M.Hiroi's Home Page

Functional Programming

お気楽 Scheme プログラミング入門

[ PrevPage | Scheme | NextPage ]

関数型電卓プログラム fncalc の作成 (4)

今回は fncalc に「末尾再帰最適化」を実装しましょう。

●末尾再帰とは?

末尾再帰の「末尾」とは、関数の最後で行われる処理のことです。とくに末尾で関数を呼び出すことを「末尾呼び出し (tail call)」といいます。関数を呼び出す場合、返ってきたあとに行う処理のため、必要な情報を保存しておかなければいけません。ところが、末尾呼び出しはそのあとに実行する処理がありません。呼び出したあと元に戻ってくる必要さえないのです。

このため、末尾呼び出しはわざわざ関数を呼び出す必要はなく、アセンブリ言語のような低水準のレベルではジャンプ命令に変換することができます。これを「末尾呼び出し最適化 (tail call optimization)」とか「末尾最適化」といいます。とくに末尾再帰は末尾で自分自身を呼び出しているので、関数の中で繰り返しに変換することができます。

また、相互再帰やもっと複雑な再帰呼び出しの場合でも、末尾最適化を適用することで、繰り返しに変換できる場合もあります。このように、再帰プログラムを繰り返しに変換してから実行することを「末尾再帰最適化 (tail recursion optimization)」といいます。厳密にいうと末尾最適化なのですが、一般的には末尾再帰最適化と呼ばれることが多いようです。

簡単な例を示しましょう。C言語で階乗を計算する関数 fact を作ります。

リスト : 末尾再帰を繰り返しに変換する (C言語)

/* 末尾再帰 */
int fact(int n, int a)
{
  if(n == 0){
    return a;
  } else {
    return fact(n - 1, a * n);
  }
}

/* 繰り返し */
int facti(int n, int a)
{
loop:
  if(n == 0) return a;
  a *= n;
  n--;
  goto loop;
}

fact は末尾再帰になっています。これを繰り返しに変換すると facti のようになります。引数 n と a の値を保存する必要がないので、n と a の値を書き換えてから goto 文で先頭の処理へジャンプするだけです。最近はC言語でも末尾再帰最適化を行う処理系 (GCC など) があるようです。

fncalc で用いている SECD 仮想マシンにはジャンプ命令がないので、末尾再帰を単純な繰り返しに変換することはできませんが、次に示すような関数呼び出しにおいて、メモリを消費せずに実行できるよう「最適化」を施すことは可能です。

Calc> def foo() foo(); end
=> closure
Calc> foo();
=> 無限ループになる

末尾再帰最適化が行われる場合、foo を評価すると無限ループになります。

●末尾最適化の仕組み

fncalc の場合、仮想マシン vm は末尾再帰になっています。Scheme で vm を動かす場合、Scheme は末尾再帰最適化が行われるので、vm の実行でメモリを消費することはありません。問題は命令 sel と app を実行するときです。たとえば、階乗を求める関数 fact を fncalc でコンパイルすると次のようになります。

リスト : 階乗 (末尾再帰)

def fact(n, a)
  if n == 0 then
    a;
  else
    fact(n - 1, a * n);
  end
end
(closure
  (
    ld   (0 . 0)
    ldc  0
    ==
    sel
      (
        ld   (0 . 1)
        join
      )
      (
        ld   (0 . 0)
        ldc  1
        -
        ld   (0 . 1)
        ld   (0 . 0)
        *
        args  2
        ldg  (fact . xxxx)
        app
        join
      )
    rtn
  )
())

SECD 仮想マシンは sel 命令を実行するとき、コードレジスタ C をダンプに保存します。ここでメモリが消費されます。fact の場合、if は末尾呼び出しであり、その後の命令は rtn しかありません。この場合、join を rtn に変更すると、コードレジスタ C をダンプに保存する必要がなりなります。

そこで、仮想マシンに新しい命令 selr を追加します。selr の状態遷移を図に示します。

(v . s) e (selr ct cf . c) d = v (真) => s e ct d
                             = v (偽) => s e cf d

sel はコード c をダンプ d に保存します。selr 命令の場合、ct と cf が末尾呼び出しになるので、コード c をダンプに保存する必要はありません。また、ct と cf は join 命令ではなく rtn 命令で終了します。

selr 命令を使うと、fact は次のようにコンパイルされます。

(closure
  (
    ld   (0 . 0)
    ldc  0
    ==
    selr
      (
        ld   (0 . 1)
        rtn
      )
      (
        ld   (0 . 0)
        ldc  1
        -
        ld   (0 . 1)
        ld   (0 . 0)
        *
        args  2
        ldg  (fact . xxxx)
        app
        rtn
      )
    rtn
  )
())

ここで fact を呼び出す app 命令に注目してください。fact は末尾呼び出しで、その後に実行する命令は rtn しかありません。この場合、レジスタ S, E, C をダンプレジスタに保存する必要はありません。そこで、新しい命令 tapp を追加します。tapp の状態遷移を示します。

((closure code env) vs . s) e (tapp . c) d => s (vs . env) code d

app は s, e, c をダンプ d に保存します。tapp 命令の場合、実行する関数は末尾呼び出しになるので、s, e, c をダンプに保存する必要はありません。コード code を環境 (v . env) の元で評価するだけです。

このように、selr と tapp 命令を追加することで、fncalc で末尾最適化を実現することができます。

●末尾最適化の実装

それではプログラムを作りましょう。拙作のページ micro Scheme コンパイラの作成 (4) では、末尾再帰最適化を実装するためコンパイラを修正しましたが、今回はコンパイラが出力したコードを直接修正することにします。

最適化を行う関数 optimize は次のようになります。

リスト : 最適化

(define (optimize code)
  (let loop ((code code))
    (when (pair? code)
      (cond ((or (eq? (car code) 'ld)
                 (eq? (car code) 'ldg))
             ;; スキップする
	     (loop (cddr code)))
            ((pair? (car code))
             (optimize (car code))
	     (loop (cdr code)))
	    ((and (eq? (car code) 'sel)
	          (eq? (cadddr code) 'rtn))
             ;; sel then else rtn ならば最適化
             (set-car! code 'selr)
             (set-car! (last-pair (cadr code)) 'rtn)  ; then 節
             (set-car! (last-pair (caddr code)) 'rtn) ; else 節
	     (loop (cdr code)))
	    ((and (eq? (car code) 'app)
	          (eq? (cadr code) 'rtn))
             ;; app rtn ならば最適化
             (set-car! code 'tapp)
	     (loop (cdr code)))
	    (else
	     (loop (cdr code)))))))

引数 code は関数 compile が出力したコードです。命令が ld, ldg の場合、次のコードは最適化する必要がないのでスキップします。もし、大域変数の値が関数で再帰呼び出しされている場合、コードは巡回リストになるためスキップしないと無限ループになってしまいます。ご注意ください。命令がリストの場合は optimize を再帰呼び出して最適化を行います。

命令が sel で then 節と else 節の次の命令が rtn の場合は selr に書き換えます。このとき、then 節と else 節の最後の命令は join になっているので、それを rtn に書き換えます。関数 last-pair はリストの最後のペアを返します。R7RS-small には定義されていないので、今回は自分で作りました。selr, then, else, rtn の rtn は実行されないコード (デッドコード) になるので削除してもかまいませんが、今回はそのまま放置しています。

そして、(loop (cdr code)) で次の命令を最適化します。この場合、then 節の最適化が行われ、次に else 節の最適化が行われます。join が rtn に書き換えられているので、関数が末尾呼び出しされていれば、app を tapp に書き換える処理が行われます。次に、命令が app で次の命令が rtn ならば app を tapp に書き換えます。それ以外の命令は何もしないで、(loop (cdr code)) で次の命令をチェックします。

●仮想マシンの修正

次は仮想マシン vm を修正します。

リスト : 仮想マシン vm の修正

(define (vm s e c d)
  (case (car c)
    
    ・・・・・省略・・・・・
    
    ((tapp)
     (let ((clo (car s)) (lvar (cadr s)))
       (case (pop! clo)
         ((primitive)
	  ;; (primitive function)
          (vm (cons (apply (car clo) lvar) (cddr s)) e (cdr c) d))
	 ((continuation)
	  (vm (cons (car lvar) (car clo)) (cadr clo) (caddr clo) (cadddr clo)))
         (else
	  (vm (cddr s) (cons lvar (cadr clo)) (car clo) d)))))
    
    ・・・・・省略・・・・・
    
    ((selr)
     (let ((t-clause (cadr c))
           (e-clause (caddr c)))
       (if (zero? (car s))
           (vm (cdr s) e e-clause d)
	 (vm (cdr s) e t-clause d))))
    
    ・・・・・省略・・・・・
    
  ))

tapp の場合、primitive と continuation の処理は app と同じです。それ以外の場合はコード (car clo) を環境 (cons lvar (cadr clo)) の元で評価します。s, e, c をダンプ d に保存する必要はありません。selr も簡単です。then 節 (t-clause) と else 節 (e-clause) どちらを評価するにしても、ダンプ d にコード c を保存する必要はありません。

●簡単な実行例

それでは簡単な実行例を示しましょう。1 から x までの合計値を求めるプログラムを作ります。次のリストを見てください。

リスト : 1 から x までの合計値を求める

def sum(x)
  if x == 0 then
    0;
  else
    x + sum(x - 1);
  end
end

def sum1(x, a)
  if x == 0 then
    a;
  else
    sum1(x - 1, a + x);
  end
end

関数 sum は末尾再帰になっていないので、大きな値を計算するとメモリを大量に消費します。関数 sum1 は末尾再帰になっているので、大きな値でもメモリを消費せずに計算することができます。実行結果は次のようになります。

Calc> sum(1000000);
=> 500000500000
Calc> sum1(1000000, 0);
=> 500000500000

Gauche version 0.9.10, Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz で実行した場合、どちらの関数でも値を求めることができます。実行時間は sum が 5.83 秒、sum1 が 5.20 秒になりました。末尾最適化を行うと、実行時間も少し速くなるようです。

また、次の関数を実行するとメモリを消費せずに無限ループとなります。

Calc> def foo() foo(); end
=> closure

Calc> foo();
=> 無限ループになる

●相互再帰

相互再帰とは、関数 foo が関数 bar を呼び出し、bar でも foo を呼び出すというように、お互いに再帰呼び出しを行っていることをいいます。簡単な例を示しましょう。次のリストを見てください。

リスト : 相互再帰

def foo(n)
  if n == 0 then
    1;
  else
    bar(n - 1);
  end
end

def bar(n)
  if n == 0 then
    0;
  else
    foo(n - 1);
  end
end

このプログラムは関数 foo と bar が相互再帰しています。foo と bar が何をしているのか、実際に動かしてみましょう。

Calc> foo(10);
=> 1
Calc> bar(10);
=> 0
Calc> foo(15);
=> 0
Calc> bar(15);
=> 1

結果を見ればおわかりのように、foo は n が偶数のときに真を返し、bar は n が奇数のときに真を返します。なお、このプログラムはあくまでも相互再帰の例題であり、実用的なプログラムではありません。

今回実装した末尾最適化はこのような相互再帰でも機能します。bar(1000000) を実行したところ、結果は次のようになりました。

     表 : 実行結果 (単位:秒)

                    最適化
                |  無  |  有  
    ------------+------+------
    bar(1000000)| 5.32 | 3.54

Gauche version 0.9.10, Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

最適化を行わないとメモリを大量に消費しますが、最適化を行うことでメモリを消費せず、実行速度も少し速くなります。末尾最適化の効果は十分に出ていると思います。

●たらいまわし関数

次は「たらいまわし関数」を試してみましょう。プログラムリストと実行結果を示します。

リスト : たらいまわし関数

def tarai(x, y, z)
  if x <= y then
    y;
  else
    tarai(tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y));
  end
end

def tak(x, y, z)
  if x <= y then
    z;
  else
    tak(tak(x - 1, y, z), tak(y - 1, z, x), tak(z - 1, x, y));
  end
end
  表 : たらいまわし関数の実行結果

                      最適化
                  |  無  |  有
  ----------------+------+------
  tarai(10, 5, 0) | 1.99 | 1.73
  tak  (14, 7, 0) | 2.44 | 2.14

  単位 : 秒

  Gauche version 0.9.10, Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

実行時間は最適化を行ったほうが少しだけ速くなりました。

●遅延評価による高速化

関数 tarai は「遅延評価 (delayed evaluation または lazy evaluation)」を行う処理系、たとえば関数型言語の Haskell では高速に実行することができます。また、Scheme でも delay と force を使って遅延評価を行うことができます。

tarai のプログラムを見てください。x <= y のときに y を返しますが、このとき引数 z の値は必要ありませんね。引数 z の値は x > y のときに計算するようにすれば、無駄な計算を省略することができます。なお、関数 tak は x <= y のときに z を返しているため、遅延評価で高速化することはできません。ご注意ください。

完全ではありませんが、fncalc でもクロージャを使って遅延評価を行うことができます。次のリストを見てください。

リスト : たらいまわし関数 (遅延評価)

def tarai_delay(x, y, z)
  if x <= y then
    y;
  else
    let zz = z() in
      tarai_delay(tarai_delay(x - 1, y, fn() zz; end),
                  tarai_delay(y - 1, zz, fn() x; end),
                  fn() tarai_delay(zz - 1, x, fn() y; end); end);
    end
  end
end

遅延評価したい処理をクロージャに包んで引数 z に渡します。そして、x > y のときに引数 z を評価 (関数呼び出し) します。すると、クロージャ内の処理が実行されて z の値を求めることができます。たとえば、fn() 0; end を z に渡す場合、z() とすると返り値は 0 になります。fn() x; end を渡せば、x に格納されている値が返されます。fn() tarai_delay( ... ); end を渡せば、関数 tarai_delay が実行されてその値が返されるわけです

それでは、実際に実行してみましょう。

tarai_delay(100, 50, fn() 0; end) : 0.04 秒

Gauche version 0.9.10, Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

tarai_delay に大きな値を与えても、高速に実行することができます。遅延評価の効果は十分に出ていると思います。


●プログラムリスト

;;;
;;; fncalc3.scm : 関数型電卓プログラム (R7RS-small 対応版)
;;;
;;;               Copyright (C) 2011-2021 Makoto Hiroi
;;;
(import (scheme base) (scheme cxr) (scheme char) (scheme inexact)
        (scheme bitwise) (scheme file) (scheme read) (scheme write)
        (scheme time))

;;;
;;; マクロ定義
;;;

;;; 多値は考慮しない簡略版
(define-syntax begin0
  (syntax-rules ()
    ((_ a) a)
    ((_ a b ...) (let ((x a)) (begin b ...) x))))

;;; データの追加
(define-syntax push!
  (syntax-rules ()
    ((_ place x) (set! place (cons x place)))))

;;; データの取得
(define-syntax pop!
  (syntax-rules ()
    ((_ place)
     (let ((x (car place)))
       (set! place (cdr place))
       x))))

;;; 末尾のセルを求める
(define (last-pair xs)
  (if (null? (cdr xs))
      xs
      (last-pair (cdr xs))))

;;;
;;; 大域変数
;;;
(define *ch*    #f)
(define *token* #f)
(define *value* #f)
(define *input* (current-input-port))
(define *line*  #f)
(define *col*   #f)

;;;
;;; グローバルな環境
;;;
(define *global-environment*
  `((exp     primitive ,exp)
    (log     primitive ,log)
    (sin     primitive ,sin)
    (cos     primitive ,cos)
    (tan     primitive ,tan)
    (asin    primitive ,asin)
    (acos    primitive ,acos)
    (atan    primitive ,atan)
    (sqrt    primitive ,sqrt)
    (expt    primitive ,expt)
    (number   primitive ,(lambda (x) (if (number? x) 1 0)))
    (string   primitive ,(lambda (x) (if (string? x) 1 0)))
    (function primitive ,(lambda (x) (if (pair? x) 1 0)))
    (load     primitive ,(lambda (x) (load-file x) 1))
    (display  primitive ,(lambda (x) (display (if (pair? x) (car x) x)) x))
    (newline  primitive ,(lambda ()  (newline) 0))
    (print    primitive ,(lambda (x) (display (if (pair? x) (car x) x)) (newline) x))))

;;; 大域変数を求める
(define (get-gvar sym)
  (let ((val (assoc sym *global-environment*)))
    (unless val
      (set! val (cons sym 0))
      (push! *global-environment* val))
    val))

;;;
;;; 入力処理
;;;

;;; 文字の読み込み
(define (nextch)
  (set! *ch* (read-char *input*))
  (cond ((eof-object? *ch*)
         (set! *ch* #\null))
        ((eqv? *ch* #\newline)
         (set! *line* (+ *line* 1))
         (set! *col* 0))
        (else
         (set! *col* (+ *col* 1)))))

;;; コンパイルエラー
(define (compile-error mes)
  (error mes *token* *line* *col*))

;;; 先読み記号の取得
(define (getch) *ch*)

;;; 数値
(define (get-number)
  (let ((buff '()))
    ;; 整数を buff に格納
    (define (get-numeric)
      (do ()
          ((not (char-numeric? (getch))))
        (push! buff (getch))
        (nextch)))
    ;; 整数部
    (get-numeric)
    (case (getch)
      ((#\.)
       ;; 小数部
       (push! buff (getch))
       (nextch)
       (get-numeric)
       (case (getch)
         ((#\d #\D #\e #\E)
          ;; 指数部
          (push! buff (getch))
          (nextch)
          (when (or (eqv? (getch) #\+)
                    (eqv? (getch) #\-))
            (push! buff (getch))
            (nextch))
          ;; 指数の数字
          (get-numeric))))
      ((#\/)
       ;; 分数
       (push! buff (getch))
       (nextch)
       (get-numeric)))
    (string->number (list->string (reverse buff)))))

;;; 識別子
(define (get-ident)
  (let loop ((a '()))
    (if (and (not (char-alphabetic? (getch)))
             (not (char-numeric? (getch)))
             (not (eqv? (getch) #\_)))
        (string->symbol (list->string (reverse a)))
      (loop (begin0 (cons (getch) a) (nextch))))))

;;; 文字列
(define (escape-code c)
  (case c
    ((#\t) #\tab)
    ((#\n) #\newline)
    (else c)))

(define (get-string)
  (nextch)
  (let loop ((buff '()))
    (cond ((eqv? (getch) #\")
           (nextch)
           (list->string (reverse buff)))
          ((eqv? (getch) #\\)
           ;; エスケープ記号
           (nextch)
           (loop (begin0 (cons (escape-code (getch)) buff) (nextch))))
          (else
           (loop (begin0 (cons (getch) buff) (nextch)))))))

;;; トークンの切り出し
(define (get-token)
  ;; 空白文字の読み飛ばし
  (do ()
      ((not (char-whitespace? (getch))))
    (nextch))
  (cond ((char-numeric? (getch))
         (set! *token* 'number)
         (set! *value* (get-number)))
        ((char-alphabetic? (getch))
         (set! *value* (get-ident))
         (case *value*
           ((def end if then else and or not while do begin let in fn eq callcc)
            (set! *token* *value*))
           (else
            (set! *token* 'ident))))
        (else
         (case (getch)
          ((#\#)
           ;; コメントの読み飛ばし
           (do ()
               ((eqv? (getch) #\newline))
             (nextch))
           (get-token))
          ((#\")
           ;; 文字列
           (set! *token* 'string)
           (set! *value* (get-string)))
          ((#\=)
           (set! *token* '=)
           (nextch)
           (when (eqv? (getch) #\=)
             (set! *token* '==)
             (nextch)))
          ((#\+)
           (set! *token* '+)
           (nextch))
          ((#\-)
           (set! *token* '-)
           (nextch))
          ((#\*)
           (set! *token* '*)
           (nextch))
          ((#\%)
           (set! *token* '%)
           (nextch))
          ((#\/)
           (set! *token* '/)
           (nextch))
          ((#\()
           (set! *token* 'lpar)
           (nextch))
          ((#\))
           (set! *token* 'rpar)
           (nextch))
          ((#\<)
           (set! *token* '<)
           (nextch)
           (when (eqv? (getch) #\=)
             (set! *token* '<=)
             (nextch)))
          ((#\>)
           (set! *token* '>)
           (nextch)
           (when (eqv? (getch) #\=)
             (set! *token* '>=)
             (nextch)))
          ((#\!)
           (set! *token* 'not)
           (nextch)
           (when (eqv? (getch) #\=)
             (set! *token* '!=)
             (nextch)))
          ((#\,)
           (set! *token* 'comma)
           (nextch))
          ((#\;)
           (set! *token* 'semic)
           (nextch))
          ((#\null)
           (set! *token* 'eof))
          (else
           (set! *token* 'others))))))

;;;
;;; 式の評価
;;;
(define (expression env)
  (let ((val (expr1 env)))
    (case *token*
      ((=)
       (get-token)
       (case (car val)
         ((ld)
          ;; 局所変数の代入
          (append (expression env) (list 'lset (cadr val))))
         ((ldg)
          ;; 大域変数の代入
          (append (expression env) (list 'gset (cadr val))))
         (else
          (compile-error "invalid assignment form"))))
      (else val))))

;;; 論理演算子 (and と or の優先順位は同じとする)
(define (expr1 env)
  (let loop ((val1 (expr2 env)))
    (case *token*
      ((and)
       (get-token)
       (loop (append val1 (expr2 env) (list 'and))))
      ((or)
       (get-token)
       (loop (append val1 (expr2 env) (list 'or))))
      (else val1))))

;;; 比較演算子 (==, !=, <, <=, >, >= の優先順位は同じとする)
(define (expr2 env)
  (let ((val1 (expr3 env)))
    (case *token*
      ((==)
       (get-token)
       (append val1 (expr3 env) (list '==)))
      ((!=)
       (get-token)
       (append val1 (expr3 env) (list '!=)))
      ((<)
       (get-token)
       (append val1 (expr3 env) (list '<)))
      ((<=)
       (get-token)
       (append val1 (expr3 env) (list '<=)))
      ((>)
       (get-token)
       (append val1 (expr3 env) (list '>)))
      ((>=)
       (get-token)
       (append val1 (expr3 env) (list '>=)))
      ((eq)
       (get-token)
       (append val1 (expr3 env) (list 'eq)))
      (else val1))))

(define (expr3 env)
  (let loop ((val (term env)))
    (case *token*
      ((+)
       (get-token)
       (loop (append val (term env) (list '+))))
      ((-)
       (get-token)
       (loop (append val (term env) (list '-))))
      (else val))))

;;; 項
(define (term env)
  (let loop ((val (factor env)))
    (case *token*
      ((*)
       (get-token)
       (loop (append val (factor env) (list '*))))
      ((/)
       (get-token)
       (loop (append val (factor env) (list '/))))
      ((%)
       (get-token)
       (loop (append val (factor env) (list '%))))
      (else val))))

;;; 実引数のコンパイル
(define (compile-argument env)
  (get-token)
  (if (eq? *token* 'rpar)
      (begin (get-token) (list 'args 0))
    (let loop ((n 1) (a '()))
      (let ((expr (expression env)))
        (case *token*
          ((rpar)
           (get-token)
           (append (append a expr) (list 'args n)))
          ((comma)
           (get-token)
           (loop (+ n 1) (append a expr)))
          (else
           (compile-error "unexpected token")))))))

;;; 仮引数の取得
(define (get-parameter)
  (get-token)
  (unless (eq? *token* 'lpar)
    (compile-error "'(' expected"))
  (get-token)
  (let loop ((a '()))
    (let ((val *value*))
      (case *token*
        ((rpar)
         (get-token)
         (reverse a))
        ((ident)
         (let ((val *value*))
           (get-token)
           (loop (cons val a))))
        ((comma)
         (get-token)
         (loop a))
        (else
         (compile-error "unexpected token"))))))

;;; 位置を求める
(define (position var ls)
  (let loop ((i 0) (ls ls))
    (cond ((null? ls) #f)
          ((eqv? var (car ls)) i)
          (else
           (loop (+ i 1) (cdr ls))))))

;;; フレームと局所変数の位置を求める
(define (location var ls)
  (let loop ((i 0) (ls ls))
    (if (null? ls)
        #f
      (let ((j (position var (car ls))))
        (if j
            (cons i j)
          (loop (+ i 1) (cdr ls)))))))

;;; 因子
(define (factor env)
  (case *token*
    ((lpar)
     (get-token)
     (begin0
      (expression env)
      (if (eq? *token* 'rpar)
          (get-token)
          (compile-error "')' expected"))))
    ((number)
     (begin0 (list 'ldc *value*) (get-token)))
    ((string)
     (begin0 (list 'ldc *value*) (get-token)))
    ((not)
     (get-token)
     (append (factor env) (list 'not)))
    ((+)
     ;; 単項演算子 (+ をはずすだけ)
     (get-token)
     (factor env))
    ((-)
     ;; 単項演算子
     (get-token)
     (append (factor env) (list 'neg)))
    ((fn)
     ;; クロージャの生成
     (let ((code (list 'ldf
                       (append (compile-block (cons (get-parameter) env))
                               (list 'rtn)))))
       (get-token)
       (if (eq? *token* 'lpar)
           ;; 関数呼び出し
           (append (compile-argument env) code (list 'app))
         code)))
    ((callcc)
     ;; 継続 callcc(f)
     ;; ldct next args 1 引数 f の評価 app next ...
     (get-token)
     (unless (eq? *token* 'lpar)
       (compile-error "callcc: '(' expected"))
     (get-token)
     (let ((code (append (list 'args 1) (expression env) (list 'app))))
       (unless (eq? *token* 'rpar)
         (compile-error "callcc: invalid token"))
       (get-token)
       (append (list 'ldct (length code)) code)))
    ((ident)
     (let ((code #f)
           (pos (location *value* env)))
       (if pos
           ;; 局所変数
           (set! code (list 'ld pos))
         ;; 大域変数
         (set! code (list 'ldg (get-gvar *value*))))
       (get-token)
       (if (eq? *token* 'lpar)
           ;; 関数呼び出し
           (append (compile-argument env) code (list 'app))
         ;; 変数
         code)))
    (else
     (compile-error "unexpected token"))))

;;; if 文のコンパイル
(define (compile-if env)
  (let ((test-form (expression env))
        (then-form #f)
        (else-form #f))
    (unless (eq? *token* 'then)
      (compile-error "if: then expected"))
    (get-token)
    (set! then-form (append (compile-statement env) (list 'join)))
    (get-token)  ; end, semic を読み飛ばす
    (if (eq? *token* 'else)
        (begin (get-token)
               (set! else-form
                     (append (begin0 (compile-statement env)
                                     (get-token)) ; end, semic を読み飛ばす
                             (list 'join))))
      (set! else-form (list 'ldc 0 'join)))
    (unless (eq? *token* 'end)
      (compile-error "if: end expected"))
    (append test-form (list 'sel then-form else-form))))

;;; while 文のコンパイル
(define (compile-while env)
  (let ((test (expression env))
        (body #f))
    (unless (eq? *token* 'do)
      (compile-error "while: do expected"))
    (get-token)
    (set! body (append (compile-block env) (list 'rpt)))
    (append (list 'bgn) test (list 'whl) (list body))))

;;; block 文のコンパイル
(define (compile-block env)
  (let loop ((code '()))
    (let ((code1 (compile-statement env)))
      (get-token)  ; 実行文の終端 (semic, end) を読み飛ばす
      (cond ((eq? *token* 'end)
             (append code code1))
            (else
             (loop (append code code1 (list 'pop))))))))

;;; let 文のコンパイル
(define (compile-let env)
  (let loop ((vars '()) (code '()))
    (cond ((eq? *token* 'in)
           (get-token)
           ;; 本体コードの生成
           (append code
                   (list 'args
                         (length vars)
                         'ldf
                         (append (compile-block (cons (reverse vars) env))
                                 (list 'rtn))
                         'app)))
          ((eq? *token* 'ident)
           (let ((var *value*))
             (get-token)
             (unless (eq? *token* '=)
               (compile-error "let: invalid assignment form"))
             (get-token)
             (loop (cons var vars) (append code (expr1 env)))))
          ((eq? *token* 'comma)
           (get-token)
           (loop vars code))
          (else
           (compile-error "let: unexpected token")))))

;;; 実行文のコンパイル
(define (compile-statement env)
  (case *token*
    ((begin)
     (get-token)
     (compile-block env))
    ((if)
     (get-token)
     (compile-if env))
    ((while)
     (get-token)
     (compile-while env))
    ((let)
     (get-token)
     (compile-let env))
    (else
     ;; 式文
     (begin0
       (expression env)
       (unless (eq? *token* 'semic)
         (compile-error "';' expected"))))))

;;; 最適化
(define (optimize code)
  (let loop ((code code))
    (when (pair? code)
      (cond ((or (eq? (car code) 'ld)
                 (eq? (car code) 'ldg))
             ;; スキップする
             (loop (cddr code)))
            ((pair? (car code))
             (optimize (car code))
             (loop (cdr code)))
            ((and (eq? (car code) 'sel)
                  (eq? (cadddr code) 'rtn))
             ;; sel then else rtn ならば最適化
             (set-car! code 'selr)
             (set-car! (last-pair (cadr code)) 'rtn)  ; then 節
             (set-car! (last-pair (caddr code)) 'rtn) ; else 節
             (loop (cdr code)))
            ((and (eq? (car code) 'app)
                  (eq? (cadr code) 'rtn))
             ;; app rtn ならば最適化
             (set-car! code 'tapp)
             (loop (cdr code)))
            (else
             (loop (cdr code)))))))

;;; コンパイル
(define (compile)
  (cond ((eq? *token* 'def)
         ;; 関数定義
         (get-token)
         (unless (eq? *token* 'ident)
           (compile-error "invalid def form"))
         (let ((name *value*)
               (code (append (compile-block (list (get-parameter)))
                     (list 'rtn))))
           (list 'ldf code 'gset (get-gvar name))))
        (else
         (compile-statement '()))))

;;;
;;; 仮想マシン
;;;

;;;
(define (drop ls n)
  (if (or (zero? n) (null? ls))
      ls
    (drop (cdr ls) (- n 1))))

;;; 局所変数の値を求める
(define (get-lvar e i j)
  (list-ref (list-ref e i) j))

;;; 局所変数の値を更新する
(define (set-lvar! e i j val)
  (set-car! (drop (list-ref e i) j) val))

(define (vm s e c d)
  (case (car c)
    ((+)
     (vm (cons (+ (cadr s) (car s)) (cddr s)) e (cdr c) d))
    ((-)
     (vm (cons (- (cadr s) (car s)) (cddr s)) e (cdr c) d))
    ((*)
     (vm (cons (* (cadr s) (car s)) (cddr s)) e (cdr c) d))
    ((/)
     (vm (cons (/ (cadr s) (car s)) (cddr s)) e (cdr c) d))
    ((%)
     (vm (cons (modulo (cadr s) (car s)) (cddr s)) e (cdr c) d))
    ((==)
     (vm (cons (if (= (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((!=)
     (vm (cons (if (= (cadr s) (car s)) 0 1) (cddr s)) e (cdr c) d))
    ((<)
     (vm (cons (if (< (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((<=)
     (vm (cons (if (<= (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((<)
     (vm (cons (if (< (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((<=)
     (vm (cons (if (<= (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((>)
     (vm (cons (if (> (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((>=)
     (vm (cons (if (>= (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((eq)
     (vm (cons (if (eqv? (cadr s) (car s)) 1 0) (cddr s)) e (cdr c) d))
    ((and)
     (vm (cons (if (zero? (bitwise-and (cadr s) (car s))) 0 1) (cddr s)) e (cdr c) d))
    ((or)
     (vm (cons (if (zero? (bitwise-ior (cadr s) (car s))) 0 1) (cddr s)) e (cdr c) d))
    ((neg)
     (vm (cons (- (car s)) (cdr s)) e (cdr c) d))
    ((not)
     (vm (cons (if (zero? (car s)) 1 0) (cdr s)) e (cdr c) d))
    ((ld)
     (let ((pos (cadr c)))
       (vm (cons (get-lvar e (car pos) (cdr pos)) s) e (cddr c) d)))
    ((ldc)
     (vm (cons (cadr c) s) e (cddr c) d))
    ((ldg)
     ;; c = (ldg (sym . val) ...)
     (vm (cons (cdr (cadr c)) s) e (cddr c) d))
    ((ldf)
     (vm (cons (list 'closure (cadr c) e) s) e (cddr c) d))
    ((ldct)
     ;; 継続
     (vm (cons (list 'continuation s e (drop (cddr c) (cadr c)) d) s)
         e
         (cddr c)
         d))
    ((lset)
     (let ((pos (cadr c)))
       (set-lvar! e (car pos) (cdr pos) (car s))
       (vm s e (cddr c) d)))
    ((gset)
     ;; c = (gset (sym . val) ...)
     (set-cdr! (cadr c) (car s))
     (vm s e (cddr c) d))
    ((app)
     (let ((clo (car s)) (lvar (cadr s)))
       (case (pop! clo)
         ((primitive)
          ;; (primitive function)
          (vm (cons (apply (car clo) lvar) (cddr s)) e (cdr c) d))
         ((continuation)
          (vm (cons (car lvar) (car clo)) (cadr clo) (caddr clo) (cadddr clo)))
         (else
          ;; (closure code env)
          (vm '()
              (cons lvar (cadr clo))
              (car clo)
              (cons (list (cddr s) e (cdr c)) d))))))
    ((tapp)
     (let ((clo (car s)) (lvar (cadr s)))
       (case (pop! clo)
         ((primitive)
          ;; (primitive function)
          (vm (cons (apply (car clo) lvar) (cddr s)) e (cdr c) d))
         ((continuation)
          (vm (cons (car lvar) (car clo)) (cadr clo) (caddr clo) (cadddr clo)))
         (else
          ;; (closure code env)
          (vm (cddr s) (cons lvar (cadr clo)) (car clo) d)))))
    ((rtn)
     (let ((save (car d)))
       (vm (cons (car s) (car save)) (cadr save) (caddr save) (cdr d))))
    ((sel)
     (let ((t-clause (cadr c))
           (e-clause (caddr c)))
       (if (zero? (car s))
           (vm (cdr s) e e-clause (cons (cdddr c) d))
         (vm (cdr s) e t-clause (cons (cdddr c) d)))))
    ((selr)
     (let ((t-clause (cadr c))
           (e-clause (caddr c)))
       (if (zero? (car s))
           (vm (cdr s) e e-clause d)
         (vm (cdr s) e t-clause d))))
    ((join)
     (vm s e (car d) (cdr d)))
    ((pop)
     (vm (cdr s) e (cdr c) d))
    ((args)
     (let loop ((n (cadr c)) (a '()))
       (if (zero? n)
           (vm (cons a s) e (cddr c) d)
         (loop (- n 1) (cons (pop! s) a)))))
    ((bgn)
     (vm s e (cdr c) (cons (cdr c) d)))
    ((whl)
     (if (zero? (car s))
         (vm (cons 0 (cdr s)) e (cddr c) (cdr d))
       (vm (cdr s) e (cadr c) d)))
    ((rpt)
     (vm (cdr s) e (car d) d))
    ((halt)
     (car s))
    (else
     (error "vm: unexpected code:" (car c)))))

(define (load-file name)
  (define (restore-env xs)
    (set! *input* (list-ref xs 0))
    (set! *token* (list-ref xs 1))
    (set! *value* (list-ref xs 2))
    (set! *ch*    (list-ref xs 3))
    (set! *line*  (list-ref xs 4))
    (set! *col*   (list-ref xs 5)))
  (call-with-input-file name
    (lambda (in)
      (let ((env (list *input* *token* *value* *ch* *line* *col*)))
        (set! *input* in)
        (set! *line* 1)
        (set! *col*  0)
        (nextch)
        (with-exception-handler
         (lambda (err) (restore-env env))
         (lambda ()
           (let loop ()
             (get-token)
             (when (not (eq? *token* 'eof))
                   (vm '() '() (append (compile) (list 'halt)) '())
                   (loop)))
           (restore-env env)))))))

;;; 入力をクリアする
(define (clear-input-data)
  (do ()
      ((eqv? *ch* #\newline))
    (nextch)))

;;; プロンプトの表示
(define (prompt)
  (display "Calc> ")
  (flush-output-port)
  (set! *line* 0)
  (set! *col* 0))

(define (calc)
  (prompt)
  (nextch)
  (call/cc
    (lambda (break)
      (let loop ()
        (guard (err
                (else (display "ERROR: ")
                      (display (error-object-message err))
                      (unless
                       (null? (error-object-irritants err))
                       (display (error-object-irritants err)))
                      (newline)
                      (clear-input-data)))
          (get-token)
          (when (eqv? *token* 'eof) (break #t))
          (let ((code (append (compile) (list 'halt))))
            (optimize code)
            (let* ((s (current-jiffy))
                   (val (vm '() '() code '())))
              (display (inexact (/ (- (current-jiffy) s) (jiffies-per-second))))
              (newline)
              (display "=> ")
              (display (if (pair? val) (car val) val))
              (newline))))
        (prompt)
        (loop)))))

;;; 実行
(calc)

初版 2011 年 8 月 28 日
改訂 2021 年 6 月 26 日

Copyright (C) 2011-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | Scheme | NextPage ]