M.Hiroi's Home Page

Scala Programming

お気楽 Scala プログラミング入門

[ PrevPage | Scala | NextPage ]

整数の論理演算とビット操作

関数型言語や Scala では、データを表すのにリストがよく使われます。ところが、問題によってはリストよりもビットで表した方が、プログラムを作るのに都合がよい場合もあります。今回は Scala のビット操作について説明します。

●ビット演算子

Scala のビット演算子は Java と同じです。

表 : ビット演算子
演算子操作
x & y ビットごとの論理積
x | y ビットごとの論理和
x ^ y ビットごとの排他的論理和
~x ビットごとの否定
x << y x を y ビット左シフト
x >> y x を y ビット右シフト (算術シフト)
x >>> y x を y ビット右シフト

演算子 & はビットごとの論理積を返します。

scala> 5 & 3
val res0: Int = 1
     0101
 AND 0011
---------
     0001

演算子 l はビットごとの論理和を返します。

scala> 5 | 3
val res1: Int = 7
    0101
 OR 0011
--------
    0111

演算子 ^ はビットごとの排他的論理和を返します。

scala> 5 ^ 3
val res2: Int = 6
     0101
 XOR 0011
---------
     0110

演算子 ~ はビットごとの論理的な否定を返します。

scala> ~1
val res3: Int = -2

scala> ~0
val res4: Int = -1

<<, >>, >>> はビットをシフトする演算子です。>> は算術シフトなので、負の数で i ビット右シフトしたあと、上位 i 個のビットは 1 になります。>>> の場合、上位ビットには 0 が挿入されます。

scala> 1 << 8
val res5: Int = 256

scala> 1 << 16
val res6: Int = 65536

scala> 256 >> 8
val res7: Int = 1

scala> 65536 >> 8
val res8: Int = 256

scala> -256 >> 8
val res9: Int = -1

scala> -256 >>> 8
val res10: Int = 16777215

それでは簡単な例題として、基本的なビット操作関数を作ってみましょう。次のリストを見てください。

リスト : 基本的なビット操作

// ビットのテスト
def testBit(x: Int, n: Int): Boolean = (x & (1 << n)) != 0

// ビットセット
def setBit(x: Int, n: Int): Int = x | (1 << n)

// ビットクリア
def clearBit(x: Int, n: Int): Int = x & ~(1 << n)

testBit は整数 x の n 番目のビットが 1 ならば true を返します。最下位 (LSB) のビットが 0 番目になります。Int (32 bit) の場合、n は 0 から 31 になります。1 を n ビット左シフトして、x との論理積が 0 でなければ、n 番目のビットは 1 であることがわかります。

bitSet は x の n 番目のビットを 1 にセットします。1 を n ビット左シフトして、x との論理和を計算すれば、n 番目のビットを 1 にすることができます。clearBit は x の n 番目のビットを 0 にクリアします。これは n 番目以外のビットを 1 に、n 番目のビットを 0 にして、それと x の論理積を計算すれば、n 番目のビットをクリアすることができます。1 を n ビット左シフトして、その否定を計算すると、n 番目のビット以外は 1 になります。

それでは実際に試してみましょう。

scala> testBit(256, 7)
val res0: Boolean = false

scala> testBit(256, 8)
val res1: Boolean = true

scala> testBit(256, 9)
val res2: Boolean = false

scala> for (i <- 0 to 16) println(setBit(0, i))
1
2
4
8
16
32
64
128
256
512
1024
2048
4096
8192
16384
32768
65536

scala> for (i <- 0 to 15) println(clearBit(65535, i))
65534
65533
65531
65527
65519
65503
65471
65407
65279
65023
64511
63487
61439
57343
49151
32767

●組み合わせの生成

組み合わせの生成は拙作のページ 順列と組み合わせ で説明しました。このほかに、n 個の中から m 個を選ぶ組み合わせは、ビットの 0, 1 で表すことができます。たとえば、5 個の数字 (0 - 4) から 3 個を選ぶ場合、数字を 0 番目 から 4 番目のビットに対応させます。すると、1, 3, 4 という組み合わせは 11010 と表すことができます。簡単な例題として、ビットを使って組み合わせを求めてみましょう。

組み合わせを求めるプログラムは次のようになります。

リスト : 組み合わせの生成

def combinations(f: Int => Unit, n: Int, m: Int, a: Int = 0): Unit = {
  if (m == 0) f(a)
  else if (m == n) f(a | ((1 << m) - 1))
  else {
    combinations(f, n - 1, m, a)
    combinations(f, n - 1, m - 1, setBit(a, n - 1))
  }
}

関数 combinations は n 個の中から m 個を選ぶ組み合わせを生成して、引数の関数 f に渡します。組み合わせは引数 a にセットします。m が 0 になったら、組み合わせがひとつできたので f(a) を呼び出します。n が m と等しくなったならば、残り m 個を全て選びます。(1 << m) - 1 で m 個のビットをオンにして関数 f を呼び出します。

あとは combinations を再帰呼び出しします。最初の呼び出しは n 番目の数字を選ばない場合です。n - 1 個の中から m 個を選びます。次の呼び出しが n 番目の数字を選ぶ場合で、a の n - 1 番目のビットをオンにします。そして、n - 1 個の中から m - 1 個を選びます。

それでは 5 個の中から 3 個を選ぶ combinations(5, 3) の実行例を示します。

scala> combinations(println, 5, 3)
7
11
13
14
19
21
22
25
26
28

この場合、最小値は 7 (111) で最大値は 28 (11100) になります。このように、combinations は組み合わせを表す数を昇順で出力します。

●組み合わせに番号を付ける方法

次は、N 通りある組み合わせに 0 から N - 1 までの番号を付ける方法を紹介しましょう。たとえば、6 個の中から 3 個を選ぶ組み合わせは 20 通りありますが、この組み合わせに 0 から 19 までの番号を付けることができます。1 1 1 0 0 0 を例題に考えてみましょう。次の図を見てください。


    図 : 6C3 の組み合わせ

最初に 5 をチェックします。5 を選ばない場合は \({}_5 \mathrm{C}_3\) = 10 通りありますね。この組み合わせに 0 から 9 までの番号を割り当てることにすると、5 を選ぶ組み合わせの番号は 10 から 19 までとなります。

次に、4 をチェックします。4 を選ばない場合は、\({}_4 \mathrm{C}_2\) = 6 通りあります。したがって、5 を選んで 4 を選ばない組み合わせに 10 から 15 までの番号を割り当てることにすると、5 と 4 を選ぶ組み合わせには 16 から 19 までの番号となります。

最後に、3 をチェックします。同様に 3 を選ばない場合は 3 通りあるので、これに 16 から 18 までの番号を割り当て、5, 4, 3 を選ぶ組み合わせには 19 を割り当てます。これで組み合わせ 1 1 1 0 0 0 の番号を求めることができました。

では、0 0 0 1 1 1 はどうなるのでしょうか。左から順番にチェックしていくと、最初の 1 が見つかった時点で、その数字を選ばない組み合わせは存在しません。つまり、残りの数字をすべて選ぶしかないわけです。したがって、これが 0 番目となります。

このように、数字を選ぶときに、数字を選ばない場合の組み合わせの数を足し算していけば、その組み合わせの番号を求めることができるのです。

●組み合わせを番号に変換

組み合わせを番号に変換するプログラムは次のようになります。

リスト : 組み合わせを番号に変換

// 組み合わせの数
def combNum(n: Int, r: Int): Int =
  if (n == r || r == 0)
    1
  else
    combNum(n, r - 1) * (n - r + 1) / r

// 組み合わせを番号に変換
def combToNum(c: Int, n: Int, r: Int, value: Int = 0): Int =
  if (r == 0 || n == r) value
  else if (testBit(c, n - 1))
    combToNum(c, n - 1, r - 1, value + combNum(n - 1, r))
  else
    combToNum(c, n - 1, r, value)

関数 combNum は組み合わせの数を求めます。combToNum の引数 c はビットのオンオフで表した組み合わせ、引数 n と r は \({}_n \mathrm{C}_r\) の n と r を表しています。引数 value は求める番号を表します。n と r の値が同じになるか、もしくは r が 0 になれば、組み合わせの番号を計算できたので value を返します。

そうでない場合、c の n - 1 ビットの値を調べます。ビットがオンであれば、value に combNum(n - 1, r) の値を足し算し、r を -1 して combToNum を再帰呼び出しします。そうでなければ、value と r の値はそのままで combToNum を再帰呼び出しします。

簡単な実行例を示します。

scala> combToNum(7, 5, 3)
val res2: Int = 0

scala> combToNum(21, 5, 3)
val res3: Int = 5

scala> combToNum(28, 5, 3)
val res4: Int = 9

●番号を組み合わせに変換

逆に、番号から組み合わせを求めるプログラムも簡単に作ることができます。次のリストを見てください。

リスト : 番号を組み合わせに変換

def numToComb(value: Int, n: Int, r: Int, c: Int = 0): Int =
  if (r == 0) c
  else if (n == r) c | ((1 << n) - 1)
  else {
    val k = combNum(n - 1, r)
    if (value >= k)
      numToComb(value - k, n - 1, r - 1, setBit(c, n - 1))
    else
      numToComb(value, n - 1, r, c)
  }

引数 value が番号で、引数 n と r は \({}_n \mathrm{C}_r\) の n と r を表しています。引数 c が求める組み合わせです。たとえば、n = 5, r = 3 の場合、ビットが 1 になるのは \({}_4 \mathrm{C}_2\) = 6 通りあり、0 になるのは \({}_4 \mathrm{C}_3\) = 4 通りあります。したがって、数値が 0 - 3 の場合はビットを 0 にし、4 - 9 の場合はビットを 1 にすればいいわけです。

ビットを 0 にした場合、残りは \({}_4 \mathrm{C}_3\) = 4 通りになるので、同様に次のビットを決定します。ビットを 1 にした場合、残りは \({}_4 \mathrm{C}_2\) = 6 通りになるので、value から 4 を引いて numToComb を再帰呼び出しして次のビットを決定します。

r が 0 の場合は、組み合わせが完成したので c を返します。n と r が等しい場合は、残りのビットをすべて 1 にセットしてから c を返します。それ以外の場合は、\({}_{n-1} \mathrm{C}_r\) の値を combNum(n - 1, r) で求めて変数 k にセットします。value が k 以上であれば変数 c のビットを 1 にセットし、value から k を引き算して numToComb を再帰呼び出しします。そうでなければ、numToComb を再帰呼び出しするだけです。

それでは、n = 5, r = 3 の場合の実行例を示します。

scala> for (i <- 0 to 9) println(numToComb(i, 5, 3))
7
11
13
14
19
21
22
25
26
28

正常に動作していますね。この方法を使うと、n 個ある組み合わせの中の i 番目 (0 <= i < n) の組み合わせを簡単に求めることができます。

●ちょっと便利なビット操作

最も右側 (LSB 側) にある 1 を 0 にクリアする、逆に最も右側にある 0 を 1 にセットすることは簡単にできます。

(1) 右側にある 1 をクリア => x & (- x)

x     : 1 1 1 1
x - 1 : 1 1 1 0
----------------
 AND  : 1 1 1 0

x     : 1 0 0 0
x - 1 : 0 1 1 1
----------------
 AND  : 0 0 0 0

(2) 右側にある 0 を 1 にセット => x | (x + 1)

x     : 0 0 0 0
x + 1 : 0 0 0 1
----------------
  OR  : 0 0 0 1

x     : 0 1 1 1
x - 1 : 1 0 0 0
----------------
  OR  : 1 1 1 1

上図 (1) を見てください。x から 1 を引くと、右側から連続している 0 は桁借りにより 1 になり、最初に出現する 1 が 0 になります。したがって、x & (x - 1) を計算すると、最も右側にある 1 を 0 にクリアすることができます。(2) の場合、x に 1 を足すと、右側から連続している 1 は桁上がりにより 0 になり、最初に出現する 0 が 1 になります。x | (x + 1) を計算すれば、最も右側にある 0 を 1 にセットすることができます。

また、最も右側にある 1 を取り出すことも簡単にできます。簡単な例として 4 ビットの整数値を考えてみます。負の整数を 2 の補数で表した場合、4 ビットで表される整数は -8 から 7 になります。次の図を見てください。

 0 : 0000
 1 : 0001    -1 : 1111    1 & (-1) => 0001
 2 : 0010    -2 : 1110    2 & (-2) => 0010
 3 : 0011    -3 : 1101    3 & (-3) => 0001
 4 : 0100    -4 : 1100    4 & (-4) => 0100
 5 : 0101    -5 : 1011    5 & (-5) => 0001
 6 : 0110    -6 : 1010    6 & (-6) => 0010
 7 : 0111    -7 : 1001    7 & (-7) => 0001
             -8 : 1000


        図 : 最も右側にある 1 を取り出す方法

2 の補数はビットを反転した値 (1 の補数) に 1 を加算することで求めることができます。したがって、x と -x の論理積 x & (-x) は、最も右側にある 1 だけが残り、あとのビットはすべて 0 になります。

●ビットが 1 の個数を求める

次は、ビットが 1 の個数を数える処理を作ってみましょう。データ型を Int とすると、プログラムは次のようになります。

リスト : ビットカウント

def bitCount(n: Int): Int = {
  var c = 0
  var m = n
  while (m != 0) {
    m &= m - 1
    c += 1
  }
  c
}

整数 n の右側から順番に 1 をクリアしていき、0 になるまでの回数を求めます。とても簡単ですね。32 個のビットを順番に調べるよりも高速です。

Int を 32 bit とする場合、次の方法で 1 の個数をもっと高速に求めることができます。

リスト : ビットカウント (2)

def bitCount1(n: Int): Int = {
  val a = (n & 0x55555555) + ((n >>> 1) & 0x55555555)
  val b = (a & 0x33333333) + ((a >>> 2) & 0x33333333)
  val c = (b & 0x0f0f0f0f) + ((b >>> 4) & 0x0f0f0f0f)
  val d = (c & 0x00ff00ff) + ((c >>> 8) & 0x00ff00ff)
  (d & 0xffff) + (d >>> 16)
}

最初に、整数を 2 bit ずつに分割して、1 の個数を求めます。たとえば、整数 n を 4 bit で考えてみましょう。5 を 2 進数で表すと 0101 になり、n と論理積を計算すると 0, 2 番目のビットが 1 であれば、結果の 0, 2 番目のビットは 1 になります。同様に n を 1 ビット右シフトして論理積を計算すると、1, 3 番目のビットが 1 であれば、結果の 0, 2 番目のビットは 1 になります。あとは、それを足し算すれば 2 bit の中にある 1 の個数を求めることができます。

変数 a には 2 ビットの中の 1 の個数が格納されています。左隣の 2 ビットの値を足し算すれば、4 ビットの中の 1 の個数を求めることができます。次に、左隣の 4 ビットの値を足し算して 8 ビットの中の 1 の個数を求め、左隣の 8 ビットの値を足し算して、というように順番に値を加算していくと 32 ビットの中にある 1 の個数を求めることができます。

bitCount は 1 の個数が多くなると遅くなりますが、bitCount1 は 1 の個数に関係なく高速に動作します。興味のある方は試してみてください。

●BitSet

Scala には scala.collection.immutable に BitSet というビットで集合を表すクラスが用意されています。集合の要素は 0 以上の整数値になります。使い方は immutable なセット (Set) と同じです。

簡単な例を示しましょう。

scala> import scala.collection.immutable.{BitSet}
import scala.collection.immutable.BitSet

scala> val s = BitSet(1, 2, 3, 4, 5)
val s: scala.collection.immutable.BitSet = BitSet(1, 2, 3, 4, 5)

scala> s(0)
val res0: Boolean = false

scala> s(5)
val res1: Boolean = true

scala> s + 6
val res2: scala.collection.immutable.BitSet = BitSet(1, 2, 3, 4, 5, 6)

scala> s + (6, 7, 8, 9)
         ^
       warning: method + in trait SetOps is deprecated (since 2.13.0): 
       Use ++ with an explicit collection argument instead of + with varargs
val res3: scala.collection.immutable.BitSet = BitSet(1, 2, 3, 4, 5, 6, 7, 8, 9)

scala> s - 1
val res4: scala.collection.immutable.BitSet = BitSet(2, 3, 4, 5)

scala> s - (1, 2, 3)
         ^
       warning: method - in trait SetOps is deprecated (since 2.13.0): 
       Use &- with an explicit collection argument instead of - with varargs
val res5: scala.collection.immutable.BitSet = BitSet(4, 5)

以前のバージョンでは、タプルに格納した複数の要素をまとめて追加 (または削除) することができたのですが、ver 2.13 から非推奨になりました。複数の要素を追加する場合は演算子 ++ (削除する場合は演算子 --) を使ってください。

scala> s ++ List(6, 7, 8, 9)
val res7: scala.collection.immutable.BitSet = BitSet(1, 2, 3, 4, 5, 6, 7, 8, 9)

scala> s -- List(1, 2, 3)
val res8: scala.collection.immutable.BitSet = BitSet(4, 5)

BitSet を使うと組み合わせの生成は次のようになります。

リスト : 組み合わせの生成 (BitSet 版)

def combinationSet(f: BitSet => Unit, n: Int, m: Int, a: BitSet = BitSet()): Unit = {
  if (m == 0) f(a)
  else if (m == n) f(a ++ Range(1, m + 1))
  else {
    combinationSet(f, n - 1, m, a)
    combinationSet(f, n - 1, m - 1, a + n)
  }
}

プログラムは簡単ですね。それでは実行してみましょう。

scala> combinationSet(println, 5, 3)
BitSet(1, 2, 3)
BitSet(1, 2, 4)
BitSet(1, 3, 4)
BitSet(2, 3, 4)
BitSet(1, 2, 5)
BitSet(1, 3, 5)
BitSet(2, 3, 5)
BitSet(1, 4, 5)
BitSet(2, 4, 5)
BitSet(3, 4, 5)

●参考 URL

ビットが 1 の個数を数える方法は フィンローダさん初級C言語Q&A(15) を参考にさせていただきました。フィンローダさんに感謝いたします。


●プログラムリスト

//
// testbit.scala : ビット演算の簡単な例題
//
//                 Copyright (C) 2014-2021 Makoto Hiroi
//
import scala.collection.immutable.{BitSet}

object testbit {
  // 基本的なビット操作
  def testBit(x: Int, n: Int): Boolean = (x & (1 << n)) != 0
  def setBit(x: Int, n: Int): Int = x | (1 << n)
  def clearBit(x: Int, n: Int): Int = x & ~(1 << n)

  // 組み合わせの生成
  def combinations(f: Int => Unit, n: Int, m: Int, a: Int = 0): Unit = {
    if (m == 0) f(a)
    else if (m == n) f(a | ((1 << m) - 1))
    else {
      combinations(f, n - 1, m, a)
      combinations(f, n - 1, m - 1, setBit(a, n - 1))
    }
  }

  // BitSet 版
  def combinationSet(f: BitSet => Unit, n: Int, m: Int, a: BitSet = BitSet()): Unit = {
    if (m == 0) f(a)
    else if (m == n) f(a ++ Range(1, m + 1))
    else {
      combinationSet(f, n - 1, m, a)
      combinationSet(f, n - 1, m - 1, a + n)
    }
  }

  // 組み合わせの数
  def combNum(n: Int, r: Int): Int =
    if (n == r || r == 0)
      1
    else
      combNum(n, r - 1) * (n - r + 1) / r

  // 組み合わせを番号に変換
  def combToNum(c: Int, n: Int, r: Int, value: Int = 0): Int =
    if (r == 0 || n == r) value
    else if (testBit(c, n - 1))
      combToNum(c, n - 1, r - 1, value + combNum(n - 1, r))
    else
      combToNum(c, n - 1, r, value)

  // 番号を組み合わせに変換
  def numToComb(value: Int, n: Int, r: Int, c: Int = 0): Int =
    if (r == 0) c
    else if (n == r) c | ((1 << n) - 1)
    else {
      val k = combNum(n - 1, r)
      if (value >= k)
        numToComb(value - k, n - 1, r - 1, setBit(c, n - 1))
      else
        numToComb(value, n - 1, r, c)
    }

  // ビット 1 の個数を数える
  def bitCount(n: Int): Int = {
    var c = 0
    var m = n
    while (m != 0) {
      m &= m - 1
      c += 1
    }
    c
  }

  // 高速版
  def bitCount1(n: Int): Int = {
    val a = (n & 0x55555555) + ((n >>> 1) & 0x55555555)
    val b = (a & 0x33333333) + ((a >>> 2) & 0x33333333)
    val c = (b & 0x0f0f0f0f) + ((b >>> 4) & 0x0f0f0f0f)
    val d = (c & 0x00ff00ff) + ((c >>> 8) & 0x00ff00ff)
    (d & 0xffff) + (d >>> 16)
  }
}

初版 2014 年 10 月 18 日
改訂 2021 年 4 月 11 日

N Queens Problem

「8 クイーン」はコンピュータに解かせるパズルの中でも特に有名な問題です。このパズルは 8 行 8 列のチェス盤の升目に、8 個のクイーンを互いの利き筋が重ならないように配置する問題です。クイーンは将棋の飛車と角をあわせた駒で、縦横斜めに任意に動くことができます。解答の一例を下図に示します。


      図 : 8 クイーンの解答例

N Queens Problem は「8 クイーン」の拡張バージョンで、N 行 N 列の盤面に N 個のクイーンを互いの利き筋が重ならないように配置する問題です。まず最初に「8 クイーン」を解いてみて、そのあと N Queens Problem に挑戦することにしましょう。

●8 クイーンの解法

8 クイーンを解くには、すべての置き方を試してみるしか方法はありません。最初のクイーンは、盤上の好きなところへ置くことができるので、64 通りの置き方があります。次のクイーンは 63 通り、その次は 62 通りあるので、置き方の総数は 64 から 57 までの整数を掛け算した 178462987637760 通りもあります。

ところが、解答例を見ればわかるように、同じ行と列に 2 つ以上のクイーンを置くことはできません。上図の解答例をリストを使って表すと、 次のようになります。

  1  2  3  4  5  6  7  8    <--- 列の位置
---------------------------
 [1, 7, 5, 8, 2, 4, 6, 3]   <--- 要素が行の位置を表す  


        図 : リストでの行と列の表現方法

列をリストの位置に、行番号を要素に対応させれば、各要素には 1 から 8 までの数字が重複しないで入ることになります。すなわち、1 から 8 までの順列の総数である 8! = 40320 通りの置き方を調べるだけでよいのです。パズルを解く場合、そのパズル固有の性質をうまく使って、調べなければならない場合の数を減らすように工夫することが大切です。あとは、その順列が 8 クイーンの条件を満たしているかチェックすればいいわけです。

●単純な生成検定法

それでは、プログラムを作りましょう。次のリストを見てください。

リスト : 8 クイーンの解法

object nqueens {
  // 衝突しているか
  def attack(x: Int, ys: List[Int]): Boolean = {
    var n = 1
    for (y <- ys) {
      if (x == y + n || x == y - n) return true
      n += 1
    }
    false
  }

  // 安全か?
  def safe(xs: List[Int]): Boolean =
    xs match {
      case Nil => true
      case y::ys => if (attack(y, ys)) false else safe(ys)
    }

  def queen0(xs: List[Int]): Int = {
    val it = xs.permutations
    var c = 0
    while (it.hasNext) {
      val xs = it.next()
      if (safe(xs)){
        c += 1
        println(xs)
      }
    }
    c
  }
}

関数 queen0 でメソッド permutations を呼び出して、引数のリスト xs の順列を生成します。permutations はイテレータを返すことに注意してください。あとは、while ループで順列を取り出して、それが 8 クイーンの条件を満たしているかチェックします。関数 safe はリストの先頭の要素から順番に衝突のチェックを行います。端にあるクイーンから順番に調べるとすると、斜めの利き筋は次のように表すことができます。

    1 2 3    --> 調べる方向
  *-------------
  | . . . . . .
  | . . . -3. .  5 - 3 = 2
  | . . -2. . .  5 - 2 = 3
  | . -1. . . .  5 - 1 = 4
  | Q . . . . .  Q の位置は 5  
  | . +1. . . .  5 + 1 = 6
  | . . +2. . .  5 + 2 = 7
  | . . . +3. .  5 + 2 = 8
  *-------------


    図 : 衝突の検出

図を見てもらえばおわかりのように、Q が行 5 にある場合、ひとつ隣の列は 4 と 6 が利き筋に当たります。2 つ隣の列の場合は 3 と 7 が利き筋に当たります。このように単純な足し算と引き算で、利き筋を計算することができます。この処理を関数 attack で行います。

attack はリストの先頭から斜めの利き筋に当たるか調べます。引数 x がクイーンの位置、ys が残りのクイーンを格納したリストです。変数 n が差分を表します。for ループで ys からクイーン y を取り出し、y + n または y - n が x と等しいかチェックします。等しい場合は衝突しているので true を返します。そうでなければ、次のクイーンを調べます。このとき、差分 n を +1 することをお忘れなく。すべてのクイーンを調べたら false を返します。

●実行結果

これでプログラムは完成です。それでは実行してみましょう。

scala> :load nqueens.scala
val args: Array[String] = Array()
Loading nqueens.scala...
object nqueens

scala> nqueens.queen0(List(1,2,3,4,5,6,7,8))
List(1, 5, 8, 6, 3, 7, 2, 4)
List(1, 6, 8, 3, 7, 4, 2, 5)
List(1, 7, 4, 6, 8, 2, 5, 3)

・・・省略・・・

List(8, 2, 5, 3, 1, 7, 4, 6)
List(8, 3, 1, 6, 2, 5, 7, 4)
List(8, 4, 1, 3, 6, 2, 7, 5)
val res0: Int = 92

8 クイーンの場合、回転解や鏡像解を含めると全部で 92 通りあります。

ところで、クイーンの個数を増やすと、プログラムの実行時間は極端に遅くなります。クイーンの個数を増やすのは簡単です。次のリストを見てください。

リスト : N Queens Problem

  def test0(): Unit = {
    for (i <- 8 to 11) {
      val s = System.currentTimeMillis
      println(queen0(Range(1, i + 1).toList))
      println((System.currentTimeMillis - s) + "msec")
    }
  }

Range で 1 から i までの数列を生成し、それを toList でリストに変換して queen0 に渡します。queen0 は解を表示するのではなく、解の個数をカウントするように修正します。実行結果は次のようになりました。

      表 : 実行結果 (時間 : 秒)

  個数 :  8  :   9  :  10  :  11  
  -----+-----+------+------+------
   解  :  92 :  352 :  724 : 2680 
  時間 :0.12 : 0.36 : 1.76 : 19.4 

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

クイーンの個数をひとつ増やしただけでも、実行時間はとても遅くなります。実はこのプログラム、とても非効率なことをやっているのです。

●無駄を省く

実行速度が遅い理由は、失敗することがわかっている順列も生成してしまうからです。たとえば、最初 (1, 1) の位置にクイーンを置くと、次のクイーンは (2, 2) の位置に置くことはできませんね。したがって、[1, 2, X, X, X, X, X, X,] という配置はすべて失敗するのですが、順列を発生させてからチェックする方法では、このような無駄を省くことができません。

そこで、クイーンの配置を決めるたびに衝突のチェックを行うことにします。これをプログラムすると次のようになります。

リスト : N Queens Problem (改良版)

  def queen1(f: List[Int] => Unit, xs: List[Int], a: List[Int]): Unit =
    if (xs == Nil)
      f(a)
    else
      for (x <- xs if (!attack(x, a))) {
        queen1(f, xs.filterNot(_ == x), x::a)
      }

  def test1(): Unit = {
    for (i <- 8 to 13) {
      val s = System.currentTimeMillis
      var c = 0
      queen1(_ => c += 1, Range(1, i + 1).toList, Nil)
      println(c)
      println((System.currentTimeMillis - s) + "msec")
    }
  }

関数 queen1 は高階関数として定義します。解をひとつ見つけたら引数 f の関数を呼び出します。queen1 でクイーンを選択するとき、関数 attack を呼び出して選択したクイーンが衝突しないかチェックします。関数 test1 では、queen1 に渡す匿名関数で解の個数をカウントします。とても簡単ですね。

●実行結果 (2)

実行結果を示します。

          表 : 実行結果 (時間 : 秒)

  個数 :   8  :   9  :  10  :  11  :  12   :  13 
  -----+------+------+------+------+-------+-------
   解  :   92 :  352 :  724 : 2680 : 14200 : 73712
   (0) : 0.12 : 0.36 : 1.76 : 19.4 : ----- : -----
   (1) : ---- : ---- : 0.02 : 0.09 : 0.34  : 1.82

    実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

実行時間は速くなりましたが、クイーンの個数が 13 を超えると実行時間が極端に遅くなります。これは、斜めの利き筋をチェックする関数 attack に時間がかかるからです。そこで、藤原博文さん8クイーン@勉強会のページ を参考にプログラムを改良してみましょう。

●プログラムの改良

斜めの利き筋のチェックは、簡単な方法で高速化できます。次の図を見てください。


            図 : 斜めの利き筋のチェック

斜めの利き筋は、行と列の位置を足す、または行から列を引くと一定の値になることを利用してチェックしています。attack は確定済みのクイーンと衝突していないかひとつずつチェックしていますが、斜めの利き筋を配列にセットしておけば、もっと簡単にチェックすることができます。

右斜め上の利き筋を配列 rUsed, 左斜め上の利き筋を配列 lUsed で表すことにすると、(x, y) にクイーンを置いた場合は次のようにセットします。

rs(x + y) = true
ls(x - y + n - 1) = true

n は盤面の大きさ (クイーンの個数) です。バックトラックするときはリセットすることをお忘れなく。プログラムは次のようになります。

リスト : N Queens Problem (2)

  val MaxSize = 16
  val board = new Array[Int](MaxSize)
  val nUsed = new Array[Boolean](MaxSize)
  val rUsed = new Array[Boolean](MaxSize * 2)
  val lUsed = new Array[Boolean](MaxSize * 2)
  var size = 0
  var cnt  = 0

  def queen2(n: Int): Unit = {
    if (n == size)
      cnt += 1
    else {
      for (m <- 0 until size) {
        if (!nUsed(m) && !rUsed(m + n) && !lUsed(m - n + size - 1)) {
          board(n) = m
          rUsed(m + n) = true
          lUsed(m - n + size - 1) = true
          nUsed(m) = true
          queen2(n + 1)
          rUsed(m + n) = false
          lUsed(m - n + size - 1) = false
          nUsed(m) = false
        }
      }
    }
  }

  def test2(): Unit = {
    for (i <- 10 to 14) {
      val s = System.currentTimeMillis
      size = i
      cnt = 0
      queen2(0)
      println(cnt)
      println((System.currentTimeMillis - s) + "msec")
    }
  }

プログラムは配列を使って書き直しています。配列 nUsed は未使用の数字を false で、使用した数字を true で表します。あとは、とくに難しいところはないと思います。説明は割愛しますので、詳細はリストをお読みくださいませ。

●実行結果 (3)

実行結果を示します。

          表 : 実行結果 (時間 : 秒)

  個数 :  10  :  11  :  12   :  13   :   14
  -----+------+------+-------+-------+--------
   解  :  724 : 2680 : 14200 : 73712 : 365596
   (0) : 1.76 : 19.4 : ----- : ----- : ------
   (1) : 0.02 : 0.09 : 0.34  : 1.82  : ------
   (2) : 0.02 : 0.06 : 0.15  : 0.50  :  3.07

  実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

改良の効果は十分に出ていますね。

●ビット演算による高速化

最後に、ビット演算を使って高速化する方法を紹介します。オリジナルは Jeff Somers さんのプログラムですが、高橋謙一郎さん が再帰を使って書き直したプログラムを Nクイーン問題(解の個数を求める) で発表されています。今回は高橋さんのプログラムを参考にさせていただきました。高橋さんに感謝します。

プログラムのポイントは二つあります。一つはクイーンの選択処理をビット演算で行うこと、もう一つは斜めの利き筋のチェックをビット演算で行うことです。

クイーンの位置をビットオンで表すことします。つまり、i 行目のクイーンは i ビットを 1 にした値になります。この場合、未選択のクイーンは整数値で表すことができます。8 クイーンの場合、まだ一つもクイーンを選択していない状態は 255 になります。残っているクイーンを表す値を n とすると、次の処理でクイーンを順番に取り出していくことができます。

リスト : クイーンの選択処理

var m = n
while (m > 0) {
  q := m & (-m)
    ...
  m &= m - 1
}

while ループの最後で m &= m - 1 とすれば、右端の 1 を 0 にクリアすることができます。そして、ループの中で q := m & (- m) とすれば、右端の 1 を取り出すことができます。n から取り出した q を削除するのも簡単で、排他的論理和 n ^ q を計算するだけです。

次は斜めの利き筋のチェックを説明します。下図を見てください。

    0 1 2 3 4
  *-------------
  | . . . . . .
  | . . . -3. .  0x02
  | . . -2. . .  0x04
  | . -1. . . .  0x08 (1 bit 右シフト)
  | Q . . . . .  0x10 (Q の位置は 4)
  | . +1. . . .  0x20 (1 bit 左シフト)  
  | . . +2. . .  0x40
  | . . . +3. .  0x80
  *-------------


      図 : 斜めの利き筋のチェック

上図の場合、1 列目の右斜め上の利き筋は 3 番目 (0x08)、2 列目の右斜め上の利き筋は 2 番目 (0x04) になります。この値は 0 列目のクイーンの位置 0x10 を 1 ビットずつ右シフトすれば求めることができます。また、左斜め上の利き筋の場合、1 列目では 5 番目 (0x20) で 2 列目では 6 番目 (0x40) になるので、今度は 1 ビットずつ左シフトすれば求めることができます。

つまり、右斜め上の利き筋を right、左斜め上の利き筋を left で表すことにすると、right と left にクイーンの位置をセットしたら、隣の列を調べるときに right と left を 1 ビットシフトするだけで、斜めの利き筋を求めることができるわけです。

プログラムは次のようになります。

リスト : N Queens Problem (3)

  def queen3(n: Int, right: Int, left: Int): Unit = {
    if (n == 0)
      cnt += 1
    else {
      var m = n
      while (m > 0) {
        val q = m & (-m)
        if ((q & (right | left)) == 0) {
          queen3(n ^ q, (right | q) >> 1, (left | q) << 1)
        }
        m &= m - 1
      }
    }
  }

  def test3(): Unit = {
    for (i <- 10 to 15) {
      val s = System.currentTimeMillis
      cnt = 0
      queen3((1 << i) - 1, 0, 0)
      println(cnt)
      println((System.currentTimeMillis - s) + "msec")
    }
  }

関数 queen3 の引数 n が未選択のクイーン、引数 right が右斜め上の利き筋、left が左斜め上の利き筋を表します。(rigth | left) のビットオンの位置が斜めの利き筋にあたります。そして、n から斜めの利き筋にあたらないクイーンを選びます。

queen を再帰呼び出しするときは、right と left にクイーンの位置をセットして、それを 1 ビットシフトします。right と left は局所変数なので、元の値に戻す処理は必要ありません。あとは、とくに難しいところはないでしょう。詳細はプログラムリストをお読みください。

●実行結果 (4)

実行結果を示します。

          表 : 実行結果 (時間 : 秒)

  個数 :  11  :  12   :  13   :   14   :   15
  -----+------+-------+-------+--------+---------
   解  : 2680 : 14200 : 73712 : 365596 : 2279184
   (1) : 0.09 : 0.34  : 1.82  : ------ : -------
   (2) : 0.06 : 0.15  : 0.50  :  3.07  : -------
   (3) : ---- : 0.02  : 0.10  :  0.56  :  3.46

  実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

ビット演算の効果はきわめて大きいですね。ここまで速くなるとは M.Hiroi も大変驚きました。


●プログラムリスト

//
// nqueens.scala : N Queens Problems
//
//                 Copyright (C) 2014-2021 Makoto Hiroi
//
object nqueens {
  // 衝突のチェック
  def attack(x: Int, ys: List[Int]): Boolean = {
    var n = 1
    for (y <- ys) {
      if (x == y + n || x == y - n) return true
      n += 1
    }
    false
  }

  // 安全か?
  def safe(xs: List[Int]): Boolean =
    xs match {
      case Nil => true
      case y::ys => if (attack(y, ys)) false else safe(ys)
    }

  // 解法 (0)
  def queen0(xs: List[Int]): Int = {
    val it = xs.permutations
    var c = 0
    while (it.hasNext) {
      val xs = it.next()
      if (safe(xs)) {
        c += 1
        println(xs)
      }
    }
    c
  }

  def test0(): Unit = {
    for (i <- 8 to 11) {
      val s = System.currentTimeMillis
      println(queen0(Range(1, i + 1).toList))
      println(s"${System.currentTimeMillis - s} msec")
    }
  }

  // 解法 (1)
  def queen1(f: List[Int] => Unit, xs: List[Int], a: List[Int]): Unit =
    if (xs == Nil)
      f(a)
    else
      for (x <- xs if (!attack(x, a))) {
        queen1(f, xs.filterNot(_ == x), x::a)
      }

  def test1(): Unit = {
    for (i <- 8 to 13) {
      val s = System.currentTimeMillis
      var c = 0
      queen1(_ => c += 1, Range(1, i + 1).toList, Nil)
      println(c)
      println(s"${System.currentTimeMillis - s} msec")
    }
  }

  // 解法 (2)
  val MaxSize = 16
  val board = new Array[Int](MaxSize)
  val nUsed = new Array[Boolean](MaxSize)
  val rUsed = new Array[Boolean](MaxSize * 2)
  val lUsed = new Array[Boolean](MaxSize * 2)
  var size = 0
  var cnt  = 0

  def queen2(n: Int): Unit = {
    if (n == size)
      cnt += 1
    else {
      for (m <- 0 until size) {
        if (!nUsed(m) && !rUsed(m + n) && !lUsed(m - n + size - 1)) {
          board(n) = m
          rUsed(m + n) = true
          lUsed(m - n + size - 1) = true
          nUsed(m) = true
          queen2(n + 1)
          rUsed(m + n) = false
          lUsed(m - n + size - 1) = false
          nUsed(m) = false
        }
      }
    }
  }

  def test2(): Unit = {
    for (i <- 10 to 14) {
      val s = System.currentTimeMillis
      size = i
      cnt = 0
      queen2(0)
      println(cnt)
      println(s"${System.currentTimeMillis - s} msec")
    }
  }

  // 解法 (3)
  def queen3(n: Int, right: Int, left: Int): Unit = {
    if (n == 0)
      cnt += 1
    else {
      var m = n
      while (m > 0) {
        val q = m & (-m)
        if ((q & (right | left)) == 0) {
          queen3(n ^ q, (right | q) >> 1, (left | q) << 1)
        }
        m &= m - 1
      }
    }
  }

  def test3(): Unit = {
    for (i <- 10 to 15) {
      val s = System.currentTimeMillis
      cnt = 0
      queen3((1 << i) - 1, 0, 0)
      println(cnt)
      println(s"${System.currentTimeMillis - s} msec")
    }
  }
}

初版 2014 年 10 月 18 日
改訂 2021 年 4 月 11 日

Copyright (C) 2014-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | Scala | NextPage ]