top page > computer > haskell > algorithm > memoization > lazy
更新日:
文責: 重城良国

Haskell: 遅延型を利用したメモ化

(工事中 70%)

このページでは

このページでは以下のような内容を扱う。

単純な再帰的関数では非常に時間がかかってしまう処理を遅延型を利用したメモ化を使うことで効率的に実行することができる。そのような手法を「単純な機械の例」を使って説明する。素直な実装では200億年かかる演算がメモ化を使うことで0.2秒で終了することを示す。

弱頭部正規形までの評価

Haskellでは代数的データ型の中身についても、必要になるまでは評価しないという特徴がある。正式に言うと弱頭部正規形までの評価ということになる。これを利用することで効率的なメモ化を簡潔なコードで行うことが可能になる。この手法はパックラット構文解析に使用されている。

メモ化のために必要な条件

メモ化するためにはメモリ上の表現として同一である必要がある。たとえば、以下のような場合にはg 3は一度しか実行されず、ある意味ではメモ化されていると言うことができる。

let f x = x + x in f $ g 3

しかし、以下の場合には1回目のg 3と2回目のg 3とは別々に評価される。つまりメモ化されていないと言える。

g 3 + g 3

つまり、関数の引数が同一であったとしても、もともとの表記上で同一でなければ、その都度別々に評価される。

単純な例

フィボナッチ数列の素直な定義として以下のようなものがある。

fib 0 = 0
fib 1 = 1
fib n = fib (n - 2) + fib (n - 1)

しかし、これは評価に指数関数時間が必要となる。これを以下のように書き換えることができる。

fib = (fibs !!)
fibs@(_ : tfibs) = 0 : 1 : zipWith (+) fibs tfibs

これはメモ化による最適化の一例である。

バックトラック

何かの作業がうまくいかなかったときに、すこしもどって、もう一度やりなおすといったアルゴリズムがある。たとえばある場所に行きたいとき、ある道順で行って行き止まりになってしまったとする。そしたらその直前の交差点までもどり別の道を行く。その交差点から行けるすべての道を試してもたどりつけなかったならば、さらに前の交差点までもどり同じことをくりかえす。これがバックトラックである。

全ての交差点で道が2又に分かれている道を考える。すると試す回数は、目的地までの距離(交差点の数)をnとして、だいたい2^n程度になる。つまり、バックトラックを愚直に実装すると最悪の場合、指数関数時間がかかってしまうことになる。

入力列から値を作成する機械の例

入力列をたどりながら値を作成する機械を考える。まずは数字の列を取り、それに続いて'+'または'-'の文字列を取る。結果は数字の列と逆順に'+'または'-'を適用していったものとする。たとえば"3456+-++"ならば3(4(5(6+)-)+)+のように解釈され、結果は0+6-5+4+3で8になる。これはスタックマシンを使えば効率的に実装できる。

再帰的処理による実装

ここではあえて再帰的処理によって実装し、それが指数関数時間になることを見ていこう。

machine_rec.hs (コード解説)

type Result v = Maybe (v, String)
char :: String -> Result Char
char (x : xs) = Just (x, xs)
char _ = Nothing
run :: String -> Maybe Int
run src | Just (n, _) <- rule src = Just n
run _ = Nothing
rule :: String -> Result Int
rule s0 = msum [ do
	(c, s) <- char s0
	guard $ isDigit c
	(n, s') <- rule s
	('+', s'') <- char s'
	return (n + fromDigit c, s''), do
	(c, s) <- char s0
	guard $ isDigit c
	(n, s') <- rule s
	('-', s'') <- char s'
	return (n - fromDigit c, s''),
	return (0, s0) ]
fromDigit :: Char -> Int
fromDigit '0' = 0
fromDigit '1' = 1
...
fromDigit '9' = 9
fromDigit c = error $ show c ++ " is not digit."

サンプル入力作成用の関数を作成して試してみる。

sample n = take n (concat $ repeat "0123456789")
	++ replicate n '-'
% time ghc -e 'run $ sample 0' machine_rec.hs
Just 0
ghc -e 'run $ sample 0' machine_rec.hs 0.20s user 0.01s
system 94% cpu 0.224 total
% time ghc -e 'run $ sample 16' machine_rec.hs
Just (-60)
ghc -e 'run $ sample 16' machine_rec.hs 0.74s user 0.02s
system 90% cpu 0.841 total
% time ghc -e 'run $ sample 17' machine_rec.hs
Just (-66)
ghc -e 'run $ sample 17' machine_rec.hs 1.29s user 0.02s
system 93% cpu 1.406 total
% time ghc -e 'run $ sample 18' machine_rec.hs
Just (-73)
ghc -e 'run $ sample 18' machine_rec.hs 2.40s user 0.02s
system 90% cpu 2.686 total

実行時間を表にすると以下のようになる。

数字の列の長さ実行時間長さ0との差
00.20s0s
160.74s0.54s
171.29s1.09s
182.40s2.20s
.........
76200億年200億年

数字の列の長さが1増えると実行時間が2倍になっているのがわかる。つまりO(2^n)であり、指数関数時間がかかっているということになる。なお、長さ76のデータは理論値である。宇宙が始まったのが138億年前と言われている(ウィキペディア:宇宙の年表参照)。

メモ化による効率化

machine_memo.hs (コード解説 工事中 40%)

type Result v = Maybe (v, Derivs)
data Derivs = Derivs {
	rule :: Result Int,
	char :: Result Char }
run :: String -> Maybe Int
run src | Just (n, _) <- rule $ derivs src = Just n
run _ = Nothing
derivs :: String -> Derivs
derivs src = d
	where
	d = Derivs rl ch
	rl = pRule d
	ch = case src of
		c : cs -> Just (c, derivs cs)
		_ -> Nothing
pRule :: Derivs -> Result Int
pRule d0 = msum [ do
	(c, d) <- char d0
	guard $ isDigit c
	(n, d') <- rule d
	('+', d'') <- char d'
	return (n + fromDigit c, d''), do
	(c, d) <- char d0
	guard $ isDigit c
	(n, d') <- rule d
	('-', d'') <- char d'
	return (n - fromDigit c, d''),
	return (0, d0) ]

試してみる。

% time ghc -e 'run $ sample 0' machine_memo.hs
Just 0
ghc -e 'run $ sample 0' machine_memo.hs 0.20s user 0.02s
system 96% cpu 0.224 total
% time ghc -e 'run $ sample 76' machine_memo.hs
Just (-330)
ghc -e 'run $ sample 76' machine_memo.hs 0.21s user 0.01s
system 95% cpu 0.228 total
% time ghc -e 'run $ sample 100000' machine_memo.hs
Just (-450000)
ghc -e 'run $ sample 100000' machine_memo.hs 1.12s user 0.04s
System 90% cpu 1.283 total
% time ghc -e 'run $ sample 200000' machine_memo.hs
Just (-900000)
ghc -e 'run $ sample 200000' machine_memo.hs 2.10s user 0.05s
System 87% cpu 2.452 total
% time ghc -e 'run $ sample 400000' machine_memo.hs
Just (-1800000)
ghc -e 'run $ sample 400000' machine_memo.hs 3.97s user 0.09s
System 83% cpu 4.870 total
数字の列の長さ実行時間長さ0との差
00.20s0s
760.21s0.01s
1000001.12s0.92s
2000002.10s1.90s
4000003.97s3.77s
.........
1億15分15分

長さ1億のデータは理論値である。メモ化なしでは200億年かかる演算がメモ化することにより、0.2秒で終了する。長さ10万についての計算が1秒で終了し、長さ40万については4秒で終了する。予想通りO(n)、つまり線形時間で終了するようになった。

「メモ化」トップへもどる

正当なCSSです! HTML5 Powered with CSS3 / styling, and Semantics