M.Hiroi's Home Page

Julia Language Programming

お気楽 Julia プログラミング超入門


Copyright (C) 2018-2021 Makoto Hiroi
All rights reserved.

Julia の基礎知識

●Expr 型と Symbol 型

julia> :((1 + 2) * (3 - 4))
:((1 + 2) * (3 - 4))

julia> typeof(:((1 + 2) * (3 - 4)))
Expr

julia> dump(:((1 + 2) * (3 - 4)))
Expr
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol *
    2: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol +
        2: Int64 1
        3: Int64 2
    3: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol -
        2: Int64 3
        3: Int64 4

julia> Meta.parse("(1 + 2) * (3 - 4)")
:((1 + 2) * (3 - 4))
julia> a = :foo
:foo

julia> b = Symbol("foo")
:foo

julia> a === b
true

julia> Meta.show_sexpr(:((1 + 2) * (3 - 4)))
(:call, :*, (:call, :+, 1, 2), (:call, :-, 3, 4))

julia> eval(:((1 + 2) * (3 - 4)))
-3

julia> x = 100
100

julia> :(x + 200)
:(x + 200)

julia> :($x + 200)
:(100 + 200)

julia> eval(:(x + 200))
300

julia> x = 10
10

julia> eval(:(x + 200))
210
リスト : REPL (Read - Eval - Print - Loop)

while true
    print(">>> ")
    s = readline(stdin)
    try
        println(eval(Meta.parse(s)))
    catch e
        println(e)
    end
end
>>> 1 + 2 * 3
7
>>> a = 10
10
>>> a * 10
100
>>> b
UndefVarError(:b)
>>> square(x) = x * x
square
>>> square(1.2345)
1.5239902499999998
>>> exit()
$ 

●マクロの基本

julia> macro m_square(x)
       :($x * $x)
       end
@m_square (macro with 1 method)

julia> @m_square(100)
10000

julia> a = 20
20

julia> @m_square(a)
400

julia> x = 123
123

julia> @m_square(x * 2)
60516

julia> square(x) = x * x
square (generic function with 1 method)

julia> foo(x) = (println("foo!"); x)
foo (generic function with 1 method)

julia> square(foo(10))
foo!
100

julia> @m_square(foo(10))
foo!
foo!
100
julia> macro chunk(expr)
       :(() -> $expr)
       end
@chunk (macro with 1 method)

julia> f = @chunk(1 + 2 * 3)
#3 (generic function with 1 method)

julia> f()
7

julia> g = @chunk(x * x for x = 1 : 10)
#5 (generic function with 1 method)

julia> for x = g()
       println(x)
       end
1
4
9
16
25
36
49
64
81
100

julia> a = Task(@chunk(for x = 1:10; sleep(0.5); println(x); end))
Task (runnable) @0x0000000010b49430

julia> schedule(a)
Task (runnable) @0x0000000010b49430

julia> 1
2
3
4
5
6
7
8
9
10
julia> macro chunk(expr) :(() -> $expr) end
@chunk (macro with 1 method)

julia> a = 1
1

julia> b = 2
2

julia> f = @chunk(a = b)
#3 (generic function with 1 method)

julia> f()
2

julia> a
1

julia> @macroexpand @chunk(a = b)
:(()->begin
          #= REPL[18]:1 =#
          var"#5#a" = Main.b  # a が変数 #5#a に置き換えられている
      end)                    # b も大域変数のアクセスになっている

julia> f = @chunk(global a = b)
#5 (generic function with 1 method)

julia> f()
2

julia> a
2

julia> @macroexpand @chunk(global a = b)
:(()->begin
          #= REPL[18]:1 =#
          global a = Main.b     # global を付けると置換しないようだ
      end)

julia> let a = 10, b = 20
       f = @chunk(a + b)        # 局所変数にはアクセスできない
       println(f())
       end
4

julia> a + b
4
julia> macro chunk(expr) esc(:(() -> $expr)) end
@chunk (macro with 1 method)

julia> a = 1
1

julia> b = 2
2

julia> f = @chunk(a = b)
#9 (generic function with 1 method)

julia> f()
2

julia> a
1

julia> @macroexpand @chunk(a = b)
:(()->begin
          #= REPL[36]:1 =#
          a = b                       # 無名関数の中なので a は局所変数になる
      end)

julia> f = @chunk(global a = b)
#11 (generic function with 1 method)

julia> f()
2

julia> a
2

julia> let a = 10, b = 20
       f = @chunk(a + b)
       println(f())
       end
30

julia> f = @chunk(a + b)              # これは大域変数をアクセスする
#15 (generic function with 1 method)

julia> f()
4

julia> @macroexpand @chunk(a + b)
:(()->begin
          #= REPL[36]:1 =#
          a + b
      end)
julia> macro setf1(x, y)
       :($(esc(x)) = $(esc(y)))
       end
@setf1 (macro with 1 method)

julia> @macroexpand @setf1(a, b)
:(a = b)

julia> macro setf2(x, y)
       quote
       z = $(esc(y))
       $(esc(x)) = z
       end
       end
@setf2 (macro with 1 method)

julia> @macroexpand @setf2(a, b)
quote
    #= REPL[52]:3 =#
    var"#6#z" = b
    #= REPL[52]:4 =#
    a = var"#6#z"
end

julia> a = 10
10

julia> b = 20
20

julia> @setf2 a b
20

julia> a
20

julia> let a = 100, b = 200
       @setf2 a b
       println(a)
       end
200
julia> macro arithmeticif(test, n, z, p)
       quote
       local r = $(esc(test))
       if r == 0
       $z
       elseif r > 0
       $p
       else
       $n
       end
       end
       end
@arithmeticif (macro with 1 method)

julia> @arithmeticif(-1, "negative", "zero", "positive")
"negative"

julia> @arithmeticif(0, "negative", "zero", "positive")
"zero"

julia> @arithmeticif(1, "negative", "zero", "positive")
"positive"

julia> a = 100
100

julia> @arithmeticif(a, "negative", "zero", "positive")
"positive"

julia> @arithmeticif(1 - 3, "negative", "zero", "positive")
"negative"

julia> @arithmeticif(1 - 3 + 2, "negative", "zero", "positive")
"zero"

●並列プログラミング

$ julia -p 4
... 省略 ...

julia> nworkers()
4

julia> nprocs()
5
julia> @everywhere fibo(n) = if n < 2 n else fibo(n - 1) + fibo(n - 2) end

julia> @time fibo(42)
  2.118362 seconds
267914296

julia> @time fibo(42) + fibo(42)
  4.224953 seconds
535828592

julia> r = remotecall(fibo, 2, 42)
Future(2, 1, 14, ReentrantLock(nothing, 0x00000000, 0x00, Base.GenericCondition
{Base.Threads.SpinLock}(Base.IntrusiveLinkedList{Task}(nothing, nothing), 
Base.Threads.SpinLock(0)), (0, 140536252341712, 140536252342000)),nothing)

julia> fetch(r)
267914296

julia> remotecall_fetch(fibo, 2, 42)
267914296

julia> function test()
       a = remotecall(fibo, 2, 42)
       b = remotecall(fibo, 3, 42)
       fetch(a) + fetch(b)
       end
test (generic function with 1 method)

julia> @time test()
  2.864782 seconds (175 allocations: 8.531 KiB)
535828592

julia> function test1()
       a = remotecall(fibo, 2, 42)
       b = remotecall(fibo, 2, 42)
       fetch(a) + fetch(b)
       end
test1 (generic function with 1 method)

julia> @time test1()
  4.222432 seconds (174 allocations: 8.234 KiB)
535828592
julia> function test2()
       a = @spawn fibo(42)
       b = @spawn fibo(42)
       fetch(a) + fetch(b)
       end
test2 (generic function with 1 method)

julia> @time test2()
  3.705540 seconds (860 allocations: 54.008 KiB, 0.26% compilation time)
535828592

julia> function test3()
       a = @spawnat 2 fibo(42)
       b = @spawnat 3 fibo(42)
       fetch(a) + fetch(b)
       end
test3 (generic function with 1 method)

julia> @time test3()
  2.762539 seconds (337 allocations: 18.578 KiB)
535828592

Julia> function test4()
       a = @spawnat :any fibo(42)
       b = @spawnat :any fibo(42)
       fetch(a) + fetch(b)
       end

julia> @time test4()
  2.585752 seconds (376 allocations: 21.203 KiB)
535828592
julia> @time @distributed (+) for _ in 1 : 2; fibo(42); end
  2.733667 seconds (189.63 k allocations: 12.874 MiB, 7.24% compilation time)
535828592

julia> @time @distributed (+) for _ in 1 : 3; fibo(42); end
  3.413032 seconds (10.43 k allocations: 738.352 KiB, 0.70% compilation time)
803742888

julia> @time @distributed (+) for _ in 1 : 4; fibo(42); end
  4.458989 seconds (10.63 k allocations: 752.883 KiB, 0.54% compilation time)
1071657184

julia> @time map(fibo, [42, 42])
  4.280798 seconds (2 allocations: 160 bytes)
2-element Vector{Int64}:
 267914296
 267914296

julia> @time pmap(fibo, [42, 42])
  3.056607 seconds (677.39 k allocations: 46.379 MiB, 0.49% gc time, 18.34% compilation time)
2-element Vector{Int64}:
 267914296
 267914296

julia> @time pmap(fibo, [42, 42, 42])
  3.401667 seconds (187 allocations: 8.422 KiB)
3-element Vector{Int64}:
 267914296
 267914296
 267914296
julia> @time @sync (@spawn fibo(42); @spawn fibo(42))
  3.618582 seconds (47.02 k allocations: 3.321 MiB, 4.92% compilation time)
Future(5, 1, 56, ReentrantLock(nothing, 0x00000000, 0x00, Base.GenericCondition
{Base.Threads.SpinLock}(Base.IntrusiveLinkedList{Task}(nothing, nothing), 
Base.Threads.SpinLock(0)), (8, 526336, 1519571675789)), nothing)

julia> @time (@spawn fibo(42); @spawn fibo(42))
  0.000850 seconds (262 allocations: 13.953 KiB)
Future(3, 1, 60, ReentrantLock(nothing, 0x00000000, 0x00, Base.GenericCondition
{Base.Threads.SpinLock}(Base.IntrusiveLinkedList{Task}(nothing, nothing), 
Base.Threads.SpinLock(0)), (0, 140536219851440, 140536219848096)), nothing)
julia> @everywhere cnt = 0

julia> @everywhere function test(n)
       for _ in 1 : 10
       global cnt        # ワーカープロセスで実行すると、
                         # そのワーカープロセスの大域変数 cnt にアクセスする
       cnt += n
       println(cnt)
       sleep(n)
       end
       end

julia> r = @spawn test(1); fetch(r)
        From worker 3:  1
        From worker 3:  2
        From worker 3:  3
        From worker 3:  4
        From worker 3:  5
        From worker 3:  6
        From worker 3:  7
        From worker 3:  8
        From worker 3:  9
        From worker 3:  10

julia> cnt
0                        # マスタープロセスの cnt は 0 のまま

julia> a = [1, 2, 3, 4, 5]
5-element Vector{Int64}:
 1
 2
 3
 4
 5

julia> @everywhere test1(a, n, x) = (a[n] = x; println(a))

julia> r = @spawn test1(a, 3, 30); fetch(r)
      From worker 5:    [1, 2, 30, 4, 5]    # ワーカープロセスの配列を変更

julia> a                 # マスタープロセスの配列は変更されていない
5-element Vector{Int64}:
 1
 2
 3
 4
 5

簡単なプログラム

●平方根

実数 a の平方根 \(\sqrt a\) の値を求める場合、方程式 \(x^2 - a = 0\) を Newton (ニュートン) 法で解くことが多いと思います。方程式を \(f(x)\), その導関数を \(f'(x)\) とすると、ニュートン法は次の漸化式の値が収束するまで繰り返す方法です。

\( x_{n+1} = x_n - \dfrac{f(x_n)}{f'(x_n)} \)

平方根を求める場合、導関数は \(f'(x) = 2x\) になるので、漸化式は次のようになります。

\( x_{n+1} = \dfrac{1}{2} (x_n + \dfrac{a}{x_n}) \)

参考文献『C言語による最新アルゴリズム事典』によると、\(\sqrt a\) より大きめの初期値から出発し、置き換え x <- (x + a / x) / 2 を減少が止まるまで繰り返すことで \(\sqrt a\) の正確な値を求めることができるそうです。

Julia でプログラムすると、次のようになります。

リスト : 平方根を求める

function sqrt1(x::Float64)
    function init(x::Float64, s = 1.0)
        while s < x
            s *= 2.0
            x /= 2.0
        end
        s
    end
    if x < 0; error("sqrt1: domain error"); end
    p = x > 1.0 ? init(x) : 1.0
    while true
        q = (p + x / p) / 2
        if q >= p; break; end
        p = q
    end
    p
end

println(sqrt1(2.0))
println(sqrt(2))
println(sqrt1(123456.0))
println(sqrt(123456))
1.414213562373095
1.4142135623730951
351.363060095964
351.363060095964

局所関数 init は \(\sqrt x\) よりも大きめの初期値を求めます。たとえば、\(\sqrt {123456}\) を求める場合、初期値の計算は次のようになります。

   s         x
-------------------
  1.0  123456.0
  2.0   61728.0
  4.0   30864.0
  8.0   15432.0
 16.0    7716.0
 32.0    3858.0
 64.0    1929.0
128.0     964.5
256.0     482.25
512.0     241.125

√123456 = 351.363060095964 

s を 2 倍、x を 1 / 2 していき、s >= x となったときの s が初期値 (512) となります。4, 16, 64, 256, ... 22n の平方根はこれだけで求めることができます。

あとは漸化式を計算して変数 q にセットし、q がひとつ前の値 p 以上になったら p を返すだけです。\(\sqrt {123456}\) を求めたときの p と q の値を示します。

   p                  q
--------------------------------------
512.0              376.5625
376.5625           352.20622925311204
352.20622925311204 351.3640693544162
351.3640693544162  351.3630600974135
351.3630600974135  351.363060095964
351.363060095964   351.363060095964

√123456 = 351.363060095964 

6 回の繰り返しで \(\sqrt {123456}\) を求めることができます。


●めのこ平方

平方根の整数部もニュートン法を使って求めることができますが、次の公式を使って平方根の整数部分を求めることもできます。

\(\begin{array}{ll} (1) & 1 + 3 + 5 + \cdots + (2n - 1) = n^2 \\ (2) & 1 + 3 + 5 + \cdots + (2n - 1) = n^2 \lt m \lt 1 + 3 + \cdots + (2n - 1) + (2n + 1) = (n + 1)^2 \end{array}\)

式 (1) は、奇数 \(1\) から \(2n - 1\) の総和は \(n^2\) になることを表しています。式 (2) のように、整数 m の値が \(n^2\) より大きくて \((n + 1)^2\) より小さいのであれば、m の平方根の整数部分は n であることがわかります。これは m から奇数 \(1, 3, 5, \ldots, (2n - 1), (2n + 1)\) を順番に引き算していき、引き算できなくなった時点の (2n + 1) / 2 = n が m の平方根になります。参考文献『平方根計算法』によると、この方法を「めのこ平方」と呼ぶそうです。

プログラムは次のようになります。

リスト : めのこ平方

# めのこ平方
function isqrt(n, m = 1)
    while n >= m
        n -= m
        m += 2
    end
    div(m, 2)
end

println(isqrt(4))
println(isqrt(16))
println(isqrt(64))
println(isqrt(80))
println(isqrt(81))
println(isqrt(82))
println(isqrt(100))
2
4
8
8
9
9
10

この方法はとても簡単ですが、数が大きくなると時間がかかるようになります。そこで、整数を 2 桁ずつ分けて計算していくことにします。次の図を見てください。

整数 6789 を 67 と 89 に分ける

1 + 3 + ... + 15 = 82 < 67

両辺を 100 倍すると 802 < 6700 < 6789

802 = 1 + 3 + ... + 159 (= 2 * 80 - 1)

161 + 163 < (6789 - 6400 = 389) < 161 + 163 + 165

整数 6789 を 67 と 89 に分けます。最初に 67 の平方根を求めます。この場合は 8 になり、82 < 67 を満たします。次に、この式を 100 倍します。すると、802 < 6700 になり、6700 に 89 を足した 6789 も 802 より大きくなります。802 は 1 から 159 までの奇数の総和であることはすぐにわかるので、6789 - 6400 = 389 から奇数 161, 163, ... を順番に引き算していけば 6789 の平方根を求めることができます。

プログラムは次のようになります。

リスト : めのこ平方 (改良版)

function isqrt1(n)
    if n < 100
        isqrt(n)
    else
        m = 10 * isqrt1(div(n, 100))
        isqrt(n - m * m, 2 * m + 1)
    end
end

println(isqrt1(6789))
println(isqrt1(123456789))
println(isqrt1(1234567890))
82
11111
35136

isqrt1() は n の平方根の整数部分を求めます。n が 100 未満の場合は isqrt() で平方根を求めます。これが再帰呼び出しの停止条件になります。n が 100 以上の場合は、n の下位 2 桁を取り除いた値 div(n, 100) の平方根を isqrt1() で求め、その値を 10 倍して変数 m にセットします。そして、isqrt() で n - m * m から奇数 2 * m + 1, 2 * m + 3 ... を順番に引き算していって n の平方根を求めます。

興味のある方はいろいろ試してみてください。

●参考文献

  1. 奥村晴彦,『C言語による最新アルゴリズム事典』, 技術評論社, 1991
  2. 仙波一郎のページ, 『平方根計算法』 (PDF)

●べき集合

今回は配列 xs のべき集合を求める高階関数 power_set() を作ります。たとえば配列 [1, 2, 3] のべき集合は [], [1], [2], [3], [1, 2], [1, 3], [2, 3], [1, 2, 3] になります。

リスト : べき集合

function power_set(f, xs)
    function iter(i, a)
        if i > length(xs)
            f(a)
        else
            iter(i + 1, a)
            push!(a, xs[i])
            iter(i + 1, a)
            pop!(a)
        end
    end
    a::typeof(xs) = []
    iter(1, a)
end

power_set(println, ["foo", "bar", "baz"])
power_set(println, [1,2,3,4])

power_set() は簡単です。実際の処理は局所関数 iter() で行います。xs の i 番目の要素を選択する場合は、その要素を引数 a に追加して iter() を再帰呼び出しします。選択しない場合は、引数 a に要素を追加せずに iter() を再帰呼び出しするだけです。これでべき集合の要素をすべて求めることができます。

String[]
["baz"]
["bar"]
["bar", "baz"]
["foo"]
["foo", "baz"]
["foo", "bar"]
["foo", "bar", "baz"]
Int64[]
[4]
[3]
[3, 4]
[2]
[2, 4]
[2, 3]
[2, 3, 4]
[1]
[1, 4]
[1, 3]
[1, 3, 4]
[1, 2]
[1, 2, 4]
[1, 2, 3]
[1, 2, 3, 4]

●マスターマインドの解法

「マスターマインド」は 0 から 9 までの重複しない 4 つの数字からなる隠しコードを当てるゲームです。数字は合っているが位置が間違っている個数を cows で表し、数字も位置も合っている個数を bulls で表します。bulls が 4 になると正解です。

   [6, 2, 8, 1] : 正解
---------------------------------
1: [0, 1, 2, 3] : cows 2 : bulls 0
2: [1, 0, 4, 5] : cows 1 : bulls 0
3: [2, 3, 5, 6] : cows 2 : bulls 0
4: [3, 2, 7, 4] : cows 0 : bulls 1
5: [3, 6, 0, 8] : cows 2 : bulls 0
6: [6, 2, 8, 1] : cows 0 : bulls 4

  図 : マスターマインドの動作例

今回はマスターマインドを解くプログラムを作ることにします。

●推測アルゴリズム

このゲームは 10 個の数字の中から 4 個選ぶわけですから、全体では 10 * 9 * 8 * 7 = 5040 通りのコードしかありません。この中から正解を見つける方法ですが、質問したコードとその結果を覚えておいて、それと矛盾しないコードを作るようにします。具体的には、4 つの数字の順列を生成し、それが今まで質問したコードと矛盾しないことを確かめます。これは生成検定法と同じですね。

矛盾しているかチェックする方法も簡単で、以前に質問したコードと比較して、bulls と cows が等しいときは矛盾していません。たとえば、次の例を考えてみてください。

[6, 2, 8, 1] が正解の場合

[0, 1, 2, 3] => bulls = 0, cows = 2

           [0, 1, 2, 3]  と比較する
     --------------------------------------------------------
           [0, X, X, X]  0 から始まるコードは bulls = 1
                         になるので矛盾する。
           ・・・・

           [1, 0, 3, 4]  cows = 3, bulls = 0 になるので矛盾する

           ・・・・

           [1, 0, 4, 5]  cows = 2, bulls = 0 で矛盾しない。
     --------------------------------------------------------

[1, 0, 4, 5] => bulls = 0, cows = 1

次は、[0, 1, 2, 3] と [1, 0, 4, 5] に矛盾しない数字を選ぶ

        図 : マスターマインドの推測アルゴリズム

[0, 1, 2, 3] で bulls が 0 ですから、その位置にその数字は当てはまりません。したがって、[0, X, X, X] というコードは [0, 1, 2, 3] と比較すると bulls が 1 となるので、矛盾していることがわかります。

次に [1, 0, 3, 4] というコードを考えてみます。[0, 1, 2, 3] の結果は cows が 2 ですから、その中で合っている数字は 2 つしかないわけです。ところが、[1, 0, 3, 4] と [0, 1, 2, 3] と比較すると cows が 3 になります。当たっている数字が 2 つしかないのに、同じ数字を 3 つ使うのでは矛盾していることになりますね。

次に [1, 0, 4, 5] というコードと比較すると、bulls が 0 で cows が 2 となります。これは矛盾していないので、このコードを質問することにします。その結果が bulls = 0, cows = 1 となり、今度は [0, 1, 2, 3] と [1, 0, 4, 5] に矛盾しないコードを選択するのです。

●プログラムの作成

それでは、プログラムを作っていきましょう。まず、質問したコードとその結果を格納するデータ型を定義します。

リスト : データ型の定義

# 定数
const CSIZE = 4

# 質問したコードとその結果
type Query
    bulls::Int
    cows::Int
    code::Array{Int, 1}
end

型名は Query としました。bulls, cows と質問したコード code を格納します。これを大域変数 query の配列に格納します。

次は bulls を数える関数 count_bulls() を作ります。

リスト : bulls を数える

function count_bulls(xs, ys)
    c = 0
    for i = 1 : CSIZE
        if xs[i] == ys[i]; c += 1; end
    end
    c
    # 次のコードでもよい
    # count(map(==, xs, ys))
end

count_bulls() は簡単です。配列 xs, ys の要素を順番に比較して、等しい場合は変数 c の値を +1 します。この処理は関数 count() と map() を使うと 1 行で書くことができます。

count(pred, iter) => Integer
count(iter) => Integer

count(pred, iter) はイテレータから要素を取り出して、関数 pred が真を返す要素の個数を求めます。iter だけ渡すと、要素が真の個数を求めます。map() の返り値は bool 型の配列になるので、count() で等しい要素の個数を求めることができます。

次は cows を数える処理を作ります。いきなり cows を数えようとすると難しいのですが、2 つのリストに共通の数字を数えることは簡単にできます。この方法では、bulls の個数を含んだ数を求めることになりますが、そこから bulls を引けば cows を求めることができます。関数名は count_same_number() としましょう。プログラムは次のようになります。

リスト : 同じ数字の個数を数える

function count_same_number(xs, ys)
    c = 0
    for x = xs
        if x in ys c += 1 end
    end
    c
end

for ループで xs の要素を順番に取り出して変数 x にセットします。そして、x in ys で x が ys に含まれているかチェックします。そうであれば、変数 c の値を +1 します。

次は、今まで質問したコードと矛盾していないか調べる関数 check を作ります。

リスト : 今まで質問したコードと矛盾していないか

function check(answer, xs)
    global query
    for q = query
        b = count_bulls(q.code, xs)
        c = count_same_number(q.code, xs) - b
        if b != q.bulls || c != q.cows
            return
        end
    end
    b = count_bulls(answer, xs)
    c = count_same_number(answer, xs) - b
    q = Query(b, c, xs)
    push!(query, q)
    n = length(query)
    println("$n: $xs, bulls = $b, cows = $c")
    if b == 4
        throw("Good Job!")
    end
end

引数 answer は正解のコード、xs は生成したコードです。最初に、大域変数 query に格納されたデータをチェックしていきます。count_bulls() と count_same_number() を使って bulls (変数 b) と cows (変数 c) を求めて、質問したときの q.bulls と q.cows に矛盾しないかチェックします。矛盾している場合は return で終了します。

それから、正解のコード answer と xs を比較して bulls と cows を求め、それらを構造体 Query にまとめて query に追加します。あとは関数 permutations() で順列を生成するだけです。詳細はプログラムリストをお読みください。

●何回で当たるか

これでプログラムは完成です。それでは実行例を示しましょう。

julia> solver([9, 8, 7, 6])
1: [0, 1, 2, 3], bulls = 0, cows = 0
2: [4, 5, 6, 7], bulls = 0, cows = 2
3: [5, 4, 8, 9], bulls = 0, cows = 2
4: [6, 7, 9, 8], bulls = 0, cows = 4
5: [8, 9, 7, 6], bulls = 2, cows = 2
6: [9, 8, 7, 6], bulls = 4, cows = 0
Good Job!

julia> solver([9, 4, 3, 1])
1: [0, 1, 2, 3], bulls = 0, cows = 2
2: [1, 0, 4, 5], bulls = 0, cows = 2
3: [2, 3, 5, 4], bulls = 0, cows = 2
4: [3, 4, 0, 6], bulls = 1, cows = 1
5: [3, 5, 6, 1], bulls = 1, cows = 1
6: [6, 5, 0, 2], bulls = 0, cows = 0
7: [7, 4, 3, 1], bulls = 3, cows = 0
8: [8, 4, 3, 1], bulls = 3, cows = 0
9: [9, 4, 3, 1], bulls = 4, cows = 0
Good Job!

肝心の質問回数ですが、5, 6 回で当たる場合が多いようです。実際に、5040 個のコードをすべて試してみたところ、平均は 5.56 回になりました。これは参考文献「数当てゲーム (MOO, マスターマインド)」の結果と同じです。質問回数の最大値は 9 回で、そのときのコードは [9 4 3 1], [9 2 4 1], [5 2 9 3], [9 2 0 4], [9 2 1 4] でした。

なお、参考文献 1 には平均質問回数がこれよりも少なくなる方法が紹介されています。単純な数当てゲームだと思っていましたが、その奥はけっこう深いようです。興味のある方はいろいろ試してみてください。

●参考文献

  1. 田中哲郎 「数当てゲーム (MOO, マスターマインド)」, 松原仁、竹内郁雄 編 『bit 別冊 ゲームプログラミング』 pp150 - 157, 共立出版, 1997

●プログラムリスト

#
# mastermind.jl : マスターマインドの解法
#
#                 Copyright (C) 2016-2021 Makoto Hiroi
#

# 定数
const CSIZE = 4

# 質問したコードとその結果
struct Query
    bulls::Int
    cows::Int
    code::Array{Int, 1}
end

# 0 - 9 から 4 個の数字を選ぶ順列を生成
function permutations(f, xs, n = 1)
    if n > CSIZE
        f(xs[1:CSIZE])
    else
        tmp = xs[n]
        for i = n : length(xs)
            xs[n] = xs[i]
            xs[i] = tmp
            permutations(f, xs, n + 1)
            xs[i] = xs[n]
            xs[n] = tmp
        end
    end
end

# bulls を数える
function count_bulls(xs, ys)
#=
    c = 0
    for i = 1 : CSIZE
        if xs[i] == ys[i]; c += 1; end
    end
    c
=#
    count(map(==, xs, ys))
end

# 同じ数字を数える
function count_same_number(xs, ys)
    c = 0
    for x = xs
        if x in ys c += 1 end
    end
    c
end

# 質問コードのチェック
function check(answer, xs)
    global query
    for q = query
        b = count_bulls(q.code, xs)
        c = count_same_number(q.code, xs) - b
        if b != q.bulls || c != q.cows
            return
        end
    end
    b = count_bulls(answer, xs)
    c = count_same_number(answer, xs) - b
    q = Query(b, c, xs)
    push!(query, q)
    n = length(query)
    println("$n: $xs, bulls = $b, cows = $c")
    if b == 4
        throw("Good Job!")
    end
end

# マスターマインドの解法
function solver(answer)
    global query
    query = Query[]
    try
        permutations(xs -> check(answer, xs), collect(0 : 9))
    catch e
        println(e)
    end
end

●マスターマインド (改)

M.Hiroi' Home Page で取り上げたマスターマインドは、0 から 9 までの重複しない 4 つの数字からなる隠しコードを当てるゲームでした。マスターマインドを解く場合、簡単な推測アルゴリズムを使うと、平均質問回数が 5.56 回で、質問回数の最大値は 9 回になります。

今回は数字の個数を 5 個に増やして、平均質問回数とその最大値がどうなるか、julia でプログラムを作って確かめてみました。プログラムは「マスターマインドの解法」を改造すると簡単に作ることができます。説明は割愛しますので、詳細はプログラムリストをお読みください。

結果ですが、平均質問回数が 5.99 回、質問回数の最大値は 9 で、そのときのコードは 84 通りになりました。もっと難しくなるかと思っていたので、予想外の結果にちょっと驚きました。

●プログラムリスト

#
# mastermind.jl : マスターマインドの解法
#                 (0 - 9 の数字から 5 個を選ぶ場合)
#
#                 Copyright (C) 2016 Makoto Hiroi
#

# 定数
const CSIZE = 5

# 質問したコードとその結果
struct Query
    bulls::Int
    cows::Int
    code::Array{Int, 1}
end

# 0 - 9 から 5 個の数字を選ぶ順列を生成
function permutations(f, xs, n = 1)
    if n > CSIZE
        f(xs[1:CSIZE])
    else
        tmp = xs[n]
        for i in n : length(xs)
            xs[n] = xs[i]
            xs[i] = tmp
            permutations(f, xs, n + 1)
            xs[i] = xs[n]
            xs[n] = tmp
        end
    end
end

# bulls を数える
function count_bulls(xs, ys)
    c = 0
    for i in 1 : CSIZE
        if xs[i] == ys[i]; c += 1; end
    end
    c
end

# 同じ数字を数える
function count_same_number(xs, ys)
    c = 0
    for x in xs
        for y in ys
            if x == y
                c += 1
                break
            end
        end
    end
    c
end

function check(answer, xs)
    global query
    for q in query
        b = count_bulls(q.code, xs)
        c = count_same_number(q.code, xs) - b
        if b != q.bulls || c != q.cows
            return
        end
    end
    b = count_bulls(answer, xs)
    c = count_same_number(answer, xs) - b
    q = Query(b, c, xs)
    push!(query, q)
    if b == CSIZE
        throw(length(query))
    end
end

function solver()
    c = 0
    m = 0
    max_code = []
    function solver_sub(answer)
        global query
        query = Query[]
        try
            permutations(xs -> check(answer, xs), collect(0:9))
        catch e
            if m < e
                m = e
                max_code = []
            end
            if m == e
                push!(max_code, answer)
            end
            c += e
        end
    end
    permutations(solver_sub, collect(0:9))
    println(c / (10 * 9 * 8 * 7 * 6))
    println(m)
    println(max_code)
    println(length(max_code))
end

solver()
5.994246031746032
9
Any[[1,8,3,9,0],[3,9,8,0,1],[5,2,9,1,7],[5,0,6,8,3],[5,7,8,1,2],[5,8,3,7,0],
[6,5,4,1,2],[6,5,4,0,2],[6,0,1,3,9],[7,2,3,4,5],[7,3,1,4,9],[7,3,8,0,6],
[7,3,8,2,5],[7,4,0,3,5],[7,4,9,2,6],[7,5,4,0,2],[7,5,4,0,1],[7,5,9,0,3],
[7,6,8,0,3],[7,8,0,2,5],[7,8,9,6,1],[7,9,1,6,3],[8,2,1,7,9],[8,2,7,6,0],
[8,3,6,4,2],[8,4,6,0,1],[8,6,4,3,0],[8,6,5,4,3],[8,6,0,1,2],[8,7,5,4,0],
[8,7,5,0,1],[8,7,6,2,3],[8,7,6,1,2],[8,7,6,0,2],[8,7,0,6,3],[8,7,0,2,5],
[8,7,0,9,1],[8,7,9,1,2],[8,7,9,0,2],[8,0,7,6,5],[8,0,7,9,4],[8,9,1,7,2],
[9,1,0,3,8],[9,1,0,4,7],[9,2,6,0,4],[9,3,7,6,5],[9,3,8,4,0],[9,4,3,7,6],
[9,4,1,8,0],[9,4,5,0,3],[9,4,6,3,0],[9,5,3,8,7],[9,5,4,2,0],[9,5,4,8,1],
[9,5,4,8,0],[9,5,6,8,7],[9,5,7,6,8],[9,5,7,0,3],[9,6,4,0,1],[9,7,3,8,1],
[9,7,3,8,0],[9,7,4,2,5],[9,7,5,1,2],[9,7,6,1,2],[9,7,1,6,3],[9,7,0,5,3],
[9,7,0,2,5],[9,8,2,4,0],[9,8,3,5,1],[9,8,3,6,0],[9,8,3,7,1],[9,8,4,1,5],
[9,8,4,0,7],[9,8,6,1,2],[9,8,6,0,2],[9,8,7,2,5],[9,8,1,5,3],[9,8,1,0,7],
[9,8,0,5,7],[9,8,0,6,7],[9,8,0,1,7],[9,0,2,8,3],[9,0,7,5,6],[9,0,8,6,5]]
84

初版 2018 年 10 月 21 日
改訂 2021 年 11 月 27 日