M.Hiroi's Home Page

Go Language Programming

お気楽 Go 言語プログラミング入門

[ PrevPage | Golang | NextPage ]

電卓プログラムの改良 (コルーチン編その1)

今回は電卓プログラムに「コルーチン (co-routine)」を追加してみましょう。

●コルーチンとは?

コルーチンはサブルーチン (sub-routine) と比較するとわかりやすいと思います。ここではサブルーチンを関数のことと考えてください。サブルーチンは呼び出してから戻ってくるまで処理を中断することはできませんが、コルーチンは途中で処理を中断し、そこから実行を再開することができます。また、コルーチンを使うと複数のプログラムを (擬似的に) 並行に動作させることができます。この動作はスレッド (thread) や goroutine とよく似ています。

一般的なスレッドや goroutine は、一定時間毎に実行するスレッドや goroutine を強制的に切り替えます。このとき、スレッドのスケジューリングは処理系が行います。これをプリエンプティブ (preemptive) といいます。コルーチンの場合、プログラムの実行は一定時間ごとに切り替わるものではなく、プログラム自身が実行を中断しないといけません。これをノンプリエンプティブ (nonpreemptive) といいます。

コルーチンで複数のプログラムを並行に動作させるには、あるプログラムだけを優先的に実行するのではなく、他のプログラムが実行できるよう自主的に処理を中断する、といった協調的な動作を行わせる必要があります。そのかわり、スレッドと違って排他制御といった面倒な処理を考える必要がなく、スレッドのような切り替え時のオーバーヘッドも少ないことから、スレッドよりも動作が軽くて扱いやすいといわれています。

今回作成するコルーチンは Lua というプログラミング言語を参考にしています。Lua のコルーチンには親子関係があり、コルーチン A からコルーチン B を呼び出した場合、A が親で B が子になります。このように主従関係を持つコルーチンを「セミコルーチン (semi-coroutine)」といいます。コルーチンの親子関係は木構造と考えることができます。

●コルーチンの動作

次は、コルーチンを操作する関数を説明します。コルーチンは関数 create で生成します。

create(f) => <Coroutine>

新しいコルーチンは create で生成します。create は引数なしの関数を引数として受け取ります。このような関数を thunk といいます。この関数がコルーチンで実行する処理になります。create はコルーチンを返します。

コルーチンを実行 (または再開) するには関数 resume を使います。

resume(co, x) => yield の引数 y

引数 co は再開 (または実行) するコルーチンです。resume を呼び出したほうが親、呼び出されたほうが子になります。resume の返り値はコルーチンの実行を中断する関数 yield に渡された引数です。

yield(y) => resume の引数 x

子コルーチンの中で関数 yield(...) を実行すると、そこでプログラムの実行を中断して親コルーチンに戻ります。このとき、yield の引数が親コルーチンで呼び出した reusme の返り値になります。また、resume に引数を渡して実行を再開すると、それが yield の返り値となります。

簡単な例を示しましょう。コルーチンを使うとジェネレータを簡単に作ることができます。たとえば、リストやベクタの要素をひとつずつ取り出して返すジェネレータは次のようになります。

リスト : 列 (ベクタ, リスト) のジェネレータ

def seqGen(xs)
  create(fn() foreach(fn(x) yield(x) end, xs), nil end)
end

高階関数 foreach は電卓プログラムのライブラリ (lib.cal) に定義されています。foreach の fn 式で yield を使って要素 x を返すだけです。これで foreach の実行か中断され、resume を呼び出すと実行が再開されます。

簡単な実行例を示しましょう。

Calc> a = seqGen(list(1,2,3,4,5));
<Coroutine>
Calc> resume(a, 0);
1
Calc> resume(a, 0);
2
Calc> resume(a, 0);
3
Calc> resume(a, 0);
4
Calc> resume(a, 0);
5
Calc> resume(a, 0);
()
Calc> resume(a, 0);
resume: Dead Coroutine
Calc> b = seqGen([1,2,3]);
<Coroutine>
Calc> resume(b, 0);
1
Calc> resume(b, 0);
2
Calc> resume(b, 0);
3
Calc> resume(b, 0);
()
Calc> resume(b, 0);
resume: Dead Coroutine

resume を呼び出すたびに、リストやベクタに格納されている要素を求めることができます。コルーチンは関数の実行が終了した場合、関数の返り値が yield に渡されます。seqGen の場合、返り値の nil が resume が返す最後の値となります。そのあと、resume を実行すると "resume: Dead Coroutine" というエラーを送出します。

●コルーチンの定義

それではプログラムを作りましょう。コルーチンは Go 言語の goroutine を使うと簡単に実装することができます。最初にコルーチンを表す構造体を定義します。次のリストを見てください。

リスト : コルーチンの定義

// コルーチンの型
type COR struct {
    ch1, ch2 chan Value
    save *COR
    live bool
}

func newCOR(ch1, ch2 chan Value) *COR {
    return &COR{ch1, ch2, nil, true}
}

func (_ *COR) isTrue() bool { return true }
func (_ *COR) String() string { return "<Coroutine>" }

// 実行中のコルーチン
var runCOR *COR

構造体の名前は COR としました。フィールド変数 ch1, ch2 は親子コルーチンの通信と同期処理に使うチャネルです。ch1 は yield の引数を親コルーチンに渡すために、ch2 は resume の引数を子コルーチンに渡すために使います。同期処理についてはあとで詳しく説明します。フィールド変数 live は、コルーチンを実行しているときは true を、終了した場合は false をセットします。

関数 newCOR は引数のチャネルを構造体 COR に格納して返します。コルーチンは値なので、メソッド isTrue を定義します。そして、コルーチンを表示するため、メソッド String() を定義します。

実行中のコルーチンは大域変数 runCOR にセットします。ここで resume(co, ...) を呼び出すと、実行中のメインルーチンまたはコルーチン (runCOR) を resume の引数 co のコルーチンに切り替えます。このとき、COR のフィールド変数 save に runCOR の値をセーブし、runCOR の値を co に書き換えます。

●コルーチンの中断と再開

関数 resume の処理は次のようになります。

リスト : コルーチンの再開

func resume(x, v Value) Value {
    y, ok := x.(*COR)
    if !ok {
        panic(fmt.Errorf("resume: Coroutine required, %v", x))
    }
    if !y.live {
        panic(fmt.Errorf("resume: Dead Coroutine"))
    }
    y.save = runCOR
    runCOR = y
    y.ch2 <- v
    z := <- y.ch1
    if z == nil { z = NIL }
    return z
}

引数 x を型アサーションでコルーチン (*COR) に変換し、y.live が true であることを確認します。それから、y.save に runCOR をセーブして、runCOR の値を y に書き換えます。次に、y.ch2 に引数 v の値を送信します。これが子コルーチンで呼び出した yield の返り値になります。最後に、子コルーチンからのデータを y.ch1 で受信して、その値を返します。子ルーチンでエラーが発生したとき、チャネルをクローズして goroutine を終了するので、受信データが nil になる場合があります。このときは、空リスト NIL を返すようにしています。

ここで親子ルーチンの同期処理が行われていることに注意してください。チャネルのバッファは 0 なので、チャネルからデータを受信するまで goroutine の実行は中断されます。子コルーチンは yield を実行したあと、親コルーチンからのデータを受信するため、実行を中断しています。したがって、y.ch2 へデータを送信すると、子コルーチンの実行が再開されて、親コルーチンはデータを受信するまで実行が中断されます。

関数 yield のプログラムは次のようになります。

リスト : コルーチンの中断

func yield(v Value) Value {
    x := runCOR
    if x == nil {
        panic(fmt.Errorf("yield: Not coroutine"))
    }
    runCOR = x.save
    x.ch1 <- v
    return <- x.ch2
}

yield は実行中のコルーチンを大域変数 runCOR から取り出して、変数 x にセットします。x が nil ならばエラーを送出します。それから、runCOR の値を x.save に戻して、x.ch1 に引数 v を送信します。この値は resume の返り値になります。あとは、x.ch2 からデータを受信するまで実行を中断し、データを受信したらその値を返すだけです。

●コルーチンの生成

最後にコルーチンを生成する関数 create を作ります。

リスト ; コルーチンの生成

// コルーチンの生成
func create(v Value) Value {
    x := newCOR(make(chan Value), make(chan Value))
    fn, ok := v.(Func)
    if ok {
        if fn.Argc() == 0 {
            go func() {
                defer func(){
                    err := recover()
                    if err != nil {
                        fmt.Fprintln(os.Stderr, err)
                    }
                    x.live = false
                    runCOR = x.save
                    close(x.ch1)
                    close(x.ch2)
                }()
                <- x.ch2
                x.ch1 <- appFunc(fn, nil, nil)
            } ()
        } else {
                panic(fmt.Errorf("create function: wrong number of arguments"))
        }
    } else {
        panic(fmt.Errorf("create: function required"))
    }
    return x
}

最初に newCOR でコルーチンの実体を生成して変数 x にセットします。次に、引数 v が引数無しの関数であることを確認します。そうでなければエラーを送出します。そして、go func(){ ... }() で goroutine を起動します。

匿名関数の中で、最初に x.ch2 からデータを受信するまで待ちます。つまり、親コルーチンで最初に resume を実行するまで、コルーチンは実行されません。そして、このときの resume の引数は捨てられることに注意してください。

関数 fn の実行は appFunc を呼び出すだけです。引数無しの関数なので、appFunc の第 2 引数は nil (空のスライス) で、局所変数の環境も nil でかまいません。create に渡される関数が匿名関数 (fn 式) であれば、そのときに有効な局所変数の環境はクロージャ内に保存されています。ただの関数であれば、そのまま呼び出せばいいわけです。

コルーチンで送出されたエラーを捕捉するため、defer 文で recover を呼び出します。エラーの場合はエラーメッセージを表示して goroutine の実行を終了します。このとき、チャネル x.ch1, x.ch2 を close でクローズして、x.live を false に書き換え、runCOR の値を x.save に戻します。

あとの修正は簡単なので説明は割愛します。詳細は プログラムリスト をお読みください。

今回はここまでです。次回はコルーチンを使った簡単なサンプルプログラムを作ってみましょう。


●プログラムリスト

//
// calc10.go : 電卓プログラム (ノンプリミティブなコルーチンを追加)
//
//             Copyright (C) 2014-2021 Makoto Hiroi
//
package main

import (
    "fmt"
    "os"
    "math"
    "time"
    "text/scanner"
)

// キーワード
const (
    DEF = -(iota+10)
    END
    IF
    THEN
    ELSE
    NOT
    AND
    OR
    EQ
    NE
    LT
    GT
    LE
    GE
    BGN
    WHL
    DO
    LET
    IN
    FN
    CALL
    LIST
    DELAY
)

// キーワード表
var keyTable = make(map[string]rune)

func initKeyTable() {
    keyTable["def"]   = DEF
    keyTable["end"]   = END
    keyTable["if"]    = IF
    keyTable["then"]  = THEN
    keyTable["else"]  = ELSE
    keyTable["and"]   = AND
    keyTable["or"]    = OR
    keyTable["not"]   = NOT
    keyTable["begin"] = BGN
    keyTable["while"] = WHL
    keyTable["do"]    = DO
    keyTable["let"]   = LET
    keyTable["in"]    = IN
    keyTable["fn"]    = FN
    keyTable["call"]  = CALL
    keyTable["list"]  = LIST
    keyTable["delay"] = DELAY
}

// 値
type Value interface {
    isTrue() bool
}

// 変数
type Variable string

// 局所変数の環境
type Env struct {
    name Variable
    val  Value
    next *Env
}

// 構文木
type Expr interface {
    Eval(*Env) Value
}

// 数
type Num interface {
    Value
    Expr
    neg() Value
    sign() int
    add(Value) Value
    sub(Value) Value
    mul(Value) Value
    div(Value) Value
}

// 比較
type Cmp interface {
    compare(Value) int
}

//
// 数値
//
type Int int64
type Flt float64

// エラー
func errorNum(mes string, v Value) error {
    return fmt.Errorf("%sNumber required, %v", mes, v)
}

func errorInt(mes string, v Value) error {
    return fmt.Errorf("%sInteger required, %v", mes, v)
}

// Value の実装
func (n Int) isTrue() bool { return n != 0 }
func (n Flt) isTrue() bool { return n != 0.0 }

// Expr の実装
func (e Int) Eval(_ *Env) Value { return e }
func (e Flt) Eval(_ *Env) Value { return e }

// 符号の反転
func (n Int) neg() Value { return -n }
func (n Flt) neg() Value { return -n }

// 符号を求める
func (n Int) sign() int {
    switch {
    case n > 0: return 1
    case n < 0: return -1
    default: return 0
    }
}

func (n Flt) sign() int {
    switch {
    case n > 0.0: return 1
    case n < 0.0: return -1
    default: return 0
    }
}

// 四則演算
func (n Int) add(x Value) Value {
    switch m := x.(type) {
    case Int: return n + m
    case Flt: return Flt(n) + m
    }
    panic(errorNum("+, ", x))
}

func (n Flt) add(x Value) Value {
    switch m := x.(type) {
    case Int: return n + Flt(m)
    case Flt: return n + m
    }
    panic(errorNum("+, ", x))
}

func (n Int) sub(x Value) Value {
    switch m := x.(type) {
    case Int: return n - m
    case Flt: return Flt(n) - m
    }
    panic(errorNum("-, ", x))
}

func (n Flt) sub(x Value) Value {
    switch m := x.(type) {
    case Int: return n - Flt(m)
    case Flt: return n - m
    }
    panic(errorNum("-, ", x))
}

func (n Int) mul(x Value) Value {
    switch m := x.(type) {
    case Int: return n * m
    case Flt: return Flt(n) * m
    }
    panic(errorNum("*, ", x))
}

func (n Flt) mul(x Value) Value {
    switch m := x.(type) {
    case Int: return n * Flt(m)
    case Flt: return n * m
    }
    panic(errorNum("*, ", x))
}

func (n Int) div(x Value) Value {
    switch m := x.(type) {
    case Int: return n / m
    case Flt: return Flt(n) / m
    }
    panic(errorNum("/, ", x))
}

func (n Flt) div(x Value) Value {
    switch m := x.(type) {
    case Int: return n / Flt(m)
    case Flt: return n / m
    }
    panic(errorNum("/, ", x))
}

// Cmp の実装
func (n Int) compare(x Value) int {
    switch m := x.(type) {
    case Int: return (n - m).sign()
    case Flt: return (Flt(n) - m).sign()
    }
    panic(errorNum("compare, ", x))
}

func (n Flt) compare(x Value) int {
    switch m := x.(type) {
    case Int: return (n - Flt(m)).sign()
    case Flt: return (n - m).sign()
    }
    panic(errorNum("compare, ", x))
}

//
// 文字列
//
type Str string

// エラー
func errorStr(mes string, v Value) error {
    return fmt.Errorf("%sString required, %v", mes, v)
}

// Value, Expr, Cmp の実装
func (_ Str) isTrue() bool { return true }
func (s Str) Eval(_ *Env) Value { return s }

func (n Str) compare(x Value) int {
    m, ok := x.(Str)
    if !ok {
        panic(errorStr("compare, ", x))
    }
    switch {
    case n > m: return 1
    case n < m: return -1
    default: return 0
    }
}

//
// 配列
//
type Vec []Value

// エラー
func errorVec(mes string, v Value) error {
    return fmt.Errorf("%sVector required , %v", mes, v)
}

// Value の実装
func (_ Vec) isTrue() bool { return true }

// ベクタ生成用構文木
type Crv struct {
    xs []Expr
}

func newCrv(xs []Expr) *Crv {
    return &Crv{xs}
}

// ベクタの生成
func (a *Crv) Eval(env *Env) Value {
    v := make(Vec, len(a.xs))
    for i := 0; i < len(a.xs); i++ {
        v[i] = a.xs[i].Eval(env)
    }
    return v
}

// ベクタの参照
type Ref struct {
    name Variable
    idxs []Expr
}

func newRef(name Variable, idxs []Expr) *Ref {
    return &Ref{name, idxs}
}

// アクセス位置を求める
func getPos(a *Ref, env *Env) *Value {
    v := a.name.Eval(env)
    for k := 0; ; k++ {
        xs, ok := v.(Vec)
        if !ok {
            panic(errorVec("", v))
        }
        y := a.idxs[k].Eval(env)
        i := toInt(y)
        if j := int(i); j < 0 || j >= len(xs) {
            panic(fmt.Errorf("Out of Range, %v, %v", xs, j))
        } else if k == len(a.idxs) - 1 {
            return &xs[j]
        } else {
            v = xs[j]
        }
    }
}

// 評価
func (a *Ref) Eval(env *Env) Value {
    return *getPos(a, env)
}
    
// ベクタの更新
type Udt struct {
    ref *Ref
    val Expr
}

func newUdt(ref *Ref, val Expr) *Udt {
    return &Udt{ref, val}
}

// 評価
func (a *Udt) Eval(env *Env) Value {
    x := getPos(a.ref, env)
    v := a.val.Eval(env)
    *x = v
    return v
}

//
// 連結リスト
//
type Cell struct {
    car, cdr Value
}

func newCell(a, b Value) *Cell {
    return &Cell{a, b}
}

// 空リスト
var NIL *Cell

// NIL の初期化
func makeNIL() {
    NIL = newCell(nil, nil)
    NIL.car = NIL
    NIL.cdr = NIL
    globalEnv["nil"] = NIL
}

// Value の実装
func (x *Cell) isTrue() bool { return x != NIL }

// リストの生成
type List struct {
    xs []Expr
}

func newList(xs []Expr) *List {
    return &List{xs}
}

func (x *List) Eval(env *Env) Value {
    cp := NIL
    for i := len(x.xs) - 1; i >= 0; i-- {
        cp = newCell(x.xs[i].Eval(env), cp)
    }
    return cp
}

// リスト -> 文字列
func sprintlist(cp *Cell) string {
    s := "("
    for cp != NIL {
        s += sprintValue(cp.car)
        b, ok := cp.cdr.(*Cell)
        if ok {
            cp = b
            if cp != NIL { s += " " }
        } else {
            s += " . " + sprintValue(cp.cdr)
            break
        }
    }
    s += ")"
    return s
}

// Vec -> string
func sprintvec(xs Vec) string {
    s := "["
    i := 0
    for ; i < len(xs) - 1; i++ {
        s += sprintValue(xs[i]) + ", "
    }
    s += sprintValue(xs[i]) + "]"
    return s
}

// Value -> string
func sprintValue(v Value) string {
    switch x := v.(type) {
    case *Cell: return sprintlist(x)
    case Vec: return sprintvec(x)
    default: return fmt.Sprint(x)
    }
}

//
// 変数と環境
//

// 環境の生成
func newEnv(name Variable, val Value, env *Env) *Env {
    return &Env{name, val, env}
}

func makeBinding(xs []Variable, es []Expr, env *Env) *Env {
    var env1 *Env
    for i := 0; i < len(xs); i++ {
        env1 = newEnv(xs[i], es[i].Eval(env), env1)
    }
    return env1
}

// 局所変数を環境に追加
func addBinding(xs []Variable, es []Expr, env *Env) *Env {
    for i := 0; i < len(xs); i++ {
        env = newEnv(xs[i], es[i].Eval(env), env)
    }
    return env
}

// クロージャ用
func addBindingClo(xs []Variable, es []Expr, env, clo *Env) *Env {
    for i := 0; i < len(xs); i++ {
        clo = newEnv(xs[i], es[i].Eval(env), clo)
    }
    return clo
}

// 局所変数の探索
func lookup(name Variable, env *Env) (Value, bool) {
    for ; env != nil; env = env.next {
        if name == env.name {
            return env.val, true
        }
    }
    return Int(0), false
}

// 局所変数の更新
func update(name Variable, val Value, env *Env) bool {
    for ; env != nil; env = env.next {
        if name == env.name {
            env.val = val
            return true
        }
    }
    return false
}

// 大域的な環境
var globalEnv = make(map[Variable]Value)

// 変数の評価
func (v Variable) Eval(env *Env) Value {
    // 局所変数の探索
    val, ok := lookup(v, env)
    if ok {
        return val
    }
    // 大域変数の探索
    val, ok = globalEnv[v]
    if !ok {
        panic(fmt.Errorf("unbound variable, %v", v))
    }
    return val
}

//
// 単項演算子
//
type Op1 struct {
    code rune
    expr Expr
}

func newOp1(code rune, e Expr) Expr {
    return &Op1{code, e}
}

// bool を Value に変換する
func boolToValue(x bool) Value {
    if x {
        return Int(1)
    } else {
        return Int(0)
    }
}

// 型変換とチェック
func toNum(v Value) Num {
    n, ok := v.(Num)
    if !ok {
        panic(errorNum("", v))
    }
    return n
}

// 評価
func (e *Op1) Eval(env *Env) Value {
    v := e.expr.Eval(env)
    switch e.code {
    case '-': return toNum(v).neg()
    case '+': return v
    case NOT: return boolToValue(!v.isTrue())
    default:
        panic(fmt.Errorf("invalid Op1 code"))
    }
}

//
// 二項演算子
//
type Op2 struct {
    code rune
    left, right Expr
}

func newOp2(code rune, left, right Expr) Expr {
    return &Op2{code, left, right}
}

// 型変換とチェック
func toInt(v Value) Int {
    n, ok := v.(Int)
    if !ok {
        errorInt("", v)
    }
    return n
}

func toCmp(v Value) Cmp {
    n, ok := v.(Cmp)
    if !ok {
        panic(fmt.Errorf("%v is uncomparable type", v))
    }
    return n
}

// 評価
func (e *Op2) Eval(env *Env) Value {
    x := e.left.Eval(env)
    y := e.right.Eval(env)
    switch e.code {
    case '+': return toNum(x).add(y)
    case '-': return toNum(x).sub(y)
    case '*': return toNum(x).mul(y)
    case '/': return toNum(x).div(y)
    case '%': return toInt(x) % toInt(y)
    case EQ:  return boolToValue(toCmp(x).compare(y) == 0)
    case NE:  return boolToValue(toCmp(x).compare(y) != 0)
    case LT:  return boolToValue(toCmp(x).compare(y) < 0)
    case GT:  return boolToValue(toCmp(x).compare(y) > 0)
    case LE:  return boolToValue(toCmp(x).compare(y) <= 0)
    case GE:  return boolToValue(toCmp(x).compare(y) >= 0)
    default:
        panic(fmt.Errorf("invalid Op2 code"))
    }
}

//
// 短絡演算子
//
type Ops struct {
    code rune
    left, right Expr
}

func newOps(code rune, left, right Expr) Expr {
    return &Ops{code, left, right}
}

// 評価
func (e *Ops) Eval(env *Env) Value {
    x := e.left.Eval(env)
    switch e.code {
    case AND:
        if x.isTrue() {
            return e.right.Eval(env)
        }
        return x
    case OR:
        if x.isTrue() {
            return x
        }
        return e.right.Eval(env)
    default:
        panic(fmt.Errorf("invalid Ops code"))
    }
}

//
// if
//
type Sel struct {
    testForm, thenForm, elseForm Expr
}

func newSel(testForm, thenForm, elseForm Expr) *Sel {
    return &Sel{testForm, thenForm, elseForm}
}

func (e *Sel) Eval(env *Env) Value {
    if e.testForm.Eval(env).isTrue() {
        return e.thenForm.Eval(env)
    }
    return e.elseForm.Eval(env)
}

//
// while
//
type Whl struct {
    testForm, body Expr
}

func newWhl(testForm, body Expr) *Whl {
    return &Whl{testForm, body}
}

func (e *Whl) Eval(env *Env) Value {
    for e.testForm.Eval(env).isTrue() {
        e.body.Eval(env)
    }
    return Int(0)
}

//
// begin
//
type Bgn struct {
    body []Expr
}

func newBgn(xs []Expr) *Bgn {
    return &Bgn{xs}
}

func (e *Bgn) Eval(env *Env) Value {
    var r Value
    for _, expr := range e.body {
        r = expr.Eval(env)
    }
    return r
}

//
// let 
//
type Let struct {
    vars []Variable
    vals []Expr
    body Expr
}

func newLet(vars []Variable, vals []Expr, body Expr) *Let {
    return &Let{vars, vals, body}
}

func (e *Let) Eval(env *Env) Value {
    return e.body.Eval(addBinding(e.vars, e.vals, env))
}

//
// 代入演算子
//
type Agn struct {
    name Variable
    expr Expr
}

func newAgn(v Variable, e Expr) *Agn {
    return &Agn{v, e}
}

func (a *Agn) Eval(env *Env) Value {
    val := a.expr.Eval(env)
    if !update(a.name, val, env) {
        globalEnv[a.name] = val
    }
    return val
}

//
// 関数
//
type Func interface {
    Value
    Expr
    Argc() int
}

type Func1 func(float64) float64

func (f Func1) Argc() int {    return 1 }
func (f Func1) isTrue() bool { return true }
func (f Func1) Eval(_ *Env) Value {    return f }
func (f Func1) String() string { return "<Function1>" }

type Func2 func(float64, float64) float64

func (f Func2) Argc() int {    return 2 }
func (f Func2) isTrue() bool { return true }
func (f Func2) Eval(_ *Env) Value {    return f }
func (f Func2) String() string { return "<Function2>" }

type FuncV0 func() Value

func (f FuncV0) Argc() int { return 0 }
func (f FuncV0) isTrue() bool {    return true }
func (f FuncV0) Eval(_ *Env) Value { return f }
func (f FuncV0) String() string { return "<FunctionV1>" }

type FuncV1 func(Value) Value

func (f FuncV1) Argc() int { return 1 }
func (f FuncV1) isTrue() bool {    return true }
func (f FuncV1) Eval(_ *Env) Value { return f }
func (f FuncV1) String() string { return "<FunctionV1>" }

type FuncV2 func(Value, Value) Value

func (f FuncV2) Argc() int { return 2 }
func (f FuncV2) isTrue() bool {    return true }
func (f FuncV2) Eval(_ *Env) Value { return f }
func (f FuncV2) String() string { return "<FunctionV2>" }

// ユーザ定義関数
type FuncU struct {
    name string
    xs   []Variable
    body Expr
}

func newFuncU(name string, xs []Variable, body Expr) *FuncU {
    return &FuncU{name, xs, body}
}

func (f *FuncU) Argc() int { return len(f.xs) }
func (f *FuncU) isTrue() bool {    return true }
func (f *FuncU) Eval(_ *Env) Value { return f }
func (f *FuncU) String() string { return fmt.Sprintf("<%s>", f.name) }

// クロージャ
type Clo struct {
    xs []Variable
    body Expr
}

func newClo(xs []Variable, body Expr) *Clo {
    return &Clo{xs, body}
}

type FuncC struct {
    xs   []Variable
    body Expr
    env  *Env
}

func (f *FuncC) Argc() int { return len(f.xs) }
func (_ *FuncC) isTrue() bool { return true }
func (f *FuncC) Eval(_ *Env) Value { return f }
func (_ *FuncC) String() string { return "<Closure>" }

func (a *Clo) Eval(env *Env) Value {
    return &FuncC{a.xs, a.body, env}
}

//
// 関数呼び出し
//
type App struct {
    fn Func
    xs []Expr
}

func newApp(fn Func, xs []Expr) *App {
    return &App{fn, xs}
}

type AppV struct {
    fn Expr
    xs []Expr
}

func newAppV(fn Expr, xs []Expr) *AppV {
    return &AppV{fn, xs}
}

func valueToFloat(v Value) float64 {
    switch x := v.(type) {
    case Int: return float64(x)
    case Flt: return float64(x)
    default:
        panic(errorNum("", v))
    }
}

// 評価
func appFunc(fn Func, xs []Expr, env *Env) Value {
    switch f := fn.(type) {
    case Func1:
        x := valueToFloat(xs[0].Eval(env))
        return Flt(f(x))
    case Func2:
        x := valueToFloat(xs[0].Eval(env))
        y := valueToFloat(xs[1].Eval(env))
        return Flt(f(x, y))
    case FuncV0:
        return f()
    case FuncV1:
        return f(xs[0].Eval(env))
    case FuncV2:
        return f(xs[0].Eval(env), xs[1].Eval(env))
    case *FuncU:
        return f.body.Eval(makeBinding(f.xs, xs, env))
    case *FuncC:
        return f.body.Eval(addBindingClo(f.xs, xs, env, f.env))
    default:
        panic(fmt.Errorf("%v is not function", f))
    }
}

func (a *App) Eval(env *Env) Value {
    return appFunc(a.fn, a.xs, env)
}

func (a *AppV) Eval(env *Env) Value {
    v := a.fn.Eval(env)
    fn, ok := v.(Func)
    if !ok {
        panic(fmt.Errorf("%v is not function", v))
    }
    if fn.Argc() != len(a.xs) {
        panic(fmt.Errorf("wrong number of arguments"))
    }
    return appFunc(fn, a.xs, env)
}

//
// 遅延評価
//
type Promise struct {
    val  Value
    expr Expr
    env  *Env
}

func newPromise(expr Expr, env *Env) *Promise {
    return &Promise{nil, expr, env}
}

func (_ *Promise) isTrue() bool { return true }
func (_ *Promise) String() string { return "<Promise>" }

// Promise の生成
type MakePromise struct {
    expr Expr
}

func newMakePromise(expr Expr) *MakePromise {
    return &MakePromise{expr}
}

func (p *MakePromise) Eval(env *Env) Value {
    return newPromise(p.expr, env)
}

// 評価
func force(v Value) Value {
    p, ok := v.(*Promise)
    if !ok {
        panic(fmt.Errorf("force: Promise required, %v", v))
    }
    if p.val == nil {
        p.val = p.expr.Eval(p.env)
    }
    return p.val
}

//
// コルーチン
//
type COR struct {
    ch1, ch2 chan Value
    save *COR
    live bool
}

func newCOR(ch1, ch2 chan Value) *COR {
    return &COR{ch1, ch2, nil, true}
}

func (_ *COR) isTrue() bool { return true }
func (_ *COR) String() string { return "<Coroutine>" }

// 実行中のコルーチン
var runCOR *COR

// 生成
func create(v Value) Value {
    x := newCOR(make(chan Value), make(chan Value))
    fn, ok := v.(Func)
    if ok {
        if fn.Argc() == 0 {
            go func() {
                defer func(){
                    err := recover()
                    if err != nil {
                        fmt.Fprintln(os.Stderr, err)
                    }
                    x.live = false
                    runCOR = x.save
                    close(x.ch1)
                    close(x.ch2)
                }()
                <- x.ch2
                x.ch1 <- appFunc(fn, nil, nil)
            } ()
        } else {
                panic(fmt.Errorf("create function: wrong number of arguments"))
        }
    } else {
        panic(fmt.Errorf("create: function required"))
    }
    return x
}

// 中断
func yield(v Value) Value {
    x := runCOR
    if x == nil {
        panic(fmt.Errorf("yield: Not coroutine"))
    }
    runCOR = x.save
    x.ch1 <- v
    return <- x.ch2
}

// 再開
func resume(x, v Value) Value {
    y, ok := x.(*COR)
    if !ok {
        panic(fmt.Errorf("resume: Coroutine required, %v", x))
    }
    if !y.live {
        panic(fmt.Errorf("resume: Dead Coroutine"))
    }
    y.save = runCOR
    runCOR = y
    y.ch2 <- v
    z := <- y.ch1
    if z == nil { z = NIL }
    return z
}

func isAlive(x Value) Value {
    y, ok := x.(*COR)
    if !ok {
        panic(fmt.Errorf("resume: Coroutine required, %v", x))
    }
    return boolToValue(y.live)
}

//
// 組み込み関数の初期化
//
var funcTable = make(map[string]Func)

// 値の表示
func print(x Value) Value {
    fmt.Print(sprintValue(x))
    return x
}

func println(x Value) Value {
    fmt.Println(sprintValue(x))
    return x
}

// ベクタの生成
func makeVector(n, x Value) Value {
    k := toInt(n)
    xs := make(Vec, int(k))
    for i := 0; i < int(k); i++ {
        xs[i] = x
    }
    return xs
}

func length(v Value) Value {
    xs, ok := v.(Vec)
    if !ok {
        panic(errorVec("len, ", v))
    }
    return Int(len(xs))
}

func isInt(v Value) Value {
    _, ok := v.(Int)
    return boolToValue(ok)
}

func isFlt(v Value) Value {
    _, ok := v.(Flt)
    return boolToValue(ok)
}

func isVec(v Value) Value {
    _, ok := v.(Vec)
    return boolToValue(ok)
}

func isStr(v Value) Value {
    _, ok := v.(Str)
    return boolToValue(ok)
}

func isFunc(v Value) Value {
    _, ok := v.(Func)
    return boolToValue(ok)
}

// 等値の判定
func eql(x, y Value) Value {
    switch a := x.(type) {
    case Func: return Int(0)
    case Vec: return Int(0)
    default: return boolToValue(a == y)
    }
}

// リスト関連
func pair(v Value) Value {
    _, ok := v.(*Cell)
    return boolToValue(ok && v != NIL)
}

func null(v Value) Value {
    return boolToValue(v == NIL)
}

func cons(x, y Value) Value {
    return newCell(x, y)
}

func car(v Value) Value {
    cp, ok := v.(*Cell)
    if !ok {
        panic(fmt.Errorf("car: List required, %v", v))
    }
    return cp.car
}

func cdr(v Value) Value {
    cp, ok := v.(*Cell)
    if !ok {
        panic(fmt.Errorf("cdr: List required, %v", v))
    }
    return cp.cdr
}

func setCar(v, x Value) Value {
    cp, ok := v.(*Cell)
    if !ok {
        panic(fmt.Errorf("setCar: List required, %v", v))
    }
    cp.car = x
    return x
}

func setCdr(v, x Value) Value {
    cp, ok := v.(*Cell)
    if !ok {
        panic(fmt.Errorf("setCdr: List required, %v", v))
    }
    cp.cdr = x
    return x
}

// 時間
func clock() Value {
    return Int(time.Now().UnixNano())
}

func since(v Value) Value {
    x, ok := v.(Int)
    if !ok {
        panic(errorInt("since: ", v))
    }
    y := time.Now().UnixNano()
    return Str(time.Duration(y - int64(x)).String())
}

// エラー
func userError(v Value) Value {
    panic(fmt.Errorf("%v", v))
}

// 入力
func input(mes Value) Value {
    var lex Lex
    lex.Init(os.Stdin)
    print(mes)
    lex.getToken()
    e := expression(&lex)
    return e.Eval(nil)
}

func initFunc() {
    funcTable["sqrt"]  = Func1(math.Sqrt)
    funcTable["sin"]   = Func1(math.Sin)
    funcTable["cos"]   = Func1(math.Cos)
    funcTable["tan"]   = Func1(math.Tan)
    funcTable["sinh"]  = Func1(math.Sinh)
    funcTable["cosh"]  = Func1(math.Cosh)
    funcTable["tanh"]  = Func1(math.Tanh)
    funcTable["asin"]  = Func1(math.Asin)
    funcTable["acos"]  = Func1(math.Acos)
    funcTable["atan"]  = Func1(math.Atan)
    funcTable["atan2"] = Func2(math.Atan2)
    funcTable["exp"]   = Func1(math.Exp)
    funcTable["pow"]   = Func2(math.Pow)
    funcTable["log"]   = Func1(math.Log)
    funcTable["log10"] = Func1(math.Log10)
    funcTable["log2"]  = Func1(math.Log2)
    funcTable["print"] = FuncV1(print)
    funcTable["println"] = FuncV1(println)
    funcTable["len"]   = FuncV1(length)
    funcTable["vector"] = FuncV2(makeVector)
    funcTable["isInt"] = FuncV1(isInt)
    funcTable["isFlt"] = FuncV1(isFlt)
    funcTable["isStr"] = FuncV1(isStr)
    funcTable["isVec"] = FuncV1(isVec)
    funcTable["isFunc"] = FuncV1(isFunc)
    funcTable["eql"]  = FuncV2(eql)
    funcTable["pair"] = FuncV1(pair)
    funcTable["null"] = FuncV1(null)
    funcTable["cons"] = FuncV2(cons)
    funcTable["car"] = FuncV1(car)
    funcTable["cdr"] = FuncV1(cdr)
    funcTable["setCar"] = FuncV2(setCar)
    funcTable["setCdr"] = FuncV2(setCdr)
    funcTable["clock"] = FuncV0(clock)
    funcTable["since"] = FuncV1(since)
    funcTable["error"] = FuncV1(userError)
    funcTable["input"] = FuncV1(input)
    funcTable["force"] = FuncV1(force)
    funcTable["create"] = FuncV1(create)
    funcTable["yield"] = FuncV1(yield)
    funcTable["isAlive"] = FuncV1(isAlive)
    funcTable["resume"] = FuncV2(resume)
}

//
// 字句解析
//
type Lex struct {
    scanner.Scanner
    Token rune
}

func (lex *Lex) getToken() {
    lex.Token = lex.Scan()
    switch lex.Token {
    case scanner.Ident:
        key, ok := keyTable[lex.TokenText()]
        if ok {
            lex.Token = key
        }
    case '=':
        if lex.Peek() == '=' {
            lex.Next()
            lex.Token = EQ
        }
    case '!':
        if lex.Peek() == '=' {
            lex.Next()
            lex.Token = NE
        } else {
            lex.Token = NOT
        }
    case '<':
        if lex.Peek() == '=' {
            lex.Next()
            lex.Token = LE
        } else {
            lex.Token = LT
        }
    case '>':
        if lex.Peek() == '=' {
            lex.Next()
            lex.Token = GE
        } else {
            lex.Token = GT
        }
    }
}

//
// 構文解析
//

// 仮引数の取得
func getParameter(lex *Lex) []Variable {
    e := make([]Variable, 0)
    if lex.Token != '(' {
        panic(fmt.Errorf("'(' expected"))
    }
    lex.getToken()
    if lex.Token == ')' {
        lex.getToken()
        return e
    }
    for {
        if lex.Token == scanner.Ident {
            e = append(e, Variable(lex.TokenText()))
            lex.getToken()
            switch lex.Token {
            case ')':
                lex.getToken()
                return e
            case ',':
                lex.getToken()
            default:
                panic(fmt.Errorf("unexpected token in parameter list"))
            }
        } else {
            panic(fmt.Errorf("unexpected token in parameter list"))
        }
    }
}


// 引数の取得
func getArgs(lex *Lex) []Expr {
    e := make([]Expr, 0)
    if lex.Token != '(' {
        panic(fmt.Errorf("'(' expected"))
    }
    lex.getToken()
    if lex.Token == ')' {
        lex.getToken()
        return e
    }
    for {
        e = append(e, expression(lex))
        switch lex.Token {
        case ')':
            lex.getToken()
            return e
        case ',':
            lex.getToken()
        default:
            panic(fmt.Errorf("unexpected token in argument list"))
        }
    }
}

// if 式の解析
func makeSel(lex *Lex) Expr {
    testForm := expression(lex)
    if lex.Token == THEN {
        lex.getToken()
        thenForm := expression(lex)
        switch lex.Token {
        case ELSE:
            lex.getToken()
            elseForm := expression(lex)
            if lex.Token != END {
                panic(fmt.Errorf("if, 'end' expected"))
            }
            lex.getToken()
            return newSel(testForm, thenForm, elseForm)
        case END:
            lex.getToken()
            return newSel(testForm, thenForm, Int(0))
        default:
            panic(fmt.Errorf("if, 'else' or 'end' expected"))
        } 
    } else {
        panic(fmt.Errorf("if, 'then' expected"))
    }
}

// begin 式の解析
func getBody(lex *Lex) []Expr {
    body := make([]Expr, 0)
    for {
        body = append(body, expression(lex))
        switch lex.Token {
        case ',':
            lex.getToken()
        default:
            return body
        }
    }
}

func makeBegin(lex *Lex) Expr {
    if lex.Token == END {
        panic(fmt.Errorf("invalid begin form"))
    }
    body := getBody(lex)
    if lex.Token != END {
        panic(fmt.Errorf("'end' expected"))
    }
    lex.getToken()
    return newBgn(body)
}

// while 式の解析
func makeWhile(lex *Lex) Expr {
    testForm := expression(lex)
    if lex.Token == DO {
        lex.getToken()
        return newWhl(testForm, makeBegin(lex))
    } else {
        panic(fmt.Errorf("'do' expected"))
    }
}

// let 式の解析
func makeLet(lex *Lex) Expr {
    vars := make([]Variable, 0)
    vals := make([]Expr, 0)
    for {
        e := expression(lex)
        a, ok := e.(*Agn)
        if !ok {
            panic(fmt.Errorf("let: invalid assign form"))
        }
        vars = append(vars, a.name)
        vals = append(vals, a.expr)
        if lex.Token == IN {
            break
        } else if lex.Token != ',' {
            panic(fmt.Errorf("let, 'in' or ',' expected"))
        }
        lex.getToken()
    }
    lex.getToken()
    return newLet(vars, vals, makeBegin(lex))
}

// 添字の取得
func getIndex(lex *Lex) []Expr {
    xs := make([]Expr, 0)
    lex.getToken()
    for {
        e := expression(lex)
        if lex.Token != ']' {
            panic(fmt.Errorf("']' expected"))
        }
        xs = append(xs, e)
        lex.getToken()
        if lex.Token != '[' {
            return xs
        }
        lex.getToken()
    }
}

// 因子
func factor(lex *Lex) Expr {
    switch lex.Token {
    case '(':
        lex.getToken()
        e := expression(lex)
        if lex.Token != ')' {
            panic(fmt.Errorf("')' expected"))
        }
        lex.getToken()
        return e
    case '[':
        lex.getToken()
        xs := getBody(lex)
        if lex.Token != ']' {
            panic(fmt.Errorf("']' expected"))
        }
        lex.getToken()
        return newCrv(xs)
    case '+':
        lex.getToken()
        return newOp1('+', factor(lex))
    case '-':
        lex.getToken()
        return newOp1('-', factor(lex))
    case NOT:
        lex.getToken()
        return newOp1(NOT, factor(lex))
    case IF:
        lex.getToken()
        return makeSel(lex)
    case BGN:
        lex.getToken()
        return makeBegin(lex)
    case WHL:
        lex.getToken()
        return makeWhile(lex)
    case LET:
        lex.getToken()
        return makeLet(lex)
    case FN:
        lex.getToken()
        xs := getParameter(lex)
        body := makeBegin(lex)
        clo := newClo(xs, body)
        if lex.Token == '(' {
            ys := getArgs(lex)
            if len(xs) != len(ys) {
                panic(fmt.Errorf("wrong number of arguments: fn"))
            }
            return newAppV(clo, ys)
        }
        return clo
    case LIST:
        lex.getToken()
        if lex.Token == '(' {
            return newList(getArgs(lex))
        }
        panic(fmt.Errorf("List: '(' expected"))
    case CALL:
        lex.getToken()
        xs := getArgs(lex)
        if len(xs) == 0 {
            panic(fmt.Errorf("wrong number of arguments: call"))
        }
        return newAppV(xs[0], xs[1:])
    case DELAY:
        lex.getToken()
        if lex.Token != '(' {
            panic(fmt.Errorf("delay: '(' expected"))
        }
        lex.getToken()
        e := expression(lex)
        if lex.Token != ')' {
            panic(fmt.Errorf("delay: ')' expected"))
        }
        lex.getToken()
        return newMakePromise(e)
    case scanner.Int:
        var n int64
        fmt.Sscan(lex.TokenText(), &n)
        lex.getToken()
        return Int(n)
    case scanner.Float:
        var n float64
        fmt.Sscan(lex.TokenText(), &n)
        lex.getToken()
        return Flt(n)
    case scanner.String:
        var s string
        fmt.Sscanf(lex.TokenText(), "%q", &s)
        lex.getToken()
        return Str(s)
    case scanner.Ident:
        name := lex.TokenText()
        lex.getToken()
        if name == "quit" {
            panic(name)
        }
        v, ok := funcTable[name]
        if ok {
            if lex.Token == '(' {
                xs := getArgs(lex)
                if len(xs) != v.Argc() {
                    panic(fmt.Errorf("wrong number of arguments: %v", name))
                }
                return newApp(v, xs)
            } else {
                // 関数値をそのまま返す
                return v
            }
        } else if lex.Token == '[' {
            return newRef(Variable(name), getIndex(lex))
        } else if lex.Token == '(' {
            // 変数に格納されている関数を呼び出す
            return newAppV(Variable(name), getArgs(lex))
        } else {
            return Variable(name)
        }
    default:
        panic(fmt.Errorf("unexpected token: %v", lex.TokenText()))
    }
}

// 項
func term(lex *Lex) Expr {
    e := factor(lex)
    for {
        switch lex.Token {
        case '*':
            lex.getToken()
            e = newOp2('*', e, factor(lex))
        case '/':
            lex.getToken()
            e = newOp2('/', e, factor(lex))
        case '%':
            lex.getToken()
            e = newOp2('%', e, factor(lex))
        default:
            return e
        }
    }
}

// 式
func expr3(lex *Lex) Expr {
    e := term(lex)
    for {
        switch lex.Token {
        case '+':
            lex.getToken()
            e = newOp2('+', e, term(lex))
        case '-':
            lex.getToken()
            e = newOp2('-', e, term(lex))
        default:
            return e
        }
    }
}

// 比較演算子
func expr2(lex *Lex) Expr {
    e := expr3(lex)
    x := lex.Token
    switch x {
    case EQ, NE, LT, GT, LE, GE:
        lex.getToken()
        return newOp2(x, e, expr3(lex))
    default:
        return e
    }
}

// 論理演算子
func expr1(lex *Lex) Expr {
    e := expr2(lex)
    for {
        x := lex.Token
        switch x {
        case AND, OR:
            lex.getToken()
            e = newOps(x, e, expr2(lex))
        default:
            return e
        }
    }
}

func expression(lex *Lex) Expr {
    e := expr1(lex)
    if lex.Token == '=' {
        switch x := e.(type) {
        case Variable:
            lex.getToken()
            return newAgn(x, expression(lex))
        case *Ref:
            lex.getToken()
            return newUdt(x, expression(lex))
        default:
            panic(fmt.Errorf("invalid assign form"))
        }
    }
    return e
}

// ユーザ関数の定義
func defineFunc(lex *Lex) string {
    lex.getToken()
    if lex.Token != scanner.Ident {
        panic(fmt.Errorf("invalid define form"))
    }
    name := lex.TokenText()
    lex.getToken()
    xs := getParameter(lex)    
    v, ok := funcTable[name]
    if ok {
        switch f := v.(type) {
        case *FuncU:
            if len(f.xs) != len(xs) {
                panic(fmt.Errorf("wrong number of arguments: %v", name))
            }
            body := newBgn(getBody(lex))
            if lex.Token != END {
                panic(fmt.Errorf("'end' expected"))
            }
            f.xs = xs
            f.body = body
        default:
            panic(fmt.Errorf("%v is built-in function", name))
        }
    } else {
        // 再帰呼び出し対応
        f := newFuncU(name, xs, nil)
        funcTable[name] = f
        f.body = newBgn(getBody(lex))
        if lex.Token != END {
            delete(funcTable, name)
            panic(fmt.Errorf("'end' expected"))
        }
    }
    return name
}

// ライブラリのロード
func loadLib(name string) {
    var lex Lex
    file, err := os.Open(name)
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
    defer func(){
        err := recover()
        file.Close()
        if err != nil {
            fmt.Fprintf(os.Stderr, "%s: %v, %s\n", name, err, lex.Position)
            os.Exit(1)
        }
    }()
    lex.Init(file)
    for {
        lex.getToken()
        if lex.Token == scanner.EOF {
            break
        } else if lex.Token == DEF {
            defineFunc(&lex)
        } else {
            e := expression(&lex)
            if lex.Token != ';' {
                panic(fmt.Errorf("invalid expression"))
            }
            e.Eval(nil)
        }
    }
}

// 式の入力と評価
func toplevel(lex *Lex) (r bool) {
    r = false
    defer func(){
        err := recover()
        if err != nil {
            mes, ok := err.(string)
            if ok && mes == "quit" {
                r = true
            } else {
                fmt.Fprintln(os.Stderr, err)
                for {
                    c := lex.Peek()
                    if c == '\n' { break }
                    lex.Next()
                }
            }
        }
    }()
    for {
        fmt.Print("Calc> ")
        lex.getToken()
        if lex.Token == DEF {
            fmt.Println(defineFunc(lex))
        } else {
            e := expression(lex)
            if lex.Token != ';' {
                panic(fmt.Errorf("invalid expression"))
            } else {
                println(e.Eval(nil))
            }
        }
    }
    return r
}

func main() {
    initKeyTable()
    initFunc()
    makeNIL()
    for _, name := range os.Args[1:] {
        loadLib(name)
    }
    var lex Lex
    lex.Init(os.Stdin)
    for {
        if toplevel(&lex) { break }
    }
}

初版 2014 年 6 月 21 日
改訂 2021 年 12 月 22 日

Copyright (C) 2014-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | Golang | NextPage ]