Tensorflow オートエンコーダでmnistを学習(Qiita記事を見ながら..)
◆目的:
・mnistデータをオートエンコーダで学習してみる。
◆キーワード:
・Tensorflow、オートエンコーダ、エンコード/デコード、mnist
・まず読込み。tensorflowのチュートリアル。mnistデータをダウンロード。
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
・ニューラルネットワークの構造、ハイパーパラメータを決める。
h = 50 batch_size = 100 dropout_rate = 0.5
・重みやバイアスの定義
def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial) def bias_variable(shape): initial = tf.constant(0.1, shape=shape) return tf.Variable(initial)
・入力層を定義、データが入るところなので”placeholder”で。
x = tf.placeholder(tf.float32, [None, 784])
・隠れ層(W、b1)
W = weight_variable((784, h))
b1 = bias_variable([h])
・隠れ層の活性化関数はソフトサイン関数
h = tf.nn.softsign(tf.matmul(x, W) + b1)
keep_prob = tf.placeholder("float")
h_drop = tf.nn.dropout(h, keep_prob)
・出力層、変数、W2は転置、ReLu関数
・tf.transposeの例 こんな感じで行↔列が入れ替えられる
W2 = tf.transpose(W)
b2 = bias_variable([784])
y = tf.nn.relu(tf.matmul(h_drop, W2) + b2)
・損失関数
loss = tf.nn.l2_loss(y - x) / batch_size
・adamオプティマイザー、初期化、
train_step = tf.train.AdamOptimizer().minimize(loss) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init)
・トレーニング
for step in range(5000): batch_xs, batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step, feed_dict={x: batch_xs, keep_prob: (1-dropout_rate)}) # Collect Summary summary_op = tf.merge_all_summaries() # Print Progress if step % 200 == 0: print(loss.eval(session=sess, feed_dict={x: batch_xs, keep_prob: 1.0}))
⇒ 問題!
”module 'tensorflow' has no attribute 'merge_all_summaries'”のエラー
⇒ 対処
・記述を一部修正
tf.summary.merge_all
実行すると、誤差関数の値が順次更新(下がっていく)
38.9383
11.8077
9.61637
…略…
6.01636
6.32462
5.61356
・結果の出力
N_COL = 5 N_ROW = 1 plt.figure(figsize=(N_COL, N_ROW*2.5)) batch_xs, _ = mnist.train.next_batch(N_COL*N_ROW) for row in range(N_ROW): for col in range(N_COL): i = row*N_COL + col data = batch_xs[i:i+1] # Draw Input Data(x) plt.subplot(2*N_ROW, N_COL, 2*row*N_COL+col+1) plt.imshow(data.reshape((28, 28)), cmap="gray", clim=(0, 1.0), origin='upper') plt.tick_params(labelbottom="off") plt.tick_params(labelleft="off") # Draw Output Data(y) plt.subplot(2*N_ROW, N_COL, 2*row*N_COL + N_COL+col+1) y_value = y.eval(session=sess, feed_dict={x: data, keep_prob: 1.0}) plt.imshow(y_value.reshape((28, 28)), cmap="gray", clim=(0, 1.0), origin='upper') plt.tick_params(labelbottom="off") plt.tick_params(labelleft="off") plt.savefig("result.png") plt.show()
◆ 参考にした記事
http://qiita.com/mokemokechicken/items/8216aaad36709b6f0b5c