M.Hiroi's Home Page

お気楽 OCaml プログラミング入門

メモ化と遅延評価


Copyright (C) 2008-2020 Makoto Hiroi
All rights reserved.

はじめに

今回は「たらいまわし関数」を例題にして、「メモ化」と「遅延評価」について説明します。

●たらいまわし関数

最初に「たらいまわし関数」について説明します。次のリストを見てください。

リスト 1 : たらいまわし関数

let rec tarai x y z =
  if x <= y then y
  else tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y)

let rec tak x y z = 
  if x <= y then z
  else tak (tak (x - 1) y z) (tak (y - 1) z x) (tak (z - 1) x y)

関数 tarai や tak は「たらいまわし関数」といって、再帰的に定義されています。これらの関数は、引数の与え方によっては実行に時間がかかるため、Lisp などのベンチマークに利用されることがあります。

関数 tarai は通称「竹内関数」と呼ばれていて、日本の代表的な Lisper である竹内郁雄先生によって考案されたそうです。そして、関数 tak は関数 tarai のバリエーションで、John Macarthy によって作成されたそうです。たらいまわし関数が Lisp のベンチマークで使われていたことは知っていましたが、このような由緒ある関数だとは思ってもいませんでした。

さっそく実行してみましょう。実行環境は Ubunts 18.04 (Windows Subsystem for Linux), Intel Core i5-6200U 2.30GHz, ocamlc (version 4.05.0) です。

tarai 13 6 0 : 1.55 [s]
tak 19 9 0   : 0.59 [s]

このように、たらいまわし関数は引数の値が小さくても実行に時間がかかります。

●メモ化による高速化

たらいまわし関数が遅いのは、同じ値を何度も計算しているためです。この場合、表 (table) を使って処理を高速化することができます。同じ値を何度も計算することがないように、計算した値は表に格納しておいて、2 回目以降は表から計算結果を求めるようにします。このような手法を「表計算法」とか「メモ化 (memoization または memoisation)」といいます。

OCaml の場合、メモ化はハッシュ表 (Hashtbl) を使うと簡単です。次のリストを見てください。

リスト 2 : たらいまわし関数のメモ化 (1)

(* メモ用のハッシュ表 *)
let table = Hashtbl.create 2048

let rec tarai x y z =
  let key = (x, y, z) in
  if Hashtbl.mem table key then Hashtbl.find table key
  else
    let value = if x <= y then y
    else
      tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y)
    in
    Hashtbl.add table key value;
    value

関数 tarai の値を格納するハッシュ表を大域変数 table に用意します。関数 tarai では、引数 x, y, z を要素とする組を作り、それをキーとしてハッシュ表 table を検索します。table に key があれば、その値を返します。そうでなければ、値を計算して table にセットして、その値を返します。

ところで、ハッシュ表は局所変数に格納することもできます。次のリストを見てください。

リスト 3 : たらいまわし関数のメモ化 (2)

(* 探索 *)
let lookup table func args =
  if Hashtbl.mem table args then
    Hashtbl.find table args
  else
    let value = func args in
    Hashtbl.add table args value;
    value

(* たらいまわし関数 *)
let rec tak (x, y, z) =
  if x <= y then z
  else memo_tak (memo_tak (x - 1, y, z),
                 memo_tak (y - 1, z, x),
                 memo_tak (z - 1, x, y))
and memo_tak =
  let table = Hashtbl.create 2048 in
  fun x -> lookup table tak x

let rec tarai (x, y, z) =
  if x <= y then y
  else memo_tarai (memo_tarai (x - 1, y, z),
                   memo_tarai (y - 1, z ,x),
                   memo_tarai (z - 1, x, y))
and memo_tarai =
  let table = Hashtbl.create 2048 in
  fun x -> lookup table tarai x

関数 lookup はハッシュ表 table から関数 func の引数 args に対応するデータを探します。ここでは関数の引数を組にまとめて args に渡すものとします。ハッシュ表にデータがある場合はその値を返します。そうでなければ、func args を評価して値 value を求め、それをハッシュ表に登録します。

関数 tak と tarai は自分自身を再帰呼び出しするのではなく、関数 memo_tak と memo_tarai を呼び出します。memo_tak と memo_tarai は、ハッシュ表を局所変数 table にセットしてから、匿名関数を使って関数本体を定義します。ハッシュ表が生成されるのは、memo_tak, memo_tarai に関数をセットするときの一回だけです。これで、その関数専用のハッシュ表を局所変数に用意することができます。memo_tak と memo_tarai の本体は、lookup を呼び出してハッシュ表から値を探索するだけです。

関数の型は次のようになります。

val memoize : ('a, 'b) Hashtbl.t -> ('a -> 'b) -> 'a -> 'b = <fun>
val tak : int * int * int -> int = <fun>
val memo_tak : int * int * int -> int = <fun>
val tarai : int * int * int -> int = <fun>
val memo_tarai : int * int * int -> int = <fun>

それでは実際に実行してみましょう。

tarai (192, 96, 0) : 0.10 [s]
tak (192, 96, 0)   : 0.34 [s]

このように、引数の値を増やしても高速に実行することができます。メモ化の効果は十分に出ていると思います。また、同じ計算を再度実行すると、メモ化の働きにより値をすぐに求めることができます。

●メモ化関数

このように関数をメモ化することは簡単にできますが、メモ化を行うたびに関数を修正するのは面倒です。このような場合、関数をメモ化する「メモ化関数」があると便利です。メモ化関数については『計算機プログラムの構造と解釈 第二版 (和田英一 訳)』の「3.3.3 表の表現」に詳しい説明があります。

ただし、変数の値を書き換えることができない関数型言語の場合、汎用的なメモ化関数を作成することは難しく、OCaml でも簡単ではありません。そこで、今回は Lisp を使ってメモ化関数を作成してみましょう。Common Lisp と Scheme のプログラムは次のようになります。

リスト 4 : メモ化関数 (Common Lisp)

(defun memoize (func)
  (let ((table (make-hash-table :test #'equal)))
    #'(lambda (&rest args)
        (let ((value (gethash args table nil)))
          (unless value
            (setf value (apply func args))
            (setf (gethash args table) value))
          value))))

; たらいまわし関数
(defun tak (x y z)
  (if (<= x y)
      z
      (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y))))

(defun tarai (x y z)
  (if (<= x y)
      y
      (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y))))

; 関数を書き換える
(setf (symbol-function 'tak) (memoize #'tak))
(setf (symbol-function 'tarai) (memoize #'tarai))
リスト 5 : メモ化関数 (Scheme : Gauche)

; 汎用のメモ化関数
(define (memoize func)
  (let ((table (make-hash-table 'equal?)))
    (lambda args
      (if (hash-table-exists? table args)
          (hash-table-get table args)
          (let ((value (apply func args)))
            (hash-table-put! table args value)
            value)))))

; たらいまわし関数
(define (tak x y z)
  (if (<= x y)
      z
      (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y))))

(define (tarai x y z)
  (if (<= x y)
      y
      (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y))))

; 値を書き換える
(set! tak (memoize tak))
(set! tarai (memoize tarai))

関数 memoize は関数 func を引数に受け取り、それをメモ化した関数を返します。memoize が返す関数はクロージャなので、memoize の引数 func や局所変数 table にアクセスすることができます。また、無名関数 lambda の引数 args は可変個の引数を受け取るように定義します。これで、複数の引数を持つ関数にも対応することができます。

args の値は引数を格納したリストになるので、これをキーとして扱います。ハッシュ表 table に値がなければ、関数 func を呼び出して値を計算し、それを table にセットします。そしで、最後に値を返します。なお、変数 tak と tarai の値 (Common Lisp の場合は関数) を書き換えないと、関数 tak, tarai の中で再帰呼び出しするとき、メモ化した関数を呼び出すことはできません。ご注意ください。

●遅延評価による高速化

関数 tarai は「遅延評価 (delayed evaluation または lazy evaluation)」を行う処理系、たとえば関数型言語の Haskell では高速に実行することができます。また、Scheme でも delay と force を使って遅延評価を行うことができます。tarai のプログラムを見てください。x <= y のときに y を返しますが、このとき引数 z の値は必要ありませんね。引数 z の値は x > y のときに計算するようにすれば、無駄な計算を省略することができます。

なお、関数 tak は x <= y のときに z を返しているため、遅延評価で高速化することはできません。ご注意ください。

OCaml には遅延評価を行うための構文 lazy とモジュール Lazy が用意されています。また、完全ではありませんが、クロージャを使って遅延評価を行うこともできます。今回は Shiro さんの WiLiKi にある「Scheme:たらいまわしべんち」を参考に、プログラムを作ってみましょう。次のリストを見てください。

リスト 6 : クロージャによる遅延評価

let rec tarai x y z =
  if x <= y then y
  else
    let zz = z () in
    tarai (tarai (x - 1) y (fun () -> zz))
          (tarai (y - 1) zz (fun () -> x))
          (fun () -> tarai (zz - 1) x (fun () -> y))

遅延評価したい処理をクロージャに包んで引数 z に渡します。そして、x > y のときに引数 z の関数を呼び出します。すると、クロージャ内の処理が評価されて z の値を求めることができます。たとえば、fun () -> 0 を z に渡す場合、z () とすると返り値は 0 になります。fun () -> x を渡せば、x に格納されている値が返されます。fun -> tarai ... を渡せば、関数 tarai が実行されてその値が返されるわけです。

関数 tarai の型は次のようになります。

val tarai : int -> int -> (unit -> int) -> int = <fun>

また、lazy 文を使うと、tarai は次のようになります。

リスト 7 : lazy による遅延評価

let rec tarai x y z =
  if x <= y then y
  else
    let zz = Lazy.force z in
    tarai (tarai (x - 1) y (lazy zz))
          (tarai (y - 1) zz (lazy x))
          (lazy (tarai (zz - 1) x (lazy y)))

lazy expr は、式 expr を評価せずに lazy_t というデータ (遅延オブジェクト) を返します。簡単な使用例を示しましょう。

# let a = lazy (10 + 20);;
val a : int lazy_t = <lazy>
# Lazy.force a;;
- : int = 30

lazy (10 + 20) の返り値を変数 a にセットします。このとき、式 10 + 20 は評価されていません。遅延オブジェクトの値を実際に求める関数が Lazy.force です。Lazy.force a を実行すると、式 10 + 20 を評価して値 30 を返します。

また、遅延オブジェクトは式の評価結果をキャッシュします。したがって、Lazy.force a を再度実行すると、同じ式を再評価することなく値を求めることができます。次の例を見てください。

# let a = lazy (print_string "eval"; 10 + 20);;
val a : int lazy_t = <lazy>
# Lazy.force a;;
eval- : int = 30
# Lazy.force a;;
- : int = 30

最初に Lazy.force a を実行すると、式 (print_string "eval"; 10 + 20) が評価されるので、画面に eval が表示されます。次に、Lazy.force a を実行すると、式を評価せずにキャッシュした値を返すので eval は表示されません。

lazy を使った場合、関数 tarai の型は次のようになります。

val tarai : int -> int -> int Lazy.t -> int = <fun>

それでは、実際に実行してみましょう。

tarai 192 96 0
closure : 0.00078 [s]
lazy    : 0.00047 [s]

実行時間が速いので、今回は tarai 192 96 0 を 100 回実行した時間から 1 回の実行時間を求めました。tarai の場合、遅延評価の効果はとても大きいですね。

ところで、クロージャや lazy を使わなくても、関数 tarai を高速化する方法があります。C++:language&libraries (リンク切れ) で Akira Higuchi さん (リンク切れ) が書かれたC言語の tarai 関数はとても高速です。OCaml でプログラムすると次のようになります。

リスト 8 : tarai の遅延評価

let rec tarai x y z =
  if x <= y then y
  else tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y
and tarai_lazy x y xx yy zz =
  if x <= y then y
  else
    let z = tarai xx yy zz in
    tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y
val tarai : int -> int -> int -> int = <fun>
val tarai_lazy : int -> int -> int -> int -> int -> int = <fun>

関数 tarai_lazy の引数 xx, yy, zz で z の値を表すところがポイントです。つまり、z の計算に必要な値を引数に保持し、z の値が必要になったときに tarai(xx, yy, zz) で計算するわけです。実際に実行してみると tarai 192 96 0 は 0.0011 [s] になりました。Akira Higuchi さんに感謝いたします。

●問題

次に示す関数を表計算法でプログラムしてください。

  1. 階乗を求める関数 fact n, (0 <= n <= 20)
  2. フィボナッチ数を求める関数 fibo n, (0 <= n <= 90)
  3. 組み合わせの数を求める関数 comb n r, (0 <= n <= 65, r <= n)












●解答

リスト : 表計算法の簡単な例題

exception Domain_error

(* 階乗 *)
let fact =
  let n = 20 in
  let table = Array.make (n + 1) 1 in
  for i = 1 to n do
    table.(i) <- table.(i - 1) * i
  done;
  fun x -> if x < 0 || x > n then raise Domain_error
           else table.(x)

(* フィボナッチ数 *)
let fibo =
  let n = 90 in
  let table = Array.make (n + 1) 0 in
  table.(1) <- 1;
  for i = 2 to n do
    table.(i) <- table.(i - 2) + table.(i - 1)
  done;
  fun x -> if x < 0 || x > n then raise Domain_error
           else table.(x)

(* 組み合わせの数 *)
let comb =
  let x = 65 in
  let table = Array.make (x + 1) [|1|] in
  table.(1) <- [|1; 1|];
  for i = 2 to x do
    table.(i) <- Array.make (i + 1) 1;
    for j = 1 to i - 1 do
      table.(i).(j) <- table.(i - 1).(j - 1) + table.(i - 1).(j)
    done
  done;
  fun n r -> if n > x || r > n then raise Domain_error
             else table.(n).(r)
exception Domain_error
val fact : int -> int = <fun>
val fibo : int -> int = <fun>
val comb : int -> int -> int = <fun>

簡単な実行例を示します。

# for i = 10 to 20 do print_int (fact i); print_newline () done;;
3628800
39916800
479001600
6227020800
87178291200
1307674368000
20922789888000
355687428096000
6402373705728000
121645100408832000
2432902008176640000
- : unit = ()
# fact (-1);;
Exception: Domain_error.
# fact 21;;
Exception: Domain_error.

# for i = 80 to 90 do print_int (fibo i); print_newline () done;;
23416728348467685
37889062373143906
61305790721611591
99194853094755497
160500643816367088
259695496911122585
420196140727489673
679891637638612258
1100087778366101931
1779979416004714189
2880067194370816120
- : unit = ()
# fibo 91;;
Exception: Domain_error.

# for i = 56 to 65 do print_int (comb i (i / 2)); print_newline () done;;
7648690600760440
15033633249770520
30067266499541040
59132290782430712
118264581564861424
232714176627630544
465428353255261088
916312070471295267
1832624140942590534
3609714217008132870
- : unit = ()
# comb 66 33;;
Exception: Domain_error.

# for i = 0 to 16 do for j = 0 to i do Printf.printf "%d " (comb i j) done; print_newline () done;;
1
1 1
1 2 1
1 3 3 1
1 4 6 4 1
1 5 10 10 5 1
1 6 15 20 15 6 1
1 7 21 35 35 21 7 1
1 8 28 56 70 56 28 8 1
1 9 36 84 126 126 84 36 9 1
1 10 45 120 210 252 210 120 45 10 1
1 11 55 165 330 462 462 330 165 55 11 1
1 12 66 220 495 792 924 792 495 220 66 12 1
1 13 78 286 715 1287 1716 1716 1287 715 286 78 13 1
1 14 91 364 1001 2002 3003 3432 3003 2002 1001 364 91 14 1
1 15 105 455 1365 3003 5005 6435 6435 5005 3003 1365 455 105 15 1
1 16 120 560 1820 4368 8008 11440 12870 11440 8008 4368 1820 560 120 16 1
- : unit = ()

初版 2008 年 9 月 14 日
改訂 2020 年 7 月 26 日