プログラミングの経験がある人ならば、「再帰」という言葉はご存じだと思います。関数定義の中で、その関数自身を呼び出すことを「再帰呼び出し (recursive call)」とか「再帰定義 (recursive definition)」といいます。
拙作のページ「新・お気楽 Python プログラミング入門第 3 回 再帰定義と高階関数」では、階乗の計算、ユークリッドの互除法、順列の生成などを例題にして、再帰定義の基本を説明しています。よろしければ参考にしてください。本ページでは新・お気楽 Python プログラミング入門で取り上げることのできなかった再帰プログラミングについて紹介します。
最初は簡単な数値計算の例を示しましょう。フィボナッチ関数は階乗と同様に再帰的に定義される関数です。
フィボナッチ関数も再帰呼び出しを使えば簡単にプログラムできます。
リスト : フィボナッチ関数 def fibo(n): if n < 2: return n return fibo(n - 1) + fibo(n - 2)
>>> for x in range(20): ... print(fibo(x)) ... 0 1 1 2 3 5 8 13 21 34 55 89 144 233 377 610 987 1597 2584 4181
関数 fibo() は階乗の計算を行う関数 fact() とは違い、自分自身を 2 回呼び出しています。これを「二重再帰」といいます。fibo() の呼び出しをトレースすると下図のようになります。
fibo(5) ┬ fibo(4) ┬ fibo(3) ┬ fibo(2) ┬ fibo(1) │ │ │ │ │ │ │ └ fibo(0) │ │ └ fibo(1) │ └ fibo(2) ┬ fibo(1) │ │ │ └ fibo(0) │ └ fibo(3) ┬ fibo(2) ┬ fibo(1) │ │ │ └ fibo(0) └ fibo(1) 図 2 : 関数 fibo のトレース
同じ値を何回も求めているため、fibo() の効率はとても悪いのです。この場合、二重再帰を「末尾再帰」に変換すると高速化することができます。そこで累算変数を使って、二重再帰を末尾再帰へ変換してみましょう。プログラムは次のようになります。
リスト : フィボナッチ関数 (末尾再帰) def fibo(n, a1 = 0, a2 = 1): if n == 0: return a1 return fibo(n - 1, a2, a1 + a2)
累算変数 a1 と a2 の使い方がポイントです。現在のフィボナッチ数を変数 a1 に、ひとつ先の値を変数 a2 に格納しておきます。あとは a1 と a2 を足し算して、新しいフィボナッチ数を計算すればいいわけです。fibo の呼び出しを下図に示します。
fibo(5, 0, 1) fibo(4, 1, 1) fibo(3, 1, 2) fibo(2, 2, 3) fibo(1, 3, 5) fibo(0, 5, 8) => a1 の値 5 を返す => 5 => 5 => 5 => 5 => 5 図 3 : fibo() の呼び出し
二重再帰では、同じ値を何回も求めていたため効率がとても悪かったのですが、このプログラムでは無駄な計算を行っていないので、値を高速に求めることができます。もちろん、末尾再帰になっているので、末尾再帰最適化を行う処理系では、プログラムをより高速に実行することができます。
Python の場合、末尾再帰最適化はサポートされていませんが、末尾再帰を繰り返しに変換することは簡単です。fibo() を繰り返しに変換すると次のようになります。
リスト : フィボナッチ関数 (繰り返し) def fibo(n): a1, a2 = 0, 1 while n > 0: a1, a2 = a2, a1 + a2 n -= 1 return a1
このように、末尾再帰は簡単に繰り返しに変換することができます。
もう一つ、簡単な数値計算の例を紹介しましょう。n を 0 以上の整数とすると、累乗 xn は x を n 回掛ける計算になります。Python には累乗を計算する演算子 ** がありますが、繰り返しでも再帰定義でも簡単にプログラムすることができます。単純な繰り返しで実装すると、次のようになります。
リスト : 累乗の計算 (1) def pow0(x, n): value = 1 while n > 0: value *= x n -= 1 return value
この場合、n 回の乗算が必要になります。ところが、式を変形するともっと少ない回数で求めることができます。
x ** 4 = (x ** 2) ** 2 -> 2 回 x ** 8 = (x ** 4) ** 2 -> 3 回 x ** 16 = (x ** 8) ** 2 -> 4 回 一般化すると x ** n = (x ** (n / 2)) ** 2; (n は偶数) x ** n = ((x ** (n / 2)) ** 2) * x; (n は奇数)
階乗計算では n を n - 1 の計算に置き換えていきますが、累乗の場合は n を n / 2 に置き換えていくことができます。n が半分になっていくので、ひとつずつ n を減らすよりも減少の度合いは大きくなります。その分だけ計算回数が少なくなるわけです。
それでは、この考え方をプログラムしてみましょう。x ** (n / 2) を計算する部分は、再帰を使えば簡単です。
リスト : 累乗の計算 (2) def pow1(x, n): if n == 0: return 1 value = pow1(x, n // 2) value *= value if n % 2 == 1: value *= x return value
関数 pow1() は一般化した累乗の定義をそのままプログラムしただけです。最初の if 文が再帰呼び出しの停止条件です。次に、n を 2 で割った値で pow1() を再帰呼び出しします。この返り値を 2 乗して、n が奇数ならば x をさらに掛け算します。最後に計算結果 value を return で返します。
ところで、pow1() も繰り返しに変換することができます。整数 n は次のように 2 ** i の和の形で表せることを利用します。
11 = 1 + 2 + 8 (2**0 + 2**1 + 2**3) 15 = 1 + 2 + 4 + 8 (2**0 + 2**1 + 2**2 + 2**3)
これは整数 n の下位ビットから順番にビットをチェックしていけば簡単に求めることができます。これを利用すると x ** n は次のように求めることができます。
x ** 11 = x * x ** 2 * x ** 8 x ** 15 = x * x ** 2 * x ** 4 * x ** 8 x ** 2 = x * x x ** 4 = x ** 2 * x ** 2 x ** 8 = x ** 4 * x ** 4
x ** 2, x ** 4, x ** 8, ... は値を次々に 2 乗していけば簡単に求めることができます。これをプログラムすると次のようになります。
リスト : 累乗の計算 (3) def pow2(x, n): value = 1 while n > 0: if n & 1: value *= x n >>= 1 x *= x return value
最初に局所変数 value を 1 に初期化します。そして、n が 0 よりも大きいあいだ、つまり 1 のビットがあるあいだ処理を繰り返します。最下位ビットが 1 の場合は value に x を掛け算します。そして、n を 1 ビット右へシフトします。x の値は x *= x で自乗されていくので、x, x ** 2, x ** 4, x ** 8, x ** 16 と増えていきます。繰り返しを終了したら value を返します。
それでは、実際に実行速度を比較して見ましょう。実行環境は Ubuntu 20.04 LTS (WSL1), Intel Core i5-6200U 2.30GHz, Python 3.8.10 です。
pow0(2, 1000000) : 17.716 [s] pow1(2, 1000000) : 0.0039 [s] pow2(2, 1000000) : 0.0051 [s]
pow0() と比べて pow1() と pow2() はとても速いですね。そして、pow2() よりも pow1() の方が少しだけ速くなったのには驚きました。pow1() のような再帰定義でも十分実用的といえるでしょう。
再帰といえば忘れてはいけないのが「ハノイの塔」でしょう。ハノイの塔は、棒に刺さっている大きさが異なる複数の円盤を、次の規則に従ってほかの棒に移動させるパズルです。
ハノイの塔は、再帰を使えば簡単に解ける問題です。たとえば、3 枚の円盤が左の棒に刺さっているとします。この場合、いちばん大きな円盤を中央の棒に移すには、その上の 2 枚の円盤を右の棒に移しておけばいいですね。いちばん大きな円盤を中央に移したら、右の棒に移した 2 枚の円盤を中央の棒に移すことを考えればよいわけです。したがって、n 枚の円盤を左から中央の棒に移すプログラムは次のように定義できます。
これを素直にプログラムすると次のようになります。
リスト : ハノイの塔 def hanoi(n, from_, to, via): if n == 1: print('{} => {}'.format(from_, to)) else: hanoi(n - 1, from_, via, to) print('{} => {}'.format(from_, to)) hanoi(n - 1, via, to, from_)
n は動かす円盤の枚数、from_ は移動元の棒、to は移動先の棒、via は残りの棒を示します。棒は文字列で表します。円盤の枚数が 1 枚の場合は簡単ですね。from_ にある円盤を to へ移すだけです。これが再帰の停止条件になります。この動作を print() で表示します。
円盤が複数枚ある場合、from_ にある円盤の n - 1 枚を via に移します。この処理は hanoi() を再帰呼び出しすればいいですね。次に、残りの 1 枚を to に移します。これを print() で表示します。そして、via に移した n - 1 枚の円盤を to に移します。この処理も hanoi() を再帰呼び出しするだけです。これでプログラムは完成です。それでは実行してみましょう。
>>> hanoi(3, 'A', 'B', 'C') A => B A => C B => C A => B C => A C => B A => B
次は組み合わせの数 \({}_n \mathrm{C}_r\) を求めるプログラムを作ってみましょう。組み合わせの数を求めるには、次の公式を使えば簡単です。
皆さんお馴染みの公式ですね。ところが、整数値の範囲が限られているプログラミング言語では、この公式を使うと乗算で「桁あふれ」を起こす恐れがあります。Python は多倍長演算をサポートしているので、桁あふれを心配する必要はありません。
この公式をそのままプログラムすることもできますが、次の式を使うともっと簡単にプログラムできます。
この式は \({}_n \mathrm{C}_r\) と \({}_n \mathrm{C}_{r-1}\) の関係を表しています。あとは階乗と同じように、再帰定義を使って簡単にプログラムできます。次のリストを見てください。
リスト : 組み合わせの数を求める def comb(n, r): if n == r or r == 0: return 1 return comb(n, r - 1) * (n - r + 1) // r
とても簡単ですね。ところで、整数値の範囲が限られているプログラミング言語では、この方法でも桁あふれする場合があるので注意してください。それでは、comb() を使って「パスカルの三角形」を作ってみましょう。次の図を見てください。
1 0C0 / \ / \ 1 1 1C0 1C1 / \ / \ / \ / \ 1 2 1 2C0 2C1 2C2 / \ / \ / \ / \ / \ / \ 1 3 3 1 3C0 3C1 3C2 3C3 / \ / \ / \ / \ / \ / \ / \ / \ 1 4 6 4 1 4C0 4C1 4C2 4C3 4C4 図 4 : パスカルの三角形
パスカルの三角形は、左側の図のように両側がすべて 1 で、内側の数はその左上と右上の和になっています。これは \((a + b)^n\) を展開したときの各項の係数を表しています。そして、その値は右側の図のように組み合わせの数 \({}_n \mathrm{C}_r\) に対応しています。
きれいな三角形にはなりませんが、簡単なプログラムを示します。
リスト : パスカルの三角形 def pascal(x): for n in range(0, x + 1): for r in range(0, n + 1): print(comb(n, r), end=' ') print()
実行結果は次のようになります。
>>> pascal(10) 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 1 5 10 10 5 1 1 6 15 20 15 6 1 1 7 21 35 35 21 7 1 1 8 28 56 70 56 28 8 1 1 9 36 84 126 126 84 36 9 1 1 10 45 120 210 252 210 120 45 10 1
図のように、きれいな三角形を出力するプログラムは、皆さんにお任せいたします。また、comb() を使わないでパスカルの三角形を出力するプログラムを作ってみるのも面白いでしょう。
今度は \({}_n \mathrm{C}_r\) 個の組み合わせを全て生成するプログラムを作ってみましょう。たとえば、1 から 5 までの数字の中から 3 個を選ぶ組み合わせは次のようになります。
[1, 2, 3], [1, 2, 4], [1, 2, 5], [1, 3, 4], [1, 3, 5], [1, 4, 5], [2, 3, 4], [2, 3, 5], [2, 4, 5], [3, 4, 5],
最初に 1 を選択した場合、次は [2, 3, 4, 5] の中から 2 個を選べばいいですね。2 番目に 2 を選択したら、次は [3, 4, 5] の中から 1 個を選べばいいわけです。これで、[1, 2, 3], [1, 2, 4], [1, 2, 5] が生成されます。
[2, 3, 4, 5] の中から 2 個選ぶとき、2 を選ばない場合があります。この場合は [3, 4, 5] の中から 2 個を選べばいいわけです。ここで 3 を選ぶと [1, 3, 4], [1, 3, 5] が生成できます。同様に、3 を除いた [4, 5] の中から 2 個を選ぶと [1, 4, 5] を生成することができます。
これで 1 を含む組み合わせを生成したので、次は 1 を含まない組み合わせ、つまり [2, 3, 4, 5] から 3 個を選ぶ組み合わせを生成すればいいわけです。けっきょく、この処理の考え方は次に示す組み合わせの公式と同じです。
Python でプログラムを作ると次のようになります。
リスト : 組み合わせの生成 def comb1(n, m, a = []): if m == 0: print(a) elif n == m: print(list(range(1, m + 1)) + a) else: comb1(n - 1, m, a) comb1(n - 1, m - 1, [n] + a)
関数 comb1() は、1 から n までの数字の中から m 個を選ぶ組み合わせを表示します。選んだ要素は変数 a のリストに格納します。m が 0 になったら組み合わせを一つ生成できたので、print() で a を出力します。n が m と等しくなったならば、残りの数字 (1 から m まで) を全て選択します。range で 1 から m までの数字をリストに格納し、a と連結してから print() で出力します。
この 2 つの条件が再帰呼び出しの停止条件になります。あとは comb1() を再帰呼び出しするだけです。最初の呼び出しは数字 n を選ばない場合です。残りの数字の中から m 個の数字を選びます。最後の呼び出しが数字 n を選択する場合です。数字 n を a に追加して、残りの数字の中から m - 1 個を選びます。
簡単な実行例を示します。
>>> comb1(5, 3) [1, 2, 3] [1, 2, 4] [1, 3, 4] [2, 3, 4] [1, 2, 5] [1, 3, 5] [2, 3, 5] [1, 4, 5] [2, 4, 5] [3, 4, 5]
正常に動作していますね。Python らしくジェネレータを使うと、次のようなプログラムになります。
リスト : 組み合わせの生成 def comb2(n, m): if m == 0: yield [] elif n == m: yield list(range(1, m + 1)) else: for a in comb2(n - 1, m): yield a for a in comb2(n - 1, m - 1): yield a + [n]
実行例は次のようになります。
>>> for x in comb2(5, 3): print(x) ... [1, 2, 3] [1, 2, 4] [1, 3, 4] [2, 3, 4] [1, 2, 5] [1, 3, 5] [2, 3, 5] [1, 4, 5] [2, 4, 5] [3, 4, 5]
ところで、n 個の中から m 個を選ぶ組み合わせは、ビットのオンオフで表すことができます。たとえば、5 個の数字 (0 - 4) から 3 個を選ぶ場合、数字を 0 bit から 4 bit に対応させます。すると、1, 3, 4 という組み合わせは 11010 と表すことができます。
これを Python でプログラムすると次のようになります。
リスト : 組み合わせの生成 def comb3(n, m, a = 0): if m == 0: print('{:2x} ({:05b})'.format(a, a)) elif m == n: b = a | ((1 << m) - 1) print('{:2x} ({:05b})'.format(b, b)) else: comb3(n - 1, m, a) comb3(n - 1, m - 1, a | (1 << (n - 1)))
関数 comb3() は n 個の中から m 個を選ぶ組み合わせを生成して出力します。組み合わせは引数 a にセットします。m が 0 になったら、組み合わせがひとつできたので a を出力します。n が m と等しくなったならば、残り m 個を全て選びます。(1 << m) - 1 で m 個のビットをオンにして出力します。
あとは comb3() を再帰呼び出しします。最初の呼び出しは n 番目の数字を選ばない場合です。n - 1 個の中から m 個を選びます。次の呼び出しが n 番目の数字を選ぶ場合で、a の n - 1 ビットをオンにします。そして、n - 1 個の中から m - 1 個を選びます。
それでは 5 個の中から 3 個を選ぶ comb3(5, 3) の実行例を示します。
>>> comb3(5, 3) 7 (00111) b (01011) d (01101) e (01110) 13 (10011) 15 (10101) 16 (10110) 19 (11001) 1a (11010) 1c (11100)
この場合、最小値は 0x07 (00111) で最大値は 0x1c (11100) になります。このように、comb3() は組み合わせを表す数を昇順で出力します。ところで、参考文献『C言語による最新アルゴリズム事典』の「組み合わせの生成」には、再帰呼び出しを使わずに同じ結果を得る方法が解説されてます。とても巧妙な方法なので、興味のある方は読んでみてください。
次は、N 通りある組み合わせに 0 から N - 1 までの番号を付ける方法を紹介しましょう。たとえば、6 個の中から 3 個を選ぶ組み合わせは 20 通りありますが、この組み合わせに 0 から 19 までの番号を付けることができます。1 1 1 0 0 0 を例題に考えてみましょう。次の図を見てください。
5 4 3 2 1 0 ───────── 0 0 0 1 1 1 ↑ 0 0 1 0 1 1 │ 0 0 1 1 0 1 │ 0 0 1 1 1 0 │ 0 1 0 0 1 1 │ 0 1 0 1 0 1 5C3 = 10 通り 0 1 0 1 1 0 │ 0 1 1 0 0 1 │ 0 1 1 0 1 0 │ 0 1 1 1 0 0 ↓ ───────── 1 0 0 0 1 1 ↑ 1 0 0 1 0 1 │ 1 0 0 1 1 0 │ 1 0 1 0 0 1 4C2 = 6 通り 1 0 1 0 1 0 │ 1 0 1 1 0 0 ↓ ──────── 1 1 0 0 0 1 ↑ 1 1 0 0 1 0 3C1 = 3 通り 1 1 0 1 0 0 ↓ ─────── 1 1 1 0 0 0 19 番目 ───────── 図 5 : 6C3 の組み合わせ
最初に 5 をチェックします。5 を選ばない場合は \({}_5 \mathrm{C}_3 = 10\) 通りありますね。この組み合わせに 0 から 9 までの番号を割り当てることにすると、5 を選ぶ組み合わせの番号は 10 から 19 までとなります。
次に、4 をチェックします。4 を選ばない場合は、\({}_4 \mathrm{C}_2 = 6\) 通りあります。したがって、5 を選んで 4 を選ばない組み合わせに 10 から 15 までの番号を割り当てることにすると、5 と 4 を選ぶ組み合わせには 16 から 19 までの番号となります。
最後に、3 をチェックします。同様に 3 を選ばない場合は 3 通りあるので、これに 16 から 18 までの番号を割り当て、5, 4, 3 を選ぶ組み合わせには 19 を割り当てます。これで組み合わせ 1 1 1 0 0 0 の番号を求めることができました。
では、0 0 0 1 1 1 はどうなるのでしょうか。左から順番にチェックしていくと、最初の 1 が見つかった時点で、その数字を選ばない組み合わせは存在しません。つまり、残りの数字をすべて選ぶしかないわけです。したがって、これが 0 番目となります。
このように、数字を選ぶときに、数字を選ばない場合の組み合わせの数を足し算していけば、その組み合わせの番号を求めることができるのです。プログラムは次のようになります。
リスト : 組み合わせを番号に変換 def comb_to_num(c, n, r, value = 0): if r == 0 or n == r: return value if c & (1 << (n - 1)): return comb_to_num(c, n - 1, r - 1, value + comb(n - 1, r)) else: return comb_to_num(c, n - 1, r, value)
引数 c はビットのオンオフで表した組み合わせ、引数 n と r は \({}_n \mathrm{C}_r\) の n と r を表しています。引数 value は求める番号を表します。n と r の値が同じになるか、もしくは r が 0 になれば、組み合わせの番号を計算できたので value を返します。
そうでない場合、c の n - 1 ビットの値を調べます。ビットがオンであれば、value に comb(n - 1, r) の値を足し算し、r を -1 して comb_to_num() を再帰呼び出しします。そうでなければ、value と r の値はそのままで comb_to_num() を再帰呼び出しします。
逆に、番号から組み合わせを求めるプログラムも簡単に作ることができます。次のリストを見てください。
リスト : 番号を組み合わせに変換 def num_to_comb(value, n, r, c = 0): if r == 0: return c elif n == r: return c | ((1 << n) - 1) else: k = comb(n - 1, r) if value >= k: return num_to_comb(value - k, n - 1, r - 1, c | 1 << (n - 1)) else: return num_to_comb(value, n - 1, r, c)
引数 value が番号で、引数 n と r は \({}_n \mathrm{C}_r\) の n と r を表しています。引数 c が求める組み合わせです。たとえば、n = 5, r = 3 の場合、ビットが 1 になるのは \({}_4 \mathrm{C}_2 = 6\) 通りあり、0 になるのは \({}_4 \mathrm{C}_3 = 4\) 通りあります。したがって、数値が 0 - 3 の場合はビットを 0 にし、4 - 9 の場合はビットを 1 にすればいいわけです。
ビットを 0 にした場合、残りは \({}_4 \mathrm{C}_3 = 4\) 通りになるので、同様に次のビットを決定します。ビット 1 にした場合、残りは \({}_4 \mathrm{C}_2 = 6\) 通りになるので、value から 4 を引いて num_to_comb() を再帰呼び出しして次のビットを決定します。
r が 0 の場合は、組み合わせが完成したので c を返します。n と r が等しい場合は、残りのビットをすべて 1 にセットしてから c を返します。それ以外の場合は、\({}_{n-1} \mathrm{C}_r\) の値を comb(n - 1, r) で求めて変数 k にセットします。value が k 以上であれば変数 c のビットを 1 にセットし、value から k を引き算して comb_to_num() を再帰呼び出しします。そうでなければ、num_to_comb() を再帰呼び出しするだけです。
それでは、n = 5, r = 3 の場合の実行例を示します。
>>> for x in range(10): ... y = num_to_comb(x, 5, 3) ... z = comb_to_num(y, 5, 3) ... print('{:d} => {:x} => {:d}'.format(x, y, z)) ... 0 => 7 => 0 1 => b => 1 2 => d => 2 3 => e => 3 4 => 13 => 4 5 => 15 => 5 6 => 16 => 6 7 => 19 => 7 8 => 1a => 8 9 => 1c => 9
正常に動作していますね。この方法を使うと、n 個ある組み合わせの中の i 番目 (\(0 \leq i \lt n\)) の組み合わせを簡単に求めることができます。
最後に、再帰を使った面白い関数を紹介しましょう。次のリストを見てください。
リスト : たらいまわし関数 def tarai(x, y, z): if x <= y: return y return tarai(tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y)) def tak(x, y, z): if x <= y: return z return tak(tak(x - 1, y, z), tak(y - 1, z, x), tak(z - 1, x, y))
tarai() や tak() は「たらいまわし関数」といい、再帰的に定義されています。これらの関数は、引数の与え方によっては実行に時間がかかるため、Lisp などのベンチマークに利用されることがあります。tarai() は通称「竹内関数」と呼ばれていて、日本の代表的な Lisper である竹内郁雄氏によって考案されたそうです。そして、tak() は tarai() のバリエーションで、John Macarthy 氏によって作成されたそうです。たらいまわし関数が Lisp のベンチマークで使われていたことは知っていましたが、このような由緒ある関数だとは思ってもいませんでした。
それでは、さっそく実行してみましょう。実行環境は Ubuntu 20.04 LTS (WSL1), Intel Core i5-6200U 2.30GHz, Python 3.8.10 です。
tarai(12, 6, 0) : 1.21 [s] tak(18, 9, 0) : 1.49 [s]
このように、たらいまわし関数は引数の値が小さくても実行に時間がかかります。
たらいまわし関数が遅いのは、同じ値を何度も計算しているためです。この場合、表 (table) を使って処理を高速化することができます。同じ値を何度も計算することがないように、計算した値は表に格納しておいて、2 回目以降は表から計算結果を求めるようにします。このような手法を「表計算法」とか「メモ化」といいます。
Python の場合、辞書 (ハッシュ表) を使うと簡単です。次のリストを見てください。
リスト : たらいまわし関数のメモ化 table = {} # メモ用の辞書 def tarai1(x, y, z): global table key = (x, y, z) if not key in table: if x <= y: table[key] = y else: table[key] = tarai1(tarai1(x - 1, y, z), tarai1(y - 1, z, x), tarai1(z - 1, x, y)) return table[key]
関数 tarai1() の値を格納する辞書をグローバル変数 table に用意します。関数 tarai では、引数 x, y, z を要素とするタプルを作り、それをキーとして辞書 table を検索します。table に key があれば、その値を返します。そうでなければ、値を計算して table にセットして、その値を返します。
このようにメモ化のプログラムは簡単にできますが、メモ化を行うたびに関数を修正するのは面倒ですね。そこで、『計算機プログラムの構造と解釈 第二版 (SICP)』 3.3.3 表の表現 を参考に、関数をメモ化する「メモ化関数」を作成してみましょう。次のリストを見てください。
リスト : メモ化関数 (1) def memoize(f): table = {} def func(*args): if not args in table: table[args] = f(*args) return table[args] return func
関数 memoize() は関数 f() を引数に受け取り、それをメモ化した関数を返します。memoize() が返す関数はクロージャなので、memoize() の引数 f() や局所変数 table にアクセスすることができます。詳しい説明は拙作のページ「新・お気楽 Python プログラミング入門 第 3 回」をお読みください。また、局所関数 func() の引数を *args で定義することで、複数の引数を持つ関数にも対応していることに注意してください。
args の値は引数を格納したタプルになるので、これをキーとして扱います。table に キー args がなければ、f() を呼び出して値を計算し、それを table にセットします。最後に table[args] の値を返すだけです。
関数をメモ化する場合、Python では次のようにデコレータ表記を使うと簡単です。
リスト : たらいまわし関数のメモ化 (1) @memoize def tarai(x, y, z): if x <= y: return y return tarai(tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y))
これは次に示すように tarai の値をメモ化した関数に書き換える動作になります。
リスト : たらいまわし関数のメモ化 (2) def tarai(x, y, z): if x <= y: return y return tarai(tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y)) tarai = memoize(tarai)
tarai の値を書き換えないと、tarai() の中で再帰呼び出しするとき、メモ化した関数を呼び出すことはできません。ご注意ください。
また、メモ化関数は Python らしくクラスを使っても実装することができます。次のリストを見てください。
リスト : メモ化関数 (2) class Memoize: def __init__(self, func): self.table = {} self.f = func def __call__(self, *args): if not args in self.table: self.table[args] = self.f(*args) return self.table[args]
クラス Memoize のインスタンス変数 table にメモ用の辞書を、インスタンス変数 f に関数 func() をセットします。あとは、メソッド __call__() で memoize() と同様の処理を行うだけです。関数のメモ化は memoize() と同様にデコレータ表記を使うと簡単です。
リスト : たらいまわし関数のメモ化 (3) @Memoize def tarai(x, y, z): if x <= y: return y return tarai(tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y))
これで関数をメモ化することができます。この場合、変数 tarai には Memoize のインスタンスが格納され、それを関数呼び出しすることでメソッド __call__() が呼び出され、関数のメモ化が機能します。
それでは実際に実行してみましょう。
@memoize tarai(120, 60, 0) : 0.023 [s] tak(180, 90, 0) : 0.16 [s] @Memoize tarai(120, 60, 0) : 0.027 [s] tak(180, 90, 0) : 0.18 [s]
引数の値を 10 倍にしましたが、どちらの方法でも高速に求めることができました。Python の場合、__call__() の呼び出しは通常の関数呼び出しよりも時間が少しかかるようで、memoize() によるメモ化の方が少しだけ速くなりました。興味のある方はいろいろ試してみてください。
tarai() は「遅延評価」を行う処理系、たとえば関数型言語の Haskell では高速に実行することができます。また、Scheme でも delay と force を使って遅延評価を行うことができます。
tarai() のプログラムを見てください。x <= y のときに y を返しますが、このとき引数 z の値は必要ありませんね。引数 z の値は x > y のときに計算するようにすれば、無駄な計算を省略することができます。なお、tak() は x <= y のときに z を返しているため、遅延評価で高速化することはできません。ご注意ください。
完全ではありませんが、Python でもクロージャを使って遅延評価を行うことができます。Shiro さんの WiLiKi にある「Scheme:たらいまわしべんち」を参考に、プログラムを作ってみましょう。次のリストを見てください。
リスト : クロージャによる遅延評価 def tarai(x, y, z): if x <= y: return y zz = z() return tarai(tarai(x - 1, y, lambda : zz), tarai(y - 1, zz, lambda : x), lambda : tarai(zz - 1, x, lambda : y))
遅延評価したい処理をクロージャに包んで引数 z に渡します。そして、x > y のときに引数 z を評価 (関数呼び出し) します。すると、クロージャ内の処理が実行されて z の値を求めることができます。
たとえば、lambda : 0 を z に渡す場合、z() とすると返り値は 0 になります。lambda : x を渡せば、x に格納されている値が返されます。lambda : tarai( ... ) を渡せば、関数 tarai が実行されてその値が返されるわけです。
それでは、実際に実行してみましょう。
tarai(192, 96, 0) メモ化 : 0.064 [s] 遅延評価 : 0.009 [s]
tarai() の場合、遅延評価の効果はとても大きいですね。ところで、クロージャを使わなくても、tarai() を高速化する方法があります。Akira Higuchi さんが書かれたC言語の tarai() はとても高速です。Python でプログラムすると次のようになります。
リスト : tarai() の遅延評価 (2) def tarai_lazy(x, y, xx, yy, zz): if x <= y: return y z = tarai(xx, yy, zz) return tarai_lazy(tarai(x - 1, y, z), tarai(y - 1, z, x), z - 1, x, y) def tarai(x, y, z): if x <= y: return y return tarai_lazy(tarai(x - 1, y, z), tarai(y - 1, z, x), z - 1, x, y)
関数 tarai_lazy() の引数 xx, yy, zz で z の値を表すところがポイントです。つまり、z の計算に必要な値を引数に保持し、z の値が必要になったときに tarai(xx, yy, zz) で計算するわけです。実際に実行してみると tarai(192, 96, 0) は 0.004 [s] になりました。このような簡単な方法で tarai 関数を高速化できるとは驚きました。Akira Higuchi さんに感謝いたします。