Rustのクロージャで再帰してみた (ダメだった)

Qiita

クロージャを再帰呼び出しする方法を考えました。

競技プログラミングではローカル変数を書き換えながら再帰する処理がよく出てきます。しかし Rust でそれを書こうとするとやや冗長になりがちです。

本稿では小さなヘルパーを用意して記述を簡略化することを試みました。

  • 環境: Rust 1.15.1 (AtCoder での現在のバージョン)
  • 筆者: AtCoder もうすぐ青といい続けて1年

要約

  • 競プロではよく再帰する。
  • 小さなアダプタを書くと再帰呼び出しできる。
  • イミュータブルなクロージャはローカル変数を書き換えられない。
    • RefCell で対処する。
  • 成果:

用例1: 階乗

単純な例として、階乗の計算を再帰で書けるようにしましょう。内部で自身を参照するために、クロージャは引数に fact (階乗関数) を受け取るようにする方針でいきます。

    let fact_5 = recurse(5, &|n, fact| {
        if n <= 1 {
            1_i64
        } else {
            n * fact(n - 1)
        }
    });
    assert_eq!(1 * 2 * 3 * 4 * 5, fact_5);

ここで recurse(x, f)f(x, f) の意味になるように後で定義するヘルパーです。

「なぜ作った関数を即座に起動するのか」という疑問があると思いますが、それは実際にそういう用途が多いからです。再帰関数がほしいときは |x| recurse(x, &|x, f| ..) のようにクロージャ化する運用でも大丈夫でしょう。

実装1: イミュータブル版

recurse の実装は簡単で、fn で定義した関数が再帰可能であることを利用します。

    fn recurse<X, Y>(x: X, f: &Fn(X, &Fn(X) -> Y) -> Y) -> Y {
        f(x, &|x: X| recurse(x, &f))
    }

注意点は、クロージャの引数の型がまたそのクロージャの型で……という無限の循環を避けるため、関数を トレイトオブジェクト への参照という形で扱っていることです。

Fn(X) -> Y というのは「型 X の値を受け取って型 Y の値を返す関数」の型を表すトレイトで、ある種のクロージャは自動的に Fn を実装した型になります。参照: std::ops::Fn - Rust

&Fn(X, &Fn(X) -> Y) -> Y
        ^^^^^^^^^^           再帰関数の型 (クロージャの引数)
 ^^^^^^^^^^^^^^^^^^^^^^^     定義したクロージャのトレイトオブジェクトの型

用例2: DFSで連結成分分解

次に現実的な例として、グラフの連結成分分解を深さ優先探索で書いてみます。

    //
    //   0 -- 1
    //   | \
    //   |  \
    //   2 -- 3    4--5
    //
    let graph =
        vec![
            vec![1, 2, 3],
            vec![0],
            vec![0, 3],
            vec![0, 2],
            vec![5],
            vec![4],
        ];
    let n = graph.len();

    let roots = RefCell::new(vec![n; n]);
    for u in 0..n {
        recurse(u, &|v, go| {
            if roots.borrow()[v] < n {
                return;
            }

            roots.borrow_mut()[v] = u;

            for &w in graph[v].iter() {
                go(w);
            }
        })
    }

    assert_eq!(&*roots.borrow(), &[0, 0, 0, 0, 4, 4]);

頂点 v が属す連結成分の代表を roots[v] に入れていきます。

このとき、再帰の途中で配列を更新する必要があります。しかし roots を let mut でミュータブル配列として宣言すると、先ほどの recurse は使えません。というもの、外部のミュータブルな変数を借用するクロージャは Fn トレイトを実装しないからです。

ここでは RefCell を使ってこの問題を回避しています。クロージャに渡すのが RefCell へのイミュータブルな参照でも、内部の値をミュータブルとして扱えます。参照: 保証を選ぶ

なんにせよ、それなりに簡潔に再帰処理ができました!

実装2. ミュータブル版

追記: ミュータブルなローカル変数を書き換えながらクロージャを再帰呼び出しする方法について記述していましたが、 安全でないコードが書けてしまう ので取り下げました。

クロージャの型が自動で実装するトレイトは Fn のほかに FnMut もあります。FnMut は、簡単にいうと「ミュータブルな状態を持つ関数」の型が実装すべきトレイトです。参照: std::ops::FnMut - Rust

外部のミュータブルな状態 (例えば let mut roots = ...) を触りながら再帰できるように、クロージャが FnMut でもいいようにしてみます。すると、借用検査が 通りません

通せるようにしたのが以下です:

fn recurse<X, Y>(x: X, f: &mut FnMut(X, &mut FnMut(X) -> Y) -> Y) -> Y {
    let fp = f as *mut FnMut(X, &mut FnMut(X) -> Y) -> Y;
    let f1 = unsafe { &mut *fp };
    let f2 = unsafe { &mut *fp };
    f1(x, &mut |x: X| recurse(x, f2))
}

これをみると分かるように、 recurse は受け取ったクロージャへの参照を2つに複製します: 即座に呼び出すための参照と、再帰用に呼び出すための参照です。ミュータブルな参照は複製できないので、unsafe を使って強制的に複製しています。

「unsafe だから危険じゃないのか」という疑問がありますが、実行中のクロージャが自分への参照を self, f で2重に受け取っているだけなので、たぶん大丈夫です。

これで深さ優先探索の例を書き直すと、RefCell が消失してすっきり。

    let mut roots = vec![n; n];
    for u in 0..n {
        recurse(u, &mut |v, go| {
            if roots[v] < n {
                return;
            }

            roots[v] = u;

            for &w in graph[v].iter() {
                go(w);
            }
        })
    }

Rust Playground で試す

参考

  • Stebalien commented on 28 Jan 2016

    Zコンビネータを使ってクロージャを再帰可能にするコードの例。引数として受け取る再帰関数の型は推論されないっぽい。

  • 無名再帰 - Google 検索

    クロージャのような匿名の関数を再帰呼び出しすることを無名再帰というらしい。

#![allow(dead_code)]
#![allow(unused_macros)]
#![allow(unused_imports)]

fn recurse<X, Y>(x: X, f: &Fn(X, &Fn(X) -> Y) -> Y) -> Y {
    f(x, &|x: X| recurse(x, &f))
}

macro_rules! memo {
    (| $f:ident, $($p:ident $(: $t:ty)*),* | $body:expr) => {{
        use std;
        let memo = std::cell::RefCell::new(std::collections::HashMap::new());

        move |$($p $(: $t)*),*| {
            recurse(
                #[allow(unused_parens)]
                { ($($p),*) },
                &|$($p $(: $t)*),*, $f| {
                    let args = ($($p),*).clone();
                    if let Some(&y) = memo.borrow().get(&args) {
                        return y;
                    }
                    let y = $body;
                    memo.borrow_mut().insert(args, y.clone());
                    y
                }
            )
        }
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;

    fn graph() -> Vec<Vec<usize>> {
        //
        //   0 -- 1
        //   | \
        //   |  \
        //   2 -- 3    4--5
        //
        vec![
            vec![1, 2, 3],
            vec![0],
            vec![0, 3],
            vec![0, 2],
            vec![5],
            vec![4],
        ]
    }

    #[test]
    fn test_fact() {
        let fact = |n| recurse(n, &|n, fact| if n <= 1 { 1_i64 } else { n * fact(n - 1) });
        assert_eq!(fact(1), 1);
        assert_eq!(fact(5), 120);
    }

    #[test]
    fn test_dfs() {
        let graph = graph();
        let n = graph.len();

        let roots = RefCell::new(vec![n; n]);
        for u in 0..n {
            recurse(u, &|v, go| {
                if roots.borrow()[v] < n {
                    return;
                }

                roots.borrow_mut()[v] = u;

                for &w in graph[v].iter() {
                    go(w);
                }
            })
        }

        assert_eq!(&*roots.borrow(), &[0, 0, 0, 0, 4, 4]);
    }

    #[test]
    fn test_memoized_fib() {
        let fib = memo!(|fib, n: i32| if n <= 1 {
            1_i64
        } else {
            fib(n - 1) + fib(n - 2)
        });
        assert_eq!(fib(5), 8);
        assert_eq!(fib(20), 10946);
    }
}
#![allow(dead_code)]
#![allow(unused_imports)]

fn recurse<X, Y>(x: X, f: &mut FnMut(X, &mut FnMut(X) -> Y) -> Y) -> Y {
    let fp = f as *mut FnMut(X, &mut FnMut(X) -> Y) -> Y;
    let f1 = unsafe { &mut *fp };
    let f2 = unsafe { &mut *fp };
    f1(x, &mut |x: X| recurse(x, f2))
}

#[cfg(test)]
mod tests {
    use super::*;

    fn graph() -> Vec<Vec<usize>> {
        //
        //   0 -- 1
        //   | \
        //   |  \
        //   2 -- 3    4--5
        //
        vec![
            vec![1, 2, 3],
            vec![0],
            vec![0, 3],
            vec![0, 2],
            vec![5],
            vec![4],
        ]
    }

    #[test]
    fn test_fact() {
        let fact = |n| recurse(n, &mut |n, fact| if n <= 1 { 1 } else { n * fact(n - 1) });
        assert_eq!(fact(1), 1);
        assert_eq!(fact(5), 120);
    }

    #[test]
    fn test_dfs() {
        let graph = graph();
        let n = graph.len();

        let mut roots = vec![n; n];
        for u in 0..n {
            recurse(u, &mut |v, go| {
                if roots[v] < n {
                    return;
                }

                roots[v] = u;

                for &w in graph[v].iter() {
                    go(w);
                }
            })
        }

        assert_eq!(roots, vec![0, 0, 0, 0, 4, 4]);
    }

    #[test]
    fn test_closure_is_dropped() {
        let n = 4;
        let mut k = 0;
        struct D<'a>(pub &'a mut i32);
        impl<'a> Drop for D<'a> {
            fn drop(&mut self) {
                *self.0 += 1;
            }
        }

        {
            recurse(0, &mut |i, go| {
                let d = D(&mut k);

                if i >= n {
                    assert_eq!(*d.0, 0);
                    return;
                }

                go(i + 1);
            });
        }

        assert_eq!(k, n + 1);
    }
}

関連記事