今回は「たらいまわし関数」を例題にして、「メモ化」と「遅延評価」について説明します。
最初に「たらいまわし関数」について説明します。次のリストを見てください。
リスト 1 : たらいまわし関数 let rec tarai x y z = if x <= y then y else tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y) let rec tak x y z = if x <= y then z else tak (tak (x - 1) y z) (tak (y - 1) z x) (tak (z - 1) x y)
関数 tarai や tak は「たらいまわし関数」といって、再帰的に定義されています。これらの関数は、引数の与え方によっては実行に時間がかかるため、Lisp などのベンチマークに利用されることがあります。
関数 tarai は通称「竹内関数」と呼ばれていて、日本の代表的な Lisper である竹内郁雄先生によって考案されたそうです。そして、関数 tak は関数 tarai のバリエーションで、John Macarthy によって作成されたそうです。たらいまわし関数が Lisp のベンチマークで使われていたことは知っていましたが、このような由緒ある関数だとは思ってもいませんでした。
さっそく実行してみましょう。実行環境は Ubunts 18.04 (Windows Subsystem for Linux), Intel Core i5-6200U 2.30GHz, ocamlc (version 4.05.0) です。
tarai 13 6 0 : 1.55 [s] tak 19 9 0 : 0.59 [s]
このように、たらいまわし関数は引数の値が小さくても実行に時間がかかります。
たらいまわし関数が遅いのは、同じ値を何度も計算しているためです。この場合、表 (table) を使って処理を高速化することができます。同じ値を何度も計算することがないように、計算した値は表に格納しておいて、2 回目以降は表から計算結果を求めるようにします。このような手法を「表計算法」とか「メモ化 (memoization または memoisation)」といいます。
OCaml の場合、メモ化はハッシュ表 (Hashtbl) を使うと簡単です。次のリストを見てください。
リスト 2 : たらいまわし関数のメモ化 (1) (* メモ用のハッシュ表 *) let table = Hashtbl.create 2048 let rec tarai x y z = let key = (x, y, z) in if Hashtbl.mem table key then Hashtbl.find table key else let value = if x <= y then y else tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y) in Hashtbl.add table key value; value
関数 tarai の値を格納するハッシュ表を大域変数 table に用意します。関数 tarai では、引数 x, y, z を要素とする組を作り、それをキーとしてハッシュ表 table を検索します。table に key があれば、その値を返します。そうでなければ、値を計算して table にセットして、その値を返します。
ところで、ハッシュ表は局所変数に格納することもできます。次のリストを見てください。
リスト 3 : たらいまわし関数のメモ化 (2) (* 探索 *) let lookup table func args = if Hashtbl.mem table args then Hashtbl.find table args else let value = func args in Hashtbl.add table args value; value (* たらいまわし関数 *) let rec tak (x, y, z) = if x <= y then z else memo_tak (memo_tak (x - 1, y, z), memo_tak (y - 1, z, x), memo_tak (z - 1, x, y)) and memo_tak = let table = Hashtbl.create 2048 in fun x -> lookup table tak x let rec tarai (x, y, z) = if x <= y then y else memo_tarai (memo_tarai (x - 1, y, z), memo_tarai (y - 1, z ,x), memo_tarai (z - 1, x, y)) and memo_tarai = let table = Hashtbl.create 2048 in fun x -> lookup table tarai x
関数 lookup はハッシュ表 table から関数 func の引数 args に対応するデータを探します。ここでは関数の引数を組にまとめて args に渡すものとします。ハッシュ表にデータがある場合はその値を返します。そうでなければ、func args を評価して値 value を求め、それをハッシュ表に登録します。
関数 tak と tarai は自分自身を再帰呼び出しするのではなく、関数 memo_tak と memo_tarai を呼び出します。memo_tak と memo_tarai は、ハッシュ表を局所変数 table にセットしてから、匿名関数を使って関数本体を定義します。ハッシュ表が生成されるのは、memo_tak, memo_tarai に関数をセットするときの一回だけです。これで、その関数専用のハッシュ表を局所変数に用意することができます。memo_tak と memo_tarai の本体は、lookup を呼び出してハッシュ表から値を探索するだけです。
関数の型は次のようになります。
val memoize : ('a, 'b) Hashtbl.t -> ('a -> 'b) -> 'a -> 'b = <fun> val tak : int * int * int -> int = <fun> val memo_tak : int * int * int -> int = <fun> val tarai : int * int * int -> int = <fun> val memo_tarai : int * int * int -> int = <fun>
それでは実際に実行してみましょう。
tarai (192, 96, 0) : 0.10 [s] tak (192, 96, 0) : 0.34 [s]
このように、引数の値を増やしても高速に実行することができます。メモ化の効果は十分に出ていると思います。また、同じ計算を再度実行すると、メモ化の働きにより値をすぐに求めることができます。
このように関数をメモ化することは簡単にできますが、メモ化を行うたびに関数を修正するのは面倒です。このような場合、関数をメモ化する「メモ化関数」があると便利です。メモ化関数については『計算機プログラムの構造と解釈 第二版 (和田英一 訳)』の「3.3.3 表の表現」に詳しい説明があります。
ただし、変数の値を書き換えることができない関数型言語の場合、汎用的なメモ化関数を作成することは難しく、OCaml でも簡単ではありません。そこで、今回は Lisp を使ってメモ化関数を作成してみましょう。Common Lisp と Scheme のプログラムは次のようになります。
リスト 4 : メモ化関数 (Common Lisp) (defun memoize (func) (let ((table (make-hash-table :test #'equal))) #'(lambda (&rest args) (let ((value (gethash args table nil))) (unless value (setf value (apply func args)) (setf (gethash args table) value)) value)))) ; たらいまわし関数 (defun tak (x y z) (if (<= x y) z (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y)))) (defun tarai (x y z) (if (<= x y) y (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y)))) ; 関数を書き換える (setf (symbol-function 'tak) (memoize #'tak)) (setf (symbol-function 'tarai) (memoize #'tarai))
リスト 5 : メモ化関数 (Scheme : Gauche) ; 汎用のメモ化関数 (define (memoize func) (let ((table (make-hash-table 'equal?))) (lambda args (if (hash-table-exists? table args) (hash-table-get table args) (let ((value (apply func args))) (hash-table-put! table args value) value))))) ; たらいまわし関数 (define (tak x y z) (if (<= x y) z (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y)))) (define (tarai x y z) (if (<= x y) y (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y)))) ; 値を書き換える (set! tak (memoize tak)) (set! tarai (memoize tarai))
関数 memoize は関数 func を引数に受け取り、それをメモ化した関数を返します。memoize が返す関数はクロージャなので、memoize の引数 func や局所変数 table にアクセスすることができます。また、無名関数 lambda の引数 args は可変個の引数を受け取るように定義します。これで、複数の引数を持つ関数にも対応することができます。
args の値は引数を格納したリストになるので、これをキーとして扱います。ハッシュ表 table に値がなければ、関数 func を呼び出して値を計算し、それを table にセットします。そしで、最後に値を返します。なお、変数 tak と tarai の値 (Common Lisp の場合は関数) を書き換えないと、関数 tak, tarai の中で再帰呼び出しするとき、メモ化した関数を呼び出すことはできません。ご注意ください。
関数 tarai は「遅延評価 (delayed evaluation または lazy evaluation)」を行う処理系、たとえば関数型言語の Haskell では高速に実行することができます。また、Scheme でも delay と force を使って遅延評価を行うことができます。tarai のプログラムを見てください。x <= y のときに y を返しますが、このとき引数 z の値は必要ありませんね。引数 z の値は x > y のときに計算するようにすれば、無駄な計算を省略することができます。
なお、関数 tak は x <= y のときに z を返しているため、遅延評価で高速化することはできません。ご注意ください。
OCaml には遅延評価を行うための構文 lazy とモジュール Lazy が用意されています。また、完全ではありませんが、クロージャを使って遅延評価を行うこともできます。今回は Shiro さんの WiLiKi にある「Scheme:たらいまわしべんち」を参考に、プログラムを作ってみましょう。次のリストを見てください。
リスト 6 : クロージャによる遅延評価 let rec tarai x y z = if x <= y then y else let zz = z () in tarai (tarai (x - 1) y (fun () -> zz)) (tarai (y - 1) zz (fun () -> x)) (fun () -> tarai (zz - 1) x (fun () -> y))
遅延評価したい処理をクロージャに包んで引数 z に渡します。そして、x > y のときに引数 z の関数を呼び出します。すると、クロージャ内の処理が評価されて z の値を求めることができます。たとえば、fun () -> 0 を z に渡す場合、z () とすると返り値は 0 になります。fun () -> x を渡せば、x に格納されている値が返されます。fun -> tarai ... を渡せば、関数 tarai が実行されてその値が返されるわけです。
関数 tarai の型は次のようになります。
val tarai : int -> int -> (unit -> int) -> int = <fun>
また、lazy 文を使うと、tarai は次のようになります。
リスト 7 : lazy による遅延評価 let rec tarai x y z = if x <= y then y else let zz = Lazy.force z in tarai (tarai (x - 1) y (lazy zz)) (tarai (y - 1) zz (lazy x)) (lazy (tarai (zz - 1) x (lazy y)))
lazy expr は、式 expr を評価せずに lazy_t というデータ (遅延オブジェクト) を返します。簡単な使用例を示しましょう。
# let a = lazy (10 + 20);; val a : int lazy_t = <lazy> # Lazy.force a;; - : int = 30
lazy (10 + 20) の返り値を変数 a にセットします。このとき、式 10 + 20 は評価されていません。遅延オブジェクトの値を実際に求める関数が Lazy.force です。Lazy.force a を実行すると、式 10 + 20 を評価して値 30 を返します。
また、遅延オブジェクトは式の評価結果をキャッシュします。したがって、Lazy.force a を再度実行すると、同じ式を再評価することなく値を求めることができます。次の例を見てください。
# let a = lazy (print_string "eval"; 10 + 20);; val a : int lazy_t = <lazy> # Lazy.force a;; eval- : int = 30 # Lazy.force a;; - : int = 30
最初に Lazy.force a を実行すると、式 (print_string "eval"; 10 + 20) が評価されるので、画面に eval が表示されます。次に、Lazy.force a を実行すると、式を評価せずにキャッシュした値を返すので eval は表示されません。
lazy を使った場合、関数 tarai の型は次のようになります。
val tarai : int -> int -> int Lazy.t -> int = <fun>
それでは、実際に実行してみましょう。
tarai 192 96 0 closure : 0.00078 [s] lazy : 0.00047 [s]
実行時間が速いので、今回は tarai 192 96 0 を 100 回実行した時間から 1 回の実行時間を求めました。tarai の場合、遅延評価の効果はとても大きいですね。
ところで、クロージャや lazy を使わなくても、関数 tarai を高速化する方法があります。C++:language&libraries (リンク切れ) で Akira Higuchi さん (リンク切れ) が書かれたC言語の tarai 関数はとても高速です。OCaml でプログラムすると次のようになります。
リスト 8 : tarai の遅延評価 let rec tarai x y z = if x <= y then y else tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y and tarai_lazy x y xx yy zz = if x <= y then y else let z = tarai xx yy zz in tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y
val tarai : int -> int -> int -> int = <fun> val tarai_lazy : int -> int -> int -> int -> int -> int = <fun>
関数 tarai_lazy の引数 xx, yy, zz で z の値を表すところがポイントです。つまり、z の計算に必要な値を引数に保持し、z の値が必要になったときに tarai(xx, yy, zz) で計算するわけです。実際に実行してみると tarai 192 96 0 は 0.0011 [s] になりました。Akira Higuchi さんに感謝いたします。
次に示す関数を表計算法でプログラムしてください。
リスト : 表計算法の簡単な例題 exception Domain_error (* 階乗 *) let fact = let n = 20 in let table = Array.make (n + 1) 1 in for i = 1 to n do table.(i) <- table.(i - 1) * i done; fun x -> if x < 0 || x > n then raise Domain_error else table.(x) (* フィボナッチ数 *) let fibo = let n = 90 in let table = Array.make (n + 1) 0 in table.(1) <- 1; for i = 2 to n do table.(i) <- table.(i - 2) + table.(i - 1) done; fun x -> if x < 0 || x > n then raise Domain_error else table.(x) (* 組み合わせの数 *) let comb = let x = 65 in let table = Array.make (x + 1) [|1|] in table.(1) <- [|1; 1|]; for i = 2 to x do table.(i) <- Array.make (i + 1) 1; for j = 1 to i - 1 do table.(i).(j) <- table.(i - 1).(j - 1) + table.(i - 1).(j) done done; fun n r -> if n > x || r > n then raise Domain_error else table.(n).(r)
exception Domain_error val fact : int -> int = <fun> val fibo : int -> int = <fun> val comb : int -> int -> int = <fun>
簡単な実行例を示します。
# for i = 10 to 20 do print_int (fact i); print_newline () done;; 3628800 39916800 479001600 6227020800 87178291200 1307674368000 20922789888000 355687428096000 6402373705728000 121645100408832000 2432902008176640000 - : unit = () # fact (-1);; Exception: Domain_error. # fact 21;; Exception: Domain_error. # for i = 80 to 90 do print_int (fibo i); print_newline () done;; 23416728348467685 37889062373143906 61305790721611591 99194853094755497 160500643816367088 259695496911122585 420196140727489673 679891637638612258 1100087778366101931 1779979416004714189 2880067194370816120 - : unit = () # fibo 91;; Exception: Domain_error. # for i = 56 to 65 do print_int (comb i (i / 2)); print_newline () done;; 7648690600760440 15033633249770520 30067266499541040 59132290782430712 118264581564861424 232714176627630544 465428353255261088 916312070471295267 1832624140942590534 3609714217008132870 - : unit = () # comb 66 33;; Exception: Domain_error. # for i = 0 to 16 do for j = 0 to i do Printf.printf "%d " (comb i j) done; print_newline () done;; 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 1 5 10 10 5 1 1 6 15 20 15 6 1 1 7 21 35 35 21 7 1 1 8 28 56 70 56 28 8 1 1 9 36 84 126 126 84 36 9 1 1 10 45 120 210 252 210 120 45 10 1 1 11 55 165 330 462 462 330 165 55 11 1 1 12 66 220 495 792 924 792 495 220 66 12 1 1 13 78 286 715 1287 1716 1716 1287 715 286 78 13 1 1 14 91 364 1001 2002 3003 3432 3003 2002 1001 364 91 14 1 1 15 105 455 1365 3003 5005 6435 6435 5005 3003 1365 455 105 15 1 1 16 120 560 1820 4368 8008 11440 12870 11440 8008 4368 1820 560 120 16 1 - : unit = ()