M.Hiroi's Home Page

Julia Language Programming

お気楽 Julia プログラミング超入門


Copyright (C) 2018-2021 Makoto Hiroi
All rights reserved.

番外編 : 式の計算

●再帰降下法

再帰降下法で式 (四則演算とカッコ) を計算するプログラムです。アルゴリズムの詳細は以下に示す拙作のページをお読みください。

●プログラムリスト

#
# calc.jl : 式の計算 (単純な再帰降下法)
#
#           Copyright (C) 2016-2018 Makoto Hiroi
#
using Printf

# 大域変数
# ch    : 文字 (Char)
# token : トークン (Symbol)
# value : 値 (Float64)

# 記号の先読み
function nextch()
    global ch
    ch = read(stdin, Char)
end

# 記号の読み込み
getch() = ch

# 整数値の読み込み
function get_fixnum(buff)
    while isdigit(getch())
        push!(buff, getch())
        nextch()
    end
end

# 数値を求める
function get_number()
    buff::Vector{Char} = []
    get_fixnum(buff)
    if getch() == '.'
        push!(buff, getch())
        nextch()
        get_fixnum(buff)
    end
    if getch() == 'e' || getch() == 'E'
        push!(buff, getch())
        nextch()
        if getch() == '+' || getch() == '-'
            push!(buff, getch())
            nextch()
        end
        get_fixnum(buff)
    end
    parse(Float64, join(buff))
end

# トークンの切り分け
function get_token()
    global token, value
    # 空白文字の読み飛ばし
    while isspace(getch()); nextch(); end
    if isdigit(getch())
        token = :NUMBER
        value = get_number()
    elseif getch() == '+'
        token = :ADD
        nextch()
    elseif getch() == '-'
        token = :SUB
        nextch()
    elseif getch() == '*'
        token = :MUL
        nextch()
    elseif getch() == '/'
        token = :DIV
        nextch()
    elseif getch() == '('
        token = :LPAR
        nextch()
    elseif getch() == ')'
        token = :RPAR
        nextch()
    elseif getch() == ';'
        token = :SEMIC
        nextch()
    else
        token = :OTHERS
    end
end

# 構文解析

# 式
function expression()
    val = term()
    while true
        if token === :ADD
            get_token()
            val += term()
        elseif token === :SUB
            get_token()
            val -= term()
        else
            break
        end
    end
    val
end

# 項
function term()
    val = factor()
    while true
        if token === :MUL
            get_token()
            val *= factor()
        elseif token === :DIV
            get_token()
            val /= factor()
        else
            break
        end
    end
    val
end

# 因子
function factor_sub()
    get_token()
    v = expression()
    if token === :RPAR
        get_token()
    else
        error("')' expected")
    end
    v
end

function factor()
    if token === :LPAR
        factor_sub()
    elseif token === :NUMBER
        get_token()
        value
    elseif token === :ADD
        get_token()
        factor()
    elseif token === :SUB
        get_token()
        -factor()
    else
        error("unexpected token")
    end
end

# トップレベル
function toplevel()
    val = expression()
    if token === :SEMIC
        @printf "=> %.14g\nCalc> " val
        flush(stdout)
    else
        error("invalid token")
    end
end

function calc()
    print("Calc> ")
    flush(stdout)
    nextch()
    while true
        try
            get_token()
            toplevel()
        catch e
            print("ERROR: ")
            showerror(stdout, e)
            println("")
            # 入力のクリア
            while getch() != '\n'; nextch(); end
            print("Calc> ")
            flush(stdout)
        end
    end
end

# 実行
calc()

●実行例

Calc> 1 + 2 + 3;
=> 6
Calc> (1 + 2) * (3 - 4);
=> -3
Calc> 1.2345678 * 1.1111111;
=> 1.3717419862826
Calc> 1 / 7;
=> 0.14285714285714
Calc> -1;
=> -1
Calc> -10 * -10;
=> 100
Calc> 1 + * 2;
ERROR: unexpected token
Calc> (1 + 2;
ERROR: ')' expected
Calc>  <--- Ctrl-C で終了

●構文木の構築

再帰降下法で構文木を構築して式 (四則演算とカッコ) を計算するプログラムです。アルゴリズムの詳細は以下の拙作のページをお読みください。

●プログラムリスト

#
# calc1.jl : 式の計算 (再帰降下法で構文木を構築)
#
#            Copyright (C) 2016-2018 Makoto Hiroi
#
using Printf

# 大域変数
# ch    : 文字 (Char)
# token : トークン (Symbol)
# value : 値 (Float64)

# 記号の先読み
function nextch()
    global ch
    ch = read(stdin, Char)
end

# 記号の読み込み
getch() = ch

# 整数値の読み込み
function get_fixnum(buff)
    while isdigit(getch())
        push!(buff, getch())
        nextch()
    end
end

# 数値を求める
function get_number()
    buff::Vector{Char} = []
    get_fixnum(buff)
    if getch() == '.'
        push!(buff, getch())
        nextch()
        get_fixnum(buff)
    end
    if getch() == 'e' || getch() == 'E'
        push!(buff, getch())
        nextch()
        if getch() == '+' || getch() == '-'
            push!(buff, getch())
            nextch()
        end
        get_fixnum(buff)
    end
    parse(Float64, join(buff))
end

# トークンの切り分け
function get_token()
    global token, value
    # 空白文字の読み飛ばし
    while isspace(getch()); nextch(); end
    if isdigit(getch())
        token = :NUMBER
        value = get_number()
    elseif getch() == '+'
        token = :ADD
        nextch()
    elseif getch() == '-'
        token = :SUB
        nextch()
    elseif getch() == '*'
        token = :MUL
        nextch()
    elseif getch() == '/'
        token = :DIV
        nextch()
    elseif getch() == '('
        token = :LPAR
        nextch()
    elseif getch() == ')'
        token = :RPAR
        nextch()
    elseif getch() == ';'
        token = :SEMIC
        nextch()
    else
        token = :OTHERS
    end
end

# 構文木
# 二項演算子
struct Op2
    op
    left
    right
end

# 単項演算子
struct Op1
    op
    right
end

# 構文解析

# 式
function expression()
    expr = term()
    while true
        if token === :ADD
            get_token()
            expr = Op2(:ADD2, expr, term())
        elseif token === :SUB
            get_token()
            expr = Op2(:SUB2, expr, term())
        else
            break
        end
    end
    expr
end

# 項
function term()
    expr = factor()
    while true
        if token === :MUL
            get_token()
            expr = Op2(:MUL2, expr, factor())
        elseif token === :DIV
            get_token()
            expr = Op2(:DIV2, expr, factor())
        else
            break
        end
    end
    expr
end

# 因子
function factor_sub()
    get_token()
    expr = expression()
    if token === :RPAR
        get_token()
    else
        error("')' expected")
    end
    expr
end

function factor()
    if token === :LPAR
        factor_sub()
    elseif token === :NUMBER
        get_token()
        value
    elseif token === :ADD
        get_token()
        Op1(:ADD1, factor())
    elseif token === :SUB
        get_token()
        Op1(:SUB1, factor())
    else
        error("unexpected token")
    end
end

# 式の評価
function eval_expr(expr)
    if typeof(expr) == Float64
        expr
    elseif typeof(expr) == Op2
        l_val = eval_expr(expr.left)
        r_val = eval_expr(expr.right)
        if expr.op === :ADD2
            l_val + r_val
        elseif expr.op === :SUB2
            l_val - r_val
        elseif expr.op === :MUL2
            l_val * r_val
        elseif expr.op === :DIV2
            l_val / r_val
        else
            error("unknown Op2")
        end
    elseif typeof(expr) == Op1
        val = eval_expr(expr.right)
        if expr.op == :ADD1
            val
        elseif expr.op == :SUB1
            -val
        else
            error("unknown Op1")
        end
    else
        error("broken expression")
    end
end

# トップレベル
function toplevel()
    expr = expression()
    if token === :SEMIC
        @printf "=> %.14g\nCalc> " eval_expr(expr)
        flush(stdout)
    else
        error("invalid token")
    end
end

function calc()
    print("Calc> ")
    flush(stdout)
    nextch()
    while true
        try
            get_token()
            toplevel()
        catch e
            print("ERROR: ")
            showerror(stdout, e)
            println("")
            # 入力のクリア
            while getch() != '\n'; nextch(); end
            print("Calc> ")
            flush(stdout)
        end
    end
end

# 実行
calc()

●組み込み関数と大域変数の追加

式を計算するプログラムに組み込み関数と大域変数の機能を追加したものです。アルゴリズムの詳細は拙作の以下のページをお読みください。

●プログラムリスト

#
# calc3.jl : 式の計算 (組み込み関数と大域変数の追加)
#
#            Copyright (C) 2016-2018 Makoto Hiroi
#
using Printf

# 大域変数
# ch    : 文字 (Char)
# token : トークン (Symbol)
# value : 値 (Float64)

# 記号の先読み
function nextch()
    global ch
    ch = read(stdin, Char)
end

# 記号の読み込み
getch() = ch

# 整数値の読み込み
function get_fixnum(buff)
    while isdigit(getch())
        push!(buff, getch())
        nextch()
    end
end

# 数値を求める
function get_number()
    buff::Vector{Char} = []
    get_fixnum(buff)
    if getch() == '.'
        push!(buff, getch())
        nextch()
        get_fixnum(buff)
    end
    if getch() == 'e' || getch() == 'E'
        push!(buff, getch())
        nextch()
        if getch() == '+' || getch() == '-'
            push!(buff, getch())
            nextch()
        end
        get_fixnum(buff)
    end
    parse(Float64, join(buff))
end

# 識別子 (identifier) を求める
function get_ident()
    buff::Vector{Char} = []
    while isletter(getch()) || isdigit(getch())
        push!(buff, getch())
        nextch()
    end
    Symbol(join(buff))
end

# トークンの切り分け
function get_token()
    global token, value
    # 空白文字の読み飛ばし
    while isspace(getch()); nextch(); end
    if isdigit(getch())
        token = :NUMBER
        value = get_number()
    elseif isletter(getch())    # isalpha => isletter
        token = :IDENT
        value = get_ident()
    elseif getch() == '+'
        token = :ADD
        nextch()
    elseif getch() == '-'
        token = :SUB
        nextch()
    elseif getch() == '*'
        token = :MUL
        nextch()
    elseif getch() == '/'
        token = :DIV
        nextch()
    elseif getch() == '('
        token = :LPAR
        nextch()
    elseif getch() == ')'
        token = :RPAR
        nextch()
    elseif getch() == ','
        token = :COMMA
        nextch()
    elseif getch() == '='
        token = :ASGN
        nextch()
    elseif getch() == ';'
        token = :SEMIC
        nextch()
    else
        token = :OTHERS
    end
end

# 構文木
# 二項演算子
struct Op2
    op
    left
    right
end

# 単項演算子
struct Op1
    op
    right
end

# 関数
struct Func
    fn
    args
end

# 組み込み関数
func_table = Dict{Symbol, Function}(
    :sqrt => sqrt,
    :sin => sin,
    :cos => cos,
    :tan => tan,
    :asin => asin,
    :acos => acos,
    :atan => atan,
    # :atan2 => atan2, atan(x, y) が atan2 になった
    :exp => exp,
    :pow => ^,
    :ln => log,
    :log => x -> log(10, x),
    :sinh => sinh,
    :cosh => cosh,
    :tanh => tanh
)

# 大域変数
const global_variable = Dict{Symbol, Float64}()

# 構文解析

# 式
function expression()
    expr = expr1()
    if token === :ASGN
        if typeof(expr) == Symbol
            get_token()
            expr = Op2(:ASGN2, expr, expression())
        else
            error("invalid assign form")
        end
    end
    expr
end

function expr1()
    expr = term()
    while true
        if token === :ADD
            get_token()
            expr = Op2(:ADD2, expr, term())
        elseif token === :SUB
            get_token()
            expr = Op2(:SUB2, expr, term())
        else
            break
        end
    end
    expr
end

# 項
function term()
    expr = factor()
    while true
        if token === :MUL
            get_token()
            expr = Op2(:MUL2, expr, factor())
        elseif token === :DIV
            get_token()
            expr = Op2(:DIV2, expr, factor())
        else
            break
        end
    end
    expr
end

# 因子
function factor_sub()
    get_token()
    expr = expression()
    if token === :RPAR
        get_token()
    else
        error("')' expected")
    end
    expr
end

# 引数の取得
function get_arguments()
    args = []
    if token !== :LPAR
        error("'(' expected")
    end
    get_token()
    while true
        push!(args, expression())
        if token !== :COMMA; break; end
        get_token()
    end
    if token !== :RPAR
        error("')' expected")
    end
    get_token()
    args
end

function factor()
    if token === :LPAR
        factor_sub()
    elseif token === :NUMBER
        get_token()
        value
    elseif token === :ADD
        get_token()
        Op1(:ADD1, factor())
    elseif token === :SUB
        get_token()
        Op1(:SUB1, factor())
    elseif token === :IDENT
        get_token()
        if haskey(func_table, value)
            Func(func_table[value], get_arguments())
        else 
            value
        end
    else
        error("unexpected token")
    end
end

# 式の評価
function eval_expr(expr)
    if typeof(expr) == Float64
        expr
    elseif typeof(expr) == Symbol
        if haskey(global_variable, expr)
            global_variable[expr]
        else
            error("undefined variable")
        end
    elseif typeof(expr) == Op2
        if expr.op == :ASGN2
            global_variable[expr.left] = eval_expr(expr.right)
        else
            l_val = eval_expr(expr.left)
            r_val = eval_expr(expr.right)
            if expr.op === :ADD2
                l_val + r_val
            elseif expr.op === :SUB2
                l_val - r_val
            elseif expr.op === :MUL2
                l_val * r_val
            elseif expr.op === :DIV2
                l_val / r_val
            else
                error("unknown Op2")
            end
        end
    elseif typeof(expr) == Op1
        val = eval_expr(expr.right)
        if expr.op === :ADD1
            val
        elseif expr.op === :SUB1
            -val
        else
            error("unknown Op1")
        end
    elseif typeof(expr) == Func
        args = map(x -> eval_expr(x), expr.args)
        expr.fn(args...)
    else
        error("broken expression")
    end
end

# トップレベル
function toplevel()
    expr = expression()
    if token === :SEMIC
        @printf "=> %.14g\nCalc> " eval_expr(expr)
        flush(stdout)
    else
        error("invalid token")
    end
end

function calc()
    print("Calc> ")
    flush(stdout)
    nextch()
    while true
        try
            get_token()
            toplevel()
        catch e
            print("ERROR: ")
            showerror(stdout, e)
            println("")
            # 入力のクリア
            while getch() != '\n'; nextch(); end
            print("Calc> ")
            flush(stdout)
        end
    end
end

# 実行
calc()

●実行例

Calc> a = 10;
=> 10
Calc> a;
=> 10
Calc> a * 10;
=> 100
Calc> (b = 20) * 10;
=> 200
Calc> b;
=> 20
Calc> x = y = z = 0;
=> 0
Calc> x;
=> 0
Calc> y;
=> 0
Calc> z;
=> 0
Calc> p = p + 1;
ERROR: undefined variable
Calc> q = 1;
=> 1
Calc> q = q + 1;
=> 2
Calc> q;
=> 2
Calc> sqrt(2);
=> 1.4142135623731
Calc> pow(2, 32);
=> 4294967296
Calc> pow(2, 32) - 1;
=> 4294967295
Calc> pi = asin(0.5) * 6;
=> 3.1415926535898
Calc> sin(0);
=> 0
Calc> sin(pi);
=> -3.2162452993533e-16
Calc> sin(pi/2);
=> 1

●最小の Lisp

小さな小さな Scheme ライクの Lisp インタプリタです。最小の Lisp については、拙作の以下のページをお読みください。

●micro Scheme の仕様

●プログラムリスト

#
# mscm.jl : Micro Scheme インタプリタ
#
#           Copyright (C) 2016-2018 Makoto Hiroi
#
using Printf

# 大域変数
# ch : 文字 (Char)

# 記号の先読み
function nextch()
    global ch
    ch = read(stdin, Char)
end

# 記号の読み込み
getch() = ch

# 整数値の読み込み
function get_fixnum(buff)
    while isdigit(getch())
        push!(buff, getch())
        nextch()
    end
end

# 数値を求める
function get_number()
    buff::Vector{Char} = []
    flag = false
    push!(buff, getch())
    nextch()
    get_fixnum(buff)
    if getch() == '.'
        flag = true
        push!(buff, getch())
        nextch()
        get_fixnum(buff)
    end
    if getch() == 'e' || getch() == 'E'
        flag = true
        push!(buff, getch())
        nextch()
        if getch() == '+' || getch() == '-'
            push!(buff, getch())
            nextch()
        end
        get_fixnum(buff)
    end
    if length(buff) == 1 && !isdigit(buff[1])
        Symbol(buff[1])
    elseif flag
        parse(Float64, join(buff))
    else
        parse(Int128, join(buff))
    end
end

# シンボルに含めてよい記号
code_list = "!&*+-/:<=>?@^_~"

# 識別子 (identifier) を求める
function get_ident()
    buff::Vector{Char} = []
    while isletter(getch()) || isdigit(getch()) || findfirst(isequal(getch()), code_list) !== nothing
        push!(buff, getch())
        nextch()
    end
    Symbol(join(buff))
end

# セル (空リストは :nil)
struct Cons
    car
    cdr
end

# クロージャ
struct Closure
    para
    body
    env
end

# コンスセルか?
consp(xs) = typeof(xs) == Cons

# 空リストか?
null(xs) = xs === :nil

# 数か?
numberp(x) = isa(x, Number)

# シンボルか?
symbolp(x) = typeof(x) == Symbol

#
# read
#

# 空白文字の読み飛ばし
skipspace() = while isspace(getch()); nextch(); end

# リストの読み込み
function read_list()
    skipspace()
    if getch() == ')'
        nextch()
        :nil
    elseif getch() == '.'
        nextch()
        x = read_s()
        skipspace()
        if getch() != ')'
            error("invalid dot list")
        end
        nextch()
        x
    else
        Cons(read_s(), read_list())
    end
end

function read_s()
    skipspace()
    c = getch()
    if isdigit(c) || c == '+' || c == '-'
        get_number()
    elseif isletter(c) || findfirst(isequal(c), code_list) !== nothing
        get_ident()
    elseif getch() == '\''
        nextch()
        Cons(:quote, Cons(read_s(), :nil))
    elseif getch() == '('
        nextch()
        read_list()
    else
        error("invalid token")
    end
end

#
# print
#
function print_s(xs)
    if consp(xs)
        print("(")
        while consp(xs)
            if typeof(xs.car) == Cons
                print_s(xs.car)
            else
                print(xs.car)
            end
            if !null(xs.cdr); print(" "); end
            xs = xs.cdr
        end
        if !null(xs)
            print(". "); print(xs)
        end
        print(")")
    else
        print(xs)
    end
end

#
# 環境
#

# 大域変数
oblist = Dict{Symbol, Any}(
    # 真偽値 (Common Lisp ライク)
    :t => :t,
    :nil => :nil,

    :car => xs -> xs.car,
    :cdr => xs -> xs.cdr,
    :cons => (x, y) -> Cons(x, y),
    Symbol("eq?") => (x, y) -> x === y ? :t : :nil,
    Symbol("pair?") => xs -> consp(xs) ? :t : :nil,

    :+ => (args...) -> +(args...),
    :* => (args...) -> *(args...),
    :- => (n, args...) -> if length(args) == 0
                              -n
                          else
                              for x in args; n -= x; end
                              n
                          end,
    :/ => (n, args...) -> if length(args) == 0
                              1 / n
                          else
                              for x in args; n /= x; end
                              n
                          end,
    Symbol("=") => (x, y) -> x == y ? :t : :nil,
    :<  => (x, y) -> x <  y ? :t : :nil,
    :<= => (x, y) -> x <= y ? :t : :nil,
    :>  => (x, y) -> x >  y ? :t : :nil,
    :>= => (x, y) -> x >= y ? :t : :nil,
    :exit => () -> exit()
)

# 連想リストの探索
function assoc(xs, x)
    while !null(xs)
        if xs.car.car == x; return xs.car; end
        xs = xs.cdr
    end
    :nil
end

# 参照
function lookup(var, env)
    # 局所変数の探索
    xs = assoc(env, var)
    if !null(xs)
        xs.cdr
    elseif haskey(oblist, var)
        # 大域変数
        oblist[var]
    else
        error("undefind variable")
    end
end

#
# 特殊形式
#
function define_f(xs, env)
    if !symbolp(xs.car)
        error("Symbol required")
    end
    oblist[xs.car] = eval_s(xs.cdr.car, env)
    xs.car
end

function if_f(xs, env)
    if eval_s(xs.car, env) !== :nil
        eval_s(xs.cdr.car, env)
    elseif !null(xs.cdr.cdr)
        eval_s(xs.cdr.cdr.car, env)
    else
        :nil
    end
end

#
# eval
#

# 引数の評価
function eval_arguments(args, env)
    a = []
    while !null(args)
        push!(a, eval_s(args.car, env))
        args = args.cdr
    end
    a
end

# 変数束縛
function add_binding(para, args, env)
    for x in args
        if !consp(para)
            error("wrong number of arguments")
        end
        env = Cons(Cons(para.car, x), env)
        para = para.cdr
    end
    if !null(para)
        error("wrong number of arguments")
    end
    env
end

# 本体の評価
function eval_body(xs, env)
    local result
    while !null(xs)
        result = eval_s(xs.car, env)
        xs = xs.cdr
    end
    result
end

# 関数適用
function apply_s(fn, args)
    if typeof(fn) <: Function
        fn(args...)                 # 組み込み関数
    elseif typeof(fn) == Closure
        eval_body(fn.body, add_binding(fn.para, args, fn.env))
    else
        error("invalid function")
    end
end

function eval_s(xs, env)
    if numberp(xs)
        xs
    elseif symbolp(xs)
        lookup(xs, env)
    elseif consp(xs)
        if xs.car == :quote
            xs.cdr.car
        elseif xs.car == :define
            define_f(xs.cdr, env)
        elseif xs.car == :if
            if_f(xs.cdr, env)
        elseif xs.car == :lambda
            Closure(xs.cdr.car, xs.cdr.cdr, env)
        else
            # 関数呼び出し
            fn = eval_s(xs.car, env)
            apply_s(fn, eval_arguments(xs.cdr, env))
        end
    else
        error("unknown object")
    end
end

# REPL
function mscm()
    print("mscm> ")
    flush(stdout)
    nextch()
    while true
        try
            sexp = read_s()
            print_s(eval_s(sexp, :nil))
            print("\nmscm> ")
        catch e
            print("ERROR: ")
            showerror(stdout, e)
            println("")
            # 入力のクリア
            while getch() != '\n'; nextch(); end
            print("mscm> ")
            flush(stdout)
        end
    end
end

# 実行
mscm()

●実行例

$ julia mscm.jl
mscm> 'a
a
mscm> 12345
12345
mscm> 1.2345
1.2345
mscm> '(1 2 3 4 5)
(1 2 3 4 5)
mscm> (car '(a b c))
a
mscm> (cdr '(a b c))
(b c)
mscm> (cons 'a 'b)
(a . b)
mscm> (pair? '(a b c))
t
mscm> (pair? 'a)
nil
mscm> (eq? 'a 'a)
t
mscm> (eq? 'a 'b)
nil
mscm> (define a 10)
a
mscm> a
10
mscm> (define square (lambda (x) (* x x)))
square
mscm> (square 123)
15129
mscm> ((lambda (x) (* x x)) 10)
100
mscm> (define fact (lambda (n) (if (= n 0) 1 (* n (fact (- n 1))))))
fact
mscm> (fact 10)
3628800
mscm> (fact 20)
2432902008176640000
mscm> (define iota (lambda (n m) (if (> n m) nil (cons n (iota (+ n 1) m)))))
iota
mscm> (iota 1 10)
(1 2 3 4 5 6 7 8 9 10)
mscm> (define map (lambda (f xs) (if (pair? xs) (cons (f (car xs)) (map f (cdr xs))))))
map
mscm> (map (lambda (x) (* x x)) (iota 1 20))
(1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361 400)
mscm> (define foo (lambda (x) (lambda (y) (+ x y))))
foo
mscm> (define foo100 (foo 100))
foo100
mscm> (foo100 100)
200
mscm> (foo100 1000)
1100
mscm> (exit)
$ 

簡単なプログラム

●直積集合

集合を表すベクトル xs, ys の「直積集合 (direct product)」を求める関数 product(xs, ys) を作りましょう。xs の要素を xi, ys 要素を yj とすると、直積集合の要素は (xi, yj) となります。たとえば、xs = [1, 2, 3], ys = [4, 5] とすると、直積集合は [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] になります。

julia> product(xs, ys) = [(x, y) for x = xs for y = ys]
product (generic function with 1 method)

julia> product([1, 2, 3], [4, 5])
6-element Vector{Tuple{Int64, Int64}}:
 (1, 4)
 (1, 5)
 (2, 4)
 (2, 5)
 (3, 4)
 (3, 5)

このように、内包表記を使うと product() は簡単に定義できます。次は、引数を 2 つに限定せず、1 個以上の集合を受け付けるように改良してみましょう。次のリストを見てください。

リスト : 直積集合

function product(args...)
    if length(args) == 0
        [()]
    else
        xs = [(x,) for x = args[1]]
        for i = 2 : length(args)
            xs = [(x..., y) for x = xs for y = args[i]]
        end
        xs
    end
end

args が空の配列であれば [()] を返します。そうでなければ、先頭の配列 args[1] の要素をタプルに包み、それを格納した配列に変換します。あとは args から順番に配列を取り出し、タプル x の後ろに配列の要素 y を追加していくだけです。Julia の場合、リストやタプルの後ろに ... を付けると、要素が展開されることに注意してください。

それでは実際に試してみましょう。

julia> product()
1-element Vector{Tuple{}}:
 ()

julia> product([1, 2, 3])
3-element Vector{Tuple{Int64}}:
 (1,)
 (2,)
 (3,)

julia> product([1, 2, 3], [4, 5, 6])
9-element Vector{Tuple{Int64, Int64}}:
 (1, 4)
 (1, 5)
 (1, 6)
 (2, 4)
 (2, 5)
 (2, 6)
 (3, 4)
 (3, 5)
 (3, 6)

julia> product([1, 2], [3, 4], [5, 6])
8-element Vector{Tuple{Int64, Int64, Int64}}:
 (1, 3, 5)
 (1, 3, 6)
 (1, 4, 5)
 (1, 4, 6)
 (2, 3, 5)
 (2, 3, 6)
 (2, 4, 5)
 (2, 4, 6)

julia> product([1, 2], [3, 4], [5, 6], [7, 8])
16-element Vector{NTuple{4, Int64}}:
 (1, 3, 5, 7)
 (1, 3, 5, 8)
 (1, 3, 6, 7)
 (1, 3, 6, 8)
 (1, 4, 5, 7)
 (1, 4, 5, 8)
 (1, 4, 6, 7)
 (1, 4, 6, 8)
 (2, 3, 5, 7)
 (2, 3, 5, 8)
 (2, 3, 6, 7)
 (2, 3, 6, 8)
 (2, 4, 5, 7)
 (2, 4, 5, 8)
 (2, 4, 6, 7)
 (2, 4, 6, 8)

正常に動作していますね。なお、Julia には直積集合を求める関数 Iterators.product(iter...) が用意されています。Julia の product() はイテレータを返すことに注意してください。

julia> Iterators.product([1, 2, 3])
Base.Iterators.ProductIterator{Tuple{Vector{Int64}}}(([1, 2, 3],))

julia> for x in Iterators.product([1,2,3]) println(x) end
(1,)
(2,)
(3,)

julia> for x in Iterators.product([1, 2, 3], [4, 5, 6]) println(x) end
(1, 4)
(2, 4)
(3, 4)
(1, 5)
(2, 5)
(3, 5)
(1, 6)
(2, 6)
(3, 6)

julia> for x in Iterators.product([1, 2], [3, 4], [5, 6], [7, 8]) println(x) end
(1, 3, 5, 7)
(2, 3, 5, 7)
(1, 4, 5, 7)
(2, 4, 5, 7)
(1, 3, 6, 7)
(2, 3, 6, 7)
(1, 4, 6, 7)
(2, 4, 6, 7)
(1, 3, 5, 8)
(2, 3, 5, 8)
(1, 4, 5, 8)
(2, 4, 5, 8)
(1, 3, 6, 8)
(2, 3, 6, 8)
(1, 4, 6, 8)
(2, 4, 6, 8)

●重複組み合わせ

ベクトル xs から重複を許して n 個の要素を選ぶ組み合わせを求める関数 repeatcomb(f, n, xs) を作ります。

リスト : 重複組み合わせ

function repeatcomb(f, n, xs)
    ys::typeof(xs) = []

    function comb(m)
        if length(ys) == n
            f(ys)
        elseif length(xs) >= m
            push!(ys, xs[m])
            comb(m)
            pop!(ys)
            comb(m + 1)
        end
    end

    comb(1)
end

重複組み合わせを求める repeatcomb() は簡単です。局所関数 comb() の elseif 節で、xs の m 番目の要素を選んだら、その要素を取り除かないで、そこから残りの要素を選ぶようにします。これで同じ要素を何回も選ぶことができます。

julia> repeatcomb(println, 2, [1, 2, 3])
[1, 1]
[1, 2]
[1, 3]
[2, 2]
[2, 3]
[3, 3]

julia> repeatcomb(println, 3, [1, 2, 3, 4])
[1, 1, 1]
[1, 1, 2]
[1, 1, 3]
[1, 1, 4]
[1, 2, 2]
[1, 2, 3]
[1, 2, 4]
[1, 3, 3]
[1, 3, 4]
[1, 4, 4]
[2, 2, 2]
[2, 2, 3]
[2, 2, 4]
[2, 3, 3]
[2, 3, 4]
[2, 4, 4]
[3, 3, 3]
[3, 3, 4]
[3, 4, 4]
[4, 4, 4]

●重複順列

ベクトル xs から重複を許して n 個の要素を選ぶ順列を求める関数 repeatperm(f, n, xs) を作ります。

リスト : 重複順列

function repeatperm(f, n, xs)
    ys::typeof(xs) = []

    function perm()
        if length(ys) == n
            f(ys)
        else
            for x = xs
                push!(ys, x)
                perm()
                pop!(ys)
            end
        end
    end

    perm()
end

重複順列はとても簡単です。同じ要素を何度も選んでいいので、重複要素のチェックを省くだけです。

julia> repeatperm(println, 2, [1, 2, 3])
[1, 1]
[1, 2]
[1, 3]
[2, 1]
[2, 2]
[2, 3]
[3, 1]
[3, 2]
[3, 3]

julia> repeatperm(println, 3, [1, 2, 3])
[1, 1, 1]
[1, 1, 2]
[1, 1, 3]
[1, 2, 1]
[1, 2, 2]
[1, 2, 3]
[1, 3, 1]
[1, 3, 2]
[1, 3, 3]
[2, 1, 1]
[2, 1, 2]
[2, 1, 3]
[2, 2, 1]
[2, 2, 2]
[2, 2, 3]
[2, 3, 1]
[2, 3, 2]
[2, 3, 3]
[3, 1, 1]
[3, 1, 2]
[3, 1, 3]
[3, 2, 1]
[3, 2, 2]
[3, 2, 3]
[3, 3, 1]
[3, 3, 2]
[3, 3, 3]

●完全順列

m 個の整数 1, 2, ..., m の順列を考えます。このとき、i 番目 (先頭要素が 1 番目) の要素が整数 i ではない順列を「完全順列 (derangement)」といいます。今回は 1 から m までの整数値で完全順列を生成する関数を作ってみましょう。

リスト : 完全順列

function derangement(f, xs)
    m = length(xs)
    ys = zeros(Int, m)

    function perm(m::Int, n::Int)
        if m < n
            f(ys)
        else
            for x = xs
                if x == n || x in ys[1:n-1] continue end
                ys[n] = x
                perm(m, n + 1)
            end
        end
    end

    perm(m, 1)
end

関数 derangement() は、基本的には 1 から m までの数字を m 個選ぶ順列を生成する処理と同じです。n 番目の数字を選ぶとき、数字 x が n と等しい場合は x を選択しません。n が m より大きい場合は m 個の数字を選んだので、関数 f に ys を渡して実行します。これで完全順列を生成することができます。

julia> derangement(println, [1,2,3])
[2, 3, 1]
[3, 1, 2]

julia> derangement(println, [1,2,3,4])
[2, 1, 4, 3]
[2, 3, 4, 1]
[2, 4, 1, 3]
[3, 1, 4, 2]
[3, 4, 1, 2]
[3, 4, 2, 1]
[4, 1, 2, 3]
[4, 3, 1, 2]
[4, 3, 2, 1]

●モンモール数

完全順列の総数を「モンモール数 (Montmort number)」といいます。モンモール数は次の漸化式で求めることができます。

\(\begin{array}{l} A_1 = 0 \\ A_2 = 1 \\ A_n = (n - 1) \times (A_{n-1} + A_{n-2}) \quad \mathrm{if} \ n \geq 3 \end{array}\)

モンモール数を求める関数 montmort() を作りましょう。

リスト : 完全順列の総数 (モンモール数)

function montmort(n)
    if n == 1
        0
    elseif n == 2
        1
    else
        (n - 1) * (montmort(n - 1) + montmort(n - 2))
    end
end

# 別解
function montmort1(n)
    a, b = 0, 1
    for i = 1 : n-1
        a, b = b, (i + 1) * (a + b)
    end
    a
end

関数 montmort() は公式をそのままプログラムしただけです。二重再帰になっているので、実行速度はとても遅くなります。これを繰り返しに変換すると別解のようになります。

考え方は「フィボナッチ数」と同じです。変数 a に i 番目の値を、b に i + 1 番目の値を保存しておきます。すると、i + 2 番目の値は (i + 1) * (a + b) で計算することができます。あとは b の値を a に、新しい値を b にセットして処理を繰り返すだけです。

julia> for i = 1 : 20 println(montmort(i)) end
0
1
2
9
44
265
1854
14833
133496
1334961
14684570
176214841
2290792932
32071101049
481066515734
7697064251745
130850092279664
2355301661033953
44750731559645106
895014631192902121

julia> for i = 1 : 20 println(montmort1(i)) end
0
1
2
9
44
265
1854
14833
133496
1334961
14684570
176214841
2290792932
32071101049
481066515734
7697064251745
130850092279664
2355301661033953
44750731559645106
895014631192902121

julia> montmort1(big(30))
97581073836835777732377428235481

julia> montmort1(big(50))
11188719610782480504630258070757734324011354208865721592720336801

●ベル数

集合を分割する方法の総数を「ベル数 (Bell Number)」といい、次の漸化式で求めることができます。

\(\begin{array}{l} B(0) = 1 \\ B(1) = 1 \\ B(n+1) = \displaystyle \sum_{k=0}^n {}_n \mathrm{C}_k B(k) \quad \mathrm{if} \ n \geq 1 \end{array}\)

ベル数を求める関数 bellnumber(n) を作りましょう。

リスト : ベル数

# 組み合わせの数
function combination_number(n, r)
    if n == 0 || r == 0
        1
    else
        div(combination_number(n, r - 1) * (n - r + 1), r)
    end
end

# ベル数
function bellnumber(n)
    bs = [big(1)]
    for i = 0 : n - 1
        a = 0
        for (k, b) = enumerate(bs)
            a += combination_number(i, k - 1) * b
        end
        push!(bs, a)
    end
    bs[end]
end

bellnumber() は公式をそのままプログラムするだけです。累積変数 bs にベル数を格納します。\({}_n \mathrm{C}_k\) は関数 combination_number() で求めます。二番目の for ループで \({}_n \mathrm{C}_k \times B(k)\) の総和を計算します。Julia の配列は 1 から始まるので、k を combination_number() に渡すときは -1 していることに注意してください。あとは、その値を push!() で bs に追加するだけです。

julia> for i = 0 : 10 println(i, " ", bellnumber(i)) end
0 1
1 1
2 2
3 5
4 15
5 52
6 203
7 877
8 4140
9 21147
10 115975

julia> bellnumber(20)
51724158235372

julia> bellnumber(40)
157450588391204931289324344702531067

julia> bellnumber(60)
976939307467007552986994066961675455550246347757474482558637

初版 2018 年 11 月 3 日
改訂 2021 年 12 月 5 日