M.Hiroi's Home Page

Functional Programming

お気楽 Haskell プログラミング入門

[ PrevPage | Haskell | NextPage ]

配列のソート (2)

ソートの続きです。今回はクイックソート、ヒープソート、マージソートについて説明します。

●クイックソート

最初に、高速なソートアルゴリズムとして有名な「クイックソート (quick sort)」を取り上げます。要素数を N とすると、クイックソートの平均的な実行時間は N * log N に比例しますが、最悪の場合は N の 2 乗に比例する遅いソートになってしまいます。

クイックソートはある値を基準にして、要素をそれより大きいものと小さいものの 2 つに分割していくことでソートを行います。2 つに分けた各々の区間を同様に分割して 2 つの区間に分けます。最後は区間の要素がひとつになってソートが完了します。

  9 5 3 7 6 4 2 8      最初の状態

  9 5 3 7 6 4 2 8      7 を枢軸にして左側から 7 以上の値を探し、
  L           R        右側から 7 以下の値を探す。

  2 5 3 7 6 4 9 8      交換する
  L           R

  2 5 3 7 6 4 9 8      検索する
        L   R

  2 5 3 4 6 7 9 8      交換する
        L   R

  2 5 3 4 6 7 9 8      検索する。R と L が交差したら分割終了。
          R L

  [2 5 3 4 6] [7 9 8]  この 2 つの区間について再び同様な分割を行う


                図 : クイックソート

基準になる値のことを「枢軸 (pivot)」といいます。枢軸は要素の中から適当な値を選びます。今回は区間の中央に位置する要素を選ぶことにしましょう。上図を見てください。左側から枢軸 7 以上の要素を探し、左側から 7 以下の要素を探します。探索のときは枢軸が番兵の役割を果たすので、ソート範囲外の要素を探索することはありません。見つけたらお互いの要素を交換します。探索位置が交差したら分割は終了です。

あとは同じ手順を分割した 2 つの区間に適用します。これは再帰定義を使えば簡単に実現できます。分割した区間の要素数が 1 になったときが再帰の停止条件になります。プログラムは次のようになります。

リスト : クイックソート (1)

search1 :: Ord a => IOArray Int a -> a -> Int -> IO Int
search1 buff pv i = do
  x <- readArray buff i
  if pv > x
    then search1 buff pv (i + 1)
    else return i

search2 :: Ord a => IOArray Int a -> a -> Int -> IO Int
search2 buff pv j = do
  x <- readArray buff j
  if pv < x
    then search2 buff pv (j - 1)
    else return j

quickPartition :: Ord a => IOArray Int a -> a -> Int -> Int -> IO (Int, Int)
quickPartition buff pv low high = do
  i <- search1 buff pv low
  j <- search2 buff pv high
  if i < j
    then do swapItem buff i j
            quickPartition buff pv (i + 1) (j - 1)
    else return (i, j)

quickSort :: Ord a => IOArray Int a -> IO ()
quickSort buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high = do
      pv <- readArray buff (low + (high - low) `div` 2)
      (i, j) <- quickPartition buff pv low high
      when (low < i - 1)  $ qsort low (i - 1)
      when (high > j + 1) $ qsort (j + 1) high

実際の処理は局所関数 qsort で行います。引数 low が区間の下限値、high が区間の上限値です。qsort は buff の low から high までの区間をソートします。最初に、区間の中央にあるデータを枢軸 (pv) として選びます。そして、関数 quickPartition で pv を基準にして区間を 2 つに分けます。

quickPartition では、関数 search1 で左側から枢軸以上の要素を探しています。ここでは枢軸以上という条件を、枢軸より小さい間は探索位置を進める、というように置き換えています。同様に関数 search2 で右側から枢軸以下の要素を探します。お互いの探索位置 i, j が交差したら分割は終了です。そうでなければお互いの要素を交換します。交換したあとは i と j の値を更新して quickPartition を再帰呼び出しします。

そして、分割した区間に対して qsort を再帰呼び出しします。このとき要素数をチェックして、2 個以上ある場合に再帰呼び出しを行います。この停止条件を忘れると正常に動作しません。ご注意ください。

それでは実行結果を示します。

*Main> a <- newListArray (0,9) [5,6,4,7,3,8,2,9,1,0] :: IO (IOArray Int Int)
*Main> quickSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
*Main> b <- newListArray (0,9) [9,8..0] :: IO (IOArray Int Int)
*Main> quickSort b
*Main> getElems b
[0,1,2,3,4,5,6,7,8,9]
*Main> c <- newListArray (0,9) [0..9] :: IO (IOArray Int Int)
*Main> quickSort c
*Main> getElems c
[0,1,2,3,4,5,6,7,8,9]
  表 : quickSort の結果 (単位 : 秒)

 [IOArray]
 個数  : 乱数   昇順   逆順   山型
-------------------------------------
 40000 : 0.010  0.002  0.003   1.293
 80000 : 0.022  0.005  0.006   5.157
160000 : 0.050  0.010  0.031  21.168
320000 : 0.111  0.022  0.028  86.607

 [IOUArray]
 個数  : 乱数   昇順   逆順   山型
------------------------------------
 40000 : 0.007  0.001  0.002   0.592
 80000 : 0.016  0.003  0.003   2.358
160000 : 0.031  0.005  0.006   9.620
320000 : 0.063  0.011  0.013  37.620

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

クイックソートは、枢軸の選び方で効率が大きく左右されます。区間の中央値 [*1] を枢軸に選ぶと、区間をほぼ半分に分割することができます。この場合がいちばん効率が良く、データ数を N とすると N * log N に比例する時間でソートすることができます。

逆に、区間での最大値または最小値を枢軸に選ぶと、その要素と残りの要素の 2 つに分割にされることになります。これが最悪の場合で、分割のたびに最大値もしくは最小値を選ぶと、実行時間は要素数の 2 乗に比例することになります。つまり、単純挿入ソートと同じくらい遅いソートになります。それだけでなく、要素数が多くなるとスタックがオーバーフローする危険性もあります。

今回は区間の中央に位置する要素を枢軸としたので、中央付近に大きい要素があるデータが最悪の場合にあてはまります。つまり、山型データがこのプログラムでは最悪の結果になります。実行結果を見ると、データ数が 2 倍になると実行時間が約 4 倍になっている、つまり N2 に比例する遅いソートになっていることがわかります。

-- note --------
[*1] N 個の要素を昇順に並べたとき、中央に位置する要素 (N / 2 番目の要素) を「中央値」といいます。中央値のことを「メディアン (median)」と呼びます。

●クイックソートの改良

それでは、クイックソートのプログラムを改良してみましょう。まずは枢軸の選び方を工夫します。区間の中からいくつかの要素を選び、その中で中央値を持つ要素を枢軸とします。たくさんの要素を選ぶとそれだけ最悪の枢軸を選ぶ危険性は減少しますが、中央値を選ぶのに時間がかかってしまいます。今回は 9 つの要素を選んで、その中から枢軸を選ぶことにしましょう。

次に、2 つに分割した区間の短い方からソートしていきます。そうすると、再帰呼び出しの深さは要素数を N とすると log2 N 程度におさまります。たとえば、100 万個の要素をソートする場合でも、再帰呼び出しの深さは 20 程度ですみます。最後に、要素数が少なくなったらクイックソートを打ち切り、単純挿入ソートに切り替えます。データ数が少ない場合は、クイックソートよりも単純なソートアルゴリズムの方が高速です。

まずは枢軸を選択するプログラムを作りましょう。次のリストを見てください。

リスト : 枢軸の選択

median3 :: Ord a => a -> a -> a -> a
median3 a b c =
  if a > b
    then if b > c
           then b
           else if a < c then a else c
    else if b < c
           then b
           else if a < c then c else a

selectPv9 :: Ord a => IOArray Int a -> Int -> Int -> IO a
selectPv9 buff low high = do
  x1 <- readArray buff low
  x2 <- readArray buff (low + m8)
  x3 <- readArray buff (low + m4)
  x4 <- readArray buff (low + m2 - m8)
  x5 <- readArray buff (low + m2)
  x6 <- readArray buff (low + m2 + m8)
  x7 <- readArray buff (high - m4)
  x8 <- readArray buff (high - m8)
  x9 <- readArray buff high
  return (median3 (median3 x1 x2 x3)
                  (median3 x4 x5 x6)
                  (median3 x7 x8 x9))
  where m2 = (high - low) `div` 2
        m4 = m2 `div` 2
        m8 = m4 `div` 2

関数 median3 は引数 a, b, c の中から中央値を返します。関数 selectPv9 は区間 (low, high) から 9 つの要素を選びます。区間を (0, 1) とすると、0, 1/8, 1/4, 3/8, 1/2, 5/8, 3/4, 7/8, 1 の位置にある要素を選びます。次に、9 つの要素を 3 つのグループ (0, 1/8, 1/4), (3/18, 1/2, 5/8), (3/4, 7/8, 1) に分けて、おのおののグループの中央値を median3 で求めます。さらに、その 3 つから中央値を median3 で選び、その値が枢軸となります。

M.Hiroi はこの方法をネットで検索して知りました。3 つの要素から枢軸を選ぶ方法を median-of-3 といい、9 つの要素から枢軸を選ぶ方法を median-of-9 と呼ぶようです。今回の方法は 9 つの要素の中から中央値を選択しているわけではありませんが、これでも十分に効果を発揮するようです。

次はクイックソートのプログラムを改良します。

リスト : クイックソートの改良

quickSort' :: Ord a => IOArray Int a -> IO ()
quickSort' buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high =
      if high - low < 16
        then insertSort' buff low high
        else do
          pv <- selectPv9 buff low high
          (i, j) <- quickPartition buff pv low high
          let a = (i - 1) - low
              b = high - (j + 1)
          if a < b
            then do qsort low (i - 1)
                    qsort (j + 1) high
            else do qsort (j + 1) high
                    qsort low (i - 1)

局所関数 qsort は、最初に high - low の値をチェックして 16 未満になったらクイックソートを打ち切り、単純挿入ソート (insertSort') に切り替えます。区間が 16 以上の場合はクイックソートを行います。区間の分割は今までのプログラムと同じです。そして、短い方の区間からクイックソートします。

次は挿入ソートを修正します。関数 insertSort' は次のようになります。

リスト : 単純挿入ソート

search_move :: Ord a => IOArray Int a -> Int -> a -> Int -> Int -> IO Int
search_move buff low x i g
  | i < low = return (i + g)
  | otherwise = do
      y <- readArray buff i
      if x < y
        then do writeArray buff (i + g) y
                search_move buff low x (i - g) g
        else return (i + g)

insertElement :: Ord a => IOArray Int a -> Int -> Int -> Int -> IO ()
insertElement buff low i gap = do
  tmp <- readArray buff i
  pos <- search_move buff low tmp (i - gap) gap
  writeArray buff pos tmp

insertSort :: Ord a => IOArray Int a -> IO ()
insertSort buff = do
  (low, high) <- getBounds buff
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

insertSort' :: Ord a => IOArray Int a -> Int -> Int -> IO ()
insertSort' buff low high = do
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

insertSort' はソートする区間を引数 low, high で指定します。insertElement と search_move の引数に下限値 low を指定して、引数 i から low までの間でデータを挿入する位置を探します。

それでは実行結果を示します。

  表 : quickSort' の結果 (単位 : 秒)

 [IOArray]
 個数    乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.022  0.006  0.007  0.013
160000 : 0.048  0.012  0.014  0.036
320000 : 0.109  0.026  0.031  0.068
640000 : 0.254  0.054  0.063  0.129

 [IOUArray]
 個数    乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.014  0.004  0.005  0.009
160000 : 0.031  0.007  0.008  0.020
320000 : 0.067  0.015  0.017  0.047
640000 : 0.138  0.032  0.038  0.090

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

昇順、降順のデータは qickSort よりも遅くなりましたが、乱数のデータは quickSort よりも速くなりました。山型のデータも高速にソートすることができます。枢軸の選択を改良した効果は十分に出ていると思います。median-of-9 は少ないコストで最悪のケースを回避する優れた方法だと思います。もちろん、median-of-9 でも最悪のケースが存在するはずですが、最悪のケースに遭遇する確率は median-of-3 よりも低くなると思います。興味のある方はいろいろ試してみてください。

●ヒープソート

ヒープ (heap) は拙作のページ ヒープ で説明したデータ構造です。実は、このヒープを使ったソートも優秀なアルゴリズムの一つです。実行時間は N * log2 N に比例しますが、平均するとクイックソートよりも遅くなります。しかし、クイックソートとは違って、データの種類によって性能が劣化することはありません。

プログラムは次のようになります。

リスト : ヒープソート

-- 子の選択
selectChild :: Ord a => IOArray Int a -> Int -> Int -> IO Int
selectChild buff c1 h = do
  let c2 = c1 + 1
  a <- readArray buff c1
  if c2 > h
    then return c1
    else do b <- readArray buff c2
            if a < b
              then return c2
              else return c1

heapIter :: Ord a => IOArray Int a -> a -> Int -> Int -> IO ()
heapIter buff x n h = do
  let c1 = 2 * n + 1
  if c1 <= h
    then do
      c <- selectChild buff c1 h
      y <- readArray buff c
      if x < y
        then do
          writeArray buff n y
          heapIter buff x c h
        else writeArray buff n x
    else writeArray buff n x

-- 葉の方向に向かってヒープを構築
downheap :: Ord a => IOArray Int a -> Int -> Int -> IO ()
downheap buff n h = do
  x <- readArray buff n
  heapIter buff x n h

heapSort :: Ord a => IOArray Int a -> IO ()
heapSort buff = do
  (low, high) <- getBounds buff
  iterR_ (low, (high - low + 1) `div` 2 - 1)
         (\i -> downheap buff i high)
  iterR_ (low + 1, high)
         (\i -> do swapItem buff 0 i
                   downheap buff 0 (i - 1))

heapSort の前半部分で関数 downheap を呼び出してヒープを構築します。親子関係が ヒープ の説明と逆になっていることに注意してください。つまり、親が子より大きいという関係を満たすようにヒープを構築します。したがって、配列の先頭 (buff[0]) が最大値になります。

後半部分で、最大値を取り出してヒープを再構築します。配列の先頭には最大値がセットされているので、これを配列の最後尾のデータと交換します。あとは、そのデータを除いた範囲でヒープを再構築すれば、その次に大きいデータを求めることができます。これを繰り返すことで、大きいデータが配列の後ろから整列していくことになります。

なお、downheap は compItem と swapItem を使わずにプログラムしています。downheap の中で swapItem を呼び出すと、実行速度はかなり遅くなります。ご注意くださいませ。

それでは実行結果を示します。

*Main> a <- newListArray (0,9) [5,6,4,7,3,8,2,9,1,0] :: IO (IOArray Int Int)
*Main> heapSort a
[0,1,2,3,4,5,6,7,8,9]
*Main> a <- newListArray (0,9) [9,8..0] :: IO (IOArray Int Int)
*Main> heapSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
*Main> a <- newListArray (0,9) [0..9] :: IO (IOArray Int Int)
*Main> heapSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
  表 : heap sort の結果 (単位 : 秒)

 [IOArray]
  個数   乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.042  0.028  0.028  0.029
160000 : 0.115  0.057  0.061  0.065
320000 : 0.343  0.125  0.126  0.135
640000 : 0.983  0.258  0.269  0.286

 [IOUArray]
  個数   乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.019  0.018  0.015  0.015
160000 : 0.040  0.028  0.029  0.033
320000 : 0.089  0.057  0.062  0.061
640000 : 0.207  0.125  0.126  0.134

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

このように、ヒープソートはどのデータに対しても、そこそこの速度でソートすることができます。ただし、実行時間はクイックソートよりも遅くなりました。参考文献 2 によると、ヒープソートの速度はクイックソートの半分くらいといわれています。ヒープソートの処理内容はクイックソートよりも複雑なので、時間がかかるのは仕方がないところでしょう。

●マージソート

マージ (併合 : merge) とはソート済みの複数の列を一つの列にまとめる操作のことです。このマージを使ったソートを「マージソート (merge sort)」といいます。最初にマージについて簡単に説明します。次の図を見てください。


      図 : マージの考え方

2 つのリスト a と b があります。これらのリストはソート済みとしましょう。これらのリストをソート済みのリストにまとめることを考えます。 a と b はソート済みなので先頭のデータがいちばん小さな値です。したがって、上図のように先頭のデータを比較し、小さい方のデータを取り出して順番に並べていけば、ソート済みのリストにまとめることができます。途中でどちらかのリストが空になったら、残ったリストのデータをそのまま追加します。当たり前だと思われるでしょうが、これがマージソート (merge sort) の原理です。次の図を見てください。

  9 5 3 7 6 4 2 8  最初の状態

 |5 9|3 7|4 6|2 8| 長さ2の列に併合

 |3 5 7 9|2 4 6 8| 長さ4の列に併合 

  2 3 4 5 6 7 8 9  ソート終了


        図 : マージソート

マージをソートに応用する場合、最初は各要素をソート済みの配列 (リスト) として考えます。この状態で隣の配列とマージを行い、長さ 2 の配列を作ります。次に、この配列に対して再度マージを行い、長さ 4 の配列を作ります。このように順番にマージしていくと、最後には一つの配列にマージされソートが完了します。

それではプログラムを作りましょう。配列の長さを 1, 2, 4, 8, ... と増やしていくよりも、再帰的に考えた方が簡単です。マージは 2 つの列を一つの列にまとめる操作です。そこで、まずソートする配列を 2 つに分けて、前半部分をソートします。次に後半部分をソートして、その結果をマージすればいいわけです。

では、どうやってソートするのかというと、再帰呼び出しするのです。そうすると、どんどん配列を 2 つに割っていくことになり、最後にデータが一つとなります。それはソート済みの配列と考えることができるので、再帰呼び出しを終了してマージ処理に移ることができます。あとはデータを順番にマージしていってソートが完了します。

プログラムは次のようになります。

リスト : マージソート

move :: Ord a => IOArray Int a -> Int -> Int -> Int -> IOArray Int a -> IO ()
move buff mid i k work
  | i <= mid = do x <- readArray work i
                  writeArray buff k x
                  move buff mid (i + 1) (k + 1) work
  | otherwise = return ()

merge :: Ord a => IOArray Int a -> Int -> Int -> Int -> Int -> Int -> IOArray Int a -> IO ()
merge buff mid high i j k work
  | i <= mid && j <= high = do
      a <- readArray work i
      b <- readArray buff j
      if a <= b
        then do writeArray buff k a
                merge buff mid high (i + 1) j (k + 1) work
        else do writeArray buff k b
                merge buff mid high i (j + 1) (k + 1) work
  | otherwise =
      if i <= mid
        then move buff mid i k work
        else return ()

msort :: Ord a => IOArray Int a -> Int -> Int -> IOArray Int a -> IO ()
msort buff low high work
  | high - low < 16 = insertSort' buff low high
  | otherwise = do
      let mid = low + (high - low) `div` 2
      msort buff low mid work
      msort buff (mid + 1) high work
      -- low から mid までの要素を work に退避
      iter_ (low, mid) (\i -> do x <- readArray buff i
                                 writeArray work i x)
      -- マージする
      merge buff mid high low (mid + 1) low work

mergeSort :: Ord a => IOArray Int a -> IO ()
mergeSort buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOArray Int a)
  msort buff low high work

最初に作業用の配列を newArray_ で生成して、それを関数 msort に渡します。今回のプログラムでは、ソートする配列と同じ大きさの作業用領域を用意しましたが、参考文献 1 によると、作業用領域の大きさはソートする配列の半分ですむそうです。興味のある方はプログラムを改良してください。

msort は、最初に区間の幅が 16 未満になったかチェックします。そうであれば、単純挿入ソート (insertSort') に切り替えてソートします。この方が少しですが速くなります。これが再帰呼び出しの停止条件になります。区間の幅が 16 以上の場合はマージソートを行います。

まず列の中央の位置を求めて変数 mid にセットします。最初に前半部、それから後半部をマージソートします。これは msort を再帰呼び出しするだけです。再帰呼び出しから戻ってくると、配列の前半部分と後半部分はソートされているのでマージ処理を行います。この処理を関数 merge で行います。

まず前半部分を作業領域 work に退避してから、merge を呼び出します。前半部分もしくは後半部分どちらかにデータがある間、データの比較と移動を繰り返し行います。前半部分と後半部分を先頭から順番に比較し、小さい方を区間の先頭から順番にセットしていきます。後半部分のデータが先になくなって、作業領域 work にデータが残っている場合は、関数 move でデータを後ろに追加します。

それでは実行結果を示します。

*Main> a <- newListArray (0,9) [5,6,4,7,3,8,2,9,1,0] :: IO (IOArray Int Int)
*Main> mergeSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
*Main> a <- newListArray (0,9) [9,8..0] :: IO (IOArray Int Int)
*Main> mergeSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
*Main> a <- newListArray (0,9) [0..9] :: IO (IOArray Int Int)
*Main> mergeSort a
*Main> getElems a
[0,1,2,3,4,5,6,7,8,9]
  表 : merge sort の結果 (単位 : 秒)

 [IOArray]
 個数  : 乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.055  0.022  0.037  0.031
160000 : 0.100  0.049  0.080  0.066
320000 : 0.219  0.109  0.158  0.141
640000 : 0.509  0.228  0.338  0.310

 [IOUArray]
 個数  : 乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.039  0.019  0.032  0.026
160000 : 0.079  0.039  0.063  0.055
320000 : 0.161  0.083  0.127  0.111
640000 : 0.344  0.182  0.278  0.236

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

マージソートの実行時間は、要素数を N とすると平均して N * log N に比例します。マージソートはクイックソートと同様に高速なアルゴリズムですが、実際にプログラムを作って比較してみると、クイックソートの方が高速になります。マージソートとヒープソートを比べると、一般的にはマージソートのほうが速いといわれていますが、IOUArray の結果を見ると、ヒープソートのほうが速くなりました。マージソートは配列と作業領域との間でデータの転送が行われます。このときに時間がかかっていると思われます。

マージソートは配列を単純に二分割していくため、クイックソートと違ってデータの種類によって性能が劣化することはありません。ヒープソートと同様に、どのようなデータに対しても力を発揮してくれるわけです。ただし、ヒープソートとは違って作業領域が必要になります。

●マージソートの改良

ところで、配列 buff と同じ大きさの作業領域 work を使うのであれば、最初に buff を work にコピーしておいて、再帰のたびに buff と work を交互に入れ換えることで、マージソートの実行速度を改善することができます。

なお、この方法は C++によるソート(sort)のページ 修正マージソート を参考にさせていただきました。同ページによると、『修正マージソートは、Java のクラス型のソートに採用されています。』 とのことです。有用な情報を公開されている作者様に感謝いたします。

プログラムは次のようになります。

リスト : マージソート (改良版)

mergeSort' :: Ord a => IOArray Int a -> IO ()
mergeSort' buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOArray Int a)
  copy buff work low high
  msort work buff low high
  where
    copy :: Ord a => IOArray Int a -> IOArray Int a -> Int -> Int -> IO ()
    copy src dst low high =
      iter_ (low, high) (\i -> do x <- readArray src i
                                  writeArray dst i x)
    msort a b low high
      | high - low < 16 = insertSort' b low high
      | otherwise = do
          let mid = (low + high) `div` 2
          msort b a low mid
          msort b a (mid + 1) high
          -- マージする
          merge mid low (mid + 1) low
          where
            merge mid i j k
              | i <= mid && j <= high = do
                  x <- readArray a i
                  y <- readArray a j
                  if x <= y
                    then do writeArray b k x
                            merge mid (i + 1) j (k + 1)
                    else do writeArray b k y
                            merge mid i (j + 1) (k + 1)
              | otherwise = do
                  move a b i mid k
                  move a b j high k
                  where
                    move a b i j k
                      | i > j = return ()
                      | otherwise = do
                          x <- readArray a i
                          writeArray b k x
                          move a b (i + 1) j (k + 1)

最初に、作業用の配列 work を確保して、局所関数 copy で buff の内容を work へコピーします。局所関数 msort a b low high は、配列 a の区間 (low, high) を二分割してソートし、その結果をマージするときに配列 b を使います。したがって、msort は msort work buff low high のように呼び出します。これで配列 buff をソートすることができます。msort を再帰呼び出しするときは、msort b a ... のように a と b を逆にすることに注意してください。

二つの区間をソートしたあと、配列 a の前半部分と後半部分はソートされているので、局所関数 merge で 2 つの区間をマージします。二つの区間をマージした結果は配列 b の区間 (low, high) にセットします。改良前の mergeSort では、あらかじめ buff の前半部分を work に退避していましたが、buff を work にコピーしておいて、buff と work を交互に切り替えることで、buff の前半部分を退避する処理が不要になります。最後に、区間内に要素が残っていたら局所関数 move で配列 b に転送します。

それでは実行結果を示します。

  表 : mergeSort' の結果 (単位 : 秒)

 [IOArray]
 個数  : 乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.046  0.020  0.031  0.025
160000 : 0.078  0.045  0.061  0.055
320000 : 0.174  0.103  0.119  0.111
640000 : 0.401  0.211  0.253  0.259

 [IOUArray]
 個数  : 乱数   昇順   逆順   山型
-----------------------------------
 80000 : 0.027  0.017  0.022  0.022
160000 : 0.057  0.036  0.047  0.046
320000 : 0.119  0.076  0.090  0.092
640000 : 0.251  0.164  0.191  0.186

実行環境 : Ubunts 18.04 (WSL), Intel Core i5-6200U 2.30GHz

mergeSort よりも mergeSort' のほうが速くなりました。改良の効果は十分に出ていると思います。メモリを多く使用することになりますが、このような簡単な方法でマージソートを改良できるとは驚きました。

なお、プログラムの実行時間は、筆者のコーディング、実行したマシン、使用するプログラミング言語(またはコンパイラ)などの環境に大きく依存しています。また、これらの環境だけではなく、データの種類によっても実行時間は大きく左右されます。興味のある方は、いろいろなデータをご自分の環境で試してみてください。

●参考文献

  1. 奥村晴彦, 『C言語による最新アルゴリズム事典』, 技術評論社, 1991
  2. 近藤嘉雪, 『Cプログラマのためのアルゴリズムとデータ構造』, ソフトバンク, 1998

●プログラムリスト1

--
-- sort1.hs : ソート
--
--            Copyright (C) 2013-2021 Makoto Hiroi
--
import Data.Array.IO
import Control.Monad
import Data.Time
import System.Random
import System.Environment

-- 要素の比較
compItem :: Ord a => IOArray Int a -> Int -> Int -> IO Ordering
compItem buff i j = liftM2 (compare) (readArray buff i) (readArray buff j)

-- 要素の交換
swapItem :: IOArray Int a -> Int -> Int -> IO ()
swapItem buff i j = do
  a <- readArray buff i
  b <- readArray buff j
  writeArray buff i b
  writeArray buff j a

-- イテレータ
iter_ :: (Int, Int) -> (Int -> IO ()) -> IO ()
iter_ (low, high) fn
  | low > high = return ()
  | otherwise  = do
      fn low
      iter_ (low + 1, high) fn

iterR_ :: (Int, Int) -> (Int -> IO ()) -> IO ()
iterR_ (low, high) fn
  | low > high = return ()
  | otherwise  = do
      fn high
      iterR_ (low, high - 1) fn

-- 単純挿入ソート
search_move :: Ord a => IOArray Int a -> Int -> a -> Int -> Int -> IO Int
search_move buff low x i g
  | i < low = return (i + g)
  | otherwise = do
      y <- readArray buff i
      if x < y
        then do writeArray buff (i + g) y
                search_move buff low x (i - g) g
        else return (i + g)

insertElement :: Ord a => IOArray Int a -> Int -> Int -> Int -> IO ()
insertElement buff low i gap = do
  tmp <- readArray buff i
  pos <- search_move buff low tmp (i - gap) gap
  writeArray buff pos tmp

insertSort :: Ord a => IOArray Int a -> IO ()
insertSort buff = do
  (low, high) <- getBounds buff
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

insertSort' :: Ord a => IOArray Int a -> Int -> Int -> IO ()
insertSort' buff low high = do
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

-- ヒープソート

-- 子の選択
selectChild :: Ord a => IOArray Int a -> Int -> Int -> IO Int
selectChild buff c1 h = do
  let c2 = c1 + 1
  a <- readArray buff c1
  if c2 > h
    then return c1
    else do b <- readArray buff c2
            if a < b
              then return c2
              else return c1

heapIter :: Ord a => IOArray Int a -> a -> Int -> Int -> IO ()
heapIter buff x n h = do
  let c1 = 2 * n + 1
  if c1 <= h
    then do
      c <- selectChild buff c1 h
      y <- readArray buff c
      if x < y
        then do
          writeArray buff n y
          heapIter buff x c h
        else writeArray buff n x
    else writeArray buff n x

-- 葉の方向に向かってヒープを構築
downheap :: Ord a => IOArray Int a -> Int -> Int -> IO ()
downheap buff n h = do
  x <- readArray buff n
  heapIter buff x n h

heapSort :: Ord a => IOArray Int a -> IO ()
heapSort buff = do
  (low, high) <- getBounds buff
  iterR_ (low, (high - low + 1) `div` 2 - 1)
         (\i -> downheap buff i high)
  iterR_ (low + 1, high)
         (\i -> do swapItem buff 0 i
                   downheap buff 0 (i - 1))

--
-- クイックソート
--
search1 :: Ord a => IOArray Int a -> a -> Int -> IO Int
search1 buff pv i = do
  x <- readArray buff i
  if pv > x
    then search1 buff pv (i + 1)
    else return i

search2 :: Ord a => IOArray Int a -> a -> Int -> IO Int
search2 buff pv j = do
  x <- readArray buff j
  if pv < x
    then search2 buff pv (j - 1)
    else return j

quickPartition :: Ord a => IOArray Int a -> a -> Int -> Int -> IO (Int, Int)
quickPartition buff pv low high = do
  i <- search1 buff pv low
  j <- search2 buff pv high
  if i < j
    then do swapItem buff i j
            quickPartition buff pv (i + 1) (j - 1)
    else return (i, j)

quickSort :: Ord a => IOArray Int a -> IO ()
quickSort buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high = do
      pv <- readArray buff (low + (high - low) `div` 2)
      (i, j) <- quickPartition buff pv low high
      when (low < i - 1)  $ qsort low (i - 1)
      when (high > j + 1) $ qsort (j + 1) high

median3 :: Ord a => a -> a -> a -> a
median3 a b c =
  if a > b
    then if b > c
           then b
           else if a < c then a else c
    else if b < c
           then b
           else if a < c then c else a

selectPv9 :: Ord a => IOArray Int a -> Int -> Int -> IO a
selectPv9 buff low high = do
  x1 <- readArray buff low
  x2 <- readArray buff (low + m8)
  x3 <- readArray buff (low + m4)
  x4 <- readArray buff (low + m2 - m8)
  x5 <- readArray buff (low + m2)
  x6 <- readArray buff (low + m2 + m8)
  x7 <- readArray buff (high - m4)
  x8 <- readArray buff (high - m8)
  x9 <- readArray buff high
  return (median3 (median3 x1 x2 x3)
                  (median3 x4 x5 x6)
                  (median3 x7 x8 x9))
  where m2 = (high - low) `div` 2
        m4 = m2 `div` 2
        m8 = m4 `div` 2

quickSort' :: Ord a => IOArray Int a -> IO ()
quickSort' buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high =
      if high - low < 16
        then insertSort' buff low high
        else do
          pv <- selectPv9 buff low high
          (i, j) <- quickPartition buff pv low high
          let a = (i - 1) - low
              b = high - (j + 1)
          if a < b
            then do qsort low (i - 1)
                    qsort (j + 1) high
            else do qsort (j + 1) high
                    qsort low (i - 1)

--
-- マージソート
--
move :: Ord a => IOArray Int a -> Int -> Int -> Int -> IOArray Int a -> IO ()
move buff mid i k work
  | i <= mid = do x <- readArray work i
                  writeArray buff k x
                  move buff mid (i + 1) (k + 1) work
  | otherwise = return ()

merge :: Ord a => IOArray Int a -> Int -> Int -> Int -> Int -> Int -> IOArray Int a -> IO ()
merge buff mid high i j k work
  | i <= mid && j <= high = do
      a <- readArray work i
      b <- readArray buff j
      if a <= b
        then do writeArray buff k a
                merge buff mid high (i + 1) j (k + 1) work
        else do writeArray buff k b
                merge buff mid high i (j + 1) (k + 1) work
  | otherwise =
      if i <= mid
        then move buff mid i k work
        else return ()

msort :: Ord a => IOArray Int a -> Int -> Int -> IOArray Int a -> IO ()
msort buff low high work
  | high - low < 16 = insertSort' buff low high
  | otherwise = do
      let mid = low + (high - low) `div` 2
      msort buff low mid work
      msort buff (mid + 1) high work
      -- low から mid までの要素を work に退避
      iter_ (low, mid) (\i -> do x <- readArray buff i
                                 writeArray work i x)
      -- マージする
      merge buff mid high low (mid + 1) low work

mergeSort :: Ord a => IOArray Int a -> IO ()
mergeSort buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOArray Int a)
  msort buff low high work

-- 改良版
mergeSort' :: Ord a => IOArray Int a -> IO ()
mergeSort' buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOArray Int a)
  copy buff work low high
  msort work buff low high
  where
    copy :: Ord a => IOArray Int a -> IOArray Int a -> Int -> Int -> IO ()
    copy src dst low high =
      iter_ (low, high) (\i -> do x <- readArray src i
                                  writeArray dst i x)
    msort a b low high
      | high - low < 16 = insertSort' b low high
      | otherwise = do
          let mid = (low + high) `div` 2
          msort b a low mid
          msort b a (mid + 1) high
          -- マージする
          merge mid low (mid + 1) low
          where
            merge mid i j k
              | i <= mid && j <= high = do
                  x <- readArray a i
                  y <- readArray a j
                  if x <= y
                    then do writeArray b k x
                            merge mid (i + 1) j (k + 1)
                    else do writeArray b k y
                            merge mid i (j + 1) (k + 1)
              | otherwise = do
                  move a b i mid k
                  move a b j high k
                  where
                    move a b i j k
                      | i > j = return ()
                      | otherwise = do
                          x <- readArray a i
                          writeArray b k x
                          move a b (i + 1) j (k + 1)

test :: (IOArray Int Int -> IO ()) -> Int -> IO ()
test sort n = do
  let m = n `div` 2
      check i ary
        | i == n    = return ()
        | otherwise = do
            t <- compItem ary (i - 1) i
            if t == GT then error "test error"
            else check (i + 1) ary
  a <- newListArray (0, n - 1) (take n (randoms (mkStdGen 11) :: [Int])) :: IO (IOArray Int Int)
  b <- newListArray (0, n - 1) [1..n] :: IO (IOArray Int Int)
  c <- newListArray (0, n - 1) [n,n-1..1] :: IO (IOArray Int Int)
  d <- newListArray (0, n - 1) ([1..m] ++ [m,m-1..1]) :: IO (IOArray Int Int)
  x1 <- getCurrentTime
  sort a
  x2 <- getCurrentTime
  check 1 a
  print (diffUTCTime x2 x1)
  x3 <- getCurrentTime
  sort b
  x4 <- getCurrentTime
  check 1 b
  print (diffUTCTime x4 x3)
  x5 <- getCurrentTime
  sort c
  x6 <- getCurrentTime
  check 1 c
  print (diffUTCTime x6 x5)
  x7 <- getCurrentTime
  sort d
  x8 <- getCurrentTime
  check 1 d
  print (diffUTCTime x8 x7)

main :: IO ()
main = do
  let xs2 = [40000, 80000, 160000, 320000, 640000]
  (x:_) <- getArgs
  case x of
    "heapSort"   -> mapM_ (test heapSort)   xs2
    "mergeSort"  -> mapM_ (test mergeSort)  xs2
    "mergeSort'" -> mapM_ (test mergeSort') xs2
    "quickSort"  -> mapM_ (test quickSort)  xs2
    "quickSort'" -> mapM_ (test quickSort') xs2

●プログラムリスト2

--
-- sortu1.hs : ソート (unboxed type)
--
--            Copyright (C) 2013-2021 Makoto Hiroi
--
import Data.Array.IO
import Control.Monad
import Data.Time
import System.Random
import System.Environment

-- 要素の比較
compItem :: IOUArray Int Int -> Int -> Int -> IO Ordering
compItem buff i j = liftM2 (compare) (readArray buff i) (readArray buff j)

-- 要素の交換
swapItem :: IOUArray Int Int -> Int -> Int -> IO ()
swapItem buff i j = do
  a <- readArray buff i
  b <- readArray buff j
  writeArray buff i b
  writeArray buff j a

-- イテレータ
iter_ :: (Int, Int) -> (Int -> IO ()) -> IO ()
iter_ (low, high) fn
  | low > high = return ()
  | otherwise  = do
      fn low
      iter_ (low + 1, high) fn

iterR_ :: (Int, Int) -> (Int -> IO ()) -> IO ()
iterR_ (low, high) fn
  | low > high = return ()
  | otherwise  = do
      fn high
      (iterR_ (low, high - 1) fn)

-- 単純挿入ソート
search_move :: IOUArray Int Int -> Int -> Int -> Int -> Int -> IO Int
search_move buff low x i g
  | i < low = return (i + g)
  | otherwise = do
      y <- readArray buff i
      if x < y
        then do writeArray buff (i + g) y
                search_move buff low x (i - g) g
        else return (i + g)

insertElement :: IOUArray Int Int -> Int -> Int -> Int -> IO ()
insertElement buff low i gap = do
  tmp <- readArray buff i
  pos <- search_move buff low tmp (i - gap) gap
  writeArray buff pos tmp

insertSort :: IOUArray Int Int -> IO ()
insertSort buff = do
  (low, high) <- getBounds buff
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

insertSort' :: IOUArray Int Int -> Int -> Int -> IO ()
insertSort' buff low high = do
  iter_ (low + 1, high) (\i -> insertElement buff low i 1)

-- ヒープソート

-- 葉の方向に向かってヒープを構築
selectChild :: IOUArray Int Int -> Int -> Int -> IO (Int, Int)
selectChild buff c1 h = do
  let c2 = c1 + 1
  a <- readArray buff c1
  if c2 > h
    then return (c1, a)
    else do
      b <- readArray buff c2
      if a < b
        then return (c2, b)
        else return (c1, a)

downheap :: IOUArray Int Int -> Int -> Int -> IO ()
downheap buff n h = do
  x <- readArray buff n
  iter x n h
  where
    iter x n h = do
      let c1 = 2 * n + 1
      if c1 <= h
        then do
          (c, y) <- selectChild buff c1 h
          if x < y
            then do
              writeArray buff n y
              iter x c h
            else writeArray buff n x
        else  writeArray buff n x

heapSort :: IOUArray Int Int -> IO ()
heapSort buff = do
  (low, high) <- getBounds buff
  iterR_ (low, (high - low + 1) `div` 2 - 1)
         (\i -> downheap buff i high)
  iterR_ (low + 1, high)
         (\i -> do swapItem buff 0 i
                   downheap buff 0 (i - 1))

--
-- クイックソート
--
search1 :: IOUArray Int Int -> Int -> Int -> IO Int
search1 buff pv i = do
  x <- readArray buff i
  if pv > x
    then search1 buff pv (i + 1)
    else return i

search2 :: IOUArray Int Int -> Int -> Int -> IO Int
search2 buff pv j = do
  x <- readArray buff j
  if pv < x
    then search2 buff pv (j - 1)
    else return j

quickPartition :: IOUArray Int Int -> Int -> Int -> Int -> IO (Int, Int)
quickPartition buff pv low high = do
  i <- search1 buff pv low
  j <- search2 buff pv high
  if i < j
    then do swapItem buff i j
            quickPartition buff pv (i + 1) (j - 1)
    else return (i, j)

quickSort :: IOUArray Int Int -> IO ()
quickSort buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high = do
      pv <- readArray buff (low + (high - low) `div` 2)
      (i, j) <- quickPartition buff pv low high
      when (low < i - 1)  $ qsort low (i - 1)
      when (high > j + 1) $ qsort (j + 1) high

median3 :: Int -> Int -> Int -> Int
median3 a b c =
  if a > b
    then if b > c
           then b
           else if a < c then a else c
    else if b < c
           then b
           else if a < c then c else a

selectPv9 :: IOUArray Int Int -> Int -> Int -> IO Int
selectPv9 buff low high = do
  x1 <- readArray buff low
  x2 <- readArray buff (low + m8)
  x3 <- readArray buff (low + m4)
  x4 <- readArray buff (low + m2 - m8)
  x5 <- readArray buff (low + m2)
  x6 <- readArray buff (low + m2 + m8)
  x7 <- readArray buff (high - m4)
  x8 <- readArray buff (high - m8)
  x9 <- readArray buff high
  return (median3 (median3 x1 x2 x3)
                  (median3 x4 x5 x6)
                  (median3 x7 x8 x9))
  where m2 = (high - low) `div` 2
        m4 = m2 `div` 2
        m8 = m4 `div` 2

quickSort' :: IOUArray Int Int -> IO ()
quickSort' buff = do
  (low, high) <- getBounds buff
  qsort low high
  where
    qsort low high =
      if high - low < 16
        then insertSort' buff low high
        else do
          pv <- selectPv9 buff low high
          (i, j) <- quickPartition buff pv low high
          let a = (i - 1) - low
              b = high - (j + 1)
          if a < b
            then do qsort low (i - 1)
                    qsort (j + 1) high
            else do qsort (j + 1) high
                    qsort low (i - 1)

--
-- マージソート
--
move :: IOUArray Int Int -> Int -> Int -> Int -> IOUArray Int Int -> IO ()
move buff mid i k work
  | i <= mid = do x <- readArray work i
                  writeArray buff k x
                  move buff mid (i + 1) (k + 1) work
  | otherwise = return ()

merge :: IOUArray Int Int -> Int -> Int -> Int -> Int -> Int -> IOUArray Int Int -> IO ()
merge buff mid high i j k work
  | i <= mid && j <= high = do
      a <- readArray work i
      b <- readArray buff j
      if a <= b
        then do writeArray buff k a
                merge buff mid high (i + 1) j (k + 1) work
        else do writeArray buff k b
                merge buff mid high i (j + 1) (k + 1) work
  | otherwise =
      if i <= mid
        then move buff mid i k work
        else return ()

msort :: IOUArray Int Int -> Int -> Int -> IOUArray Int Int -> IO ()
msort buff low high work
  | high - low < 16 = insertSort' buff low high
  | otherwise = do
      let mid = low + (high - low) `div` 2
      msort buff low mid work
      msort buff (mid + 1) high work
      -- low から mid までの要素を work に退避
      iter_ (low, mid) (\i -> do x <- readArray buff i
                                 writeArray work i x)
      -- マージする
      merge buff mid high low (mid + 1) low work

mergeSort :: IOUArray Int Int -> IO ()
mergeSort buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOUArray Int Int)
  msort buff low high work

-- 改良版
mergeSort' :: IOUArray Int Int -> IO ()
mergeSort' buff = do
  (low, high) <- getBounds buff
  work <- newArray_ (low, high) :: IO (IOUArray Int Int)
  copy buff work low high
  msort work buff low high
  where
    copy :: IOUArray Int Int -> IOUArray Int Int -> Int -> Int -> IO ()
    copy src dst low high =
      iter_ (low, high) (\i -> do x <- readArray src i
                                  writeArray dst i x)
    msort a b low high
      | high - low < 16 = insertSort' b low high
      | otherwise = do
          let mid = (low + high) `div` 2
          msort b a low mid
          msort b a (mid + 1) high
          -- マージする
          merge mid low (mid + 1) low
          where
            merge mid i j k
              | i <= mid && j <= high = do
                  x <- readArray a i
                  y <- readArray a j
                  if x <= y
                    then do writeArray b k x
                            merge mid (i + 1) j (k + 1)
                    else do writeArray b k y
                            merge mid i (j + 1) (k + 1)
              | otherwise = do
                  move a b i mid k
                  move a b j high k
                  where
                    move a b i j k
                      | i > j = return ()
                      | otherwise = do
                          x <- readArray a i
                          writeArray b k x
                          move a b (i + 1) j (k + 1)

test :: (IOUArray Int Int -> IO ()) -> Int -> IO ()
test sort n = do
  let m = n `div` 2
      check i ary
        | i == n    = return ()
        | otherwise = do
            t <- compItem ary (i - 1) i
            if t == GT then error "test error"
            else check (i + 1) ary
  a <- newListArray (0, n - 1) (take n (randoms (mkStdGen 11) :: [Int])) :: IO (IOUArray Int Int)
  b <- newListArray (0, n - 1) [1..n] :: IO (IOUArray Int Int)
  c <- newListArray (0, n - 1) [n,n-1..1] :: IO (IOUArray Int Int)
  d <- newListArray (0, n - 1) ([1..m] ++ [m,m-1..1]) :: IO (IOUArray Int Int)
  x1 <- getCurrentTime
  sort a
  x2 <- getCurrentTime
  --check 1 a
  print (diffUTCTime x2 x1)
  x3 <- getCurrentTime
  sort b
  x4 <- getCurrentTime
  --check 1 b
  print (diffUTCTime x4 x3)
  x5 <- getCurrentTime
  sort c
  x6 <- getCurrentTime
  --check 1 c
  print (diffUTCTime x6 x5)
  x7 <- getCurrentTime
  sort d
  x8 <- getCurrentTime
  --check 1 d
  print (diffUTCTime x8 x7)

main :: IO ()
main = do
  let xs2 = [80000, 160000, 320000, 640000]
  (x:_) <- getArgs
  case x of
    "heapSort"   -> mapM_ (test heapSort)   xs2
    "mergeSort"  -> mapM_ (test mergeSort)  xs2
    "mergeSort'" -> mapM_ (test mergeSort') xs2
    "quickSort"  -> mapM_ (test quickSort)  xs2
    "quickSort'" -> mapM_ (test quickSort') xs2

初版 2013 年 5 月 26 日
改訂 2021 年 8 月 8 日

Copyright (C) 2013-2021 Makoto Hiroi
All rights reserved.

[ PrevPage | Haskell | NextPage ]