PRML 読書会 #13 「10.2 変分混合ガウス分布」資料(2)

「パターン認識と機械学習」(PRML)読書会 #13 で担当する 10.2.1「変分事後分布」の資料の後半です。その1はこちら

負担率 r_nk を求める

q^*(π_k,μ_k,Λ_k) を推定したので、ρ(z_nk) の各項を計算できる。

 \ln \rho_{nk} = \mathbb{E}[\ln\pi_k]+\frac{1}{2}\mathbb{E}[\ln|\boldsymbol{\Lambda}_k|]-\frac{D}{2}\ln(2\pi)-\frac{1}{2}\mathbb{E}_{\boldsymbol{\mu}_k,\boldsymbol{\Lambda}_k}[(\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k)]

(B.21)より  \mathbb{E}[\ln\pi_k]=\psi(\alpha_k)-\psi(\hat\alpha)

  • ψ(・) はディガンマ関数 (B.25)
  •  \hat\alpha=\sum \alpha_k

(B.81) より
 \mathbb{E}[\ln|\boldsymbol{\Lambda}_k|] = \sum_{i=1}^D\psi\left(\frac{\nu_k+1-i}{2}\right)+D\ln 2+\ln|\boldsymbol{W}|


 (\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k) = \rm{Tr}\{ (\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k) \} = \rm{Tr}\{  \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k) (\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \}

であり、また積分とトレースは交換可能ゆえ

 \begin{eqnarray}&& \mathbb{E}_{\boldsymbol{\mu}_k,\boldsymbol{\Lambda}_k}[(\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k)]\\&=& \int\int q^*(\boldsymbol{\mu}_k|\boldsymbol{\Lambda}_k) q^*(\boldsymbol{\Lambda}_k) (\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T \boldsymbol{\Lambda}_k (\boldsymbol{x}_n-\boldsymbol{\mu}_k) d\boldsymbol{\mu}_k d\boldsymbol{\Lambda}_k \\&=& \rm{Tr}\left\{ \int q^*(\boldsymbol{\Lambda}_k) \boldsymbol{\Lambda}_k \int q^*(\boldsymbol{\mu}_k|\boldsymbol{\Lambda}_k) (\boldsymbol{x}_n-\boldsymbol{\mu}_k)(\boldsymbol{x}_n-\boldsymbol{\mu}_k)^T d\boldsymbol{\mu}_k d\boldsymbol{\Lambda}_k \right\} \\&=& \rm{Tr}\left\{ \int q^*(\boldsymbol{\Lambda}_k) \boldsymbol{\Lambda}_k \mathbb{E}_{\boldsymbol{\mu}_k}[(\boldsymbol{\mu}_k-\boldsymbol{x}_n)(\boldsymbol{\mu}_k-\boldsymbol{x}_n)^T]d\boldsymbol{\Lambda}_k \right\} \\&=& \rm{Tr}\left\{ \int q^*(\boldsymbol{\Lambda}_k) \boldsymbol{\Lambda}_k\{(\boldsymbol{m}_k-\boldsymbol{x}_n)(\boldsymbol{m}_k-\boldsymbol{x}_n)^T+(\beta_k \boldsymbol{\Lambda}_k)^{-1}\}d\boldsymbol{\Lambda}_k \right\} \\&=&  \rm{Tr}\left\{ \nu_k\boldsymbol{W}_k(\boldsymbol{x}_n-\boldsymbol{m}_k)(\boldsymbol{x}_n-\boldsymbol{m}_k)^T + \beta_k^{-1}I \right\}\\&=& \nu_k(\boldsymbol{x}_n-\boldsymbol{m}_k)^T\boldsymbol{W}_k(\boldsymbol{x}_n-\boldsymbol{m}_k) + \beta_k^{-1}D\end{eqnarray}


以上より、
 \ln\tilde{\Lambda}_k \equiv \mathbb{E}[\ln|\boldsymbol{\Lambda}_k|], \; \ln\tilde{\pi_k} \equiv \mathbb{E}[\ln\pi_k] とおけば、負担率 r_nk を以下の式から求めることが出来る。

 r_{nk}\propto \tilde{\pi_k}\tilde{\Lambda}_k^{1/2} \exp\left\{ -\frac{D}{2\beta_k}-\frac{\nu_k}{2}(\boldsymbol{x}_n-\boldsymbol{m}_k)^T\boldsymbol{W}_k(\boldsymbol{x}_n-\boldsymbol{m}_k) \right\}

この右辺を計算し、 \sum_{n=1}^N r_{nk} = 1 により正規化すればよい。

変分ベイズアルゴリズム(まとめ)

1. パラメータ α_0, m_0, β_0, ν_0, W_0 を適宜与える
パラメータ α_k, m_k, β_k, ν_k, W_k を初期化する。
初期化の目安は後述。

2. 「変分ベイズ E ステップ」
負担率 r_nk を得る

 \ln\tilde{\Lambda}_k \equiv \mathbb{E}[\ln|\boldsymbol{\Lambda}_k|] = \sum_{i=1}^D\psi\left(\frac{\nu_k+1-i}{2}\right)+D\ln 2+\ln|\boldsymbol{W}|
 \ln\tilde{\pi_k} \equiv \mathbb{E}[\ln\pi_k]=\psi(\alpha_k)-\psi(\hat\alpha)
より
 r_{nk}\propto \tilde{\pi_k}\tilde{\Lambda}_k^{1/2} \exp\left\{ -\frac{D}{2\beta_k}-\frac{\nu_k}{2}(\boldsymbol{x}_n-\boldsymbol{m}_k)^T\boldsymbol{W}_k(\boldsymbol{x}_n-\boldsymbol{m}_k) \right\}
の右辺を計算し、 \sum_{n=1}^N r_{nk} = 1 を満たすよう正規化する。


3. 「変分ベイズ M ステップ」

r_nk を用いて、統計量 N_k, \bar{\boldsymbol x}_k, S_k を求める。

 N_k = \sum_{n=1}^N r_{nk}
 \bar{\boldsymbol{x}}_k=\frac{1}{N_k}\sum_{n=1}^N r_{nk} \boldsymbol{x}_n
 S_k=\frac{1}{N_k}\sum_{n=1}^N r_{nk}(\boldsymbol{x}_n-\bar{\boldsymbol{x}}_k)(\boldsymbol{x}_n-\bar{\boldsymbol{x}}_k)^T

これらを用いて、パラメータ α_k, m_k, β_k, ν_k, W_k を更新する。

 \alpha_k = \alpha_0 + N_k, \; \beta_k = \beta_0 + N_k, \; \nu_k = \nu_0 + N_k
 \boldsymbol{m}_k = \frac{ \beta_0 \boldsymbol{m}_0 + N_k \bar{\boldsymbol{x}}_k}{\beta_0+N_k}
 \boldsymbol{W}_k=\boldsymbol{W}_0^{-1} + N_k S_k + \frac{\beta_0N_k}{\beta_0+N_k}(\bar{\boldsymbol{x}}_k-\boldsymbol{m}_0)(\bar{\boldsymbol{x}}_k-\boldsymbol{m}_0)^T

4. 2 & 3 を収束するまで繰り返す

パラメータの初期化

  • 「 m_0 は対称性から0にする」(p189) とあるが、そうしてしまったら E ステップで得られる負担率が均等になり(∀n, k に対し r_nk=1/K)、M ステップで得られる更新後のパラメータも一定、以下ループ。したがって m_0 は正規乱数などで散らす必要あり。
  • m_0 は 0 でよい。m_k の初回値に正規乱数を載せて散らす。→フォロー記事 もごらんください。
  • α_0 が小さいと N_k(=k番目の混合要素に含まれるデータ点の個数の期待値)にばらつきが多く、0に縮退(後述)することもある。α_0 が大きいと、N_k のばらつきが少なくなりやすい?
    • 実験しても、そこまで明確な差は感じられなかった。
  • β_0 は適当。ν_0 はウィシャート分布の制限により ν_0 > D - 1
  • W_0 は単位行列
  • α_k = α_0 + N / K, β_k = β_0 + N / K, ν_k = ν_0 + N / K, m_k = m_0, W_k = W_0 として E ステップから始める。

→パラメータの傾向については フォロー記事 でもいろいろ調べてます。


PRML 図10.6 の「0回目」の各混合要素の分布

m_0 散らばってるし! m_k の初回の値、と考えれば問題ありません。

縮退

  • どのデータ点も説明しない混合要素 k については r_nK≒0 であり、したがって N_k≒0
    • 自然にその混合要素が除外された状態になる(→関連度自動決定(ARD))
  • (B.17) より混合比の期待値は  \mathbb{E}[\pi_k] = \frac{\alpha_k}{\textstyle\sum_{j=1}^K \alpha_j} = \frac{\alpha_0+N_k}{K\alpha_0+N}
    • α_0 が小さいときは N_k (つまり r_nk) の影響が強くなる
    • α_0 → ∞ のとき、混合係数の期待値は E[π_k] → 1/K となる。
  • 「α_0<1 の事前分布からは混合比のいくつかの要素が0になる解が選ばれやすい」(p193)
    • 「α_0=10^-3→2つの混合係数だけが非零」
    • 「α_0=1→非零の混合係数が3つ」
    • 「α_0=10→6つ全ての混合係数が非零」

と書いてあるが、Old Faithful で実験したところ α_0 が小さいからといって縮退が発生するとは限らないし、大きくても縮退が発生することもある。


→ここに書いてある「縮退しない」というのは間違いです。フォロー記事 をごらんください。

Old Faithful を VB で推論

パラメータをランダムに変えながら、縮退がどれくらいの割合で発生するか確認。

  • K=6
  • α_0 = 10^x, (-4<=x<=2)
  • β_0 = x^2, (1<=x<=6)
  • ν_0 = D + x, (-1<=x<1)
  • m_0 は正規乱数、W_0 は単位行列
    • 2000回試行して、混合要素が5個残った(1つつぶれた)のが 194回、4個残ったのが 26回、3個残ったのが4回、2個までつぶれることは無し。


なかなかこんなにうまくいかない……

【追記】バグが見つかった! thanx>tsubosakaさん →フォロー記事

その他気がついたことなど

  • 10章の全ての「下限」→「下界」
  • 10.2.2 の変分下界は「導出と実装が正しいかをテストすることが出来る」とあるが、10.2.1 の計算途中に出てくるような変数も大量に使っている。これでは、テストできた! と安心するわけにはいかない気がする……
  • 「(一部の)変数(群)の関数に対する下界や上界を見つけていく」ことを、どうして「局所的」変分法と呼ぶのかと小一時間(ry "parameterized VB" とか他にもっといい名前があるだろうに。しかも解析的に積分して解いちゃった日には、変分法と呼ぶのすら微妙……
  • ラプラス近似は MAP の周りでのヘシアンを一致させたガウス分布に「近似」、事後分布はデータ点が増えればとがってくるし、MAP の近傍でのヘシアンがあれば周辺化のための積分が計算できる、というのがラプラス近似がそこそこ悪くない直感的な理解だったと思う。そのラプラス近似より「精度の良い近似となっている」(p214)局所的変分法については、そういう直感的な理解はないのかな。「q(w) を正規化したら下界でなくなる」んだから、本当にラプラス近似より(いつも必ず)精度がいいんだろうか……