SML/NJ は関数型言語なので、関数を引数として受け取る関数、いわゆる「高階関数 (higher order function)」を簡単に定義することができます。もちろん、値として関数を返すこともできるので、関数を作る関数を定義することも簡単です。実際、関数の操作は Common Lisp よりも柔軟で簡単です。
簡単な例として、リストの要素に関数 f を適用して、その結果をリストに格納して返す関数を作ってみましょう。このような操作を「マッピング (写像)」といいます。次のリストを見てください。
リスト : マッピング fun mapcar(_, nil) = nil | mapcar(f, x::xs) = f x :: mapcar(f, xs)
SML/NJ には同等の機能を持つ関数 map が定義されているので、関数名は mapcar としました。名前は Common Lisp から拝借しました。SML/NJ の場合、受け取った関数を呼び出すとき、特別なことを行う必要はありません。SML/NJ は引数 f が関数として使われているので、引数 f を関数型と推論してコンパイルします。関数を渡す場合も簡単です。関数が束縛されている変数を渡すだけでいいのです。とても簡単ですね。
mapcar を定義すると、次のように表示されます。
val mapcar = fn : ('a -> 'b) * 'a list -> 'b list
第 1 引数が関数型 'a -> 'b で第 2 引数がリスト 'a list になります。関数 f はリストの要素を受け取るので、関数 f の引数の型とリストの型は一致します。これを型変数 'a で表しています。同様に、関数 f の返り値の型と mapcar の返り値のリストの型は一致します。これを型変数 'b で表しています。このように、mapcar は多相型関数として定義されます。
それでは簡単な実行例を示しましょう。
- fun square x = x * x; val square = fn : int -> int - mapcar(square, [1, 2, 3, 4, 5]); val it = [1,4,9,16,25] : int list - mapcar(fn x => x * x, [1, 2, 3, 4, 5]); val it = [1,4,9,16,25] : int list
引数を 2 乗する関数 square を定義します。この関数を mapcar に渡すと、リストの要素が 2 乗されます。また、Common Lisp のラムダ式のように、SML/NJ にも「匿名関数」があります。匿名関数は次のように定義します。
fn(args, ...) => 式
匿名関数を定義する場合、fun ではなく fn を使うことに注意してください。mapcar に fn x => x * x を渡せば、リストの要素を 2 乗することができます。
フィルター (filter) はリストの要素に関数 f を適用し、関数 f が真を返す要素をリストに格納して返す関数です。真または偽を返す関数のことを「述語 (predicate)」といいます。SML/NJ には filter が定義されているので、ここでは述語が真を返す要素を削除する関数 remove_if を作ってみましょう。関数名は Common Lisp から拝借しました。
リスト : remove_if fun remove_if(_, nil) = nil | remove_if(p, x::xs) = if p x then remove_if(p, xs) else x::remove_if(p, xs)
remove_if も簡単ですね。p x が真ならば x をリストに加えず、偽ならば x をリストに加えるだけです。remove_if を定義すると次のようになります。
val remove_if = fn : ('a -> bool) * 'a list -> 'a list
関数 p が if のテストで使われているので、SML/NJ は関数 p の型を 'a -> bool と推論しています。もちろん、remove_if も多相型関数として定義されます。SML/NJ の型推論はとても便利ですね。簡単な実行例を示します。
- remove_if(fn x => x >= 10, [1, 10, 2, 12, 3, 13]); val it = [1,2,3] : int list - remove_if(fn x => x = "abc", ["abc", "def", "abc", "ghi"]); val it = ["def","ghi"] : string list
最初の例は述語が x >= 10 なので、10 以上の要素が削除されます。次の例は文字列 "abc" が削除されます。
2 つの引数を取る関数 f とリストを引数に受け取る関数 reduce を考えます。そして、reduce はリストの各要素に対して関数 f を次のように適用します。
(1) [a1, a2, a3, ..., an-1, an] => f( ... f( f( a1, a2 ), a3 ), ...), an-1 ), an ) (2) [a1, a2, a3, ..., an-1, an] => f( a1, f( a2, f( a3, ..., f( an-1, an ) ... )))
関数 f を適用する順番で 2 通りの方法があります。たとえば、関数 f が単純な加算関数とすると、reduce の結果はリストの要素の和になります。
f(x, y) = x + y の場合 : reduce => a1 + a2 + a3 + ... + an-1 + an
このような操作を「縮約」とか「畳み込み」といいます。reduce は引数に初期値 g を指定する場合があります。この場合は、次のような動作になります。
(1) [a1, a2, a3, ..., an-1, an] => f( ... f( f( g, a1 ), a2 ), ...), an-1 ), an ) (2) [a1, a2, a3, ..., an-1, an] => f( a1, f( a2, f( a3, ..., f( an, g ) ... )))
SML/NJ には foldl と foldr という同等の機能を持つ関数が定義されているので、ここでは関数 fold_left と fold_right を作ってみましょう。プログラムは次のようになります。
リスト : 畳み込み fun fold_left(f, g, nil) = g | fold_left(f, g, x::xs) = fold_left(f, f(g, x), xs) fun fold_right(f, g, nil) = g | fold_right(f, g, x::xs) = f(x, fold_right(f, g, xs))
第 1 引数 f が適用する関数、第 2 引数 g が初期値、第 3 引数がリストです。最初の定義は再帰呼び出しの停止条件ですが、引数に空リストが与えられた場合にも対応します。この場合は初期値 g を返します。次の節でリストの要素を取り出して関数 f を呼び出します。
たとえば、リストが [1, 2, 3] で g が 0 とします。最初は f(0, 1) が実行され、その返り値が fold_left の第 2 引数に渡されます。次は f(g, 2) が実行されますが、これは f(f(0, 1), 2) と同じことです。そして、その結果が fold_left の第 2 引数になります。最後に f(g, 3) が実行されますが、これは f(f(f(0, 1), 2), 3) となり、上図 (1) と同じ動作になります。
fold_left の場合、リストの要素が関数 f の第 2 引数になり、第 1 引数にはこれまでの処理結果が渡されます。これに対し、fold_right の場合は逆になり、関数 f の第 1 引数にリストの要素が渡されて、これまでの処理結果は第 2 引数に渡されます。これで上図 (2) の動作を実現することができます。
それでは、fold_left と fold_right の型を示します。
val fold_left = fn : ('a * 'b -> 'a) * 'a * 'b list -> 'a val fold_right = fn : ('a * 'b -> 'b) * 'b * 'a list -> 'b
fold_left と fold_right の返り値は初期値 g の型と同じになります。つまり、リストの要素と同じ型である必要はありません。そして、渡される関数の型にも注目してください。fold_left の場合、関数の第 1 引数と初期値の型が同じです。第 1 引数に今までの処理結果が渡されて、第 2 引数にリストの要素が渡されることがわかります。これに対し、fold_right は逆になっていることに注意してください。
簡単な実行例を示します。
- fold_left(op +, 0, [1,2,3,4,5]); val it = 15 : int - fold_left(fn(a, x) => x::a, nil, [1,2,3,4,5]); val it = [5,4,3,2,1] : int list - fold_right(op +, 0, [1,2,3,4,5]); val it = 15 : int - fold_right(op ::, nil, [1,2,3,4,5]); val it = [1,2,3,4,5] : int list
op は 2 項演算子を 2 引数の関数として使うためのものです。次の例を見てください。
- op +(1, 2); val it = 3 : int
二項演算子を関数として使用する場合は op を使ってください。上の例では、初期値に 0 を指定することで、リストの要素の合計値を求めています。また、初期値に nil を指定して、関数にコンス演算子を渡すと、リストの逆順やリストのコピーを求めることができます。
このように、fold_left, fold_right と 2 引数の関数を組み合わせることで、いろいろな関数を実現することができます。もう少しだけ簡単な例を示しましょう。最初は length です。
- fold_left(fn(a, _) => a + 1, 0, [1,1,1,1,1,1,1]); val it = 7 : int - fold_right(fn(_, a) => a + 1, 0, [1,1,1,1,1,1,1]); val it = 7 : int
fold_left で length を実現する場合、初期値を 0 にして第 1 引数の値を +1 することで実現できます。fold_right の場合は第 2 引数の値を +1 します。
次に map の例を示します。
- fold_right(fn(x, a) => (x * x)::a, nil, [1,2,3,4,5,6]); val it = [1,4,9,16,25,36] : int list
map の場合は fold_rigth を使うと簡単です。初期値を nil にして第 1 引数の計算結果を第 2 引数のリストに追加するだけです。
次に filter の例を示します。
- fold_right(fn(x, a) => if x mod 2 = 0 then x::a else a, nil, [1,2,3,4,5,6,7,8]); val it = [2,4,6,8] : int list
filter の場合も初期値を nil にして、第 1 引数が条件を満たしていれば第 2 引数のリストに追加します。
最後に述語が真となる要素の個数を求めてみましょう。これは Common Lisp の関数 count-if と同じです。
- fold_left(fn(a, x) => if x mod 2 = 0 then a + 1 else a, 0, [1,2,3,4,5,6,7,8,9]); val it = 4 : int
このように、畳み込みを使っていろいろな処理を実現することができます。
次の高階関数を定義してください。
リスト : 解答例 fun sum_of(f, s, e) = if s > e then 0 else f s + sum_of(f, s + 1, e) fun tabulate(f, s, e) = if s > e then nil else f s :: tabulate(f, s + 1, e) fun maplist(_, nil) = nil | maplist(f, xs) = f xs :: maplist(f, tl xs) fun take_while(_, nil) = nil | take_while(pred, x::xs) = if pred x then x :: take_while(pred, xs) else nil fun drop_while(_, nil) = nil | drop_while(pred, xs) = if pred (hd xs) then drop_while(pred, tl xs) else xs fun scan_left(f, a, nil) = [a] | scan_left(f, a, x::xs) = a :: scan_left(f, f(a, x), xs) fun scan_right(f, a, nil) = [a] | scan_right(f, a, x::xs) = let val ys = scan_right(f, a, xs) in f(x, hd ys) end
- sum_of; val it = fn : (int -> int) * int * int -> int - sum_of(fn x => x, 1, 100); val it = 5050 : int - sum_of(fn x => x * x, 1, 100); val it = 338350 : int - sum_of(fn x => x * x * x, 1, 100); val it = 25502500 : int - tabulate; val it = fn : (int -> 'a) * int * int -> 'a list - tabulate(fn x => x, 1, 10); val it = [1,2,3,4,5,6,7,8,9,10] : int list - tabulate(fn x => x *x, 1, 10); val it = [1,4,9,16,25,36,49,64,81,100] : int list - tabulate(fn x => x * x * x, 1, 10); val it = [1,8,27,64,125,216,343,512,729,1000] : int list - maplist; val it = fn : ('a list -> 'b) * 'a list -> 'b list - maplist(fn xs => xs, [1,2,3,4,5]); val it = [[1,2,3,4,5],[2,3,4,5],[3,4,5],[4,5],[5]] : int list list - maplist(fn xs => length xs, [1,2,3,4,5]); val it = [5,4,3,2,1] : int list - take_while; val it = fn : ('a -> bool) * 'a list -> 'a list - take_while(fn x => x < 5, [1,2,3,4,5,6,7,8]); val it = [1,2,3,4] : int list - take_while(fn x => x < 9, [1,2,3,4,5,6,7,8]); val it = [1,2,3,4,5,6,7,8] : int list - take_while(fn x => x < 0, [1,2,3,4,5,6,7,8]); val it = [] : int list - drop_while; val it = fn : ('a -> bool) * 'a list -> 'a list - drop_while(fn x => x < 5, [1,2,3,4,5,6,7,8]); val it = [5,6,7,8] : int list - drop_while(fn x => x < 9, [1,2,3,4,5,6,7,8]); val it = [] : int list - drop_while(fn x => x < 0, [1,2,3,4,5,6,7,8]); val it = [1,2,3,4,5,6,7,8] : int list - scan_left; val it = fn : ('a * 'b -> 'a) * 'a * 'b list -> 'a list - scan_left(op +, 0, [1,2,3,4,5,6,7,8,9,10]); val it = [0,1,3,6,10,15,21,28,36,45,55] : int list - scan_left(op *, 1, [1,2,3,4,5,6,7,8,9,10]); val it = [1,1,2,6,24,120,720,5040,40320,362880,3628800] : int list - scan_left(fn(a, x) => x::a, nil, [1,2,3,4,5]); val it = [[],[1],[2,1],[3,2,1],[4,3,2,1],[5,4,3,2,1]] : int list list - scan_right; val it = fn : ('a * 'b -> 'b) * 'b * 'a list -> 'b list - scan_right(op +, 0, [1,2,3,4,5,6,7,8,9,10]); val it = [55,54,52,49,45,40,34,27,19,10,0] : int list - scan_right(op *, 1, [1,2,3,4,5,6,7,8,9,10]); val it = [3628800,3628800,1814400,604800,151200,30240,5040,720,90,10,1] : int list - scan_right(op ::, nil, [1,2,3,4,5]); val it = [[1,2,3,4,5],[2,3,4,5],[3,4,5],[4,5],[5],[]] : int list list
scan_left はリストの最後の要素が最終の累積値になります。xs が空リストのとき、累積変数 a の値をリストに格納して返します。そうでなければ、scan_left を再帰呼び出しして、その返り値に累積変数 a の値を追加して返します。scan_left を再帰呼び出しするときは、関数 f を呼び出して累積変数の値を更新することに注意してください。
scan_right はリストの先頭の要素が最終の累積値、最後の要素が初期値になります。リスト xs が空リストの場合は [a] を返します。そうでなければ、scan_right を再帰呼び出しします。このとき、累積変数 a の値は更新しません。返り値のリストは変数 ys にセットします。この ys の先頭要素が一つ前の累積値になるので、この値とリストの先頭要素 x を関数 f に渡して評価します。あとは、f の返り値を ys の先頭に追加して返せばいいわけです。