オンラインEMアルゴリズムで混合ガウス分布推論

ずいぶん前にできていたのだが、変分ベイズのフォローのために、ブログに書くのを後回しにしてたオンラインEMについて。


確率的勾配法など、通常はオンラインの方がバッチより収束が遅い。
が、EMアルゴリズムについては、オンラインの方が収束が速いらしい。PRML にも「この逐次型のアルゴリズムはバッチ型に比べて速く収束する」と書かれており、また論文にもそういうデータが載っている。


EM アルゴリズムを実装してみると、確かに収束が遅い。
収束し始めてから遅いのなら許せるけど、2サイクル目くらいからすでに遅い。せめて最初くらいもうちょっと速くなんないの! と文句言いたくなる。
EM アルゴリズムの1サイクルは結構重いので、さらにその気分を助長する。


というわけで、オンラインEMアルゴリズムについてとても興味が湧いたので、実装してみた。
参考にしたのは Neal and Hinton の incremental EM の論文。

[Neal and Hinton, 1999] A View of the EM Algorithm that Justifies Incremental, Sparse, and Other Variants
http://www.cs.toronto.edu/~hinton/absps/emk.pdf

ただしこの論文で例として取り上げられているのは1次元2クラス混合ガウスという簡単なパターン。
PRML 9章10章を読んだ身としては、多次元多クラス混合ガウスでやってみたい。
というわけで、PRML の式 (9.78), (9.79) の更新式を参考に、共分散と混合係数の更新式も導いてみた。

多次元混合ガウス分布での incremental EM 更新式
http://d.hatena.ne.jp/n_shuyo/20100309/incremental

超力作な式なので、再掲しておこう。

E step

x_m について更新するとき、

\gamma(z_{nk})^{(t+1)} = \gamma(z_{nk})^{(t)}, \; (n \ne m)

\gamma(z_{mk})^{(t+1)} = \frac {\pi_k^{(t)} \mathcal{N}(\bf{x}_m | \bf{\mu}_k^{(t)}, \bf{\Sigma}_k^{(t)})} {\sum_{j=1}^K \pi_j^{(t)} \mathcal{N}(\bf{x}_m | \bf{\mu}_j^{(t)}, \bf{\Sigma}_j^{(t)})}, \; (k=1,...,K)

M step

\Delta_k^{(t+1)} = \gamma(z_{mk})^{(t+1)} - \gamma(z_{mk})^{(t)} とおくと


N_k^{(t+1)} = N_k^{(t)} + \Delta_k^{(t+1)}

\pi_k^{(t+1)} = \pi_k^{(t)} + \frac{\Delta_k^{(t+1)}}{N}

\bf{\mu}_k^{(t+1)} = \bf{\mu}_k^{(t)} + \frac{\Delta_k^{(t+1)}}{N_k^{(t+1)}} (\bf{x}_m - \bf{\mu}_k^{(t)})

\bf{\Sigma}_k^{(t+1)} = \left(1-\frac{\Delta_k^{(t+1)}}{N_k^{(t+1)}}\right) \left\{ \bf{\Sigma}_k^{(t)} + \frac{\Delta_k^{(t+1)}}{N_k^{(t+1)}} (\bf{x}_m - \bf{\mu}_k^{(t)})(\bf{x}_m - \bf{\mu}_k^{(t)})^T\right\}


\gamma(z_{nk}) が Neal and Hinton で言うところの十分統計量に当たる。
Σ_k の更新式の導出がなかなか骨があって間違いやすい*1ので、検算してやろうと思った人はちょっと覚悟しておくように(苦笑
E step の π_k は全部 N_k に置き換え可能なので、π_k の更新式は実は不要だが、PRML との対比したときにわかりやすいよう、式でも実装でも残してある。


実装にあたって、次に必要なのは初期値である。毎回これでハマるわけだが、今回も例に漏れず。


このあたりは順当。問題は \gamma(z_{nk})^{(0)} だった。


手始めに正規乱数を突っ込んでみたが、全くまともな挙動にならない。
あれこれ試して、最終的に「1回目のサイクルは通常のEMアルゴリズム。2回目以降を incremental EM にする」という方法でうまく動くようになった。
実は、それが一番最初に思いついた方法だったのだが、「通常の EMA と incremental EMA の両方を実装しないといけないのは嬉しくないよなあ」と避けようとしてしまったのが敗因。


そうやって実装したのがこちら。

同じ初期値を用いて、通常の EM と incremental EM を行い、結果を出力する。これを初期値を取り替えながら10回繰り返すようになっている。
データセットは、3次元3クラスのデータセット(450点)をランダムに生成するか、Old Faithful(コマンドラインに "faithful" を与える)を用いている。


Old Faithful の場合の結果がこちら。

Normal 1:convergence=20, likelihood=-884.3615, 7.70sec
Online 1:convergence=13, likelihood=-884.3615, 5.09sec

Normal 2:convergence=13, likelihood=-884.3617, 5.02sec
Online 2:convergence=9, likelihood=-884.3620, 3.47sec

Normal 3:convergence=11, likelihood=-884.3616, 4.33sec
Online 3:convergence=8, likelihood=-884.3619, 3.06sec

Normal 4:convergence=10, likelihood=-884.3620, 3.88sec
Online 4:convergence=8, likelihood=-884.3615, 3.11sec

Normal 5:convergence=10, likelihood=-884.3619, 3.88sec
Online 5:convergence=8, likelihood=-884.3615, 3.09sec

Normal 6:convergence=12, likelihood=-884.3617, 4.61sec
Online 6:convergence=9, likelihood=-884.3618, 3.46sec

Normal 7:convergence=9, likelihood=-884.3617, 3.51sec
Online 7:convergence=7, likelihood=-884.3616, 2.66sec

Normal 8:convergence=15, likelihood=-884.3619, 5.75sec
Online 8:convergence=11, likelihood=-884.3615, 4.36sec

Normal 9:convergence=10, likelihood=-884.3615, 3.85sec
Online 9:convergence=7, likelihood=-884.3622, 2.65sec

Normal 10:convergence=19, likelihood=-884.3616, 7.30sec
Online 10:convergence=13, likelihood=-884.3615, 5.23sec

同じ番号の Normal と Online はそれぞれ同じ初期値を与えた場合の通常の EMA と incremental EMA の結果を指している。
convergence は収束回数(Online では、最初の1回の通常 EM の分もカウントしている)。
likelihood は収束時の対数尤度。
sec はもちろん時間(全然最適化とかしていないので参考程度に)。


収束回数の平均は通常 EM が 12.9 回に対し、オンラインは 9.3 回。
オンラインでは、収束に必要なイテレーションの回数が7割程度に減っていることがわかる。
同じ対数尤度に収束していることもわかるだろう。


しかし Old Faithful はデータセットとしてはきれいすぎる。
というわけで、3次元3クラスのデータセット(450点)の場合。

Normal 1:convergence=30, likelihood=-3863.4727, 26.24sec
Online 1:convergence=20, likelihood=-3863.4719, 18.15sec

Normal 2:convergence=45, likelihood=-3863.4734, 39.02sec
Online 2:convergence=29, likelihood=-3863.4715, 26.50sec

Normal 3:convergence=25, likelihood=-3863.4730, 21.71sec
Online 3:convergence=17, likelihood=-3863.4717, 15.28sec

Normal 4:convergence=48, likelihood=-3863.4733, 41.70sec
Online 4:convergence=32, likelihood=-3863.4722, 29.22sec

Normal 5:convergence=58, likelihood=-3863.4728, 50.57sec
Online 5:convergence=36, likelihood=-3863.4723, 33.04sec

Normal 6:convergence=52, likelihood=-3863.4731, 45.40sec
Online 6:convergence=32, likelihood=-3863.4722, 29.32sec

Normal 7:convergence=140, likelihood=-3901.9555, 123.04sec
Online 7:convergence=70, likelihood=-3863.4716, 65.34sec

Normal 8:convergence=32, likelihood=-3863.4729, 27.98sec
Online 8:convergence=22, likelihood=-3863.4715, 20.00sec

Normal 9:convergence=17, likelihood=-3863.4727, 14.75sec
Online 9:convergence=12, likelihood=-3863.4714, 10.66sec

Normal 10:convergence=32, likelihood=-3863.4725, 27.89sec
Online 10:convergence=21, likelihood=-3863.4714, 19.15sec

収束回数の平均は、通常 EM が 47.9 回、online が 29.1 回となり、比率が6割に拡大している。
データ点が増えるほどオンライン EM は有利になる、というのが直感的な予想だが、きっと外してないだろう(未検証)。


それよりさらに注目して欲しいのは Normal 7 と Online 7。
きっと特異な初期値だったのだろう。通常 EM は 140 回も回したあげく、局所解に収束してしまっている(likelihood=-3901.9555)のだが、online版は 70 回で他のケースと同じ -3863.47 (おそらく大域解)に収束している。
online 版は局所解にハマりにくい傾向がある、というのを実証してくれた格好だ。


というわけで、online EM はいろいろと有効そうだ、という感触が得られた。


pLSI も実装してみたりしているのだが、こちらの EM アルゴリズムも例によって収束が遅い遅いほんとに遅い。
是非オンライン化してみたいところだが、更新式を導出するのが明らかに混合ガウスより大変そうなので、二の足を踏んでいるところ……

*1:変分下界もそこそこ大変だったが、これに比べればかわいいもん