M.Hiroi's Home Page

Scala Programming

お気楽 Scala プログラミング入門

[ PrevPage | Scala | NextPage ]

継続渡しスタイル

今回は「継続渡しスタイル (Continuation Passing Style : CPS)」という手法について説明します。Scheme には「継続」という他の言語 [*1] にはない強力な機能がありますが、使いこなすのはちょっと難しいといわれています。継続渡しスタイルはクロージャを使った汎用的な方法で、クロージャがあるプログラミング言語であれば、継続渡しスタイルでプログラムを作成することができます。

-- note --------
[*1] 実は Ruby にも「継続」があります。また、標準的な機能ではありませんが、SML/NJ や OCaml でも拡張機能を使って「継続」を取り扱うことができます。

●継続とは?

最初に継続について簡単に説明します。継続は「次に行われる計算」のことです。たとえば、次のプログラムを例に考えてみましょう。

scala> def foo(): Unit = println("foo")
def foo(): Unit

scala> def bar(): Unit = println("bar")
def bar(): Unit

scala> def baz(): Unit = println("baz")
def baz(): Unit

scala> def test(): Unit = {foo(); bar(); baz()}
def test(): Unit

scala> test()
foo
bar
baz

関数 test は関数 foo, bar, baz を順番に呼び出します。foo の次に実行される処理は bar, baz の関数呼び出しです。この処理が foo を呼び出したあとの「継続」になります。同様に、bar のあとに実行されるのは baz の呼び出しで、この処理がこの時点での「継続」になります。また、baz を呼び出したあと、test の中では次に実行する処理はありませんが、test は関数呼び出しされているので、関数呼び出しから元に戻る処理が baz を呼び出したあとの「継続」になります。

このように、あるプログラムを実行しているとき、そのプログラムを終了するまでには「次に実行する処理 (計算)」が必ず存在します。一般に、この処理 (計算) のことを「継続」といいます。

Scheme の場合、次の計算を続行するための情報を取り出して、それを保存することができます。Scheme では、この保存した情報を「継続」といって、通常のデータ型と同様に取り扱うことができます。つまり、継続を変数に代入したり関数の引数に渡すことができるのです。継続を使うとプログラムの実行を途中で中断し、あとからそこに戻ってプログラムの実行を再開することができます。

●継続渡しスタイルとは?

一般のプログラミング言語では、Scheme のように継続を取り出して保存することはできません。そこで、継続 (次に行う処理) を関数 (クロージャ) で表して、それを引数に渡して実行することにします。これを「継続渡しスタイル (CPS)」といいます。たとえば、次の例を見てください。

scala> def testCps(cont: () => Unit): Unit = { foo(); bar(); cont() }
def testCps(cont: () => Unit): Unit

scala> testCps(baz)
foo
bar
baz

scala> testCps(bar)
foo
bar
bar

関数 testCps は foo, bar を呼び出したあと、引数 cont に渡された処理 (継続) を実行します。関数 baz を渡せば foo, bar, baz と表示されますし、他の処理を渡せばそれを実行することができます。

もう一つ簡単な例を示しましょう。継続に値を渡して処理を行うこともできます。

scala> def addCps(a: Int, b: Int, cont: Int => Int): Int = cont(a + b)
def addCps(a: Int, b: Int, cont: Int => Int): Int

scala> addCps(1, 2, x => x)
val res3: Int = 3

scala> addCps(1, 2, x => {println(x); x})
3
val res4: Int = 3

関数 addCps は引数 a と b を加算して、その結果を継続 cont に渡します。cont に x => x を渡せば、計算結果を返すことができます。また、cont で println(x) を呼び出せば、計算結果を表示することができます。

ところで、addCps は次のように多相関数として定義するともっと便利になります。

scala> def addCps1[A](a: Int, b: Int, cont: Int => A): A = cont(a + b)
def addCps1[A](a: Int, b: Int, cont: Int => A): A

scala> addCps1(10, 20, x => x)
val res5: Int = 30

scala> addCps1(10, 20, println)
30

継続 cont の返り値を型パラメータ A にします。そうすると、addCps1 には println を直接渡すことができます。

●再帰呼び出しと継続渡しスタイル

CPS を使うと再帰呼び出しを末尾再帰に変換することができます。たとえば、階乗の計算を CPS でプログラムすると次のようになります。

リスト : 階乗の計算 (CPS)

  def factCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0) cont(1)
    else factCps(n - 1, x => cont(n * x))

引数 cont が継続を表します。n == 0 のときは、cont に階乗の値 1 を渡します。それ以外の場合は、階乗の計算を継続の処理にまかせて factCps を再帰呼び出します。ここで、factCps の呼び出しは末尾再帰になることに注意してください。

継続の処理 x => cont(n * x) では、継続の引数 x と factCps の引数 n を掛け算して、その結果を cont に渡します。たとえば、factCps(4, x => x) の呼び出しを図に示すと、次のようになります。

   fact(4, x => x)
=>      4 (x1 => (x => x) (4 * x1))
=>      3 (x2 => (x1 => (x => x) (4 * x1)) (3 * x2))
=>      2 (x3 => (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) (2 * x3))
=>      1 (x4 => (x3 => (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) (2 * x3)) (1 * x4))
=>      0 (x4 => (x3 => (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) (2 * x3)) (1 * x4)) 1

継続の評価

   (x4 => (x3 => (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) (2 * x3)) (1 * x4)) 1
=> (x3 => (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) (2 * x3)) 1
=> (x2 => (x1 => (x => x) (4 * x1)) (3 * x2)) 2
=> (x1 => (x => x) (4 * x1)) 6
=> (x => x) 24
=> 24


                    図 1 : fact_cps の実行

このように、継続の中で階乗の式が組み立てられていきます。そして、n == 0 のとき継続 cont に引数 1 を渡して評価すると、今までに組み立てられた式が評価されて階乗の値を求めることができます。つまり、n の階乗を求めるとき、継続 x => cont(n * x) の引数 x には n - 1 の階乗の値が渡されていくわけです。そして、最後に継続 x => x に n の階乗の値が渡されるので、階乗の値を返すことができます。

それでは実際に実行してみましょう。

scala> for (i <- 1 to 20) factCps(i, println)
1
2
6
24
120
720
5040
40320
362880
3628800
39916800
479001600
6227020800
87178291200
1307674368000
20922789888000
355687428096000
6402373705728000
121645100408832000
2432902008176640000

●二重再帰と継続渡しスタイル

次はフィボナッチ数列を求める関数を CPS で作りましょう。次のリストを見てください。

リスト : フィボナッチ関数

  // 二重再帰
  def fibo(n: Int): BigInt =
    if (n == 0 || n == 1) n
    else fibo(n - 1) + fibo(n - 2)

  // CPS
  def fiboCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0 || n == 1) cont(n)
    else fiboCps(n - 1, x => fiboCps(n - 2, y => cont(x + y)))

関数 fiboCps は、引数 n が 0 または 1 のとき cont 1 を評価します。それ以外の場合は fiboCps を再帰呼び出しします。fiboCps(n - 1) が求まると、その値は継続の引数 x に渡されます。継続の中で、今度は fiboCps(n - 2) の値を求めます。すると、その値は fiboCps (n - 2) の継続の引数 y に渡されます。したがって、fiboCps(n) の値は x + y で求めることができます。この値を fiboCps(n) の継続 cont に渡せばいいわけです。

fiboCps の実行を図に示すと、次のようになります。

cont は継続を表します。fiboCps は末尾再帰になっているので、n - 1 の値を求めるために左から右へ処理が進みます。このとき、n - 2 の値を求める継続 cont が生成されていくことに注意してください。そして、f(1) の実行が終了すると継続が評価され、n - 2 の値が求められます。すると、2 番目の継続が評価されて n - 1 の値 x と n - 2 の値 y を加算して、その値を継続 cont に渡します。こうして、次々と継続が評価されてフィボナッチ関数の値を求めることができます。

それでは実際に実行してみましょう。

scala> for (i <- 0 to 15) fiboCps(i, println)
0
1
1
2
3
5
8
13
21
34
55
89
144
233
377
610

正常に動作していますね。ところが、fiboCps(16, println) を実行するとエラーになってしまいました。Scala には末尾再帰最適化が行われているかチェックするアノテーション @tailrec が用意されています。@tailrec を使う場合、scala.annotation.tailrec をインポートしてください。

リスト : 末尾再帰最適化のチェック

  @tailrec
  def fiboCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0 || n == 1) cont(n)
    else fiboCps(n - 1, x => fiboCps(n - 2, y => cont(x + y)))

末尾再帰最適化が行われない場合、コンパイルエラーになります。

scala> import scala.annotation.tailrec
import scala.annotation.tailrec

scala> :paste
// Entering paste mode (ctrl-D to finish)

  @tailrec
  def fiboCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0 || n == 1) cont(n)
    else fiboCps(n - 1, x => fiboCps(n - 2, y => cont(x + y)))

// Exiting paste mode, now interpreting.


    else fiboCps(n - 1, x => fiboCps(n - 2, y => cont(x + y)))
                                    ^
<pastie>:4: error: could not optimize @tailrec annotated method fiboCps: 
it contains a recursive call not in tail position

他の関数型言語 (Scheme, SML/NJ, OCaml など) では末尾再帰最適化されるのですが、Scala では最適化されないようです。Scala と CPS の相性はあまりよくないのかもしれません。

なお、fiboCps は末尾再帰最適化されたとしても、関数の呼び出し回数は二重再帰の場合と同じです。したがって、実行速度は二重再帰の場合とほとんどかわりません。また、二重再帰の場合は関数呼び出しによりスタックが消費されますが、CPS の場合はクロージャが生成されるのでメモリ (ヒープ領域) が消費されます。このように、再帰呼び出しを CPS に変換したからといって、効率の良いプログラムになるとは限りません。ご注意くださいませ。

●CPS の便利な使い方

階乗やフィボナッチ関数の場合、CPS に変換するメリットはほとんどありませんが、場合によっては CPS に変換した方が簡単にプログラムできることもあります。たとえば、リストを平坦化する関数 flatten で、リストの要素に空リストが含まれていたら空リストを返すようにプログラムを修正することを考えてみましょう。次のリストを見てください。

リスト : リストの平坦化 (間違い)

  def flatten[A](xs: List[List[A]]): List[A] =
    xs match {
      case Nil => Nil
      case x::_ if (x == Nil) => Nil
      case x::xs => x ::: flatten(xs)
    }

関数 flatten は空リストを見つけたら空リストを返していますが、これでは正常に動作しません。実際に試してみると次のようになります。

scala> flatten(List(List(1,2), List(3,4), List(5,6)))
val res0: List[Int] = List(1, 2, 3, 4, 5, 6)

scala> flatten(List(List(1,2), List(3,4), Nil, List(5,6)))
val res1: List[Int] = List(1, 2, 3, 4)

2 番目の例が空リストを含む場合です。この場合、空リストを返したいのですが、その前の要素を連結したリストを返しています。空リストを見つける前にリストの連結処理を行っているので、空リストを見つけたらその処理を廃棄しないといけないのです。

このような場合、CPS を使うと簡単です。次のリストを見てください。

リスト : リストの平坦化 (CPS)

  def flattenCps[A](xs: List[List[A]], cont: List[A] => List[A]): List[A] =
    xs match {
      case Nil => cont(Nil)
      case x::_ if (x == Nil) => Nil
      case x::xs1 => flattenCps(xs1, (y: List[A]) => cont(x ::: y))
    }

flatten を CPS に変換するのは簡単です。リストの先頭の要素 x と平坦化したリストの連結を継続で行うだけです。平坦化したリストは継続の引数 y に渡されるので、x @ y でリストを連結して、それを継続 cont に渡せばいいわけです。

引数のリストが空リストになったら継続 cont に空リストを渡して評価します。これで、リストの連結処理が行われます。もしも、途中で空リストを見つけた場合は、空リストをそのまま返します。この場合、継続 cont は評価されないので、リストの連結処理は行われず、空リストをそのまま返すことができます。

それでは実行してみましょう。

scala> flattenCps(List(List(1,2), List(3,4), List(5,6)), (x: List[Int]) => x)
val res2: List[Int] = List(1, 2, 3, 4, 5, 6)

scala> flattenCps(List(List(1,2), List(3,4), Nil, List(5,6)), (x: List[Int]) => x)
val res3: List[Int] = List()

正常に動作していますね。

●二分木の巡回を CPS で実装

次は二分木を巡回するプログラムを CPS で作ってみましょう。二分木の詳しい説明は拙作のページ 多相クラス (3)「不変 (immutable) で多相的な二分木」をお読みください。

リスト : 二分木の巡回

  // 抽象クラス
  abstract class Tree[+A] {
    def left: Tree[A]
    def right: Tree[A]
    def item: A
    def isEmpty: Boolean

    // 巡回
    def foreach(f: A => Unit): Unit = {
      if (!isEmpty) {
        left.foreach(f)
        f(item)
        right.foreach(f)
      }
    }
  }

  // 節
  case class Node[A](item: A, left: Tree[A], right: Tree[A]) extends Tree[A] {
    def isEmpty: Boolean = false
  }

  // 終端 (空の木)
  case object Nils extends Tree[Nothing] {
    def item  = throw new Exception("Nils: item is not member")
    def left  = throw new Exception("Nils: left is not member")
    def right = throw new Exception("Nils: right is not member")
    def isEmpty: Boolean = true
  }

メソッド foreach は二重再帰になっています。そこで、f(item) の評価と右部分木の巡回は継続で行うことにします。プログラムは次のようになります。

リスト : 二分木の巡回 (CPS)

    def foreachCps(f: A => Unit)(cont: () => Unit) {
      if (isEmpty) cont()
      else left.foreachCps(f)(() => {f(item); right.foreachCps(f)(() => cont())})
    }

foreachCpsは副作用が目的なので、継続に値を渡す必要はありません。そこで、cont には Unit を渡すことにします。左部分木をたどったら継続 cont を呼び出します。その中で f(item) を評価し、そのあと右部分木をたどります。このときの継続は cont() を評価するだけです。これで生成された継続を呼び出して、木を巡回することができます。

それでは実際に試してみましょう。

scala> val a: Tree[Int] = Node(4, Node(2, Node(1, Nils, Nils), Node(3, Nils, Nils)),
     | Node(6, Node(5, Nils, Nils), Node(7, Nils, Nils)))
val a: sample23.Tree[Int] = Node(4,Node(2,Node(1,Nils,Nils),Node(3,Nils,Nils)),
Node(6,Node(5,Nils,Nils),Node(7,Nils,Nils)))

scala> a.foreachCps(println)(() => ())
1
2
3
4
5
6
7

このように、foreachCps で二分木を通りがけ順で巡回することができます。

●二分木と遅延ストリーム

二分木の巡回を CPS に変換すると、遅延ストリームに対応するのも簡単です。次のリストを見てください。

リスト : 二分木の巡回 (遅延ストリーム版)

    def streamOfTree[B >: A](cont: () => LazyList[B]): LazyList[B] = 
      if (isEmpty) cont()
      else left.streamOfTree(() => item #:: right.streamOfTree(() => cont()))

streamOfTree は二分木を巡回してその要素を順番に出力する遅延ストリームを生成します。foreachCps は継続の中で関数 f を呼び出しましたが、streamOfTree は継続の中で遅延ストリーム LazyList を返します。そして、演算子 #:: の右辺で右部分木をたどり、その継続の中で cont() を呼び出します。

ここで継続 cont の型は () => LazyList[B] になることに注意してください。streamOfTree を呼び出すときに渡す継続が一番最後に呼び出されるので、遅延ストリームの終端 Strem.empty を返すように定義してください。

簡単な実行例を示しましょう。

scala> a
val res2: sample23.Tree[Int] = Node(4,Node(2,Node(1,Nils,Nils),Node(3,Nils,Nils)),
Node(6,Node(5,Nils,Nils),Node(7,Nils,Nils)))

scala> val s = a.streamOfTree(() => LazyList.empty)
val s: scala.collection.immutable.LazyList[Int] = LazyList()

scala> s.head
val res3: Int = 1

scala> s.tail.head
val res4: Int = 2

scala> s.tail.tail.head
val res5: Int = 3

scala> s.take(7).toList
val res6: List[Int] = List(1, 2, 3, 4, 5, 6, 7)

streamOfTree を使うと、2 つの二分木が等しいか判定する述語 isEqual を簡単に作ることができます。二分木の要素がすべて等しい場合、isEqual は true を返し、そうでなければ false を返すことにします。つまり、二分木を集合として扱うわけです。プログラムは次のようになります。

リスト : 同値の判定

    def isEqual[B >: A](xs: Tree[B]): Boolean = {
      def iter(s1: LazyList[B], s2: LazyList[B]): Boolean =
        if (s1.isEmpty && s2.isEmpty) true
        else if(s1.isEmpty || s2.isEmpty || s1.head != s2.head) false
        else iter(s1.tail, s2.tail)
      //
      iter(this.streamOfTree(() => LazyList.empty),
           xs.streamOfTree(() => LazyList.empty))
    }

実際の処理は局所関数 iter で行います。iter には二分木の遅延ストリームを渡します。あとは、遅延ストリームから要素を一つずつ取り出して、それが等しいかチェックするだけです。

それでは実行例を示します。

scala> val a: Tree[Int] = Node(4, Node(2, Node(1, Nils, Nils), Node(3, Nils, Nils)),
     | Node(6, Node(5, Nils, Nils), Node(7,Nils,Nils)))
val a: sample23.Tree[Int] = Node(4,Node(2,Node(1,Nils,Nils),Node(3,Nils,Nils)),
Node(6,Node(5,Nils,Nils),Node(7,Nils,Nils)))

scala> val b = Node(1,Nils,Node(2,Nils, Node(3,Nils, Node(4, Nils, Node(5,Nils,
     | Node(6, Nils, Node(7, Nils, Nils)))))))
val b: sample23.Node[Int] = Node(1,Nils,Node(2,Nils,Node(3,Nils,Node(4,Nils,
Node(5,Nils,Node(6,Nils,Node(7,Nils,Nils)))))))

scala> a.isEqual(b)
val res7: Boolean = true

scala> val c = Node(1,Nils,Node(2,Nils, Node(3,Nils, Node(4, Nils, Node(5,Nils,
     | Node(6, Nils, Node(8, Nils, Nils)))))))
val c: sample23.Node[Int] = Node(1,Nils,Node(2,Nils,Node(3,Nils,Node(4,Nils,
Node(5,Nils,Node(6,Nils,Node(8,Nils,Nils)))))))

scala> a.isEqual(c)
val res8: Boolean = false

変数 a, b に二分木をセットします。a と b では二分木の形状は異なりますが要素はすべて同じです。したがって、isEqual(a, b) は true を返します。変数 c にセットされた二分木は要素が一つだけ異なっているので、isEqual(a, c) は false を返します。

部分集合を判定する関数 isSubSet も簡単です。次のリストを見てください。

リスト : 部分集合の判定

    def isSubSet[B >: A <% Ordered[B]](xs: Tree[B]): Boolean = {
      def iter(s1: LazyList[B], s2: LazyList[B]): Boolean =
        if (s1.isEmpty) true
        else if (s2.isEmpty) false
        else if (s1.head == s2.head) iter(s1.tail, s2.tail)
        else if (s1.head > s2.head) iter(s1, s2.tail)
        else false
      //
      iter(this.streamOfTree(() => LazyList.empty),
           xs.streamOfTree(() => LazyList.empty))
    }

実際の処理は局所関数 iter で行います。遅延ストリーム s1 が s2 の途中で終了した場合、s1 の要素はすべて s2 にあるので s1 は s2 の部分集合です。isSubSet は true を返します。s2 が途中で終了した場合、s1 は s2 に含まれていない要素があるので、部分集合ではありません。false を返します。

そうでなければ、遅延ストリームから要素を一つずつ取り出します。s1.head == s2.head ならば次の要素を調べます。 s1.head > s2.head の場合、x と等しい要素が s2 に存在するかもしれないので、x と s2 の次の要素を比較します。それ以外の場合、s1.head と等しい要素は s2 に存在しないことがわかるので false を返します。

それでは実行例を示します。

scala> a
val res9: sample23.Tree[Int] = Node(4,Node(2,Node(1,Nils,Nils),Node(3,Nils,Nils)),
Node(6,Node(5,Nils,Nils),Node(7,Nils,Nils)))

scala> val b: Tree[Int] = Node(3,Node(2,Node(1, Nils, Nils), Nils), Nils)
val b: sample23.Tree[Int] = Node(3,Node(2,Node(1,Nils,Nils),Nils),Nils)

scala> b.isSubSet(a)
val res10: Boolean = true

scala> val c: Tree[Int] = Node(3,Node(2,Node(0, Nils, Nils), Nils), Nils)
val c: sample23.Tree[Int] = Node(3,Node(2,Node(0,Nils,Nils),Nils),Nils)

scala> c.isSubSet(a)
val res11: Boolean = false

正常に動作していますね。

●TailCalls

Scala は fiboCps だけではなく、相互再帰の場合も末尾再帰は最適化されません。このような場合、オブジェクト scala.util.control.TailCalls に用意されている関数を使うと、スタックオーバーフローせずにプログラムを実行することができます。TailCalls に用意されている関数を示します。

def done[A](result: A): TailRec[A]
def tailcall[A](rest: => TailRec[A]): TailRec[A]

done は値を返すときに呼び出します。値は TailRec のフィールド変数 result に格納されて返されます。再帰呼び出しするときは tailcall を使います。TailCalls は遅延評価を使って末尾再帰のプログラムをスタックを消費せずに実行します。本当に最適化される (繰り返しに変換される) わけではないので、実行速度は遅くなるかもしれません。ご注意くださいませ。

簡単な例を示しましょう。

リスト : 相互再帰

  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true)
    else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] = 
    if (n == 0) done(false)
    else tailcall(isEven(n - 1))

isEven と isOdd は相互再帰で偶数と奇数を判定します。Scala の場合、TailCalls を使わないとスタックオーバーフローします。isEven と isOdd の返り値の型は TailRec[Boolean] になります。n == 0 のとき、真偽値を返しますが done で TailRec に格納して返します。そして、再帰呼び出しするときは tailcall を経由して行います。

それでは実行してみましょう。

scala> isEven(1000).result
val res12: Boolean = true

scala> isOdd(10000).result
val res13: Boolean = false

scala> isEven(100000).result
val res14: Boolean = true

scala> isOdd(999999999).result
val res15: Boolean = true

このように、大きな値でもスタックオーバーフローせずに実行することができます。

次は階乗 (CPS 版) を TailCalls で書き直してみましょう。

リスト : 階乗

  def factCps1(n: Int, cont: TailRec[BigInt] => TailRec[BigInt]): TailRec[BigInt] =
    if (n == 0) cont(done(1))
    else tailcall(factCps1(n - 1, x => cont(done(x.result * n))))

継続 cont の型は TailRec[BigInt] => TailRec[BigInt] になります。そして、cont を呼び出すとき、その引数を done で TailRec に変換して渡します。cont の引数 x は TailRec なので、計算するときは x.result で値を取り出してください。

それでは実行してみましょう。

scala> factCps1(100, x => x).result
val res16: BigInt = 933262154439441526816992388562667004907159682643816214685929
63895217599993229915608941463976156518286253697920827223758251185210916864000000
000000000000000000

同様に、フィボナッチ関数 fiboCps を書き直すこと次のようになります。

リスト : フィボナッチ関数

  def fiboCps1(n: Int, cont: TailRec[BigInt] => TailRec[BigInt]): TailRec[BigInt] =
    if (n == 0 || n == 1) cont(done(n))
    else tailcall(fiboCps1(n - 1, x => tailcall(fiboCps1(n - 2, y => cont(done(x.result + y.result))))))

匿名関数の引数 x, y は TailRec なので、x.result と y.result で値を取り出して計算するだけです。

実行結果を示します。

scala> fiboCps1(16, x => x).result
val res0: BigInt = 987

scala> fiboCps1(20, x => x).result
val res1: BigInt = 6765

scala> fiboCps1(30, x => x).result
val res2: BigInt = 832040

スタックオーバーフローせずに値を求めることができました。


●プログラムリスト

//
// sample23.html : 継続渡しスタイル (CPS) サンプルプログラム
//
//                 Copyright (C) 2014-2021 Makoto Hiroi
//
import scala.collection.immutable.LazyList
import scala.util.control.TailCalls._

object sample23 {

  // 階乗
  def factCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0) cont(1)
    else factCps(n - 1, x => cont(n * x))

  // フィボナッチ関数
  def fibo(n: Int): BigInt =
    if (n == 0 || n == 1) n
    else fibo(n - 1) + fibo(n - 2)

  def fiboCps[A](n: Int, cont: BigInt => A): A =
    if (n == 0 || n == 1) cont(n)
    else fiboCps(n - 1, x => fiboCps(n - 2, y => cont(x + y)))

  // リストの平坦化
  def flatten[A](xs: List[List[A]]): List[A] =
    xs match {
      case Nil => Nil
      case x::_ if (x == Nil) => Nil
      case x::xs1 => x ::: flatten(xs1)
    }

  def flattenCps[A](xs: List[List[A]], cont: List[A] => List[A]): List[A] =
    xs match {
      case Nil => cont(Nil)
      case x::_ if (x == Nil) => Nil
      case x::xs1 => flattenCps(xs1, (y: List[A]) => cont(x ::: y))
    }

  // 二分木
  abstract class Tree[+A] {
    def left: Tree[A]
    def right: Tree[A]
    def item: A
    def isEmpty: Boolean

    // 巡回
    def foreach(f: A => Unit): Unit = {
      if (!isEmpty) {
        left.foreach(f)
        f(item)
        right.foreach(f)
      }
    }

    // CPS 版
    def foreachCps(f: A => Unit)(cont: () => Unit): Unit = {
      if (isEmpty) cont()
      else left.foreachCps(f)(() => {f(item); right.foreachCps(f)(() => cont())})
    }

    // 遅延ストリーム版
    def streamOfTree[B >: A](cont: () => LazyList[B]): LazyList[B] =
      if (isEmpty) cont()
      else left.streamOfTree(() => item #:: right.streamOfTree(() => cont()))

    // 同値の判定
    def isEqual[B >: A](xs: Tree[B]): Boolean = {
      def iter(s1: LazyList[B], s2: LazyList[B]): Boolean =
        if (s1.isEmpty && s2.isEmpty) true
        else if(s1.isEmpty || s2.isEmpty || s1.head != s2.head) false
        else iter(s1.tail, s2.tail)
      //
      iter(this.streamOfTree(() => LazyList.empty),
           xs.streamOfTree(() => LazyList.empty))
    }

    // 部分集合の判定
    def isSubSet[B >: A](xs: Tree[B])(implicit f: B => Ordered[B]): Boolean = {
      def iter(s1: LazyList[B], s2: LazyList[B]): Boolean =
        if (s1.isEmpty) true
        else if (s2.isEmpty) false
        else if (s1.head == s2.head) iter(s1.tail, s2.tail)
        else if (s1.head > s2.head) iter(s1, s2.tail)
        else false
      //
      iter(this.streamOfTree(() => LazyList.empty),
           xs.streamOfTree(() => LazyList.empty))
    }

  }

  case class Node[A](item: A, left: Tree[A], right: Tree[A]) extends Tree[A] {
    def isEmpty: Boolean = false
  }

  case object Nils extends Tree[Nothing] {
    def item  = throw new Exception("Nils: item is not member")
    def left  = throw new Exception("Nils: left is not member")
    def right = throw new Exception("Nils: right is not member")
    def isEmpty: Boolean = true
  }

  // TailCalls の使用例
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true)
    else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false)
    else tailcall(isEven(n - 1))

  def factCps1(n: Int, cont: TailRec[BigInt] => TailRec[BigInt]): TailRec[BigInt] =
    if (n == 0) cont(done(1))
    else tailcall(factCps1(n - 1, x => cont(done(x.result * n))))

  def fiboCps1(n: Int, cont: TailRec[BigInt] => TailRec[BigInt]): TailRec[BigInt] =
    if (n == 0 || n == 1) cont(done(n))
    else tailcall(fiboCps1(n - 1, x => tailcall(fiboCps1(n - 2, y => cont(done(x.result + y.result))))))
}

Appendix: 末尾再帰と繰り返し

ここで「末尾再帰」についてもう少し深く考えてみましょう。末尾再帰の「末尾」とは、関数の最後で行われる処理のことです。とくに末尾で関数を呼び出すことを「末尾呼び出し (tail call)」といいます。関数を呼び出す場合、返ってきたあとに行う処理のため、必要な情報を保存しておかなければいけません。ところが、末尾呼び出しはそのあとに実行する処理がありません。呼び出したあと元に戻ってくる必要さえないのです。

このため、末尾呼び出しはわざわざ関数を呼び出す必要はなく、アセンブリ言語のような低水準のレベルではジャンプ命令に変換することができます。これを「末尾呼び出し最適化 (tail call optimization)」とか「末尾最適化」といいます。とくに末尾再帰は末尾で自分自身を呼び出しているので、関数の中で繰り返しに変換することができます。

また、相互再帰やもっと複雑な再帰呼び出しの場合でも、末尾最適化を適用することで、繰り返しに変換できる場合もあります。このように、再帰プログラムを繰り返しに変換してから実行することを「末尾再帰最適化 (tail recursion optimization)」といいます。厳密にいうと末尾最適化なのですが、一般的には末尾再帰最適化と呼ばれることが多いようです。

末尾再帰最適化を行うプログラミング言語、たとえば Scheme の場合、次に示すような関数呼び出しは、スタックを消費せずに実行することができます。処理系は Gauche を使いました。

gosh> (define (foo) (foo))
foo
gosh> (foo)
=> 無限ループになる

これは Scala でも同様に実行することができますが、M.Hiroi の環境では CTRL-C でブレークすることができなかったので注意してください。

もうひとつ簡単な例を示しましょう。C言語で階乗を計算する関数 fact を作ります。

リスト : 末尾再帰を繰り返しに変換する (C言語)

/* 末尾再帰 */
int fact(int n, int a)
{
  if(n == 0){
    return a;
  } else {
    return fact(n - 1, a * n);
  }
}

/* 繰り返し */
int facti(int n, int a)
{
loop:
  if(n == 0) return a;
  a *= n;
  n--;
  goto loop;
}

fact は末尾再帰になっています。これを繰り返しに変換すると facti のようになります。引数 n と a の値を保存する必要がないので、n と a の値を書き換えてから goto 文で先頭の処理へジャンプするだけです。最近はC言語でも末尾再帰最適化を行う処理系 (GCC など) があるようです。

●末尾再帰をスタックオーバーフローせずに実行する

Scala の場合、末尾再帰最適化ができなければ繰り返しでプログラムすることになりますが、CPS は末尾再帰でプログラムしたほうが簡単です。この場合、TailCalls を使うことになりますが、その基本的な考え方は簡単なので、少々不恰好でもよければ私たちでもプログラムすることができます。

末尾再帰の場合、再帰呼び出しのあとに行う処理は存在せず、関数の返り値をそのまま返すだけです。この返り値のかわりに、関数呼び出しの部分を遅延評価して、それをオブジェクトに格納して返すこともできます。ここで実行中の処理を中断することができます。そして、オブジェクトに格納された処理を評価すると、中断された処理を再開することができます。ようするに、遅延ストリームと同じような処理になります。

プログラムは次のようになります。

リスト : 末尾再帰の計算

// 抽象クラス
abstract class TailCall[A] {
  def result: A
  def call: TailCall[A]
  def isDone: Boolean
}

// 計算終了
class Done[A](val result: A) extends TailCall[A] {
  def call = throw new Exception("Done: call not member")
  def isDone: Boolean = true
}

// 計算中
class Call[A](f: => TailCall[A]) extends TailCall[A] {
  lazy val func = f
  def result = throw new Exception("Call: result not member")
  def call: TailCall[A] = func
  def isDone: Boolean = false
}

クラス TailCall は末尾再帰の計算を表します。TailCalls の TailRec に相当するクラスです。クラス Done は計算終了を表すクラスで、計算結果をフィールド変数 result に格納します。Call は計算中の処理を格納するクラスです。遅延評価で実行する処理を受け取り、フィールド変数 func にセットします。メソッド call は遅延評価した処理 func を実行します。

次に、TailCall を使って末尾再帰を実行する関数を作ります。

リスト : 末尾再帰を実行する

  // 値を返す
  def done[A](x: A): TailCall[A] = new Done(x)

  // Call オブジェクトを返す
  def tailcall[A](f: => TailCall[A]): TailCall[A] = new Call(f)

  // 実行
  def exec[A](f: TailCall[A]): A = {
    var fn = f
    while (!fn.isDone) fn = fn.call
    fn.result
  }

関数 done は値を返すために使います。関数 tailcall は末尾再帰するときに使います。done は Done のオブジェクトを、tailcall は Call のオブジェクトを生成して返すだけです。関数 exec は末尾再帰で書かれたプログラムを実行します。引数 f を mutable な変数 fn にセットします。あとは、計算が終わるまで fn.call を呼び出して fn の値を更新するだけです。fn.call の返り値は Done か Call のオブジェクトです。Done が帰ってくれば計算終了、Call が返ってくれば tailcall を使った再帰呼び出しであることがわかります。その場合は、計算を続行すればいいわけです。

簡単な実行例を示します。

リスト : TailCall の簡単な例題

  // 相互再帰
  def isEven(n: Int): TailCall[Boolean] =
    if (n == 0) done(true)
    else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailCall[Boolean] = 
    if (n == 0) done(false)
    else tailcall(isEven(n - 1))

  // 階乗
  def factCps1(n: Int, cont: TailCall[BigInt] => TailCall[BigInt]): TailCall[BigInt] =
    if (n == 0) cont(done(1))
    else tailcall(factCps1(n - 1, x => cont(done(x.result * n))))

  // フィボナッチ関数
  def fiboCps1(n: Int, cont: TailCall[BigInt] => TailCall[BigInt]): TailCall[BigInt] =
    if (n == 0 || n == 1) cont(done(n))
    else tailcall(fiboCps1(n - 1, x => tailcall(fiboCps1(n - 2, y => cont(done(x.result + y.result))))))
scala> exec(isEven(1000))
val res0: Boolean = true

scala> exec(isOdd(10000))
val res1: Boolean = false

scala> exec(factCps1(100, x => x))
val res2: BigInt = 93326215443944152681699238856266700490715968264381621468592963895
217599993229915608941463976156518286253697920827223758251185210916864000000000000000
000000000

scala> exec(fiboCps1(16, x => x))
val res3: BigInt = 987

scala> exec(fiboCps1(20, x => x))
val res4: BigInt = 6765

scala> exec(fiboCps1(30, x => x))
val res5: BigInt = 832040

tailcall と done の使い方は TailCalls と同じです。関数を実行するときは exec を使うので、TailCalls よりも使い勝手は少々悪いと思いますが、スタックオーバーフローせずにプログラムを実行することができます。


●プログラムリスト2

//
// sample2301.scala : 末尾再帰を繰り返しのように実行する
//
//                    Copyright (C) 2014-2021 Makoto Hiroi
//

// 抽象クラス
abstract class TailCall[A] {
  def result: A
  def call: TailCall[A]
  def isDone: Boolean
}

// 計算終了
class Done[A](val result: A) extends TailCall[A] {
  def call = throw new Exception("Done: call not member")
  def isDone: Boolean = true
}

// 計算中
class Call[A](f: => TailCall[A]) extends TailCall[A] {
  lazy val func = f
  def result = throw new Exception("Call: result not member")
  def call: TailCall[A] = func
  def isDone: Boolean = false
}

object sample2301 {
  // 値を返す
  def done[A](x: A): TailCall[A] = new Done(x)

  // Call オブジェクトを返す
  def tailcall[A](f: => TailCall[A]): TailCall[A] = new Call(f)

  // 実行
  def exec[A](f: TailCall[A]): A = {
    var fn = f
    while (!fn.isDone) fn = fn.call
    fn.result
  }

  def isEven(n: Int): TailCall[Boolean] =
    if (n == 0) done(true)
    else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailCall[Boolean] = 
    if (n == 0) done(false)
    else tailcall(isEven(n - 1))

  def factCps1(n: Int, cont: TailCall[BigInt] => TailCall[BigInt]): TailCall[BigInt] =
    if (n == 0) cont(done(1))
    else tailcall(factCps1(n - 1, x => cont(done(x.result * n))))

  def fiboCps1(n: Int, cont: TailCall[BigInt] => TailCall[BigInt]): TailCall[BigInt] =
    if (n == 0 || n == 1) cont(done(n))
    else tailcall(fiboCps1(n - 1, x => tailcall(fiboCps1(n - 2, y => cont(done(x.result + y.result))))))

}

初版 2014 年 10 月 25 日
改訂 2021 年 4 月 4 日

Copyright (C) 2014-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | Scala | NextPage ]