M.Hiroi's Home Page

続・お気楽 Java プログラミング入門

immutable な連結リスト


Copyright (C) 2016-2021 Makoto Hiroi
All rights reserved.

はじめに

パッケージの簡単な例題として immutable な連結リストを作ってみました。

●不変データ構造

Java の場合、フィールド変数を final で宣言すると、コンストラクタで初期化する以外に値を変更することができなくなります。さらに、class を final 宣言して継承を禁止すると、新しいサブクラスを作って mutable なフィールド変数を追加することができなくなります。これで immutable なデータ構造 (不変データ構造) を作ることができるように思われますが、実はそう簡単ではありません。

●実装上の問題点

格納する要素が immutable なデータ (String や基本データ型) であれば問題ないのですが、mutable なデータ (オブジェクト) を格納すると、そのデータ構造は immutable ではなくなります。たとえば、データ構造 xs にオブジェクトの参照を返すメソッド get() があるとしましょう。オブジェクトのセッターを setX(...) とすると、xs.get().setX(...) でオブジェクトの値を更新することが可能です。つまり、データ構造 xs の値を書き換えることができるわけです。

この問題を解決するため、データ構造を生成するときに格納するオブジェクトをコピーする、参照を返すときにもコピーを作るといった方法があるのですが、今回は簡単な例題ということで、mutable なオブジェクトを格納するときは連結リストの不変性を放棄することにします。つまり、不変データ構造として使いたい場合は、プログラマの責任で immutable なデータを入れてください、ということにします。

あくまでも学習が目的のライブラリなので、実用性はほとんどありませんが、興味のある方はいろいろ試してみてください。

●連結リストの仕様

●プログラムリスト

//
// ImList.java : immutable な連結リスト
//
//               Copyright (C) 2016-2021 Makoto Hiroi
//
package immutable;

import java.util.Iterator;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;


public final class ImList<E> implements Iterable<E> {
  private final E item;
  private ImList<E> next;

  private ImList(E x, ImList<E> y) {
    item = x;
    next = y;
  }

  // 終端
  private static final ImList<?> NIL = new ImList<>(null, null);

  public static <E> ImList<E> nil() {
    @SuppressWarnings("unchecked")
    ImList<E> t = (ImList<E>)NIL;
    return t;
  }

  public boolean isEmpty() { return this == NIL; }

  // リストの生成
  public static <E> ImList<E> cons(E x, ImList<E> y) {
    return new ImList<E>(x , y);
  }
  
  @SafeVarargs
  public static <E> ImList<E> of(E... args) {
    ImList<E> xs = nil();
    for (int i = args.length - 1; i >= 0; i--) {
      xs = cons(args[i], xs);
    }
    return xs;
  }

  public static <E> ImList<E> fill(int n, E x) {
    ImList<E> xs = nil();
    while (n-- > 0) xs = cons(x, xs);
    return xs;
  }

  public static <E> ImList<E> tabulate(int n, Function<Integer, ? extends E> func) {
    ImList<E> xs = nil();
    while (--n >= 0) {
      xs = cons(func.apply(n), xs);
    }
    return xs;
  }
  
  public static ImList<Integer> iota(int n, int m) {
    ImList<Integer> xs = nil();
    while (m >= n) xs = cons(m--, xs);
    return xs;
  }

  public static <T, U, V> ImList<V> zipWith(BiFunction<? super T, ? super U, ? extends V> func, ImList<T> xs, ImList<U> ys) {
    ImList<V> zs = nil();
    while (xs != NIL && ys != NIL) {
      zs = cons(func.apply(xs.first(), ys.first()), zs);
      xs = xs.rest();
      ys = ys.rest();
    }
    return zs.nreverse();
  }
  
  //
  // 基本操作
  //

  // 先頭要素
  public E first() {
    if (this == NIL)
      throw new IndexOutOfBoundsException("ImList.first()");
    return item;
  }

  // 先頭要素を取り除いたリスト
  public ImList<E> rest() {
    if (this == NIL)
      throw new IndexOutOfBoundsException("ImList.rest()");
    return next;
  }

  // 最後の要素
  public E last() {
    ImList<E> xs = this;
    if (xs == NIL)
      throw new IndexOutOfBoundsException("ImList.last()");
    while (xs.rest() != NIL) xs = xs.rest();
    return xs.first();
  }

  // n 番目の要素
  public E get(int n) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (n-- == 0) return xs.first();
      xs = xs.rest();
    }
    throw new IndexOutOfBoundsException("ImList.get()");
  }

  // n 番目に要素を追加
  public ImList<E>add(int n, E x) {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    if (n == 0) return cons(x, xs);
    while (n-- > 0) {
      ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    ImList<E> zs = ys.nreverse();
    ys.next = cons(x, xs);
    return zs;
  }

  // n 番目の要素を削除
  public ImList<E> remove(int n) {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    if (n == 0) return xs.rest();
    while (n-- > 0) {
      ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    ImList<E> zs = ys.nreverse();
    ys.next = xs.rest();
    return zs;
  }

  // 長さを求める
  public int length() {
    ImList<E> xs = this;
    int c = 0;
    while (xs != NIL) {
      c++;
      xs = xs.rest();
    }
    return c;
  }

  // 連結
  public ImList<E> append(ImList<E> ys) {
    ImList<E> xs = this;
    ImList<E> zs = nil();
    if (xs == NIL) return ys;
    while (xs != NIL) {
      zs = cons(xs.first(), zs);
      xs = xs.rest();
    }
    xs = zs.nreverse();
    zs.next = ys;
    return xs;
  }

  // 反転
  public ImList<E> reverse() {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    while (xs != NIL) {
      ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    return ys;
  }

  // 破壊的な反転
  private ImList<E> nreverse() {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    while (xs != NIL) {
      ImList<E> zs = xs.rest();
      xs.next = ys;
      ys = xs;
      xs = zs;
    }
    return ys;
  }
  
  // 先頭から n 個の要素を取り出す
  public ImList<E> take(int n) {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    while (n -- > 0 && xs != NIL) {
      ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    return ys.nreverse();
  }

  // 先頭から n 個の要素を取り除く
  public ImList<E> drop(int n) {
    ImList<E> xs = this;
    while (n-- > 0 && xs != NIL) xs = xs.rest();
    return xs;
  }

  // 探索
  public int indexOf(E x) {
    int i = 0;
    ImList<E> xs = this;
    while (xs != NIL) {
      if (x.equals(xs.first())) return i;
      i++;
      xs = xs.rest();
    }
    return -1;
  }

  public boolean contains(E x) {
    return indexOf(x) >= 0;
  }

  public Optional<E> findIf(Predicate<? super E> pred) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (pred.test(xs.first())) return Optional.of(xs.first());
      xs = xs.rest();
    }
    return Optional.empty();
  }

  public ImList<E> member(E x) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (x.equals(xs.first())) break;
      xs = xs.rest();
    }
    return xs;
  }
  
  public ImList<E> memberIf(Predicate<? super E> pred) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (pred.test(xs.first())) break;
      xs = xs.rest();
    }
    return xs;
  }

  public ImList<E> takeWhile(Predicate<? super E> pred) {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    while (xs != NIL) {
      if (pred.test(xs.first()))
        ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    return ys.nreverse();
  }

  public ImList<E> dropWhile(Predicate<? super E> pred) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (!pred.test(xs.first())) break;
      xs = xs.rest();
    }
    return xs;
  }
  
  //
  // 高階関数
  //

  // マッピング
  public <U> ImList<U> map(Function<? super E, ? extends U> func) {
    ImList<E> xs = this;
    ImList<U> ys = ImList.nil();
    while (xs != NIL) {
      ys = cons(func.apply(xs.first()), ys);
      xs = xs.rest();
    }
    return ys.nreverse();
  }

  public <U> ImList<U> flatMap(Function<? super E, ImList<U>> func) {
    ImList<E> xs = this;
    ImList<ImList<U>> ys = ImList.nil();
    while (xs != NIL) {
      ys = cons(func.apply(xs.first()), ys);
      xs = xs.rest();
    }
    ImList<U> zs = nil();
    while (ys != NIL) {
      zs = ys.first().append(zs);
      ys = ys.rest();
    }
    return zs;
  }

  // フィルター
  public ImList<E> filter(Predicate<? super E> pred) {
    ImList<E> xs = this;
    ImList<E> ys = ImList.nil();
    while (xs != NIL) {
      if (pred.test(xs.first())) ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    return ys.nreverse();
  }

  // 畳み込み
  public <U> U foldLeft(BiFunction<U, ? super E, U> func, U a) {
    ImList<E> xs = this;
    while (xs != NIL) {
      a = func.apply(a, xs.first());
      xs = xs.rest();
    }
    return a;
  }

  public <U> U foldRight(BiFunction<? super E, U, U> func, U a) {
    ImList<E> xs = this.reverse();
    while (xs != NIL) {
      a = func.apply(xs.first(), a);
      xs = xs.rest();
    }
    return a;
  }
  
  // 巡回
  public void forEach(Consumer<? super E> func) {
    ImList<E> xs = this;
    while (xs != NIL) {
      func.accept(xs.first());
      xs = xs.rest();
    }
  }

  // 述語
  public boolean allMatch(Predicate<? super E> pred) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (!pred.test(xs.first())) return false;
      xs = xs.rest();
    }
    return true;
  }

  public boolean anyMatch(Predicate<? super E> pred) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (pred.test(xs.first())) return true;
      xs = xs.rest();
    }
    return false;
  }

  //
  // 集合演算
  //
  
  // 重複要素を取り除く
  public ImList<E> distinct() {
    ImList<E> xs = this;
    ImList<E> ys = nil();
    while (xs != NIL) {
      if (!ys.contains(xs.first()))
        ys = cons(xs.first(), ys);
      xs = xs.rest();
    }
    return ys.nreverse();
  }

  // 和集合
  public ImList<E> union(ImList<E> ys) {
    ImList<E> xs = this;
    ImList<E> zs = ys;
    while (xs != NIL) {
      if (!ys.contains(xs.first()))
        zs = cons(xs.first(), zs);
      xs = xs.rest();
    }
    return zs;
  }

  // 積集合
  public ImList<E> intersection(ImList<E> ys) {
    ImList<E> xs = this;
    ImList<E> zs = ImList.nil();
    while(xs != NIL) {
      if (ys.contains(xs.first()))
        zs = cons(xs.first(), zs);
      xs = xs.rest();
    }
    return zs;
  }

  // 差集合
  public ImList<E> difference(ImList<E> ys) {
    ImList<E> xs = this;
    ImList<E> zs = ImList.nil();
    while(xs != NIL) {
      if (!ys.contains(xs.first()))
        zs = cons(xs.first(), zs);
      xs = xs.rest();
    }
    return zs;
  }

  // 部分集合
  public boolean isSubset(ImList<E> ys) {
    ImList<E> xs = this;
    while (xs != NIL) {
      if (!ys.contains(xs.first())) return false;
      xs = xs.rest();
    }
    return true;
  }
  
  // イテレータ
  public Iterator<E> iterator() {
    // 無名クラス
    return new Iterator<E>() {
      ImList<E> xs = ImList.this;
      public boolean hasNext() { return xs != NIL; }
      public E next() {
        E item = xs.first();
        xs = xs.rest();
        return item;
      }
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  // 文字列に変換
  public String toString() {
    String s = "(";
    ImList<E> xs = this;
    while (xs != NIL) {
      s += xs.first().toString();
      if (xs.rest() != NIL) s += " ";
      xs = xs.rest();
    }
    s += ")";
    return s;
  }
}

●簡単なテスト

リスト : 簡単なテスト

import immutable.ImList;
import static immutable.ImList.*;

public class testimlist {
  public static void main(String[] args) {
    ImList<Integer> xs = nil();
    System.out.println(xs);
    System.out.println(xs.isEmpty());
    for (int i = 0; i < 10; i++) xs = cons(i, xs);
    System.out.println(xs);
    System.out.println(xs.isEmpty());
    System.out.println(xs.first());
    System.out.println(xs.rest());
    System.out.println(xs.rest().first());
    System.out.println(xs.rest().rest().first());
    System.out.println(xs.rest().rest().rest().first());
    xs = fill(5, 0);
    System.out.println(xs);
    xs = of(1, 3, 5, 7, 9);
    System.out.println(xs);
    xs = tabulate(5, x -> x * x);
    System.out.println(xs);
    xs = zipWith((x, y) -> x * y, of(1, 3, 5, 7, 9), of(2, 4, 6, 8, 10));
    System.out.println(xs);
    xs = iota(1, 10);
    System.out.println(xs);
    for (int i = 0; i < xs.length(); i++)
      System.out.println(xs.get(i));
    System.out.println(xs.last());
    System.out.println(xs.add(0, 100));
    System.out.println(xs.add(5, 100));
    System.out.println(xs.add(10, 100));
    System.out.println(xs.remove(0));
    System.out.println(xs.remove(4));
    System.out.println(xs.remove(9));
    System.out.println(xs.append(xs));
    System.out.println(xs.reverse());
    for (int i = 0; i <= 10; i++) {
      System.out.println(xs.take(i));
      System.out.println(xs.drop(i));
    }
    for (int i = 0; i <= 11; i++) {
      final int j = i;
      System.out.println(xs.indexOf(i));
      System.out.println(xs.contains(i));
      System.out.println(xs.findIf(x -> x == j));
      System.out.println(xs.member(i));
      System.out.println(xs.memberIf(x -> x == j));
    }
    System.out.println(xs.takeWhile(x -> x < 5));
    System.out.println(xs.dropWhile(x -> x < 5));
    System.out.println(xs.map(x -> x * x));
    System.out.println(xs.flatMap(x -> ImList.of(x, x)));
    System.out.println(xs.filter(x -> x % 2 == 0));
    System.out.println(xs.foldLeft((a, x) -> a + x, 0));
    System.out.println(xs.foldLeft((a, x) -> cons(x, a), nil()));
    System.out.println(xs.foldRight((x, a) -> a + x, 0));
    System.out.println(xs.foldRight((x, a) -> cons(x, a), nil()));
    xs.forEach(System.out::println);
    System.out.println(xs.allMatch(x -> x <= 10));
    System.out.println(xs.allMatch(x -> x < 10));
    System.out.println(xs.anyMatch(x -> x == 10));
    System.out.println(xs.anyMatch(x -> x == 0));
    for (int x: xs) System.out.print(x);
    System.out.println("");
    ImList<Integer> a = of(1, 2, 3, 4, 2, 3, 4, 3, 4, 4);
    ImList<Integer> b = a.distinct();
    System.out.println(b);
    ImList<Integer> c = of(3, 4, 5, 6);
    System.out.println(b.union(c));
    System.out.println(b.intersection(c));
    System.out.println(b.difference(c));
    System.out.println(b.isSubset(c));
    System.out.println(b.isSubset(b));
  }
}
$ javac testimlist.java
$ java testimlist
()
true
(9 8 7 6 5 4 3 2 1 0)
false
9
(8 7 6 5 4 3 2 1 0)
8
7
6
(0 0 0 0 0)
(1 3 5 7 9)
(0 1 4 9 16)
(2 12 30 56 90)
(1 2 3 4 5 6 7 8 9 10)
1
2
3
4
5
6
7
8
9
10
10
(100 1 2 3 4 5 6 7 8 9 10)
(1 2 3 4 5 100 6 7 8 9 10)
(1 2 3 4 5 6 7 8 9 10 100)
(2 3 4 5 6 7 8 9 10)
(1 2 3 4 6 7 8 9 10)
(1 2 3 4 5 6 7 8 9)
(1 2 3 4 5 6 7 8 9 10 1 2 3 4 5 6 7 8 9 10)
(10 9 8 7 6 5 4 3 2 1)
()
(1 2 3 4 5 6 7 8 9 10)
(1)
(2 3 4 5 6 7 8 9 10)
(1 2)
(3 4 5 6 7 8 9 10)
(1 2 3)
(4 5 6 7 8 9 10)
(1 2 3 4)
(5 6 7 8 9 10)
(1 2 3 4 5)
(6 7 8 9 10)
(1 2 3 4 5 6)
(7 8 9 10)
(1 2 3 4 5 6 7)
(8 9 10)
(1 2 3 4 5 6 7 8)
(9 10)
(1 2 3 4 5 6 7 8 9)
(10)
(1 2 3 4 5 6 7 8 9 10)
()
-1
false
Optional.empty
()
()
0
true
Optional[1]
(1 2 3 4 5 6 7 8 9 10)
(1 2 3 4 5 6 7 8 9 10)
1
true
Optional[2]
(2 3 4 5 6 7 8 9 10)
(2 3 4 5 6 7 8 9 10)
2
true
Optional[3]
(3 4 5 6 7 8 9 10)
(3 4 5 6 7 8 9 10)
3
true
Optional[4]
(4 5 6 7 8 9 10)
(4 5 6 7 8 9 10)
4
true
Optional[5]
(5 6 7 8 9 10)
(5 6 7 8 9 10)
5
true
Optional[6]
(6 7 8 9 10)
(6 7 8 9 10)
6
true
Optional[7]
(7 8 9 10)
(7 8 9 10)
7
true
Optional[8]
(8 9 10)
(8 9 10)
8
true
Optional[9]
(9 10)
(9 10)
9
true
Optional[10]
(10)
(10)
-1
false
Optional.empty
()
()
(1 2 3 4)
(5 6 7 8 9 10)
(1 4 9 16 25 36 49 64 81 100)
(1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 10)
(2 4 6 8 10)
55
(10 9 8 7 6 5 4 3 2 1)
55
(1 2 3 4 5 6 7 8 9 10)
1
2
3
4
5
6
7
8
9
10
true
false
true
false
12345678910
(1 2 3 4)
(2 1 3 4 5 6)
(4 3)
(2 1)
false
true

初版 2016 年 12 月 10 日
改訂 2021 年 2 月 21 日