EM アルゴリズム実装(勉強用)


最近忙しくて*1PRML の予習が滞り中。
しかし、次の PRML 読書会徒手空拳で行ったら、気持ちよく昇天してしまいそうなので、なんとか頑張って読んでみる。
EM アルゴリズムは何となくわかるが、変分ベイズがわからん……

というわけで、Old Faithful の混合正規分布での推論を K-means と EM と変分ベイズについて、Rで実装してみる。

K-means


Old Faithful + K-means については、すでに 前回の記事でお試し済み
その記事では、イテレーションを1行で書いてネタっぽくしてしまってたので、わかりやすく整理したのが以下のコード。
距離を取るところは少し変えて短くしてある。

# Old Faithful dataset を取得して正規化
data("faithful");
xx <- scale(faithful, apply(faithful, 2, mean), apply(faithful, 2, sd));

# クラス数
K <- 2;

# 中心の初期値(正規乱数)
mu <- matrix(rnorm(K*ncol(xx)), nrow=K);

# 以降、収束するまで繰り返し

# 各点について、一番近いμ_iを探す
nearest_index <- max.col(-sapply(1:K, function(i) {
	colSums((t(xx)-mu[i,])^2)
}));

# 各 i ごとにベクトルの平均を取り、新しいμとする
mu <- t(sapply(1:K, function(k) {
	colMeans(xx[nearest_index==k,]
})));


ライブラリを作るのが目的ではないので、繰り返しは手動。収束は雰囲気でw。
実行の様子は前回の記事参照。


実装のポイントその1は、 max.col() という各行ごとに最大値のインデックスを返す関数。実際に欲しいのは最小値なので、符号反転して突っ込んでいる( min.col() は無いのだ)。
実装のポイントその2は、そうやって得た nearest_index(一番近い u_i) を xx[nearest_index==k,] と使うことで、u_k に近いベクトルを抽出、colMeans() でその平均を取っているところ。


Rはこういうのが短く書けるのが嬉しい。
やり過ぎるとアクロバティックになって読めなくなるが……

EM Algorithm


EM アルゴリズムは「同時分布 P(X,Z) の最適化は容易に可能」という仮定の下、周辺分布 P(X) を最適化する代わりに、「事後分布のもとでの対数尤度の期待値 Σ P(Z|X) lnP(X,Z)」を最大化する、というもの。
PRML 9.4 では、いきなり KL ダイバージェンスに分解するだなんて大上段に振りかぶっていたが、Jensen 不等式から導出して「でもこれって KL だよね?」とかしてくれたほうが、個人的には腑に落ちた。


って、EM アルゴリズムについて説明するのが目的ではないからこれくらいにしておいて、EM アルゴリズムでの混合正規分布推論を実装。
「Rはこういうのが短く書けるのが嬉しい」と書いたしりから長い(苦笑)。


でも、一番最初に書き下した奴は、この倍以上あった。
共分散まわりがループを回す方法しか思いつかないのが敗因。
きっと1次元ならコード量3分の1になるんだけど。

# Old Faithful dataset を取得して正規化
data("faithful");
xx <- scale(faithful, apply(faithful, 2, mean), apply(faithful, 2, sd));

# クラス数
K <- 2;

# 平均、共分散、混合率の初期値(正規乱数)
mu <- matrix(rnorm(K*ncol(xx)), nrow=K);
mix <- numeric(K)+1/K;
sig <- list();
for(k in 1:K) sig[[k]] <- diag(ncol(xx));

# 多次元正規分布密度関数(パッケージ使えって?)
dmnorm <- function(x,mu,sig) {
	D <- length(mu);
	1/(sqrt((2 * pi)^D * det(sig))) * exp(- t(x-mu) %*% solve(sig) %*% (x-mu) / 2)[1];
}

# EM アルゴリズムの E ステップ
Estep <- function(xx, mu, sig, mix) {
	K <- nrow(mu);
	t(apply(xx, 1, function(x){
		numer <- sapply(1:K, function(k) {
			mix[k] * dmnorm(x, mu[k,], sig[[k]])
		});
		numer / sum(numer);
	}))
}

# EM アルゴリズムの M ステップ
Mstep <- function(xx, gamma_nk) {
	K <- ncol(gamma_nk);
	D <- ncol(xx);
	N <- nrow(xx);

	N_k <- colSums(gamma_nk);
	new_mix <- N_k / N;
	new_mu <- (t(gamma_nk) %*% xx) / N_k;

	new_sig <- list();
	for(k in 1:K) {
		sig <- matrix(numeric(D^2), D);
		for(n in 1:N) {
			x <- xx[n,] - new_mu[k,];
			sig <- sig + gamma_nk[n, k] * (x %*% t(x));
		}
		new_sig[[k]] <- sig / N_k[k]
	}

	list(new_mu, new_sig, new_mix);
}

# 以降、収束するまで繰り返し

gamma_nk <- Estep(xx, mu, sig, mix);
(ret <- Mstep(xx, gamma_nk));
mu <- ret[[1]]; sig <- ret[[2]]; mix <- ret[[3]];


追記西尾さんからの指摘 により、dmnorm の正規化係数が間違っていたのを修正。すまん&ありがとう。でも、dmnorm を呼んだ後に正規化しているので結果には影響なし。よかったよかった(ぇ【/追記


PRML 9.2.2 の「混合ガウス分布のためのEMアルゴリズム」の通りの実装で、E ステップと M ステップそれぞれわかれており、収束するまで交互に呼び出していく。


以下は実行を繰り返してほぼ収束したところ。
K-means とだいたい同じあたりに平均が出ていることがわかる。

> (ret <- Mstep(xx, gamma_nk));
[[1]]
      eruptions   waiting
[1,] -1.2716236 -1.207692
[2,]  0.7025575  0.667236

[[2]]
[[2]][[1]]
      eruptions    waiting
[1,] 0.05309447 0.02804473
[2,] 0.02804473 0.18232160

[[2]][[2]]
      eruptions    waiting
[1,] 0.13047113 0.06061833
[2,] 0.06061833 0.19503065


[[3]]
[1] 0.3558729 0.6441271


全出力を載せたいところだが、出力が多い上、PRML にも書いてあるとおり収束が遅いのでやめておいた。
K-means は5回も回せば収束してしまうのだが、EM だと15回以上かかる。
ちょうど今、別口で pLSI を囓っているのだが、こちらも収束しないしない。
online EM は収束が速いという噂なので、ちょっと試してみたいところ。


もちろんチャートも描いておこう。
各データ点の負担率(どのガウス分布から生成されているか)は隠れ変数の事後確率 p(z_nk=1|X) であり、それが γ(z_nk) だったわけだから、E ステップの返値 gamma_nk を使えば、PRML の図 9.8 のような図を描くことが出来る*2

plot(xx, col=rgb(gamma_nk[,1],0,gamma_nk[,2]), xlab=paste(sprintf("%1.3f",t(mu)),collapse=","), ylab="");
points(mu, col = 1:2, pch = 8)


12回の繰り返しまであげてみたが、平均がじりじり動いていく様子が見えるだろうか。
よく見ると、まだ全然収束していない。ほんと遅い。


ちなみに、一応まじめに実装してあるので、K=3 とかにしてもちゃんと動く。
3次元以上でも動く……と思うけど、試してない。



あと、PRML 9.2.1 で説明している、ガウス分布の1つが1点でつぶれてしまうような特異性を再現させてみた。
平均の片方をデータ点の1つに一致させながらループを回すくらいのことでは全然発生しなかったが、分散を恣意的に小さくしてやると、そのまま0になってしまうことが確認できた。


長くなったので、変分ベイズはまた次回。

*1:斬撃のレギンレイヴ」というゲームがありまして。

*2:等高線は……