ずいぶん前にできていたのだが、変分ベイズのフォローのために、ブログに書くのを後回しにしてたオンラインEMについて。
確率的勾配法など、通常はオンラインの方がバッチより収束が遅い。
が、EMアルゴリズムについては、オンラインの方が収束が速いらしい。PRML にも「この逐次型のアルゴリズムはバッチ型に比べて速く収束する」と書かれており、また論文にもそういうデータが載っている。
EM アルゴリズムを実装してみると、確かに収束が遅い。
収束し始めてから遅いのなら許せるけど、2サイクル目くらいからすでに遅い。せめて最初くらいもうちょっと速くなんないの! と文句言いたくなる。
EM アルゴリズムの1サイクルは結構重いので、さらにその気分を助長する。
というわけで、オンラインEMアルゴリズムについてとても興味が湧いたので、実装してみた。
参考にしたのは Neal and Hinton の incremental EM の論文。
- [Neal and Hinton, 1999] A View of the EM Algorithm that Justifies Incremental, Sparse, and Other Variants
- http://www.cs.toronto.edu/~hinton/absps/emk.pdf
ただしこの論文で例として取り上げられているのは1次元2クラス混合ガウスという簡単なパターン。
PRML 9章10章を読んだ身としては、多次元多クラス混合ガウスでやってみたい。
というわけで、PRML の式 (9.78), (9.79) の更新式を参考に、共分散と混合係数の更新式も導いてみた。
- 多次元混合ガウス分布での incremental EM 更新式
- http://d.hatena.ne.jp/n_shuyo/20100309/incremental
超力作な式なので、再掲しておこう。
E step
x_m について更新するとき、
M step
とおくと
が Neal and Hinton で言うところの十分統計量に当たる。
Σ_k の更新式の導出がなかなか骨があって間違いやすい*1ので、検算してやろうと思った人はちょっと覚悟しておくように(苦笑
E step の π_k は全部 N_k に置き換え可能なので、π_k の更新式は実は不要だが、PRML との対比したときにわかりやすいよう、式でも実装でも残してある。
実装にあたって、次に必要なのは初期値である。毎回これでハマるわけだが、今回も例に漏れず。
- = 正規乱数ベクトル
- = 単位行列
このあたりは順当。問題は だった。
手始めに正規乱数を突っ込んでみたが、全くまともな挙動にならない。
あれこれ試して、最終的に「1回目のサイクルは通常のEMアルゴリズム。2回目以降を incremental EM にする」という方法でうまく動くようになった。
実は、それが一番最初に思いついた方法だったのだが、「通常の EMA と incremental EMA の両方を実装しないといけないのは嬉しくないよなあ」と避けようとしてしまったのが敗因。
そうやって実装したのがこちら。
同じ初期値を用いて、通常の EM と incremental EM を行い、結果を出力する。これを初期値を取り替えながら10回繰り返すようになっている。
データセットは、3次元3クラスのデータセット(450点)をランダムに生成するか、Old Faithful(コマンドラインに "faithful" を与える)を用いている。
Old Faithful の場合の結果がこちら。
Normal 1:convergence=20, likelihood=-884.3615, 7.70sec Online 1:convergence=13, likelihood=-884.3615, 5.09sec Normal 2:convergence=13, likelihood=-884.3617, 5.02sec Online 2:convergence=9, likelihood=-884.3620, 3.47sec Normal 3:convergence=11, likelihood=-884.3616, 4.33sec Online 3:convergence=8, likelihood=-884.3619, 3.06sec Normal 4:convergence=10, likelihood=-884.3620, 3.88sec Online 4:convergence=8, likelihood=-884.3615, 3.11sec Normal 5:convergence=10, likelihood=-884.3619, 3.88sec Online 5:convergence=8, likelihood=-884.3615, 3.09sec Normal 6:convergence=12, likelihood=-884.3617, 4.61sec Online 6:convergence=9, likelihood=-884.3618, 3.46sec Normal 7:convergence=9, likelihood=-884.3617, 3.51sec Online 7:convergence=7, likelihood=-884.3616, 2.66sec Normal 8:convergence=15, likelihood=-884.3619, 5.75sec Online 8:convergence=11, likelihood=-884.3615, 4.36sec Normal 9:convergence=10, likelihood=-884.3615, 3.85sec Online 9:convergence=7, likelihood=-884.3622, 2.65sec Normal 10:convergence=19, likelihood=-884.3616, 7.30sec Online 10:convergence=13, likelihood=-884.3615, 5.23sec
同じ番号の Normal と Online はそれぞれ同じ初期値を与えた場合の通常の EMA と incremental EMA の結果を指している。
convergence は収束回数(Online では、最初の1回の通常 EM の分もカウントしている)。
likelihood は収束時の対数尤度。
sec はもちろん時間(全然最適化とかしていないので参考程度に)。
収束回数の平均は通常 EM が 12.9 回に対し、オンラインは 9.3 回。
オンラインでは、収束に必要なイテレーションの回数が7割程度に減っていることがわかる。
同じ対数尤度に収束していることもわかるだろう。
しかし Old Faithful はデータセットとしてはきれいすぎる。
というわけで、3次元3クラスのデータセット(450点)の場合。
Normal 1:convergence=30, likelihood=-3863.4727, 26.24sec Online 1:convergence=20, likelihood=-3863.4719, 18.15sec Normal 2:convergence=45, likelihood=-3863.4734, 39.02sec Online 2:convergence=29, likelihood=-3863.4715, 26.50sec Normal 3:convergence=25, likelihood=-3863.4730, 21.71sec Online 3:convergence=17, likelihood=-3863.4717, 15.28sec Normal 4:convergence=48, likelihood=-3863.4733, 41.70sec Online 4:convergence=32, likelihood=-3863.4722, 29.22sec Normal 5:convergence=58, likelihood=-3863.4728, 50.57sec Online 5:convergence=36, likelihood=-3863.4723, 33.04sec Normal 6:convergence=52, likelihood=-3863.4731, 45.40sec Online 6:convergence=32, likelihood=-3863.4722, 29.32sec Normal 7:convergence=140, likelihood=-3901.9555, 123.04sec Online 7:convergence=70, likelihood=-3863.4716, 65.34sec Normal 8:convergence=32, likelihood=-3863.4729, 27.98sec Online 8:convergence=22, likelihood=-3863.4715, 20.00sec Normal 9:convergence=17, likelihood=-3863.4727, 14.75sec Online 9:convergence=12, likelihood=-3863.4714, 10.66sec Normal 10:convergence=32, likelihood=-3863.4725, 27.89sec Online 10:convergence=21, likelihood=-3863.4714, 19.15sec
収束回数の平均は、通常 EM が 47.9 回、online が 29.1 回となり、比率が6割に拡大している。
データ点が増えるほどオンライン EM は有利になる、というのが直感的な予想だが、きっと外してないだろう(未検証)。
それよりさらに注目して欲しいのは Normal 7 と Online 7。
きっと特異な初期値だったのだろう。通常 EM は 140 回も回したあげく、局所解に収束してしまっている(likelihood=-3901.9555)のだが、online版は 70 回で他のケースと同じ -3863.47 (おそらく大域解)に収束している。
online 版は局所解にハマりにくい傾向がある、というのを実証してくれた格好だ。
というわけで、online EM はいろいろと有効そうだ、という感触が得られた。
pLSI も実装してみたりしているのだが、こちらの EM アルゴリズムも例によって収束が遅い遅いほんとに遅い。
是非オンライン化してみたいところだが、更新式を導出するのが明らかに混合ガウスより大変そうなので、二の足を踏んでいるところ……
*1:変分下界もそこそこ大変だったが、これに比べればかわいいもん