ニューラルネットワークでXORを学習させてみた


PRML 読書会 #6 がありました。皆さんお疲れ様でした。
読書会の内容については、別途。


id:tsubosaka さんに「(ニューラルネットワークの実装で) XOR の学習できました?」と聞かれて「出来るように作ったはずだけど、まだ試してない〜」と答えたので、試した。


XOR は線形分離できないので、パーセプトロン等では単純には解けないが、ニューラルネットワークなら大丈夫! というのがメリットの一つなので、それを検証してみる。
正しく実装できていれば、きっと学習できるはず……!


入力2個、隠れユニット4個(tanh)、出力1個(シグモイド)のネットワークを構築して、XOR を学習させるスクリプト
ただし誤差関数は交差エントロピーではなく二乗和誤差。

require "neural.rb"

# training data (XOR)
D = [
  [[0, 0], [0]],
  [[1, 1], [0]],
  [[0, 1], [1]],
  [[1, 0], [1]],
]

# units
in_units = [Unit.new("x1"), Unit.new("x2")]
bias = [BiasUnit.new("1")]
hiddenunits = [TanhUnit.new("z1"), TanhUnit.new("z2"), TanhUnit.new("z3"), TanhUnit.new("z4")]
out_unit = [SigUnit.new("y1")]

# network
network = Network.new
network.in  = in_units
network.link in_units + bias, hiddenunits
network.link hiddenunits + bias, out_unit
network.out = out_unit

# learning
eta = 0.1
sum_e = 999999
1000.times do |tau|
  s = 0
  D.each do |data|
    s += network.sum_of_squares_error(data[0], data[1])
  end
  puts "sum of errors: #{tau} => #{s}"
  break if s > sum_e
  sum_e = s

  D.sort{rand}.each do |data|
    grad = network.gradient_E(data[0], data[1])
    network.descent_weights eta, grad
  end
end
network.weights.dump

ざくっと実行。
誤差関数が増加に転じたら学習を停止する仕組みを入れたが、1000回で収束しなかった……。交差エントロピーならもうちょっと収束早いのかな−。


結果を R でグラフに。

f <- function(x1, x2) {
#### xor.rb の結果(ここに挿入)
z1 <- tanh( 0.969456061013559 * x1 - 1.74829473902389 * x2 - 0.70646976479347 );
z2 <- tanh( -2.82278910957702 * x1 - 3.0634918736921 * x2 + 1.09531741119177  );
z3 <- tanh( -2.20675222612653 * x1 - 1.27935644243531 * x2 - 1.19663674571256 );
z4 <- tanh( -1.71577976900074 * x1 - 0.786283825284134 * x2 + 1.67666557366823 );
y1 <- sig( 2.12626231146572 * z1 - 3.09386676517861 * z2 + 0.869358796456109 * z3 + 2.71748870208687 * z4 - 0.189921028360212 );
#### xor.rb の結果(ここまで)
y1;
}

x1 <- seq(0,1,length=30)
z <- outer(x1, x1, f)
persp(x1, x1, z, phi=45)


ちなみに各点での分類関数の値はこんな感じ。

> f(0, 0);
[1] 0.1047931
> f(0, 1);
[1] 0.8540736
> f(1, 0);
[1] 0.9078301
> f(1, 1);
[1] 0.1509714


ちゃんと XOR 学習できてますね!
汎用的に作るの、それなりに苦労したけど、甲斐はあったようだ。


全ソースは github に。

http://github.com/shuyo/iir/tree/6148e3426933c423e10b045a7a368cbd02f7a47a/neural