変分ベイズ実装(PRML 10.2)

「Old Faithful の推論を K-means と EM について、Rで実装」の続き。


追記】実装にバグが見つかり、この記事の末尾の「うまく縮退しない」は間違いでした。→フォロー記事へ


PRML 10章、変分推論(変分ベイズ)がいまいちわからない。

  • 観測&隠し変数のハイパーパラメータにも事前分布を導入(ここが「ベイズ」)
  • 隠し変数+ハイパーパラメータ間に、「適当な独立性を仮定」して排反なグループに分割し、それぞれ任意の分布を想定(ここが「変分近似」???)
  • 個々のグループごとに、対数同時分布の事後期待値を最適化(ここは EM と同様の枠組み)

ということだろう、と理解したつもりだが、実感として掴めない。「排反なグループに分割」って漠然と言われてもなあ。
例によって、言葉を定義せずに使うし。「変分近似」って、何ね?


毎度ながら、こういうときは手を動かすに限る。
明日の PRML 読書会の範囲は 10.1 までだが、10.2 で変分ベイズによる混合ガウス分布の推論を具体的に取り扱っているので、そちらに従って実装していく。

変分ベイズアルゴリズム

PRML 9章では EM アルゴリズムの手順をまとめてくれていたが、10章の変分ベイズにはない。
そこでまず、PRML 10.2 を参照しながら混合ガウス分布の変分ベイズ推論のアルゴリズムをまとめておこう。

1. パラメータ α_0, m_0, β_0, ν_0, W_0 を初期化する。
それぞれのパラメータの意味は (10.39) および (10.40) 参照。


2. 「変分ベイズ E ステップ」
(10.65)〜(10.67) により負担率 r_nk を得る(※注:(10.66) のΨはψの誤植)。


3. 「変分ベイズ M ステップ」
r_nk を用いて、(10.51)〜(10.53) により統計量 N_k, x_k, S_k を求め、
それらを用いて、(10.58), (10.60)〜(10.63) によりパラメータ α_k, m_k, β_k, ν_k, W_k を更新する。


4. 2 & 3 を収束するまで繰り返す


パラメータの数、参照している式の数が手強さを物語っている……が、それ以前にこの手順にはいくつか問題がある。


ループの初回で、まず r_nk を求めようとする。
それを求めるには (10.67) から W_k などが必要で、それを求めるには (10.62) から N_k などが必要で、それを求めるには (10.51) から r_nk が必要……って、あれれ??


つまりパラメータが循環参照されていて、どこから始めたらいいのかわからない(上述のアルゴリズムにても、式番号が昇順になっていない)。
*_k = *_0 で初回を回す、というのをやってみたが、α_k たちが初回と2回目以降で値が変わりすぎるので、どうにも腑に落ちない。
変分ベイズの実装に関する他の説明などをみると、α_k = α_0 + N / K, β_k = β_0 + N / K, ν_k = ν_0 + N / K, m_k = m_0, W_k = W_0 で E ステップから始めるのが妥当っぽい。


また、パラメータの初期値をどうすればよいか。
PRML にて言及があるのは「 m_0 は対称性から0にする」と、α_0 の取り方によって収束の仕方が変わる(後述)の2点。
ところがそれを信じて、残りの β_0, ν_0, W_0 に適当な値を入れてアルゴリズムを実行すると、得られた負担率は必ず均等なものになり(すなわち r_nk = 1/K, ∀n, k )、当然 M ステップで得られる更新パラメータも k によらず一定になり……以下ループ。
なので初期値を k に対して散らしてあげるしかないのだが、α_0, β_0, ν_0, W_0 については k ごとに変える理由が全くない(アルゴリズムの実装上も嬉しくない)ので、m_0 をランダムにするのが最も妥当っぽい。PRML に書いてある内容とあまりにも違うことをしないといけないのが気になるけど……


初回 E ステップだけは m_k を散らすけど、m_0 は0にしておく、というのも思いついて試してみたが、m_k が0に収束していく。残念。

変分ベイズの実装


できるだけ上のアルゴリズムの逐語訳になるように実装した*1
パラメータの初期値は α_0=0.001, β_0=25, ν_0=D, W_0=単位行列, m_0=K個の正規乱数ベクトル としている。
βは適当、νは「正規化係数のガンマ関数が適切に定義されることを保証するように ν>D-1 と制限する」から。

# Old Faithful dataset を取得して正規化
data("faithful");
xx <- scale(faithful, apply(faithful, 2, mean), apply(faithful, 2, sd));

# クラス数
K <- 2;

# 1. パラメータ α_0, m_0, β_0, ν_0, W_0 を初期化する。
D <- ncol(xx);
N <- nrow(xx)
init_param <- list(
    alpha = 0.001,
    beta  = 25,
    nyu   = D,
    W     = diag(D),
    m     = matrix(rnorm(K * D), nrow=K)
);
param <- list(
    alpha = numeric(K) + init_param$alpha + N / K,
    beta  = numeric(K) + init_param$beta + N / K,
    nyu   = numeric(K) + init_param$nyu + N / K,
    W     = list(),
    m     = init_param$m
);
for(k in 1:K) param$W[[k]] <- init_param$W;


# 2. (10.65)〜(10.67) により負担率 r_nk を得る
VB_Estep <- function(xx, param) {
	K <- length(param$alpha);
	D <- ncol(xx);

	# (10.65)
	ln_lambda <- sapply(1:K, function(k) {
		sum(digamma((param$nyu[k] + 1 - 1:D) / 2)) + D * log(2) + log(det(param$W[[k]]));
	});

	# (10.66)
	ln_pi <- exp(digamma(param$alpha) - digamma(sum(param$alpha)));

	# (10.67)
	t(apply(xx, 1, function(x){
		quad <- sapply(1:K, function(k) {
			xm <- x - param$m[k,];
			t(xm) %*% param$W[[k]] %*% xm;
		});
		ln_rho <- ln_pi + ln_lambda / 2 - D / 2 / param$beta - param$nyu / 2 * quad;
		ln_rho <- ln_rho - max(ln_rho);   # exp を Inf にさせないよう
		rho <- exp(ln_rho);
		rho / sum(rho);
	}));
}


# 3. r_nk を用いて、(10.51)〜(10.53) により統計量 N_k, x_k, S_k を求め、
# それらを用いて、(10.58), (10.60)〜(10.63) によりパラメータ α_k, m_k, β_k, ν_k, W_k を更新する。
VB_Mstep <- function(xx, init_param, resp) {
	K <- ncol(resp);
	D <- ncol(xx);
	N <- nrow(xx);

	# (10.51)
	N_k <- colSums(resp);

	# (10.52)
	x_k <- (t(resp) %*% xx) / N_k;

	# (10.53)
	S_k <- list();
	for(k in 1:K) {
		S <- matrix(numeric(D * D), D);
		for(n in 1:N) {
			x <- xx[n,] - x_k[k,];
			S <- S + resp[n,k] * ( x %*% t(x) );
		}
		S_k[[k]] <- S / N_k[k];
	}

	param <- list(
	  alpha = init_param$alpha + N_k,    # (10.58)
	  beta  = init_param$beta + N_k,     # (10.60)
	  nyu   = init_param$nyu + N_k,      # (10.63)
	  W     = list()
	);

	# (10.61)
	param$m <- (init_param$beta * init_param$m + N_k * x_k) / param$beta;

	# (10.62)
	W0_inv <- solve(init_param$W);
	for(k in 1:K) {
		x <- x_k[k,] - init_param$m[k,];
		Wk_inv <- W0_inv + N_k[k] * S_k[[k]] + init_param$beta * N_k[k] * ( x %*% t(x)) / param$beta[k];
		param$W[[k]] <- solve(Wk_inv);
	}

	param;
}


# 以降、収束するまで繰り返し

resp <- VB_Estep(xx, param);
plot(xx, col=rgb(resp[,1],0,resp[,2]), xlab=paste(sprintf(" %1.3f",t(param$m)),collapse=","), ylab="");
points(param$m, pch = 8);
param <- VB_Mstep(xx, init_param, resp);


長い。これでもずいぶんがんばったのだが。
Rのプロに「Rで for 文ネストってwww」とか笑われそう。


ポイントは、ρ_nk を正規化して r_nk を得るところで、exp が発散してしまって r_nk が NaN だらけになるのを避けるために ln_rho <- ln_rho - max(ln_rho); しているところ*2


Old Faithful を推論して描いたチャートはこんな感じ。

*マークは m_k の位置を出しているので、K-means や EM の結果とは ずれている。
本当は E[μ_k] を出すべきなのだろうけど、ウィシャート分布の計算で心が折れて……

縮退?


PRML 図10.6 には、K=6, α_0=0.001 で Old Faithful を推論すると、4つの混合要素が縮退して(N_k が0に、混合係数の期待値が0に近づく)、2つの混合要素だけが残る様子を示し、K に大きな値を選んでも over fitting が起こらず、適切な混合要素数を交差検定などに頼らずに求められる可能性がある、と述べられている。


ほうほう、それはすごい。K=6 にして、さっそく試してみよう。
データ点を r_nk に応じて6色に塗り分け&縮退している様子がわかるようにラベル部に N_k が出るようにして、と……

plot(xx, xlab=paste(sprintf(" %1.3f",colSums(resp)),collapse=","), ylab="",
  col=rgb(rowSums(resp[,1:3])*0.9, rowSums(resp[,3:5])*0.8, rowSums(resp[,c(2,4,6)])*0.9));
points(param$m, pch = 8)

ダメだ。もう一回。

んー……


α_0, β_0, ν_0 についてあれこれ変えて何度も試してみたが、期待したような現象は一度も確認できなかった。
あるループで、1つの m_k が あさっての場所にあって、N_k が 1 未満の状態になっていたとしても、ループを回している内に少しずつデータ点に寄っていって、そのうち N_k がそこそこの値になってしまう。


PRML には、α<1 のディリクレ分布はいくつかの要素が0になりやすいから、と書いてあり、その説明でふむふむなるほどそうなんだー、と納得してしまっていたのだが……


追記】実装にバグが見つかり、「うまく縮退しない」は間違いでした。→フォロー記事へ

*1:S_k のあたりとか、明らかに最適化できるところも放置

*2:ソフトマックスで同様のことが起きるのを id:tsubosaka さんに教えてもらってて助かった