您的位置:首页 > 其它

TensorFlow: Logistic Regression on MNIST

2016-06-30 15:13 288 查看
import numpy as np
import os
import matplotlib.pyplot as plt
import pprint
# from sklearn.datasets import load_boston
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the local 'data/' directory with one-hot encoded labels,
# then keep the train/test images and labels under short module-level names.
mnist = input_data.read_data_sets('data/', one_hot=True)
train_img, train_lbl = mnist.train.images, mnist.train.labels
test_img, test_lbl = mnist.test.images, mnist.test.labels

# Show the shape of the training image array as a quick sanity check.
print(train_img.shape)

# Training hyperparameters.
lr = 0.01  # learning rate handed to the gradient-descent optimizer
epoch = 50  # number of iterations of the outer (per-epoch) training loop
batch_size = 100  # examples requested from the training set per mini-batch
snapshot = 5  # progress is printed whenever the epoch index is a multiple of this

# Graph inputs: a batch of 784-feature rows and the matching 10-way label rows
# (labels are one-hot because the loader is called with one_hot=True).
x = tf.placeholder(tf.float32, [None, 784], name='input')
y = tf.placeholder(tf.float32, [None, 10], name='groundtruth')

# Model parameters: weights start from a random normal (stddev 0.5), the bias
# from zeros. An all-zero weight start, tf.zeros([784, 10]), is the alternative.
w = tf.Variable(tf.random_normal([784, 10], stddev=0.5))
b = tf.Variable(tf.zeros([1, 10]))

# Linear scores (logits) and the class probabilities derived from them.
score = tf.matmul(x, w) + b
prob = tf.nn.softmax(score)

# Mean cross-entropy over the batch. The fused op takes the raw logits rather
# than log(prob). Arguments are passed by keyword because the positional order
# differs between TensorFlow releases and newer ones reject positional calls.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y))

# One gradient-descent step on the loss per run of `optimizer`.
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)

# `pred` marks, per example, whether the arg-max class matches the label;
# `acc` is the fraction of matches in the fed batch.
pred = tf.equal(tf.argmax(prob, 1), tf.argmax(y, 1))
acc = tf.reduce_mean(tf.cast(pred, tf.float32))

# Op that initializes every variable defined above; run once before training.
init = tf.initialize_all_variables()

# Train and evaluate inside one managed session. The `with` block closes the
# session on exit; no second, unmanaged session is created beforehand.
with tf.Session() as sess:
    sess.run(init)

    loss_cache = []  # mean training loss of each epoch
    acc_cache = []   # accuracy of the LAST mini-batch of each epoch (noisy)

    # Number of full mini-batches per epoch. It never changes, so it is
    # computed once; floor division keeps it an integer on Python 2 and 3.
    num_batch = mnist.train.num_examples // batch_size

    for ep in range(epoch):
        avg_loss = 0
        for _ in range(num_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Running `optimizer` applies the update; acc and loss are fetched
            # in the same call so they describe the same batch.
            _, batch_acc, batch_loss = sess.run(
                [optimizer, acc, loss], feed_dict={x: batch_x, y: batch_y})
            avg_loss += batch_loss / num_batch

        loss_cache.append(avg_loss)
        acc_cache.append(batch_acc)
        if ep % snapshot == 0:
            print('Epoch: %d, loss: %.4f, acc: %.4f'
                  % (ep, avg_loss, acc_cache[-1]))

    # acc.eval() uses the default session installed by the `with` block, so
    # the test-set evaluation has to happen before the block exits.
    print('test accuracy: %s' % acc.eval({x: test_img, y: test_lbl}))

# Draw the two training curves, one figure each: per-epoch loss first, then
# per-epoch accuracy. Each plt.show() is issued before the next figure starts.
for fig_id, series, line_fmt, curve_label, legend_loc in (
        (1, loss_cache, 'b-', 'loss', 'upper right'),
        (2, acc_cache, 'o-', 'acc', 'lower right')):
    plt.figure(fig_id)
    plt.plot(range(len(series)), series, line_fmt, label=curve_label)
    plt.legend(loc=legend_loc)
    plt.show()




# Epoch: 0, loss: 3.1894, acc: 0.3900
# Epoch: 5, loss: 0.7776, acc: 0.8300
# Epoch: 10, loss: 0.6080, acc: 0.8600
# Epoch: 15, loss: 0.5365, acc: 0.8500
# Epoch: 20, loss: 0.4944, acc: 0.9000
# Epoch: 25, loss: 0.4657, acc: 0.8700
# Epoch: 30, loss: 0.4442, acc: 0.9100
# Epoch: 35, loss: 0.4274, acc: 0.9000
# Epoch: 40, loss: 0.4136, acc: 0.8600
# Epoch: 45, loss: 0.4022, acc: 0.9000
# test accuracy: 0.8925
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: