M.Hiroi's Home Page

お気楽 Standard ML of New Jersey 入門

高階関数

Copyright (C) 2005-2020 Makoto Hiroi
All rights reserved.

はじめに

SML/NJ は関数型言語なので、関数を引数として受け取る関数、いわゆる「高階関数 (higher order function)」を簡単に定義することができます。もちろん、値として関数を返すこともできるので、関数を作る関数を定義することも簡単です。実際、関数の操作は Common Lisp よりも柔軟で簡単です。

●マッピング

簡単な例として、リストの要素に関数 f を適用して、その結果をリストに格納して返す関数を作ってみましょう。このような操作を「マッピング (写像)」といいます。次のリストを見てください。

リスト : マッピング

fun mapcar(_, nil) = nil
|   mapcar(f, x::xs) = f x :: mapcar(f, xs)  

SML/NJ には同等の機能を持つ関数 map が定義されているので、関数名は mapcar としました。名前は Common Lisp から拝借しました。SML/NJ の場合、受け取った関数を呼び出すとき、特別なことを行う必要はありません。SML/NJ は引数 f が関数として使われているので、引数 f を関数型と推論してコンパイルします。関数を渡す場合も簡単です。関数が束縛されている変数を渡すだけでいいのです。とても簡単ですね。

mapcar を定義すると、次のように表示されます。

val mapcar = fn : ('a -> 'b) * 'a list -> 'b list

第 1 引数が関数型 'a -> 'b で第 2 引数がリスト 'a list になります。関数 f はリストの要素を受け取るので、関数 f の引数の型とリストの型は一致します。これを型変数 'a で表しています。同様に、関数 f の返り値の型と mapcar の返り値のリストの型は一致します。これを型変数 'b で表しています。このように、mapcar は多相型関数として定義されます。

それでは簡単な実行例を示しましょう。

- fun square x = x * x;
val square = fn : int -> int

- mapcar(square, [1, 2, 3, 4, 5]);
val it = [1,4,9,16,25] : int list

- mapcar(fn x => x * x, [1, 2, 3, 4, 5]);
val it = [1,4,9,16,25] : int list

引数を 2 乗する関数 square を定義します。この関数を mapcar に渡すと、リストの要素が 2 乗されます。また、Common Lisp のラムダ式のように、SML/NJ にも「匿名関数」があります。匿名関数は次のように定義します。

fn(args, ...) => 式

匿名関数を定義する場合、fun ではなく fn を使うことに注意してください。mapcar に fn x => x * x を渡せば、リストの要素を 2 乗することができます。

●フィルター

フィルター (filter) はリストの要素に関数 f を適用し、関数 f が真を返す要素をリストに格納して返す関数です。真または偽を返す関数のことを「述語 (predicate)」といいます。SML/NJ には filter が定義されているので、ここでは述語が真を返す要素を削除する関数 remove_if を作ってみましょう。関数名は Common Lisp から拝借しました。

リスト : remove_if

fun remove_if(_, nil) = nil
|   remove_if(p, x::xs) =
  if p x then remove_if(p, xs) else x::remove_if(p, xs)  

remove_if も簡単ですね。p x が真ならば x をリストに加えず、偽ならば x をリストに加えるだけです。remove_if を定義すると次のようになります。

val remove_if = fn : ('a -> bool) * 'a list -> 'a list

関数 p が if のテストで使われているので、SML/NJ は関数 p の型を 'a -> bool と推論しています。もちろん、remove_if も多相型関数として定義されます。SML/NJ の型推論はとても便利ですね。簡単な実行例を示します。

- remove_if(fn x => x >= 10, [1, 10, 2, 12, 3, 13]);
val it = [1,2,3] : int list

- remove_if(fn x => x = "abc", ["abc", "def", "abc", "ghi"]);
val it = ["def","ghi"] : string list

最初の例は述語が x >= 10 なので、10 以上の要素が削除されます。次の例は文字列 "abc" が削除されます。

●畳み込み

2 つの引数を取る関数 f とリストを引数に受け取る関数 reduce を考えます。そして、reduce はリストの各要素に対して関数 f を次のように適用します。

(1) [a1, a2, a3, ..., an-1, an] => f( ... f( f( a1, a2 ), a3 ), ...), an-1 ), an )
(2) [a1, a2, a3, ..., an-1, an] => f( a1, f( a2, f( a3, ..., f( an-1, an ) ... )))

関数 f を適用する順番で 2 通りの方法があります。たとえば、関数 f が単純な加算関数とすると、reduce の結果はリストの要素の和になります。

f(x, y) = x + y の場合 : reduce => a1 + a2 + a3 + ... + an-1 + an

このような操作を「縮約」とか「畳み込み」といいます。reduce は引数に初期値 g を指定する場合があります。この場合は、次のような動作になります。

(1) [a1, a2, a3, ..., an-1, an] => f( ... f( f( g, a1 ), a2 ), ...), an-1 ), an )
(2) [a1, a2, a3, ..., an-1, an] => f( a1, f( a2, f( a3, ..., f( an, g ) ... )))

SML/NJ には foldl と foldr という同等の機能を持つ関数が定義されているので、ここでは関数 fold_left と fold_right を作ってみましょう。プログラムは次のようになります。

リスト : 畳み込み

fun fold_left(f, g, nil) = g
|   fold_left(f, g, x::xs) = fold_left(f, f(g, x), xs)

fun fold_right(f, g, nil) = g
|   fold_right(f, g, x::xs) = f(x, fold_right(f, g, xs))

第 1 引数 f が適用する関数、第 2 引数 g が初期値、第 3 引数がリストです。最初の定義は再帰呼び出しの停止条件ですが、引数に空リストが与えられた場合にも対応します。この場合は初期値 g を返します。次の節でリストの要素を取り出して関数 f を呼び出します。

たとえば、リストが [1, 2, 3] で g が 0 とします。最初は f(0, 1) が実行され、その返り値が fold_left の第 2 引数に渡されます。次は f(g, 2) が実行されますが、これは f(f(0, 1), 2) と同じことです。そして、その結果が fold_left の第 2 引数になります。最後に f(g, 3) が実行されますが、これは f(f(f(0, 1), 2), 3) となり、上図 (1) と同じ動作になります。

fold_left の場合、リストの要素が関数 f の第 2 引数になり、第 1 引数にはこれまでの処理結果が渡されます。これに対し、fold_right の場合は逆になり、関数 f の第 1 引数にリストの要素が渡されて、これまでの処理結果は第 2 引数に渡されます。これで上図 (2) の動作を実現することができます。

それでは、fold_left と fold_right の型を示します。

val fold_left = fn : ('a * 'b -> 'a) * 'a * 'b list -> 'a
val fold_right = fn : ('a * 'b -> 'b) * 'b * 'a list -> 'b

fold_left と fold_right の返り値は初期値 g の型と同じになります。つまり、リストの要素と同じ型である必要はありません。そして、渡される関数の型にも注目してください。fold_left の場合、関数の第 1 引数と初期値の型が同じです。第 1 引数に今までの処理結果が渡されて、第 2 引数にリストの要素が渡されることがわかります。これに対し、fold_right は逆になっていることに注意してください。

簡単な実行例を示します。

- fold_left(op +, 0, [1,2,3,4,5]);
val it = 15 : int
- fold_left(fn(a, x) => x::a, nil, [1,2,3,4,5]);
val it = [5,4,3,2,1] : int list

- fold_right(op +, 0, [1,2,3,4,5]);
val it = 15 : int
- fold_right(op ::, nil, [1,2,3,4,5]);
val it = [1,2,3,4,5] : int list

op は 2 項演算子を 2 引数の関数として使うためのものです。次の例を見てください。

- op +(1, 2);
val it = 3 : int

二項演算子を関数として使用する場合は op を使ってください。上の例では、初期値に 0 を指定することで、リストの要素の合計値を求めています。また、初期値に nil を指定して、関数にコンス演算子を渡すと、リストの逆順やリストのコピーを求めることができます。

●畳み込みの使用例

このように、fold_left, fold_right と 2 引数の関数を組み合わせることで、いろいろな関数を実現することができます。もう少しだけ簡単な例を示しましょう。最初は length です。

- fold_left(fn(a, _) => a + 1, 0, [1,1,1,1,1,1,1]);
val it = 7 : int
- fold_right(fn(_, a) => a + 1, 0, [1,1,1,1,1,1,1]);
val it = 7 : int

fold_left で length を実現する場合、初期値を 0 にして第 1 引数の値を +1 することで実現できます。fold_right の場合は第 2 引数の値を +1 します。

次に map の例を示します。

- fold_right(fn(x, a) => (x * x)::a, nil, [1,2,3,4,5,6]);
val it = [1,4,9,16,25,36] : int list

map の場合は fold_rigth を使うと簡単です。初期値を nil にして第 1 引数の計算結果を第 2 引数のリストに追加するだけです。

次に filter の例を示します。

- fold_right(fn(x, a) => if x mod 2 = 0 then x::a else a, nil, [1,2,3,4,5,6,7,8]);
val it = [2,4,6,8] : int list

filter の場合も初期値を nil にして、第 1 引数が条件を満たしていれば第 2 引数のリストに追加します。

最後に述語が真となる要素の個数を求めてみましょう。これは Common Lisp の関数 count-if と同じです。

- fold_left(fn(a, x) => if x mod 2 = 0 then a + 1 else a, 0, [1,2,3,4,5,6,7,8,9]);
val it = 4 : int

このように、畳み込みを使っていろいろな処理を実現することができます。

●問題

次の高階関数を定義してください。

  1. s から e までの整数に関数 f を適用し、その合計値を求める関数 sum_of(f, s, e)
  2. s から e までの整数に関数 f を適用し、その結果をリストに格納して返す関数 tabulate(f, s, e)
  3. 関数 f にリストを渡すマップ関数 maplist(f, xs)
  4. 述語 pred を満たす要素を先頭から取り出す関数 take_while(pred, xs)
  5. 述語 pred を満たす要素を先頭から取り除く関数 drop_while(pred, xs)
  6. リスト xs を先頭から畳み込むとき、計算途中の累積値をリストに格納して返す scan_left(f, a, xs)
  7. リスト xs を末尾から畳み込むとき、計算途中の累積値をリストに格納して返す scan_right(f, a, xs)
















●解答

リスト : 解答例

fun sum_of(f, s, e) =
  if s > e then 0
  else f s + sum_of(f, s + 1, e)

fun tabulate(f, s, e) =
  if s > e then nil
  else f s :: tabulate(f, s + 1, e)

fun maplist(_, nil) = nil
|   maplist(f, xs) = f xs :: maplist(f, tl xs)

fun take_while(_, nil) = nil
|   take_while(pred, x::xs) =
  if pred x then x :: take_while(pred, xs)
  else nil

fun drop_while(_, nil) = nil
|   drop_while(pred, xs) =
  if pred (hd xs) then drop_while(pred, tl xs)
  else xs

fun scan_left(f, a, nil) = [a]
|   scan_left(f, a, x::xs) = a :: scan_left(f, f(a, x), xs)

fun scan_right(f, a, nil) = [a]
|   scan_right(f, a, x::xs) = 
  let
    val ys = scan_right(f, a, xs)
  in
    f(x, hd ys)
  end
- sum_of;
val it = fn : (int -> int) * int * int -> int

- sum_of(fn x => x, 1, 100);
val it = 5050 : int
- sum_of(fn x => x * x, 1, 100);
val it = 338350 : int
- sum_of(fn x => x * x * x, 1, 100);
val it = 25502500 : int

- tabulate;
val it = fn : (int -> 'a) * int * int -> 'a list

- tabulate(fn x => x, 1, 10);
val it = [1,2,3,4,5,6,7,8,9,10] : int list
- tabulate(fn x => x *x, 1, 10);
val it = [1,4,9,16,25,36,49,64,81,100] : int list
- tabulate(fn x => x * x * x, 1, 10);
val it = [1,8,27,64,125,216,343,512,729,1000] : int list

- maplist;
val it = fn : ('a list -> 'b) * 'a list -> 'b list
- maplist(fn xs => xs, [1,2,3,4,5]);
val it = [[1,2,3,4,5],[2,3,4,5],[3,4,5],[4,5],[5]] : int list list
- maplist(fn xs => length xs, [1,2,3,4,5]);
val it = [5,4,3,2,1] : int list

- take_while;
val it = fn : ('a -> bool) * 'a list -> 'a list
- take_while(fn x => x < 5, [1,2,3,4,5,6,7,8]);
val it = [1,2,3,4] : int list
- take_while(fn x => x < 9, [1,2,3,4,5,6,7,8]);
val it = [1,2,3,4,5,6,7,8] : int list
- take_while(fn x => x < 0, [1,2,3,4,5,6,7,8]);
val it = [] : int list

- drop_while;
val it = fn : ('a -> bool) * 'a list -> 'a list
- drop_while(fn x => x < 5, [1,2,3,4,5,6,7,8]);
val it = [5,6,7,8] : int list
- drop_while(fn x => x < 9, [1,2,3,4,5,6,7,8]);
val it = [] : int list
- drop_while(fn x => x < 0, [1,2,3,4,5,6,7,8]);
val it = [1,2,3,4,5,6,7,8] : int list

- scan_left;
val it = fn : ('a * 'b -> 'a) * 'a * 'b list -> 'a list
- scan_left(op +, 0, [1,2,3,4,5,6,7,8,9,10]);
val it = [0,1,3,6,10,15,21,28,36,45,55] : int list
- scan_left(op *, 1, [1,2,3,4,5,6,7,8,9,10]);
val it = [1,1,2,6,24,120,720,5040,40320,362880,3628800] : int list
- scan_left(fn(a, x) => x::a, nil, [1,2,3,4,5]);
val it = [[],[1],[2,1],[3,2,1],[4,3,2,1],[5,4,3,2,1]] : int list list

- scan_right;
val it = fn : ('a * 'b -> 'b) * 'b * 'a list -> 'b list
- scan_right(op +, 0, [1,2,3,4,5,6,7,8,9,10]);
val it = [55,54,52,49,45,40,34,27,19,10,0] : int list
- scan_right(op *, 1, [1,2,3,4,5,6,7,8,9,10]);
val it = [3628800,3628800,1814400,604800,151200,30240,5040,720,90,10,1]
  : int list
- scan_right(op ::, nil, [1,2,3,4,5]);
val it = [[1,2,3,4,5],[2,3,4,5],[3,4,5],[4,5],[5],[]] : int list list

●補足

scan_left はリストの最後の要素が最終の累積値になります。xs が空リストのとき、累積変数 a の値をリストに格納して返します。そうでなければ、scan_left を再帰呼び出しして、その返り値に累積変数 a の値を追加して返します。scan_left を再帰呼び出しするときは、関数 f を呼び出して累積変数の値を更新することに注意してください。

scan_right はリストの先頭の要素が最終の累積値、最後の要素が初期値になります。リスト xs が空リストの場合は [a] を返します。そうでなければ、scan_right を再帰呼び出しします。このとき、累積変数 a の値は更新しません。返り値のリストは変数 ys にセットします。この ys の先頭要素が一つ前の累積値になるので、この値とリストの先頭要素 x を関数 f に渡して評価します。あとは、f の返り値を ys の先頭に追加して返せばいいわけです。


初版 2005 年 5 月 7 日
改訂 2020 年 8 月 9 日