M.Hiroi's Home Page

お気楽 Haskell プログラミング入門

micro Scheme 編 : Haskell で作る micro Scheme (6)

Copyright (C) 2013-2021 Makoto Hiroi
All rights reserved.

はじめに

今回は micro Scheme に「継続 (continuation)」の機能を追加してみましょう。micro Scheme で継続を扱う場合、式を評価する関数を「継続渡しスタイル (CPS)」で書き直す必要があります。この修正が少々面倒な作業になりますが、継続そのものは簡単に実装することができます。

●継続の使い方

Scheme の継続は Haskell の継続モナドとは異なるので、ここで簡単に説明しておきましょう。

Scheme の場合、継続を取り出すには関数 call/cc を使います。call/cc には関数をひとつ渡します。call/cc に渡される関数は引数がひとつで、その引数に call/cc が取り出した継続が渡されます。call/cc はその関数を評価し、その結果が call/cc の返り値になります。プログラムは継続渡しスタイルで記述する必要はありません。普通のプログラムの中で継続を取り扱うことができます。

Scheme の仕様書 (R5RS) によると、継続は引数を一つ取る関数で表されます。引数を渡して継続を評価すると、今までの処理を破棄して、call/cc で取り出された残りの計算 (継続) を実行します。このとき、継続に渡した引数が call/cc の返り値になります。

簡単な例を示しましょう。

Scm> (call/cc (lambda (k) k))
<continuation>
Scm> (+ 1 (* 2 (call/cc (lambda (k) 3))))
7
Scm> (+ 1 (* 2 (call/cc (lambda (k) (k 4) 3))))
9

最初の例では、ラムダ式の引数 k に継続が渡されます。ラムダ式は k をそのまま返しているので、call/cc の返り値は取り出された継続になります。micro Scheme では継続を <continuation> と表示します。

次の例を見てください。call/cc によって取り出される継続は、call/cc の返り値を 2 倍して、その結果に 1 を加えるという処理になります。call/cc の返り値を X とすると、継続は (+ 1 (* 2 X)) という式で表すことができます。ラムダ式では継続を評価せずに 3 をそのまま返しているので、(+ 1 (* 2 3)) をそのまま計算して値は 7 になります。

最後の例では、匿名関数の中で (k 4) を評価しています。継続を評価しているので、現在の処理を破棄して、取り出した継続 (+ 1 (* 2 X)) を評価します。したがって、ラムダ式で (k 4) の後ろにある 3 を返す処理は実行されません。X の値は継続の引数 4 になるので、(+ 1 (* 2 4)) を評価して値は 9 になります。

Scheme の場合、継続を変数に保存しておいて、あとから実行することもできます。次の例を見てください。

Scm> (define c false)
c
Scm> (+ 1 (* 2 (call/cc (lambda (k) (set! c k) 3))))
7
Scm> (c 10)
21
Scm> (c 100)
201

ラムダ式の中で取り出した継続を大域変数 c に保存します。継続で行う処理は (+ 1 (* 2 X)) なので、(c 10) は (+ 1 (* 2 10)) を評価して値は 21 になります。同様に、(c 100) は (+ 1 (* 2 100)) を評価して値は 201 になります。

Scheme の継続については拙作のページ Scheme 入門「継続と継続渡しスタイル」で詳しく説明しています。よろしければ参考にしてください。

●継続を表すデータ型の定義

それではプログラムを作りましょう。最初に、継続を表すデータ型を定義します。

リスト : データ型の定義

type Cont r a = (a -> r) -> r

type ScmFunc = Env -> SExpr -> Cont (Scm SExpr) (Scm SExpr)

data SExpr = INT  !Integer
           | REAL !Double
           | SYM  String
           | STR  String
           | CELL SExpr SExpr
           | NIL
           | PRIM ScmFunc
           | SYNT ScmFunc
           | CLOS SExpr LEnv
           | CONT (Scm SExpr -> Scm SExpr)
           | MACR SExpr

継続を表すデータ型 Cont r a を定義します。プリミティブ (PRIM)、シンタックス形式 (SYNT)、S 式を評価する関数はすべて継続渡しスタイルで書き直します。このデータ型を ScmFunc で表します。継続を表す関数の型は Scm SExpr -> Scm SExpr になるので、これを CONT に格納して「継続」を表すことにします。継続 CONT の生成は micro Scheme の関数 call/cc の処理を行うプリミティブ callcc で行い、継続の評価は関数 apply で行います。

●S 式の評価の修正

次は S 式を評価する関数 eval を修正します。

リスト : S 式の評価

eval :: ScmFunc
eval env NIL        c = c $ return NIL
eval env v@(INT _)  c = c $ return v
eval env v@(REAL _) c = c $ return v
eval env v@(STR _)  c = c $ return v
eval env (SYM name) c = do
  a <- liftIO $ lookupLEnv name $ snd env
  case a of
    Nothing -> do b <- liftIO $ H.lookup (fst env) name
                  case b of
                    Nothing -> throwError $ "unbound variable: " ++ name
                    Just v  -> c $ return v
    Just v -> c $ return v
eval env (CELL func args) c =
  eval env func (\m ->
    do v <- m
       case v of
         SYNT f -> f env args c
         MACR f -> apply env f args (\m1 -> do expr <- m1
                                               eval env expr c)
         _      -> evalArguments env args (\m2 -> do vs <- m2
                                                     apply env v vs c))

eval の引数 c が継続を表すクロージャです。たとえば、値 value を返す場合は c に return value を渡して評価します。引数が自己評価フォームの場合はそれ自身を、変数の場合は求めた値を return で包んで継続 c に渡して評価します。エラーを返す場合は継続 c にエラーを渡す必要はありません。継続 c を破棄してエラーを返すだけです。なお、継続 c にエラーを渡した場合でも、モナドの働きによって継続が評価されることは無いので、エラーはそのまま返されることになります。

引数がリストの場合はちょっと複雑です。まず、先頭要素を eval で評価して、その値を継続 (ラムダ式) の引数 m に渡して処理します。このように、eval を呼び出すときは、必ず継続渡しスタイルでプログラムを記述してください。

m はモナドなので <- で値を取り出して変数 v にセットします。あとは case で処理を振り分けます。SYNT f の場合は引数 args を評価しないで f に渡します。MACR f の場合は apply でマクロ本体を評価し、その結果をラムダ式に渡します。ラムダ式ではモナド m1 から式 expr を取り出して、それを eval で評価します。それ以外の場合は、evalArgument で引数 args を評価して、その結果をラムダ式に渡します。その中でモナド m2 から引数を取り出して apply に渡します。

引数を評価する evalArgument も CPS で書き直します。次のリストを見てください。

リスト : 引数の評価

evalArguments :: ScmFunc
evalArguments env NIL c = c $ return NIL
evalArguments env (CELL expr rest) c =
  eval env
       expr
       (\m1 -> evalArguments env
                             rest
                             (\m2 -> do v  <- m1
                                        vs <- m2
                                        c $ return (CELL v vs)))
evalArguments _ _ _ = throwError "invalid function form"

引数が NIL の場合は継続 c で return NIL を返します。そうでなければ、eval で先頭要素 expr を評価し、その結果を継続 (ラムダ式) の引数 m1 に渡します。この中で evalArgument を再帰呼び出しし、残りのリスト rest を処理します。その結果は継続の引数 m2 に渡されます。ここで、モナド m1, m2 から値を取り出して、その値を CELL に格納して継続 c で返します。これで引数を評価した結果をリストに格納して返すことができます。

継続渡しスタイルのプログラムはちょっと難しいと思います。よく理解できない方は拙作のページ「継続渡しスタイル」をお読みください。

●シンタックス形式の修正

次はシンタックス形式を処理する関数を修正します。次のリストを見てください。

リスト : シンタックス形式 (1)

-- quote
evalQuote :: ScmFunc
evalQuote env (CELL expr _) c = c $ return expr
evalQuote _ _ _ = throwError "invalid quote form"

-- define
evalDef :: ScmFunc
evalDef env (CELL sym@(SYM name) (CELL expr NIL)) c =
  eval env expr (\m -> do
    v <- m
    lift $ H.update (fst env) name v
    c $ return sym)
evalDef _ _ _ = throwError "invalid define form"

-- define-macro
evalDefM :: ScmFunc
evalDefM env (CELL sym@(SYM name) (CELL expr NIL)) c =
  eval env expr (\m -> do
    v <- m
    lift $ H.update (fst env) name (MACR v)
    c $ return sym)
evalDefM _ _ _ = throwError "invalid define-macro form"

evalQuote は簡単です。リストの先頭要素 expr を return でモナドに包んで継続 c に渡すだけです。evalDef は eval でリストの第 2 要素 expr を評価して、その値を継続の引数 m に渡します。この中で、モナド m から値を v を取り出して大域変数にセットします。そして、継続 c で return sym を返します。マクロを定義する evalDefM も同じです。

リスト : シンタックス形式 (2)

-- if
evalIf :: ScmFunc
evalIf env (CELL pred (CELL thenForm rest)) c =
  eval env pred (\m -> do
    v <- m
    if v /= false
    then eval env thenForm c
    else case rest of
           CELL elseForm _ -> eval env elseForm c
           _               -> c $ return false)
evalIf _ _ _ = throwError $ "if : " ++ errNEA

-- lambda
evalLambda :: ScmFunc
evalLambda env expr c = c $ return (CLOS expr (snd env))

evalIf は条件部 pred を eval で評価して、その結果を継続の引数 m に渡します。その中でモナド m から値 v を取り出し、その値が真の場合、eval で then 節を評価します。偽の場合、else 節があればそれを eval で評価します。なければ、継続 c で return false を返します。evalLambda は簡単です。クロージャを生成して return でモナドに包み、それを継続 c で返すだけです。

リスト : シンタックス形式 (3)

-- set!
evalSet :: ScmFunc
evalSet env (CELL (SYM name) (CELL expr _)) c =
  eval env expr (\m -> do
    v <- m
    a <- lift $ lookupLEnv name (snd env)
    case a of
      Nothing -> do b <- lift $ H.lookup (fst env) name
                    case b of
                      Nothing -> throwError $ "unbound variable: " ++ name
                      Just _ -> do lift $ H.update (fst env) name v
                                   c $ return v
      Just _  -> do lift $ updateLEnv name v (snd env)
                    c $ return v)
evalSet _ _ _ = throwError "invalid set! form"

evalSet は eval で式 expr を評価して、継続の引数 m に渡します。この中でモナド m から値 v を取り出します。次に、大域変数から変数 name を探します。見つからない場合は局所変数から name を探します。変数を見つけた場合はその値を v に更新して、継続 c で return v を返します。

●関数適用の修正

次は apply の修正と call/cc を処理するプリミティブを作ります。

リスト : micro Scheme 用 call/cc

callcc :: ScmFunc
callcc env (CELL func _) c = 
  apply env func (CELL (CONT c) NIL) c
callcc _ _ _ = throwError $ "call/cc " ++ errNEA

Scheme の関数 call/cc の処理はプリミティブ callcc で行います。call/cc に渡される関数 func を apply で呼び出します。このとき、継続 c を CONT に格納して func の引数として渡すだけです。

apply のプログラムは次のようになります。

リスト : 関数適用

apply :: Env -> SExpr -> SExpr -> Cont (Scm SExpr) (Scm SExpr)
apply env func actuals c =
  case func of
    PRIM f  -> f env actuals c
    CONT c1 -> case actuals of
                 NIL -> throwError errNEA
                 (CELL x _) -> c1 $ return x
    CLOS (CELL parms body) lenv0 -> do
      lenv1 <- makeBindings lenv0 parms actuals
      evalBody (fst env, lenv1) body c
    _ -> throwError $ "Not Function: " ++ show func

apply も継続渡しスタイルで記述します。関数 func が PRIM f の場合はプリミティブ f を呼び出すだけです。CONT c1 の場合、継続 c1 に引数を一つ渡して評価します。このとき、apply の継続 c を評価して値を返してはいけません。ここで継続 c を破棄して、call/cc で取り出した継続 c1 を実行します。CLOS の場合、makeBindings で変数束縛を行い、evalBody で本体を評価します。

evalBody も継続渡しスタイルでプログラムします。次のリストを見てください。

リスト : 本体の評価

evalBody :: ScmFunc
evalBody env (CELL expr NIL) c = eval env expr c
evalBody env (CELL expr rest) c =
  eval env expr (\_ -> evalBody env rest c)
evalBody _ _ _ = throwError "invalid body form"

引数の要素が残り一つの場合、それを eval で評価します。要素が複数有る場合、eval で先頭要素を評価し、その結果を継続に渡しますが、その値は捨てていることに注意してください。それから、ラムダ式の中で evalBody を再帰呼び出しします。結局、最後の S 式の評価結果を継続 c で返すことになります。

●REPL の修正

最後に REPL (read-eval-print-loop) を修正します。次のリストを見てください。

リスト : REPL

initGEnv :: [(String, SExpr)]
initGEnv = [("true",   true),
            ("false",  false),

            ・・・ 省略 ・・・

            ("display", PRIM display),
            ("newline", PRIM newline),
            ("error",   PRIM error'),
            ("call/cc", PRIM callcc)]

repl :: Env -> String -> IO ()
repl env xs = do
  putStr "Scm> "
  hFlush stdout
  case readSExpr xs of
    Left  (ParseErr xs' mes) -> do putStrLn mes
                                   repl env $ dropWhile (/= '\n') xs'
    Right (expr, xs') -> do result <- runExceptT $ eval env expr id
                            case result of
                              Left mes -> putStrLn mes
                              Right v  -> print v
                            repl env xs'

initGEnv には call/cc のほかに、テストで使うため display と newline を追加します。関数 repl は eval に id を渡すだけです。これで、eval の評価結果を求めて、値を表示することができます。

あとの修正は簡単なので説明は割愛します。詳細はプログラムリストをお読みください。

●簡単な実行例

それでは実際に継続を使ってみましょう。

Scm> (define a false)
a
Scm> (list 'a 'b (call/cc (lambda (k) (set! a k) 'c)) 'd)
(a b c d)
Scm> (a 'e)
(a b e d)
Scm> (a 'f)
(a b f d)

変数 a に取り出した継続をセットします。この場合、継続は (list 'a 'b [ ] 'd) になります。list の処理だけではなく、'd を評価する処理も残っています。継続 a に引数を渡して評価すると、[ ] の部分に継続の引数がセットされ、'd を評価して list に渡されます。したがって、(a 'e) を評価すると (a b e d) になり、(a 'f) を評価すると (a b f d) になります。正常に動作していますね。

●大域脱出

次は大域脱出を試してみましょう。

Scm> (define bar1 (lambda (k) (display "call bar1\n")))
bar1
Scm> (define bar2 (lambda (k) (display "call bar2\n") (k false)))
bar2
Scm> (define bar3 (lambda (k) (display "call bar3\n")))
bar3
Scm> (define test (lambda (k) (bar1 k) (bar2 k) (bar3 k)))
test
Scm> (call/cc (lambda (k) (test k)))
call bar1
call bar2
false

bar2 からトップレベルへ脱出するので、bar3 は呼び出されていません。これも正常に動作していますね。

●繰り返しからの脱出

もちろん、繰り返しから脱出することもできます。次の例を見てください。

リスト : do から脱出する場合

(define find-do
  (lambda (fn ls)
    (call/cc
      (lambda (k)
        (do ((xs ls (cdr xs)))
            ((null? xs) false)
          (if (fn (car xs)) (k (car xs))))))))

リスト ls から関数 fn が真を返す要素を探します。継続のテストということで、あえて do を使って実装しています。fn が真を返す場合、継続 k でその要素を返します。それでは実行してみましょう。

Scm> (find-do (lambda (x) (eq? x 'c)) '(a b c d e))
c
Scm> (find-do (lambda (x) (eq? x 'c)) '(a b d e f))
false

もちろん高階関数からも脱出することができます。

リスト : map から脱出する場合

(define map-check (lambda (fn chk ls)
  (call/cc
    (lambda (k)
      (map (lambda (x) (if (chk x) (k '()) (fn x))) ls)))))
Scm> (map-check (lambda (x) (cons x x)) (lambda (x) (eq? x 'e)) '(a b c d e))
()
Scm> (map-check (lambda (x) (cons x x)) (lambda (x) (eq? x 'e)) '(a b c d f))
((a . a) (b . b) (c . c) (d . d) (f . f))

関数 chk が真となる要素がある場合、処理を中断して空リストを返します。これも正常に動いていますね。

●再帰呼び出しからの脱出

再帰呼び出しから脱出することも簡単です。

リスト : flatten の再帰呼び出しから脱出する場合

(define flatten (lambda (ls)
  (call/cc
    (lambda (cont)
      (letrec ((flatten-sub
                (lambda (ls)
                  (cond ((null? ls) '())
                        ((not (pair? ls)) (list ls))
                        ((null? (car ls)) (cont '()))
                        (else (append (flatten-sub (car ls))
                                      (flatten-sub (cdr ls))))))))
        (flatten-sub ls))))))

拙作のページ Scheme 入門「継続と継続渡しスタイル」で作成したプログラムと同じです。リストを平坦化する関数 flatten で、要素に空リストが含まれている場合は空リストを返します。

Scm> (flatten '(a (b (c (d . e) f) g) h))
(a b c d e f g h)
Scm> (flatten '(a (b (c (d () . e) f) g) h))
()

これも正常に動作しています。

●イテレータの生成

最後に、イテレータを生成する関数 make-iter を試してみます。

リスト : イテレータを生成する関数

(define make-iter
  (lambda (proc . args)
    (letrec ((iter
              (lambda (return)
                (apply 
                 proc
                 (lambda (x)             ; 高階関数に渡す関数の本体
                   (set! return          ; 脱出先継続の書き換え
                         (call/cc
                          (lambda (cont)
                            (set! iter cont)  ; 継続の書き換え
                            (return x)))))
                 args)
                ; 終了後は継続 return で脱出
                (return false))))
            (lambda ()
              (call/cc
               (lambda (cont) (iter cont)))))))
リスト : 木の高階関数

(define for-each-tree 
  (lambda (fn ls)
    (let loop ((ls ls))
      (cond ((null? ls) '())
            ((pair? ls)
             (loop (car ls))
             (loop (cdr ls)))
            (else (fn ls))))))

拙作のページ Scheme 入門「継続と継続渡しスタイル」で作成したプログラムと同じです。詳しい説明は「継続と継続渡しスタイル」をお読みください。

それでは実行してみましょう。

Scm> (define a (make-iter for-each-tree '(a (b (c (d . e) f) g) h)))
a
Scm> (a)
a
Scm> (a)
b
Scm> (a)
c
Scm> (a)
d
Scm> (a)
e
Scm> (a)
f
Scm> (a)
g
Scm> (a)
h
Scm> (a)
false

正常に動作していますね。

今回はここまでです。次回は継続を使った例題として、非決定性計算を行う関数 amb を作ってみましょう。

●参考文献, URL

  1. 黒川利明, 『LISP 入門』, 培風館, 1982
  2. Patrick Henry Winston, Berthold Klaus Paul Horn, 『LISP 原書第 3 版 (1)』, 培風館, 1992, 18. Lisp で書く Lisp
  3. R. Kent Dybvig (著), 村上雅章 (訳), 『プログラミング言語 SCHEME』, 株式会社ピアソン・エデュケーション, 2000, 9.2 Scheme のメタ循環インタプリタ
  4. Ravi Sethi (著), 神林靖 (訳), 『プログラミング言語の概念と構造』, アジソンウェスレイ, 1995, 第 11 章 定義インタプリタ
  5. 小西弘一, 清水剛, 『CプログラムブックⅢ』, アスキー, 1986
  6. Harold Abelson, Gerald Jay Sussman, Julie Sussman, "Structure and Interpretation of Computer Programs", 4.1 The Metacircular Evaluator
  7. 稲葉雅幸, ソフトウェア特論, Scheme インタプリタ

●プログラムリスト

--
-- mscheme4.hs : microScheme インタプリタ
--
--               Copyright (C) 2013-2021 Makoto Hiroi
--
import Data.Char
import Data.IORef
import qualified Data.HashTable.IO as H
import Control.Monad.Except
import Control.Monad.Trans
import Control.Monad.IO.Class
import System.IO

-- S 式の定義
type Cont r a = (a -> r) -> r

type ScmFunc = Env -> SExpr -> Cont (Scm SExpr) (Scm SExpr)

data SExpr = INT  !Integer
           | REAL !Double
           | SYM  String
           | STR  String
           | CELL SExpr SExpr
           | NIL
           | PRIM ScmFunc
           | SYNT ScmFunc
           | CLOS SExpr LEnv
           | CONT (Scm SExpr -> Scm SExpr)
           | MACR SExpr

-- 等値の定義
instance Eq SExpr where
  INT x  == INT y  = x == y
  REAL x == REAL y = x == y
  SYM x  == SYM y  = x == y
  STR x  == STR y  = x == y
  NIL    == NIL    = True
  _      == _      = False

-- パーサエラーの定義
data ParseErr = ParseErr String String deriving Show

-- パーサの定義
type Parser a = Either ParseErr a

-- エラーの送出
parseError :: String -> String -> Parser a
parseError x s = Left (ParseErr x s)

-- 評価器の定義
type Scm a = ExceptT String IO a

-- ローカル環境の定義
type LEnv = [(String, IORef SExpr)]

pushLEnv :: String -> SExpr -> LEnv -> IO LEnv
pushLEnv s v env = do
  a <- v `seq` newIORef v
  return ((s, a):env)

lookupLEnv :: String -> LEnv -> IO (Maybe SExpr)
lookupLEnv s env =
  case lookup s env of
    Nothing -> return Nothing
    Just v  -> do a <- readIORef v
                  return (Just a)

updateLEnv :: String -> SExpr -> LEnv -> IO (LEnv)
updateLEnv s v env =
  case lookup s env of
    Nothing -> pushLEnv s v env
    Just a  -> do writeIORef a v
                  return env

-- グローバルな環境
type HashTable k v = H.BasicHashTable k v
type GEnv = HashTable String SExpr

-- 両方の環境を保持する
type Env = (GEnv, LEnv)

-- 真偽値
true  = SYM "true"
false = SYM "false"

-- Primitive の定義
errNUM  = "Illegal argument, Number required"
errINT  = "Illegal argument, Integer required"
errNEA  = "Not enough arguments"
errCELL = "Illegal argument, List required"
errZERO = "Divide by zero"

-- リスト操作
car, cdr, cons, pair :: ScmFunc
car _ NIL c = throwError $ "car : " ++ errNEA
car _ (CELL (CELL a _) _) c = c $ return a
car _ _                   c = throwError $ "car : " ++ errCELL

cdr _ NIL c = throwError $ "cdr : " ++ errNEA
cdr _ (CELL (CELL _ d) _) c = c $ return d
cdr _ _                   c = throwError $ "cdr : " ++ errCELL

cons _ (CELL a (CELL b _)) c = c $ return (CELL a b)
cons _ _                   c = throwError $ "cons : " ++ errNEA

pair _ NIL                 c = throwError $ "pair? : " ++ errNEA
pair _ (CELL (CELL _ _) _) c = c $ return true
pair _ _                   c = c $ return false

-- 畳み込み
foldCell :: (SExpr -> SExpr -> Scm SExpr) -> SExpr -> SExpr -> Scm SExpr
foldCell _ a NIL = return a
foldCell f a (CELL x rest) = do v <- f a x
                                foldCell f v rest
foldCell _ _ _ = throwError errCELL

-- 四則演算
adds, subs, muls, divs, mod' :: ScmFunc
add, sub, mul, div' :: SExpr -> SExpr -> Scm SExpr

add (INT x)  (INT y)  = return (INT (x + y))
add (INT x)  (REAL y) = return (REAL (fromIntegral x + y))
add (REAL x) (INT y)  = return (REAL (x + fromIntegral y))
add (REAL x) (REAL y) = return (REAL (x + y))
add _        _        = throwError $ "+ : " ++ errNUM

adds _ xs c = c $ foldCell add (INT 0) xs

sub (INT x)  (INT y)  = return (INT (x - y))
sub (INT x)  (REAL y) = return (REAL (fromIntegral x - y))
sub (REAL x) (INT y)  = return (REAL (x - fromIntegral y))
sub (REAL x) (REAL y) = return (REAL (x - y))
sub _        _        = throwError $ "- : " ++ errNUM

subs _ (CELL (INT a) NIL)  c = c $ return (INT (-a))
subs _ (CELL (REAL a) NIL) c = c $ return (REAL (-a))
subs _ (CELL a rest) c = c $ foldCell sub a rest
subs _ _ _ = throwError $ "- : " ++ errNEA

mul (INT x)  (INT y)  = return (INT (x * y))
mul (INT x)  (REAL y) = return (REAL (fromIntegral x * y))
mul (REAL x) (INT y)  = return (REAL (x * fromIntegral y))
mul (REAL x) (REAL y) = return (REAL (x * y))
mul _        _        = throwError $ "- : " ++ errNUM

muls _ xs c = c $ foldCell mul (INT 1) xs

div' _        (INT 0)  = throwError errZERO
div' _        (REAL 0) = throwError errZERO
div' (INT x)  (INT y)  = return (INT (x `div` y))
div' (INT x)  (REAL y) = return (REAL (fromIntegral x / y))
div' (REAL x) (INT y)  = return (REAL (x / fromIntegral y))
div' (REAL x) (REAL y) = return (REAL (x / y))
div' _        _        = throwError $ "- : " ++ errNUM

divs _ (CELL a NIL)  c = c $ div' (INT 1) a
divs _ (CELL a rest) c = c $ foldCell div' a rest
divs _ _ _ = throwError $ "/ : " ++ errNEA

mod' _ NIL          c = throwError $ "mod : " ++ errNEA
mod' _ (CELL _ NIL) c = throwError $ "mod : " ++ errNEA
mod' _ (CELL _ (CELL (INT 0) _))  c = throwError errZERO
mod' _ (CELL _ (CELL (REAL 0) _)) c = throwError errZERO
mod' _ (CELL (INT x) (CELL (INT y) _)) c = c $ return (INT (mod x y))
mod' _ _ _ = throwError $ "mod : " ++ errINT

-- 等値の判定
eq', equal' :: ScmFunc

eq' _ (CELL x (CELL y _)) c =
  if x == y then c (return true) else c (return false)
eq' _ _ _ = throwError $ "eq : " ++ errNEA

equal' _ (CELL x (CELL y _)) c =
  if iter x y then c (return true) else c (return false)
  where iter (CELL a b) (CELL c d) = iter a c && iter b d
        iter x y = x == y
equal' _ _ _ = throwError $ "equal : " ++ errNEA

-- 数値の比較演算子
compareNum :: SExpr -> SExpr -> Scm Ordering
compareNum (INT x)  (INT y)  = return $ compare x y
compareNum (INT x)  (REAL y) = return $ compare (fromIntegral x) y
compareNum (REAL x) (INT y)  = return $ compare x (fromIntegral y)
compareNum (REAL x) (REAL y) = return $ compare x y
compareNum _ _ = throwError errNUM

compareNums :: (Ordering -> Bool) -> SExpr -> Cont (Scm SExpr) (Scm SExpr)
compareNums _ NIL          _ = throwError errNEA
compareNums _ (CELL _ NIL) _ = throwError errNEA
compareNums p (CELL x (CELL y NIL)) c = do
  r <- compareNum x y
  if p r then c (return true) else c (return false)
compareNums p (CELL x ys@(CELL y _)) c = do
  r <- compareNum x y
  if p r then compareNums p ys c else c (return false)
compareNums _ _ _ = throwError "invalid function form"

eqNum, ltNum, gtNum, ltEq, gtEq :: ScmFunc
eqNum _ = compareNums (== EQ)
ltNum _ = compareNums (== LT)
gtNum _ = compareNums (== GT)
ltEq  _ = compareNums (<= EQ)
gtEq  _ = compareNums (>= EQ)


-- apply
apply' :: ScmFunc
apply' _ (CELL _ NIL) _ = throwError $ "apply : " ++ errNEA
apply' env (CELL func args) c = do
  apply env func (iter args) c
  where iter (CELL NIL NIL) = NIL
        iter (CELL xs@(CELL _ _) NIL) = xs
        iter (CELL x xs) = CELL x (iter xs)
apply' _ _ _ = throwError $ "apply : " ++ errNEA

-- call/cc
callcc :: ScmFunc
callcc env (CELL func _) c =
  apply env func (CELL (CONT c) NIL) c
callcc _ _ _ = throwError $ "call/cc " ++ errNEA

-- S 式の表示
display :: ScmFunc
display _ (CELL x _) c = do case x of
                              STR s -> lift $ putStr s
                              _     -> lift $ putStr $ show x
                            c $ return NIL
display _ _ _ = throwError $ "display : " ++ errNEA

-- 改行
newline :: ScmFunc
newline _ _ c = do lift $ putStrLn ""
                   c $ return NIL

-- エラー
error' :: ScmFunc
error' _ (CELL (STR x) NIL) c = throwError $ "ERROR: " ++ x
error' _ (CELL (STR x) (CELL y _)) c = throwError $ "ERROR: " ++ x ++ " " ++ show y
error' _ (CELL x _) c = throwError $ "ERROR: " ++ show x
error' _ _ _ = throwError "ERROR: "

-- load
load :: ScmFunc
load env (CELL (STR filename) _) c = do
  xs <- lift $ readFile filename
  r <- lift $ iter xs
  if r then c (return true) else c (return false)
  where
    iter :: String -> IO Bool
    iter xs =
      case readSExpr xs of
        Left  (ParseErr xs' mes) -> if mes == "EOF"
                                      then return True
                                      else do print mes
                                              return False
        Right (expr, xs') -> do result <- runExceptT $ eval env expr id
                                case result of
                                  Left mes -> do print mes
                                                 return False
                                  Right _  -> iter xs'
load _ _ _ = throwError "invalid load form"

--
-- S 式の表示
--
showCell :: SExpr -> String
showCell (CELL a d) =
  show a ++ case d of
              NIL      -> ""
              PRIM _   -> "<primitive>"
              CLOS _ _ -> "<closure>"
              SYNT _   -> "<syntax>"
              MACR _   -> "<macro>"
              CONT _   -> "<continuation>"
              INT x    -> " . " ++ show x
              REAL x   -> " . " ++ show x
              SYM x    -> " . " ++ x
              STR x    -> " . " ++ show x
              _        -> " " ++ showCell d
showCell xs = show xs

instance Show SExpr where
  show (INT x)    = show x
  show (REAL x)   = show x
  show (SYM x)    = x
  show (STR x)    = show x
  show NIL        = "()"
  show (SYNT _)   = "<syntax>"
  show (PRIM _)   = "<primitive>"
  show (CLOS _ _) = "<closure>"
  show (MACR _)   = "<macro>"
  show (CONT _)   = "<continuation>"
  show xs         = "(" ++ showCell xs ++ ")"

--
-- S 式の読み込み
--

isAlpha' :: Char -> Bool
isAlpha' x = elem x "!$%&*+-/:<=>?@^_~"

isIdent0 :: Char -> Bool
isIdent0 x = isAlpha x || isAlpha' x

isIdent1 :: Char -> Bool
isIdent1 x = isAlphaNum x || isAlpha' x

isREAL :: Char -> Bool
isREAL x = elem x ".eE"

quote           = SYM "quote"
quasiquote      = SYM "quasiquote"
unquote         = SYM "unquote"
unquoteSplicing = SYM "unquote-splicing"

isNUM :: String -> Bool
isNUM (x:_) = isDigit x
isNUM _     = False

getNumber :: String -> Parser (SExpr, String)
getNumber xs =
  let (s, ys) = span isDigit xs
  in if not (null ys) && isREAL (head ys)
     then case reads xs of
            [] -> parseError "" "" -- ありえないエラー
            [(y', ys')] -> return (REAL y', ys')
     else return (INT (read s), ys)

readSExpr :: String -> Parser (SExpr, String)
readSExpr [] = parseError "" "EOF"
readSExpr (x:xs)
  | isSpace x  = readSExpr xs
  | isDigit x  = getNumber (x:xs)
  | isIdent0 x = if x == '+' && isNUM xs
                 then getNumber xs
                 else if x == '-' && isNUM xs
                 then do (y, ys) <- getNumber xs
                         case y of
                           INT x  -> return (INT  (- x), ys)
                           REAL x -> return (REAL (- x), ys)
                 else let (name, ys) = span isIdent1 (x:xs)
                      in return (SYM name, ys)
  | otherwise  =
      case x of
        '('  -> readCell 0 xs
        ';'  -> readSExpr $ dropWhile (/= '\n') xs
        '"'  -> case reads (x:xs) of
                  [] -> parseError "" ""
                  [(y, ys)] -> y `seq` ys `seq` return (STR y, ys)
        '\'' -> readSExpr xs >>= 
               \(e, ys) -> e `seq` ys `seq` return (CELL quote (CELL e NIL), ys)
        '`'  -> readSExpr xs >>= 
               \(e, ys) -> e `seq` ys `seq` return (CELL quasiquote (CELL e NIL), ys)
        ','  -> if not (null xs) && head xs == '@'
                  then readSExpr (tail xs) >>= \(e, ys) -> e `seq` ys `seq` 
                    return (CELL unquoteSplicing (CELL e NIL), ys)
                  else readSExpr xs >>= \(e, ys) -> e `seq` ys `seq` 
                    return (CELL unquote (CELL e NIL), ys)
        _    -> parseError xs ("unexpected token: " ++ show x)

readCell :: Int -> String -> Parser (SExpr, String)
readCell _ [] = parseError "" "EOF"
readCell n (x:xs)
  | isSpace x = readCell n xs
  | otherwise =
      case x of
        ')' -> xs `seq` return (NIL, xs)
        '.' -> if n == 0
               then parseError xs "invalid dotted list"
               else do (e, ys) <- readSExpr xs
                       case dropWhile isSpace ys of
                         ')':zs -> return (e, zs)
                         _      -> parseError xs "invalid dotted list"
        '(' -> do (a, ys) <- readCell 0 xs
                  (d, zs) <- readCell 1 ys
                  return (CELL a d, zs)
        _   -> do (a, ys) <- readSExpr (x:xs)
                  (d, zs) <- readCell 1 ys
                  return (CELL a d, zs)

--
-- S 式の評価
--
eval :: ScmFunc
eval env NIL        c = c $ return NIL
eval env v@(INT _)  c = c $ return v
eval env v@(REAL _) c = c $ return v
eval env v@(STR _)  c = c $ return v
eval env (SYM name) c = do
  a <- liftIO $ lookupLEnv name $ snd env
  case a of
    Nothing -> do b <- liftIO $ H.lookup (fst env) name
                  case b of
                    Nothing -> throwError $ "unbound variable: " ++ name
                    Just v  -> c $ return v
    Just v -> c $ return v
eval env (CELL func args) c =
  eval env func (\m ->
    do v <- m
       case v of
         SYNT f -> f env args c
         MACR f -> apply env f args (\m1 -> do expr <- m1
                                               eval env expr c)
         _      -> evalArguments env args (\m2 -> do vs <- m2
                                                     apply env v vs c))

-- 引数の評価
evalArguments :: ScmFunc
evalArguments env NIL c = c $ return NIL
evalArguments env (CELL expr rest) c =
  eval env
       expr
       (\m1 -> evalArguments env
                             rest
                             (\m2 -> do v  <- m1
                                        vs <- m2
                                        c $ return (CELL v vs)))
evalArguments _ _ _ = throwError "invalid function form"

-- 変数束縛
makeBindings :: LEnv -> SExpr -> SExpr -> Scm LEnv
makeBindings lenv NIL        _    = return lenv
makeBindings lenv (SYM name) rest = lift $ pushLEnv name rest lenv
makeBindings lenv (CELL (SYM name) parms) (CELL v args) = do
  lenv' <- makeBindings lenv parms args
  lift (pushLEnv name v lenv')
makeBindings _ _ NIL = throwError errNEA
makeBindings _ _ _   = throwError "invalid arguments form"

-- 関数適用
apply :: Env -> SExpr -> SExpr -> Cont (Scm SExpr) (Scm SExpr)
apply env func actuals c =
  case func of
    PRIM f  -> f env actuals c
    CONT c1 -> case actuals of
                 NIL -> throwError errNEA
                 (CELL x _) -> c1 $ return x
    CLOS (CELL parms body) lenv0 -> do
      lenv1 <- makeBindings lenv0 parms actuals
      evalBody (fst env, lenv1) body c
    _ -> throwError $ "Not Function: " ++ show func

-- 本体の評価
evalBody :: ScmFunc
evalBody env (CELL expr NIL) c = eval env expr c
evalBody env (CELL expr rest) c =
  eval env expr (\_ -> evalBody env rest c)
evalBody _ _ _ = throwError "invalid body form"

--
-- シンタックス形式
--

-- quote
evalQuote :: ScmFunc
evalQuote env (CELL expr _) c = c $ return expr
evalQuote _ _ _ = throwError "invalid quote form"

-- define
evalDef :: ScmFunc
evalDef env (CELL sym@(SYM name) (CELL expr NIL)) c =
  eval env expr (\m -> do
    v <- m
    lift $ H.insert (fst env) name v
    c $ return sym)
evalDef _ _ _ = throwError "invalid define form"

-- define-macro
evalDefM :: ScmFunc
evalDefM env (CELL sym@(SYM name) (CELL expr NIL)) c =
  eval env expr (\m -> do
    v <- m
    lift $ H.insert (fst env) name (MACR v)
    c $ return sym)
evalDefM _ _ _ = throwError "invalid define-macro form"

-- if
evalIf :: ScmFunc
evalIf env (CELL pred (CELL thenForm rest)) c =
  eval env pred (\m -> do
    v <- m
    if v /= false
    then eval env thenForm c
    else case rest of
           CELL elseForm _ -> eval env elseForm c
           _               -> c $ return false)
evalIf _ _ _ = throwError $ "if : " ++ errNEA

-- lambda
evalLambda :: ScmFunc
evalLambda env expr c = c $ return (CLOS expr (snd env))

-- set!
evalSet :: ScmFunc
evalSet env (CELL (SYM name) (CELL expr _)) c =
  eval env expr (\m -> do
    v <- m
    a <- lift $ lookupLEnv name (snd env)
    case a of
      Nothing -> do b <- lift $ H.lookup (fst env) name
                    case b of
                      Nothing -> throwError $ "unbound variable: " ++ name
                      Just _ -> do lift $ H.insert (fst env) name v
                                   c $ return v
      Just _  -> do lift $ updateLEnv name v (snd env)
                    c $ return v)
evalSet _ _ _ = throwError "invalid set! form"

--
-- 大域変数の初期化
--
initGEnv :: [(String, SExpr)]
initGEnv = [("true",   true),
            ("false",  false),
            ("quote",  SYNT evalQuote),
            ("define", SYNT evalDef),
            ("lambda", SYNT evalLambda),
            ("if",     SYNT evalIf),
            ("set!",   SYNT evalSet),
            ("define-macro", SYNT evalDefM),
            ("eq?",    PRIM eq'),
            ("equal?", PRIM equal'),
            ("pair?",  PRIM pair),
            ("+",      PRIM adds),
            ("-",      PRIM subs),
            ("*",      PRIM muls),
            ("/",      PRIM divs),
            ("mod",    PRIM mod'),
            ("=",      PRIM eqNum),
            ("<",      PRIM ltNum),
            (">",      PRIM gtNum),
            ("<=",     PRIM ltEq),
            (">=",     PRIM gtEq),
            ("car",    PRIM car),
            ("cdr",    PRIM cdr),
            ("cons",   PRIM cons),
            ("load",   PRIM load),
            ("apply",  PRIM apply'),
            ("display", PRIM display),
            ("newline", PRIM newline),
            ("error",   PRIM error'),
            ("call/cc", PRIM callcc)]

-- read-eval-print-loop
repl :: Env -> String -> IO ()
repl env xs = do
  putStr "Scm> "
  hFlush stdout
  case readSExpr xs of
    Left  (ParseErr xs' mes) -> do putStrLn mes
                                   repl env $ dropWhile (/= '\n') xs'
    Right (expr, xs') -> do result <- runExceptT $ eval env expr id
                            case result of
                              Left mes -> putStrLn mes
                              Right v  -> print v
                            repl env xs'

main :: IO ()
main = do
  xs <- hGetContents stdin
  ht <- H.fromList initGEnv :: IO (GEnv)
  repl (ht, []) xs

初版 2013 年 9 月 15 日
改訂 2021 年 8 月 1 日