M.Hiroi's Home Page

Functional Programming

お気楽 Standard ML of New Jersey 入門

[ PrevPage | SML/NJ | NextPage ]

関数型電卓プログラムの作成 (6)

今回は電卓プログラムに「継続 (continuation)」の機能を追加してみましょう。電卓プログラムで継続を扱う場合、式を評価する関数 eval_expr を「継続渡しスタイル (CPS)」で書き直す必要があります。この修正が少々面倒な作業になりますが、継続そのものは簡単に実装することができます。CPS については、拙作のページ 継続渡しスタイル をお読みください。

●継続の使い方

今回作成する電卓プログラムでは、継続を Scheme と同じ方法で取り扱うことにします。SML/NJ の継続とは使い方が少し異なるので、簡単に説明しておきましょう。

継続を取り出すには callcc を使います。callcc には関数をひとつ渡します。callcc に渡される関数は引数がひとつで、その引数に callcc が取り出した継続が渡されます。callcc はその関数を評価し、その結果が callcc の返り値になります。

Scheme の仕様書 (R5RS) によると、継続は引数を一つ取る関数で表されます。引数を渡して継続を評価すると、今までの処理を破棄して、callcc で取り出された残りの計算 (継続) を実行します。このとき、継続に渡した引数が callcc の返り値になります。

簡単な例を示しましょう。

Calc> callcc(fn(k) k end);
<Function>
Calc> 1 + 2 * callcc(fn(k) 3 end);
7
Calc> 1 + 2 * callcc(fn(k) k(4), 3 end);
9

最初の例では、匿名関数の引数 k に継続が渡されます。匿名関数は k をそのまま返しているので、callcc の返り値は取り出された継続になります。関数型電卓プログラムの場合、継続は <Function> と表記され、継続が関数として実装されていることがわかります。

次の例を見てください。callcc によって取り出される継続は、callcc の返り値を 2 倍して、その結果に 1 を加えるという処理になります。callcc の返り値を X とすると、継続は 1 + 2 * X という式で表すことができます。匿名関数では継続を評価せずに 3 をそのまま返しているので、1 + 2 * 3 をそのまま計算して値は 7 になります。

最後の例では、匿名関数の中で k(4) を評価しています。継続を評価しているので、現在の処理を破棄して、取り出した継続 1 + 2 * X を評価します。したがって、匿名関数で k(4) の後ろにある 3 を返す処理は実行されません。X の値は継続の引数 4 になるので、1 + 2 * 4 を評価して値は 9 になります。

継続を変数に保存しておいて、あとから実行することもできます。次の例を見てくください。

Calc> 1 + 2 * callcc(fn(k) c = k, 3 end);
7
Calc> c(10);
21
Calc> c(100);
201

匿名関数の中で取り出した継続を大域変数 c に保存します。継続で行う処理は 1 + 2 * X なので、c(10) は 1 + 2 * 10 を評価して値は 21 になります。同様に、c(100) は 1 + 2 * 100 を評価して値は 201 になります。

●継続を表すデータ型の定義

それではプログラムを作りましょう。最初に、継続を表すデータ型を定義します。

リスト : 式の定義

datatype value = Nil                           (* 空を表す値 *)
               | Integer of IntInf.int         (* 整数 *)
               | Float of real                 (* 実数 *)
               | Func of func                  (* 関数 *)
               | Pair of value ref * value ref (* 連結リスト *)
               | Vec  of value array           (* ベクタ *)
and func = F1  of value -> value
         | F2  of (value * value) -> value
         | CLO of string list * expr * (string * value ref) list
         | CT  of value -> value             (* 継続 *)
and expr = Val of value                      (* 値 *)
         | Var of string                     (* 変数 *)
         | Op1 of operator * expr            (* 単項演算子 *)
         | Op2 of operator * expr * expr     (* 二項演算子 *)
         | Ops of operator * expr * expr     (* 短絡演算子 *)
         | Sel of expr * expr * expr         (* if expr then expr else expr end *)
         | Whl of expr * expr                (* while expr do expr end *)
         | Bgn of expr list                  (* begin expr, ... end *)
         | Clo of string list * expr         (* fn (仮引数) body end *)
         | Let of string list * expr list * expr
         | Rec of string list * expr list * expr
         | Lst of expr list                  (* リストの生成 *)
         | Crv of expr list                  (* ベクタの生成 *)
         | Ref of expr * expr list           (* ベクタのアクセス *)
         | App of expr * expr list           (* 関数の適用 *)
         | Cct of expr                       (* 継続 *)

今回は CPS で継続を取り扱うので、継続の実体は SML/NJ の関数になります。データ型は value -> value になるので、func に継続を表す CT of value -> value を追加します。expr には Cct of expr を追加します。これは callcc の処理に対応します。callcc の文法を示します。

callcc式 = "callcc", "(", 式, ")".

callcc の処理は関数 factor で行います。カッコ内の式を取り出して Cct of expr の第 1 要素に格納します。実行する場合は式 expr を評価して、その値が関数 (クロージャ) であれば、その第 1 引数に取り出した継続を渡して呼び出します。

それから、token に callcc を表す CALLCC を追加します。字句解析の修正は簡単なので説明は割愛します。詳細は プログラムリスト をお読みください。

●構文解析の修正

次は関数 factor に callcc の処理を追加します。プログラムは次のようになります。

リスト : callcc の処理

and factor s =

    ・・・ 省略 ・・・

    | LET => (get_token s; make_let s)
    | LIST => (get_token s; make_list s)
    | CALLCC => (get_token s; make_ct s)

    ・・・ 省略 ・・・

(* 継続の生成 *)
and make_ct s =
    case !tokenBuff of
         Lpar => (get_token s;
                  let
                    val v = expression s
                  in
                    case !tokenBuff of
                         Rpar => (get_token s; Cct(v))
                       | _ => raise Syntax_error("')' expected")
                  end)
       | _ => raise Syntax_error("'(' expected")

トークンが CALLCC の場合、関数 make_ct を呼び出して callcc の処理を行います。make_ct では、次のトークンが Lpar であることを確認してから、expression を呼び出して式 v を取り出します。次に、トークンが Rpar であることを確認して Cct(v) を返します。とても簡単ですね。

●eval_expr の修正

次は、関数 eval_expr を CPS で書き直します。まず、継続を表す引数 cont を追加します。cont のデータ型は value -> value になります。値 v を返す場合、必ず cont(v) を評価してください。それから、eval_expr を再帰呼び出しする場合、継続 cont が途切れないように注意してください。たとえば、二項演算子の処理を CPS で書き直すと次のようになります。

リスト : 二項演算子の処理

|   eval_expr(Op2(op2, expr1, expr2), env, cont) = 
    eval_expr(
      expr1,
      env,
      fn v => eval_expr(
                expr2,
                env,
                fn w => case op2 of
                             Add => cont(eval_op(op +, op +, v, w))
                           | Sub => cont(eval_op(op -, op -, v, w))
                           | Mul => cont(eval_op(op *, op *, v, w))
                           | Quo => cont(eval_op(op div, op /, v, w))
                           | Mod => cont(eval_op_int(op mod,  v, w))
                           | EQ => cont(eval_comp(op =, Real.==, v, w))
                           | NE => cont(eval_comp(op <>, Real.!=, v, w))
                           | LT => cont(eval_comp(op <, op <, v, w))
                           | GT => cont(eval_comp(op >, op >, v, w))
                           | LE => cont(eval_comp(op <=, op <=, v, w))
                           | GE => cont(eval_comp(op >=, op >=, v, w))
                           | _  => raise Calc_run_error("Illegal operator") ))

eval_expr の第 3 引数 cont が継続です。最初に、eval_expr で expr1 を評価します。その結果は継続を表す匿名関数の引数 v に渡されます。その中で、再度 eval_expr を呼び出して expr2 を評価します。その結果は継続を表す匿名関数の引数 w に渡されます。その中で op2 をチェックして適切な関数を呼び出し、その結果を cont で返します。このように eval_expr を呼び出すときは、継続を連鎖させていくことに注意してください。

次は、継続を呼び出す App の処理を修正します。

リスト : 継続の呼び出し

|   eval_expr(App(expr, args), env, cont) = 
    let
      fun iter([], _, k) = k []
      |   iter(x::xs, env, k) =
          eval_expr(x,
                    env,
                    fn v => iter(xs, env, fn w => k (v::w)))
    in
      iter(args,
           env,
           fn vs => eval_expr(
                      expr,
                      env,
                      fn v => case v of
                                   Func(F1 f1) => (check_args_num(vs, 1);
                                                   cont(f1(hd vs)))
                                 | Func(F2 f2) => (check_args_num(vs, 2);
                                                   cont(f2(hd vs, hd (tl vs))))
                                 | Func(CLO(parm, body, clo)) =>
                                   eval_expr(body, add_binding(parm, vs, clo), cont)
                                 | Func(CT k) => (check_args_num(vs, 1);
                                                  k(hd vs))
                                 | _ => raise Calc_run_error("Not function") ))
    end

局所関数 iter は引数 args の評価を CPS で書き直したものです。まず、要素 x を eval_expr で評価します。その値は匿名関数の引数 v に渡されます。この中で iter を再帰呼び出しして次の要素を評価します。その値は匿名関数の w に渡されるので、継続 k に (v::w) を渡して呼び出せば、引数の評価結果をリストに格納して返すことができます。

App の処理は、最初に iter を呼び出して引数 args を評価します。その結果は匿名関数の引数 vs に渡されます。次に、その中で eval_expr を呼び出して expr を評価します。その結果は匿名関数 v に渡されます。その値が関数 Func であれば、関数を呼び出してその値を cont で返します。

クロージャ CLO の場合、body を eval_expr で評価してその結果を返せばいいので、第 3 引数の継続には cont を渡すだけです。継続 CT の場合、CT に格納されている継続 k を取り出し、引数 (hd vs) を渡して呼び出すだけです。ここで、cont に格納されている継続 (次に行う処理) が破棄され、継続 k に格納されている処理が実行されます。ここで cont を評価して値を返すと、継続は正常に動作しません。ご注意ください。

次は、継続を取り出す Cct の処理を追加します。

リスト : 継続の取得

|   eval_expr(Cct(expr), env, cont) =
    eval_expr(expr,
              env,
              fn f => case f of
                           Func(CLO(parm, body, clo)) =>
                           eval_expr(body, add_binding(parm, [Func(CT cont)], clo), cont)
                         | _ => raise Calc_run_error("Not Closure") )

最初に eval_expr で式 expr を評価します。その結果は匿名関数の引数 f に渡されます。f がクロージャの場合、関数の本体 body を eval_expr で評価します。このとき、add_binding で変数束縛を行いますが、実引数として継続を表す値 Func(CT cont) を渡します。これで引数に継続を渡して関数を評価することができます。

あとは、プログラムを CPS に書き換えるだけですが、ベクタを生成する処理で注意点がひとつあります。次のリストを見てください。

リスト : ベクタの生成

|   eval_expr(Crv(args), env, cont) =
    let
      fun toVector(_, [], v) = v
      |   toVector(i, x::xs, v) =
          (Array.update(v, i - 1, x); toVector(i - 1, xs, v))
      fun iter(i, [], a) =
          cont(Vec(toVector(i, a, Array.array(i, Nil))))
      |   iter(i, x::xs, a) = 
          eval_expr(x, env, fn w => iter(i + 1, xs, w::a))
    in
      iter(0, args, [])
    end

Array.array でベクタを生成するのは、引数 args を評価した後で行います。プログラムでは、引数を評価した結果をリストに格納しておいて、最後にその値を局所関数 toVecotr でベクタにコピーしています。最初にベクタを生成すると、都合の悪いことがあるのです。簡単な例を示しましょう。

Calc> a = [1, 2, callcc(fn(k) c = k, 3 end)]
[1, 2, 3]
Calc> b = a;
[1, 2, 3]
Calc> a;
[1, 2, 3]
Calc> b;
[1, 2, 3]
Calc> c(10);
[1, 2, 10]
Calc> a;
[1, 2, 10]
Calc> b;
[1, 2, 3]

変数 a にベクタを生成してセットします。そして、a の値を b にセットします。このとき、取り出される継続 c の処理は a = [1, 2, X] となるので、c(10) とすると a の値は [1, 2, 10] になります。b の値は [1, 2, 3] のままですが、もしも最初にベクタを生成すると、b の値も [1, 2, 10] になってしまいます。

つまり、処理の順番が ベクタの生成 -> 引数の評価 の場合、引数を評価するところで継続を取り出すと、継続の処理にベクタを生成する処理が含まれていないため、最初に生成したベクタを使いまわすことになるのです。これでは継続を使ったプログラム、たとえば非決定性 amb を作るときに都合が悪いので、引数を評価した後でベクタを生成することにします。

あとはプログラムを CPS に変換するだけなので説明は割愛します。手作業での変換はちょっと面倒ですが、興味のある方は プログラムリスト を読んでみてください。

●末尾再帰のチェック

今回は eval_expr を CPS に変換したので、継続以外の機能が正しく動作するかテストする必要がありますが、ここでは「末尾再帰最適化」に注目してチェックしてみることにしましょう。次のリストを見てください。

リスト : 1 から x までの合計値を求める

def sum0(n)
  if n == 0 then 0 else n + sum0(n - 1) end
end

def sum1(n, a)
  if n == 0 then a else sum1(n - 1, a + n) end
end

def sum2(n)
  let
    a = 0
  in
    while n > 0 do
      a = a + n,
      n = n - 1
    end,
    a
  end
end

関数 sum0 は末尾再帰になっていないので、大きな値を計算するとメモリを大量に消費します。関数 sum1 は末尾再帰になっているので、大きな値でもメモリを消費せずに計算することができます。関数 sum2 は while ループを使っているので、メモリを消費せずに実行できるはずです。実行結果は次のようになります。

Calc> sum0(1000000);
500000500000
Calc> sum1(1000000, 0);
500000500000
Calc> sum2(1000000);
500000500000

Windows 10, Intel Core i5-6200U 2.30GHz, SML/NJ ver 110.98 で実行した場合、どの関数でも値を求めることができました。実行時間は sum0 が約 11 秒、sum1 が約 0.52 秒、sum2 が約 0.47 秒になりました。CPS に変換するとクロージャの処理が少し増えるので、実行時間は少し遅くなるようです。

次は相互再帰の場合をチェックしてみましょう。

リスト : 相互再帰

def foo(n)
  if n == 0 then
    1
  else
    bar(n - 1)
  end
end

def bar(n)
  if n == 0 then
    0
  else
    foo(n - 1)
  end
end

電卓プログラムの末尾最適化はこのような相互再帰でも機能します。bar(1000000) を実行したところ、実行時間は約 0.38 秒でした。末尾最適化はきちんと動作しているようです。

今回はここまでです。次回は実際に継続を使って簡単なサンプルプログラムを作ってみましょう。


●プログラムリスト

(*
 * calc.sml : 電卓プログラム
 *
 *            Copyright (C) 2012-2021 Makoto Hiroi
 *
 * (1) 四則演算の実装
 * (2) 変数と組み込み関数の追加
 * (3) ユーザー定義関数の追加
 * (4) 論理演算子, 比較演算子, if の追加
 * (5) begin, while の追加
 * (6) 関数を値とし、匿名関数 (クロージャ) と let を追加
 * (7) 空リスト Nil と型述語 (isNil, isInteger, isFloat, isFunction) の追加
 * (8) 連結リストの実装
 * (9) ベクタの実装
 * (10) CPS による継続の実装
 *
 *)

open TextIO

(* 例外 *)
exception Calc_exit
exception Syntax_error of string
exception Calc_run_error of string

(* 演算子の定義 *)
datatype operator = Add | Sub | Mul | Quo | Mod | Assign
                  | NOT | AND | OR 
                  | EQ  | NE  | LT  | GT  | LE  | GE

(* 式の定義 *)
datatype value = Nil                           (* 空を表す値 *)
               | Integer of IntInf.int         (* 整数 *)
               | Float of real                 (* 実数 *)
               | Func of func                  (* 関数 *)
               | Pair of value ref * value ref (* 連結リスト *)
               | Vec  of value array           (* ベクタ *)
and func = F1  of value -> value
         | F2  of (value * value) -> value
         | CLO of string list * expr * (string * value ref) list
         | CT  of value -> value
and expr = Val of value                      (* 値 *)
         | Var of string                     (* 変数 *)
         | Op1 of operator * expr            (* 単項演算子 *)
         | Op2 of operator * expr * expr     (* 二項演算子 *)
         | Ops of operator * expr * expr     (* 短絡演算子 *)
         | Sel of expr * expr * expr         (* if expr then expr else expr end *)
         | Whl of expr * expr                (* while expr do expr end *)
         | Bgn of expr list                  (* begin expr, ... end *)
         | Clo of string list * expr         (* fn (仮引数) body end *)
         | Let of string list * expr list * expr
         | Rec of string list * expr list * expr
         | Lst of expr list                  (* リストの生成 *)
         | Crv of expr list                  (* ベクタの生成 *)
         | Ref of expr * expr list           (* ベクタのアクセス *)
         | App of expr * expr list           (* 関数の適用 *)
         | Cct of expr                       (* 継続 *)

(* トークンの定義 *)
datatype token = Value of value         (* 値 *)
               | Ident of string        (* 識別子 *)
               | Oper of operator       (* 演算子 *)
               | Lpar | Rpar            (* (, ) *)
               | Lbra | Rbra            (* [, ] *)
               | Semic                  (* ; *)
               | Comma                  (* , *)
               | DEF                    (* def *)
               | END                    (* end *)
               | IF                     (* if *)
               | THEN                   (* then *)
               | ELSE                   (* else *)
               | WHL                    (* while *)
               | DO                     (* do *)
               | BGN                    (* begin *)
               | FN                     (* fn *)
               | LET                    (* let *)
               | IN                     (* in *)
               | REC                    (* rec *)
               | LIST                   (* list *)
               | CALLCC                 (* callcc *)
               | Quit                   (* 終了 *)
               | Others                 (* その他 *)


(* value を real に変換 *)
fun toReal(Float(v)) = v
|   toReal(Integer(v)) = Real.fromLargeInt(v)
|   toReal(_) = raise Calc_run_error("Not Number")

(* 関数を呼び出す *)
fun call_real_func1 f v = Float(f(toReal v))
fun call_real_func2 f (v, w) = Float(f(toReal v, toReal w))

(* 値の表示 *)
fun print_value x =
    case x of
         Nil => (print "()"; Nil)
       | Integer(n) => (print(IntInf.toString(n)); Nil)
       | Float(n) => (print(Real.toString(n)); Nil)
       | Func(_) => (print "<Function>"; Nil)
       | Pair(_) => (print "("; print_pair(x); print ")"; Nil)
       | Vec(_) => (print "["; print_vector(x); print "]"; Nil)
and print_pair(Pair(ref x, ref y)) = (
      print_value x;
      case y of
           Nil => Nil
         | Pair(_, _) => (print " "; print_pair y)
         | _ => (print " . "; print_value y)
    )
|   print_pair x = print_value x
and print_vector(Vec(v)) =
    let
      val i = ref 0
      val k = Array.length(v)
    in
      while !i < k - 1 do (
        print_value(Array.sub(v, !i));
        print ", ";
        i := !i + 1
      );
      print_value(Array.sub(v, !i));
      Nil
    end
|   print_vector x = print_value x

(* 文字の表示 *)
fun print_char(n as Integer(x)) = (
      output1(stdOut, chr(IntInf.toInt(x))); Nil
    )
|   print_char(_) = raise Calc_run_error("Not Integer")

(* 型チェック *)
val True = Integer(1)
val False = Integer(0)

fun isNil(Nil) = True
|   isNil(_) = False

fun isInteger(Integer(_)) = True
|   isInteger(_) = False

fun isFloat(Float(_)) = True
|   isFloat(_) = False

fun isFunction(Func(_)) = True
|   isFunction(_) = False

fun isPair(Pair(_, _)) = True
|   isPair(_) = False

fun isVector(Vec(_)) = True
|   isVector(_) = False

(* 連結リストの基本関数 *)
fun car(Pair(ref x, _)) = x
|   car(_) = raise Calc_run_error("Not Pair")

fun cdr(Pair(_, ref y)) = y
|   cdr(_) = raise Calc_run_error("Not Pair")

fun cons(x, y) = Pair(ref x, ref y)

fun setCar(Pair(x, _), z) = (x := z; z)
|   setCar(_, _) = raise Calc_run_error("Not Pair")

fun setCdr(Pair(_, y), z) = (y := z; z)
|   setCdr(_, _) = raise Calc_run_error("Not Pair")

(* ベクタの生成 *)
fun make_vector(Integer(size), v) =
    Vec(Array.array(IntInf.toInt(size), v))
|   make_vector(_, _) = raise Calc_run_error("Not Integer")

(* ベクタの大きさ *)
fun vector_length(Vec(v)) =
    Integer(IntInf.fromInt(Array.length(v)))
|   vector_length(_) = raise Calc_run_error("Not Vector")

(* 大域変数 *)
val global_env = ref [("sqrt",  ref (Func(F1(call_real_func1 Math.sqrt)))),
                      ("sin",   ref (Func(F1(call_real_func1 Math.sin)))),
                      ("cos",   ref (Func(F1(call_real_func1 Math.cos)))),
                      ("tan",   ref (Func(F1(call_real_func1 Math.tan)))),
                      ("asin",  ref (Func(F1(call_real_func1 Math.asin)))),
                      ("acos",  ref (Func(F1(call_real_func1 Math.acos)))),
                      ("atan",  ref (Func(F1(call_real_func1 Math.atan)))),
                      ("atan2", ref (Func(F2(call_real_func2 Math.atan2)))),
                      ("exp",   ref (Func(F1(call_real_func1 Math.exp)))),
                      ("pow",   ref (Func(F2(call_real_func2 Math.pow)))),
                      ("ln",    ref (Func(F1(call_real_func1 Math.ln)))),
                      ("log10", ref (Func(F1(call_real_func1 Math.log10)))),
                      ("sinh",  ref (Func(F1(call_real_func1 Math.sinh)))),
                      ("cosh",  ref (Func(F1(call_real_func1 Math.cosh)))),
                      ("tanh",  ref (Func(F1(call_real_func1 Math.tanh)))),
                      ("print",      ref (Func(F1 print_value))),
                      ("putc",       ref (Func(F1 print_char))),
                      ("isNil",      ref (Func(F1 isNil))),
                      ("isInteger",  ref (Func(F1 isInteger))),
                      ("isFloat",    ref (Func(F1 isFloat))),
                      ("isFunction", ref (Func(F1 isFunction))),
                      ("isPair",     ref (Func(F1 isPair))),
                      ("isVector",   ref (Func(F1 isVector))),
                      ("car",        ref (Func(F1 car))),
                      ("cdr",        ref (Func(F1 cdr))),
                      ("cons",       ref (Func(F2 cons))),
                      ("setCar",     ref (Func(F2 setCar))),
                      ("setCdr",     ref (Func(F2 setCdr))),
                      ("makeVector", ref (Func(F2 make_vector))),
                      ("len",        ref (Func(F1 vector_length))),
                      ("nil",        ref Nil)]

(* 探索 *)
fun lookup name =
    let
      fun iter [] = NONE
      |   iter ((x as (n, _))::xs) =
          if n = name then SOME x else iter xs
    in
      iter(!global_env)
    end

(* 追加 *)
fun update(name, value) = 
    global_env := (name, ref value)::(!global_env)

(* 切り出したトークンを格納するバッファ *)
val tokenBuff = ref Others

(* 整数の切り出し *)
fun get_number s =
    let
      val buff = ref []
      fun get_numeric() =
          let val c = valOf(lookahead s) in
            if Char.isDigit(c) then (
              buff := valOf(input1 s) :: (!buff);
              get_numeric()
            ) else ()
          end
      fun check_float(c) =
          case c of
            #"." => true
          | #"e" => true
          | #"E" => true
          | _ => false
    in
      get_numeric();    (* 整数部の取得 *)
      if check_float(valOf(lookahead s)) then (
        if valOf(lookahead s) = #"." then (
          (* 小数部の取得 *)
          buff := valOf(input1 s) :: (!buff);
          get_numeric()
        ) else ();
        if Char.toUpper(valOf(lookahead s)) = #"E" then (
          (* 指数形式 *)
          buff := valOf(input1 s) :: (!buff);
          let val c = valOf(lookahead s) in
            if c = #"+" orelse c = #"-" then
              buff := (valOf(input1 s)) :: (!buff)
            else ()
          end;
          get_numeric()
        ) else ();
        tokenBuff := Value(Float(valOf(Real.fromString(implode(rev (!buff))))))
      ) else
        tokenBuff := Value(Integer(valOf(IntInf.fromString(implode(rev (!buff))))))
    end

(* 識別子の切り出し *)
fun get_ident s =
    let fun iter a =
      if Char.isAlphaNum(valOf(lookahead s)) then
        iter ((valOf(input1 s)) :: a)
      else Ident(implode(rev a))
    in
      iter []
    end

(* トークンの切り出し *)
fun get_token s =
    let val c = valOf(lookahead s) in
      if Char.isSpace(c) then (input1 s; get_token s)
      else if Char.isDigit(c) then get_number s
      else if Char.isAlpha(c) then
        let val (id as Ident(name)) = get_ident s in
          tokenBuff := (
            case name of 
                 "quit" => Quit
               | "def"  => DEF
               | "end"  => END
               | "not"  => Oper(NOT)
               | "and"  => Oper(AND)
               | "or"   => Oper(OR)
               | "if"   => IF
               | "then" => THEN
               | "else" => ELSE
               | "while" => WHL
               | "do"    => DO
               | "begin" => BGN
               | "fn"    => FN
               | "let"   => LET
               | "in"    => IN
               | "rec"   => REC
               | "list"  => LIST
               | "callcc" => CALLCC
               | _        => id
          )
        end
      else if c = #"#" then (inputLine s; get_token s)
      else (
        input1 s; (* s から c を取り除く *)
        tokenBuff := (case c of
            #"+" => Oper(Add)
          | #"-" => Oper(Sub)
          | #"*" => Oper(Mul)
          | #"/" => Oper(Quo)
          | #"%" => Oper(Mod)
          | #"=" => (case valOf(lookahead s) of
                          #"=" => (input1 s; Oper(EQ))
                        | _ => Oper(Assign))
          | #"!" => (case valOf(lookahead s) of
                          #"=" => (input1 s; Oper(NE))
                        | _ => Oper(NOT))
          | #"<" => (case valOf(lookahead s) of
                          #"=" => (input1 s; Oper(LE))
                        | _ => Oper(LT))
          | #">" => (case valOf(lookahead s) of
                          #"=" => (input1 s; Oper(GE))
                        | _ => Oper(GT))
          | #"(" => Lpar
          | #")" => Rpar
          | #"[" => Lbra
          | #"]" => Rbra
          | #";" => Semic
          | #"," => Comma
          | _    => Others
        )
      )
    end

(* 構文木の組み立て *)
fun expression s =
    let
      fun iter v =
        case !tokenBuff of
             Oper(Assign) => (
               case v of
                    (Var(_) | Ref(_)) => (get_token s;
                                          Op2(Assign, v, expression s))
                  | _ => raise Syntax_error("invalid assign form")
             )
           | _ => v
    in
      iter(expr1 s)
    end
(* 論理演算子 and, or の処理 *)
and expr1 s =
    let
      fun iter v =
          case !tokenBuff of
               Oper(AND) => (get_token s; iter(Ops(AND, v, expr2 s)))
             | Oper(OR)  => (get_token s; iter(Ops(OR,  v, expr2 s)))
             | _ => v
    in
      iter(expr2 s)
    end
(* 比較演算子の処理 *)
and expr2 s =
    let
      fun iter v =
          case !tokenBuff of
               Oper(EQ) => (get_token s; iter(Op2(EQ, v, expr3 s)))
             | Oper(NE) => (get_token s; iter(Op2(NE, v, expr3 s)))
             | Oper(LT) => (get_token s; iter(Op2(LT, v, expr3 s)))
             | Oper(GT) => (get_token s; iter(Op2(GT, v, expr3 s)))
             | Oper(LE) => (get_token s; iter(Op2(LE, v, expr3 s)))
             | Oper(GE) => (get_token s; iter(Op2(GE, v, expr3 s)))
             | _ => v
    in
      iter(expr3 s)
    end
and expr3 s =
    let
      fun iter v =
          case !tokenBuff of
            Oper(Add) => (get_token s; iter(Op2(Add, v, term s)))
          | Oper(Sub) => (get_token s; iter(Op2(Sub, v, term s)))
          | _ => v
    in
      iter (term s)
    end
and term s =
    let
      fun iter v =
          case !tokenBuff of
            Oper(Mul) => (get_token s; iter(Op2(Mul, v, factor s)))
          | Oper(Quo) => (get_token s; iter(Op2(Quo, v, factor s)))
          | Oper(Mod) => (get_token s; iter(Op2(Mod, v, factor s)))
          | _ => v
    in
      iter (factor s)
    end
and factor s =
    case !tokenBuff of
      Lpar => (
          get_token s;
          let
            val v = expression s
          in
            case !tokenBuff of
              Rpar => (get_token s; v)
            | _ => raise Syntax_error("')' expected")
          end

        )
    | Lbra => (
          get_token s;
          let val args = get_comma_list(s, []) in
            case !tokenBuff of
                 Rbra => (get_token s; Crv(args))
               | _ => raise Syntax_error("']' expected")
          end
        )
    | Value(n) => (get_token s; Val(n))
    | Quit => raise Calc_exit
    | IF => (get_token s; make_sel s)
    | WHL => (get_token s; make_while s)
    | BGN => (get_token s; make_begin s)
    | FN  => (get_token s; make_clo s)
    | LET => (get_token s; make_let s)
    | LIST => (get_token s; make_list s)
    | CALLCC => (get_token s; make_ct s)
    | Oper(NOT) => (get_token s; Op1(NOT, factor s))
    | Oper(Sub) => (get_token s; Op1(Sub, factor s))
    | Oper(Add) => (get_token s; Op1(Add, factor s))
    | Ident(name) => (
        get_token s;
        case !tokenBuff of
             Lpar => App(Var(name), get_argument s)
           | Lbra => Ref(Var(name), get_index s)
           | _ => Var(name)
      )
    | _ => raise Syntax_error("unexpected token")
(* カンマで区切られた式を取得 *)
and get_comma_list(s, a) =
    let val v = expression s in
      case !tokenBuff of
           Comma => (get_token s; get_comma_list(s, v::a))
         | _ => rev(v::a)
    end
(* 引数の取得 *)
and get_argument s =
    case !tokenBuff of
         Lpar => (get_token s;
                  case !tokenBuff of
                       Rpar => (get_token s; [])
                     | _ => let val args = get_comma_list(s, []) in
                              case !tokenBuff of
                                   Rpar => (get_token s; args)
                                 | _ => raise Syntax_error("unexpected token")
                            end)
       | _ => raise Syntax_error("'(' expected")
(* 仮引数の取得 *)
and get_parameter s =
    let val parm = get_argument s in
      map (fn x => case x of
                        Var(name) => name
                      | _ => raise Syntax_error("bad parameter"))
          parm
    end
(* if *)
and make_sel s =
    let val test_form = expression s in
      case !tokenBuff of
           THEN => (
             get_token s;
             let val then_form = get_comma_list(s, []) in
               case !tokenBuff of
                    ELSE => (
                      get_token s;
                      let val else_form = get_comma_list(s, []) in
                        case !tokenBuff of
                             END => (get_token s;
                                     Sel(test_form, Bgn(then_form), Bgn(else_form)))
                           | _ => raise Syntax_error("end expected")
                      end
                    )
                  | END => (get_token s;
                            Sel(test_form, Bgn(then_form), Val(False)))
                  | _ => raise Syntax_error("else or end expected")
             end
           )
         | _ => raise Syntax_error("then expected")
    end
(* while *)
and make_while s = 
    let val test_form = expression s in
      case !tokenBuff of
           DO => (get_token s; Whl(test_form, make_begin s))
         | _ => raise Syntax_error("do expected")
    end
(* begin *)
and make_begin s =
    let
      val body = get_comma_list(s, [])
    in
      case !tokenBuff of
           END => (get_token s; Bgn(body))
         | _ => raise Syntax_error("end expected")
    end
(* closure *)
and make_clo s =
    let
      val args = get_parameter s
      val body = make_begin s
    in
      case !tokenBuff of
           Lpar => App(Clo(args, body), get_argument s)
         | _ => Clo(args, body)
    end
and make_let s =
    let
      fun iter(a, b) =
          case !tokenBuff of
               IN => (get_token s; (a, b, make_begin s))
             | Comma => (get_token s; iter(a, b))
             | _ => let val e1 = expression s in
                      case e1 of
                           Op2(Assign, Var(x), e2) => iter(x::a, e2::b)
                         | _ => raise Syntax_error("invalid let form")
                    end
    in
      case !tokenBuff of
           REC => (get_token s; Rec(iter([], [])))
         | _ => Let(iter([], []))
    end
and make_list s =
    case !tokenBuff of
         Lpar => (get_token s;
                  let
                    val args = get_comma_list(s, [])
                  in
                    case !tokenBuff of
                         Rpar => (get_token s; Lst(args))
                       | _ => raise Syntax_error("')' expected")
                  end)
       | _ => raise Syntax_error("'(' expected")
(* ベクタの添字を取得する *)
and get_index s =
    let
      fun iter a =
          let 
            val v = expression s
          in
            case !tokenBuff of
                 Rbra => (get_token s;
                          case !tokenBuff of
                               Lbra => (get_token s; iter(v::a))
                             | _ => rev(v::a))
               | _ => raise Syntax_error("']' expected")
          end
    in
      get_token s;
      iter([])
    end
(* 継続の生成 *)
and make_ct s =
    case !tokenBuff of
         Lpar => (get_token s;
                  let
                    val v = expression s
                  in
                    case !tokenBuff of
                         Rpar => (get_token s; Cct(v))
                       | _ => raise Syntax_error("')' expected")
                  end)
       | _ => raise Syntax_error("'(' expected")

(* 変数束縛 *)
fun add_binding([], _, a) = a
|   add_binding(_, [], _) = raise Calc_run_error("Not enough argument")
|   add_binding(name::ps, x::xs, a) = add_binding(ps, xs, (name, ref x)::a)

fun check_args_num(args, n) =
    if length(args) < n
    then raise Calc_run_error("Not enough argument")
    else ()

(* 変数を求める *)
fun get_var(name, []) = lookup(name)
|   get_var(name, (x as (n, _))::xs) =
    if name = n then SOME x else get_var(name, xs)

(* 真偽のチェック *)
fun isTrue(Float(v))  = Real.!=(v, 0.0)
|   isTrue(Integer(v)) = v <> 0
|   isTrue(Nil) = false
|   isTrue(_) = true

(* 演算子の評価 *)
fun eval_op(op1, op2, v, w) =
    case (v, w) of
         (Integer(n), Integer(m)) => Integer(op1(n, m))
       | (Integer(n), Float(m)) => Float(op2(Real.fromLargeInt(n), m))
       | (Float(n), Integer(m)) => Float(op2(n, Real.fromLargeInt(m)))
       | (Float(n), Float(m)) => Float(op2(n, m))
       | (_, _) => raise Calc_run_error("Not Number")

fun eval_op_int(op1, v, w) =
    case (v, w) of
         (Integer(n), Integer(m)) => Integer(op1(n, m))
       | (_, _) => raise Calc_run_error("Not Integer")

(* 比較演算子の評価 *)
fun eval_comp(op1, op2, v, w) =
    case (v, w) of
         (Integer(n), Integer(m)) =>
         if op1(n, m) then True else False
       | (Integer(n), Float(m)) =>
         if op2(Real.fromLargeInt(n), m) then True else False
       | (Float(n), Integer(m)) =>
         if op2(n, Real.fromLargeInt(m)) then True else False
       | (Float(n), Float(m)) =>
         if op2(n, m) then True else False
       | (_, _) => raise Calc_run_error("Not Number")

(* ベクタの更新 *)
fun update_vector(Vec(v), [x], w) = (Array.update(v, x, w); w)
|   update_vector(Vec(v), x::xs, w) = update_vector(Array.sub(v, x), xs, w)
|   update_vector(_, _, _) = raise Calc_run_error("Not Vector")

(* ベクタの値を取得 *)
fun get_vector(Vec(v), [x]) = Array.sub(v, x)
|   get_vector(Vec(v), x::xs) = get_vector(Array.sub(v, x), xs)
|   get_vector(_, _) = raise Calc_run_error("Not Vector")

(* 式の評価 *)
fun eval_expr(Val(n), _, cont) = cont n
|   eval_expr(Var(name), env, cont) = (
      case get_var(name, env) of
           NONE => raise Calc_run_error("Unbound variable: " ^ name)
         | SOME (_, ref v) => cont v
    )
|   eval_expr(Ref(expr, args), env, cont) =
    eval_expr(
      expr,
      env,
      fn v => eval_index(args,
                         env,
                         fn a => cont(get_vector(v, a))))
|   eval_expr(Op2(Assign, expr1, expr2), env, cont) =
    eval_expr(
      expr2,
      env,
      fn w => case expr1 of
                   Var(name) =>
                     (case get_var(name, env) of
                           NONE => (update(name, w); cont w)
                         | SOME (_, v) => (v := w; cont w) )
                 | Ref(expr, args) =>
                     eval_expr(
                       expr,
                       env,
                       fn v => eval_index(
                                 args,
                                 env, 
                                 fn a => (update_vector(v, a, w); cont w)))
                 | _ => raise Calc_run_error("Illegal assign form") )
|   eval_expr(Op2(op2, expr1, expr2), env, cont) = 
    eval_expr(
      expr1,
      env,
      fn v => eval_expr(
                expr2,
                env,
                fn w => case op2 of
                             Add => cont(eval_op(op +, op +, v, w))
                           | Sub => cont(eval_op(op -, op -, v, w))
                           | Mul => cont(eval_op(op *, op *, v, w))
                           | Quo => cont(eval_op(op div, op /, v, w))
                           | Mod => cont(eval_op_int(op mod,  v, w))
                           | EQ => cont(eval_comp(op =, Real.==, v, w))
                           | NE => cont(eval_comp(op <>, Real.!=, v, w))
                           | LT => cont(eval_comp(op <, op <, v, w))
                           | GT => cont(eval_comp(op >, op >, v, w))
                           | LE => cont(eval_comp(op <=, op <=, v, w))
                           | GE => cont(eval_comp(op >=, op >=, v, w))
                           | _  => raise Calc_run_error("Illegal operator") ))
|   eval_expr(Op1(op1, expr1), env, cont) =
    eval_expr(
      expr1,
      env,
      fn v => case (op1, v) of
                   (Add, _) => cont v
                 | (Sub, Integer(n)) => cont(Integer(~n))
                 | (Sub, Float(n)) => cont(Float(~n))
                 | (NOT, _) => cont(if isTrue(v) then False else True)
                 | _ => raise Calc_run_error("Illegal expression") )
|   eval_expr(Ops(ops, expr1, expr2), env, cont) =
    eval_expr(
     expr1,
     env,
     fn v => case ops of
                  AND => if isTrue(v)
                         then eval_expr(expr2, env, cont)
                         else cont v
                | OR  => if isTrue(v)
                         then cont v
                         else eval_expr(expr2, env, cont)
                | _   => raise Calc_run_error("Illegal operator") )
|   eval_expr(Sel(expr_c, expr_t, expr_e), env, cont) =
    eval_expr(
      expr_c,
      env,
      fn v => if isTrue(v)
              then eval_expr(expr_t, env, cont)
              else eval_expr(expr_e, env, cont) )
|   eval_expr(App(expr, args), env, cont) = 
    let
      fun iter([], _, k) = k []
      |   iter(x::xs, env, k) =
          eval_expr(x,
                    env,
                    fn v => iter(xs, env, fn w => k (v::w)))
    in
      iter(args,
           env,
           fn vs => eval_expr(
                      expr,
                      env,
                      fn v => case v of
                                   Func(F1 f1) => (check_args_num(vs, 1);
                                                   cont(f1(hd vs)))
                                 | Func(F2 f2) => (check_args_num(vs, 2);
                                                   cont(f2(hd vs, hd (tl vs))))
                                 | Func(CLO(parm, body, clo)) =>
                                   eval_expr(body, add_binding(parm, vs, clo), cont)
                                 | Func(CT k) => (check_args_num(vs, 1);
                                                  k(hd vs))
                                 | _ => raise Calc_run_error("Not function") ))
    end
|   eval_expr(Cct(expr), env, cont) =
    eval_expr(expr,
              env,
              fn f => case f of
                           Func(CLO(parm, body, clo)) =>
                           eval_expr(body, add_binding(parm, [Func(CT cont)], clo), cont)
                         | _ => raise Calc_run_error("Not Closure") )
|   eval_expr(Whl(expr_c, expr_b), env, cont) = 
    let
      fun iter () =
          eval_expr(expr_c,
                    env,
                    fn v => if isTrue(v)
                            then eval_expr(expr_b, env, fn _ => iter ())
                            else cont False)
    in
      iter ()
    end
|   eval_expr(Bgn(xs), env, cont) =
    let
      fun iter [] = raise Calc_run_error("ivalid begin form")
      |   iter [x] = eval_expr(x, env, cont)
      |   iter (x::xs) = eval_expr(x, env, fn _ => iter(xs))
    in
      iter(xs)
    end
|   eval_expr(Clo(args, expr), env, cont) = cont(Func(CLO(args, expr, env)))
|   eval_expr(Let(parm, args, body), env, cont) =
    let
      fun iter([], [], a) = eval_expr(body, a, cont)
      |   iter(n::ns, e::es, a) =
          eval_expr(e, env, fn v => iter(ns, es, (n, ref v)::a))
    in
      iter(parm, args, env)
    end
|   eval_expr(Rec(parm, args, body), env, cont) =
    let
      val new_env = foldl (fn(x, a) => (x, ref Nil)::a) env parm
      fun iter([], []) = eval_expr(body, new_env, cont)
      |   iter(n::ns, e::es) =
          eval_expr(e,
                    new_env,
                    fn v => (case get_var(n, new_env) of
                                  NONE => raise Calc_run_error("let rec error")
                                | SOME(_, var) => var := v;
                             iter(ns, es)) )
    in
      iter(parm, args)
    end
|   eval_expr(Lst(args), env, cont) =
    let
      fun iter([], k) = k Nil
      |   iter(x::xs, k) =
          eval_expr(x,
                    env,
                    fn v => iter(xs, fn w => k (Pair(ref v, ref w))))
    in
      iter(args, fn v => cont v)
    end
|   eval_expr(Crv(args), env, cont) =
    let
      fun toVector(_, [], v) = v
      |   toVector(i, x::xs, v) =
          (Array.update(v, i - 1, x); toVector(i - 1, xs, v))
      fun iter(i, [], a) =
          cont(Vec(toVector(i, a, Array.array(i, Nil))))
      |   iter(i, x::xs, a) = 
          eval_expr(x, env, fn w => iter(i + 1, xs, w::a))
    in
      iter(0, args, [])
    end
(* 添字の評価 *)
and eval_index(args, env, cont) =
    let
      fun iter([], k) = k []
      |   iter(x::xs, k) =
          eval_expr(
            x,
            env,
            fn w => iter(xs, fn v => case w of
                                          Integer(n) => k(IntInf.toInt(n)::v)
                                        | _ => raise Calc_run_error("Index is not Integer")))
    in
      iter(args, fn v => cont v)
    end

(* 実行 *)
fun toplevel s = (
    get_token s;
    case !tokenBuff of
      DEF => (
        get_token s;
        case !tokenBuff of
             Ident(name) => (
               get_token s;
               let
                 val a = get_parameter s
                 val b = get_comma_list(s, [])
               in
                 case !tokenBuff of
                      END => (update(name, Func(CLO(a, Bgn(b), [])));
                              print (name ^ "\n"))
                    | _ => raise Syntax_error("end expected")
               end
             )
           | _ => raise Syntax_error("ivalid def form")
    )
    | _ => let val result = expression s in
        case !tokenBuff of
          Semic => ()
        | Quit  => raise Calc_exit
        | _ => raise Syntax_error("unexpected token");
        case eval_expr(result, [], fn x => x) of
             Nil => Nil
           | v => print_value(v);
        print "\n"
      end
)

(* ファイルのロード *)
fun load_library(filename) =
    let
      val a = openIn(filename)
    in
      (while true do toplevel(a)) handle
          Option => ()
        | Syntax_error(mes) => print("ERROR: " ^ mes ^ "\n")
        | Calc_run_error(mes) => print("ERROR: " ^ mes ^ "\n")
        | Div => print("ERROR: divide by zero\n")
        | Subscript => print("ERROR: subscript out of bounds\n")
        | err => raise err;
      closeIn(a)
    end

fun calc(filename) = (
    if filename <> "" then load_library(filename) else ();
    while true do (
      print "Calc> ";
      flushOut(stdOut);
      toplevel(stdIn) handle 
        Syntax_error(mes) => print("ERROR: " ^ mes ^ "\n")
      | Calc_run_error(mes) => print("ERROR: " ^ mes ^ "\n")
      | Div => print("ERROR: divide by zero\n")
      | Subscript => print("ERROR: subscript out of bounds\n")
      | err => raise err;
      inputLine(stdIn)
    )
)

初版 2012 年 9 月 9 日
改訂 2021 年 6 月 5 日

Copyright (C) 2012-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | SML/NJ | NextPage ]