PRML 読書会 #6 がありました。皆さんお疲れ様でした。
読書会の内容については、別途。
id:tsubosaka さんに「(ニューラルネットワークの実装で) XOR の学習できました?」と聞かれて「出来るように作ったはずだけど、まだ試してない〜」と答えたので、試した。
XOR は線形分離できないので、パーセプトロン等では単純には解けないが、ニューラルネットワークなら大丈夫! というのがメリットの一つなので、それを検証してみる。
正しく実装できていれば、きっと学習できるはず……!
入力2個、隠れユニット4個(tanh)、出力1個(シグモイド)のネットワークを構築して、XOR を学習させるスクリプト。
ただし誤差関数は交差エントロピーではなく二乗和誤差。
require "neural.rb" # training data (XOR) D = [ [[0, 0], [0]], [[1, 1], [0]], [[0, 1], [1]], [[1, 0], [1]], ] # units in_units = [Unit.new("x1"), Unit.new("x2")] bias = [BiasUnit.new("1")] hiddenunits = [TanhUnit.new("z1"), TanhUnit.new("z2"), TanhUnit.new("z3"), TanhUnit.new("z4")] out_unit = [SigUnit.new("y1")] # network network = Network.new network.in = in_units network.link in_units + bias, hiddenunits network.link hiddenunits + bias, out_unit network.out = out_unit # learning eta = 0.1 sum_e = 999999 1000.times do |tau| s = 0 D.each do |data| s += network.sum_of_squares_error(data[0], data[1]) end puts "sum of errors: #{tau} => #{s}" break if s > sum_e sum_e = s D.sort{rand}.each do |data| grad = network.gradient_E(data[0], data[1]) network.descent_weights eta, grad end end network.weights.dump
ざくっと実行。
誤差関数が増加に転じたら学習を停止する仕組みを入れたが、1000回で収束しなかった……。交差エントロピーならもうちょっと収束早いのかな−。
結果を R でグラフに。
f <- function(x1, x2) { #### xor.rb の結果(ここに挿入) z1 <- tanh( 0.969456061013559 * x1 - 1.74829473902389 * x2 - 0.70646976479347 ); z2 <- tanh( -2.82278910957702 * x1 - 3.0634918736921 * x2 + 1.09531741119177 ); z3 <- tanh( -2.20675222612653 * x1 - 1.27935644243531 * x2 - 1.19663674571256 ); z4 <- tanh( -1.71577976900074 * x1 - 0.786283825284134 * x2 + 1.67666557366823 ); y1 <- sig( 2.12626231146572 * z1 - 3.09386676517861 * z2 + 0.869358796456109 * z3 + 2.71748870208687 * z4 - 0.189921028360212 ); #### xor.rb の結果(ここまで) y1; } x1 <- seq(0,1,length=30) z <- outer(x1, x1, f) persp(x1, x1, z, phi=45)
ちなみに各点での分類関数の値はこんな感じ。
> f(0, 0); [1] 0.1047931 > f(0, 1); [1] 0.8540736 > f(1, 0); [1] 0.9078301 > f(1, 1); [1] 0.1509714
ちゃんと XOR 学習できてますね!
汎用的に作るの、それなりに苦労したけど、甲斐はあったようだ。
全ソースは github に。
http://github.com/shuyo/iir/tree/6148e3426933c423e10b045a7a368cbd02f7a47a/neural