oikakerublogの日記

知識ゼロから色々しらべてみた話し

Tensorflow オートエンコーダ(Qiita記事を見ながら写経)

◆Mnistデータをオートエンコーダで学習してみる。

【Tensorflow、オートエンコーダ、エンコード、デコード、mnist】

☞ 参考にした記事
http://qiita.com/mokemokechicken/items/8216aaad36709b6f0b5c



f:id:oikakerublog:20170519165332p:plain

# 2017-5-19 Autoencoder、Tfチュートリアル
# 参考にした記事
# Qiita TensorFlowで機械学習と戯れる: AutoEncoderを作ってみる

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

# データ読込み
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
H = 50
BATCH_SIZE = 100
DROP_OUT_RATE = 0.5
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

#重みW、バイアスbの変数定義

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# 入力層 x 28*28=784
x = tf.placeholder(tf.float32, [None, 784])

# Variable: W, b1
W = weight_variable((784, H))
b1 = bias_variable([H])

# 隠れ層 h ソフトサイン関数
h = tf.nn.softsign(tf.matmul(x, W) + b1)
keep_prob = tf.placeholder("float")
h_drop = tf.nn.dropout(h, keep_prob)

# デコード側の変数、W2は転置、ReLu関数
W2 = tf.transpose(W)
b2 = bias_variable([784])

y = tf.nn.relu(tf.matmul(h_drop, W2) + b2)

# loss関数
loss = tf.nn.l2_loss(y - x) / BATCH_SIZE
In [34]:
# Adam Optimizer
train_step = tf.train.AdamOptimizer().minimize(loss)

# 初期化、Session
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# トレーニング
for step in range(2000):
    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
    sess.run(train_step, feed_dict={x: batch_xs, keep_prob: (1-DROP_OUT_RATE)})
    
    # Collect Summary
    summary_op = tf.merge_all_summaries()
    
    # Print Progress
    if step % 100 == 0:
        print(loss.eval(session=sess, feed_dict={x: batch_xs, keep_prob: 1.0}))

# => 
38.9964
20.0019
15.4875
12.9807
12.3335
10.0853
10.5489
10.4406
10.1243
9.82316
9.74813
9.9554
9.81496
10.0637
9.55074
9.21692
8.61834
8.72896
8.79601
9.06065

# Draw Encode/Decode Result
N_COL = 10
N_ROW = 2
plt.figure(figsize=(N_COL, N_ROW*2.5))
batch_xs, _ = mnist.train.next_batch(N_COL*N_ROW)
for row in range(N_ROW):
    for col in range(N_COL):
        i = row*N_COL + col
        data = batch_xs[i:i+1]

        # Draw Input Data(x)
        plt.subplot(2*N_ROW, N_COL, 2*row*N_COL+col+1)
        plt.title('IN:%02d' % i)
        plt.imshow(data.reshape((28, 28)), cmap="magma", clim=(0, 1.0), origin='upper')
        plt.tick_params(labelbottom="off")
        plt.tick_params(labelleft="off")

        # Draw Output Data(y)
        plt.subplot(2*N_ROW, N_COL, 2*row*N_COL + N_COL+col+1)
        plt.title('OUT:%02d' % i)
        y_value = y.eval(session=sess, feed_dict={x: data, keep_prob: 1.0})
        plt.imshow(y_value.reshape((28, 28)), cmap="magma", clim=(0, 1.0), origin='upper')
        plt.tick_params(labelbottom="off")
        plt.tick_params(labelleft="off")

plt.savefig("result.png")
plt.show()
広告を非表示にする