TensorFlow: Logistic Regression on MNIST
2016-06-30 15:13
288 查看
import numpy as np
import os
import matplotlib.pyplot as plt
import pprint
# from sklearn.datasets import load_boston
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Softmax (multinomial logistic) regression on MNIST using the TensorFlow 1.x
# graph API: train with minibatch gradient descent, report test accuracy,
# then plot the per-epoch training loss and accuracy curves.

# Load MNIST into data/ (via the tutorial helper). one_hot=True yields
# 10-wide label vectors, matching the `y` placeholder below.
mnist = input_data.read_data_sets('data/', one_hot=True)
train_img = mnist.train.images
train_lbl = mnist.train.labels
test_img = mnist.test.images
test_lbl = mnist.test.labels
print(train_img.shape)

lr = 0.01          # learning rate for plain gradient descent
epoch = 50         # number of passes over the training set
batch_size = 100   # examples per gradient step
snapshot = 5       # print progress every `snapshot` epochs

x = tf.placeholder(tf.float32, [None, 784], name='input')
y = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
# Alternative: zero initialisation, tf.Variable(tf.zeros([784, 10])).
w = tf.Variable(tf.random_normal([784, 10], stddev=0.5))
b = tf.Variable(tf.zeros([1, 10]))
score = tf.matmul(x, w) + b          # unnormalised logits, shape [batch, 10]
prob = tf.nn.softmax(score)
# The fused op works on the raw logits and is numerically stable, unlike
# -reduce_sum(y * log(softmax(score))), which can evaluate log(0).
# Named arguments are mandatory from TF 1.0 on and are also accepted by
# earlier releases (whose positional order was logits, labels).
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score))
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
pred = tf.equal(tf.argmax(prob, 1), tf.argmax(y, 1))
acc = tf.reduce_mean(tf.cast(pred, tf.float32))
init = tf.initialize_all_variables()

loss_cache = []   # mean training loss per epoch
acc_cache = []    # mean training (minibatch) accuracy per epoch
# A single session, closed automatically by the context manager; it is also
# the default session, which `acc.eval` below relies on.
with tf.Session() as sess:
    sess.run(init)
    # Floor division keeps this an int under Python 3 as well as Python 2.
    num_batch = mnist.train.num_examples // batch_size
    for ep in range(epoch):
        avg_loss = 0.0
        avg_acc = 0.0
        for _ in range(num_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, batch_acc, batch_loss = sess.run(
                [optimizer, acc, loss], feed_dict={x: batch_x, y: batch_y})
            # Average over the epoch rather than keeping only the last
            # minibatch, so loss and accuracy are measured the same way.
            avg_loss += batch_loss / num_batch
            avg_acc += batch_acc / num_batch
        loss_cache.append(avg_loss)
        acc_cache.append(avg_acc)
        if ep % snapshot == 0:
            print('Epoch: %d, loss: %.4f, acc: %.4f' % (ep, avg_loss, avg_acc))
    print('test accuracy: %.4f' % acc.eval({x: test_img, y: test_lbl}))

plt.figure(1)
plt.plot(range(len(loss_cache)), loss_cache, 'b-', label='loss')
plt.legend(loc='upper right')
plt.show()

plt.figure(2)
plt.plot(range(len(acc_cache)), acc_cache, 'o-', label='acc')
plt.legend(loc='lower right')
plt.show()

# Sample output from an earlier version of this script. In that run the "acc"
# column was the accuracy of each epoch's last minibatch only (100 samples),
# not the epoch average reported now:
# Epoch: 0, loss: 3.1894, acc: 0.3900
# Epoch: 5, loss: 0.7776, acc: 0.8300
# Epoch: 10, loss: 0.6080, acc: 0.8600
# Epoch: 15, loss: 0.5365, acc: 0.8500
# Epoch: 20, loss: 0.4944, acc: 0.9000
# Epoch: 25, loss: 0.4657, acc: 0.8700
# Epoch: 30, loss: 0.4442, acc: 0.9100
# Epoch: 35, loss: 0.4274, acc: 0.9000
# Epoch: 40, loss: 0.4136, acc: 0.8600
# Epoch: 45, loss: 0.4022, acc: 0.9000
# test accuracy: 0.8925
相关文章推荐
- C语言内存思考题
- dubbo
- Up to 8% free bonus for runescape 2007 gp on Rsorder as july best gift&Enjoy Telos During 7.1-7.22
- debian 6.0.10禁用r8169网卡驱动
- thinkphp使用自定义类方法
- 修改TFS与本地源代码映射路径
- BZOJ1112: [POI2008]砖块Klo
- HDFS的主要设计理念
- iOS中帮你轻松code之可复用代码块
- Terminating app due to uncaught execption'NSUnknownKeyException'的解决方式
- 纯Java实现控制台对数据库的增删改查(Eclipse)
- 开源视频会议SIP协议栈
- 关于程序进入包含EditText控件的界面会自动获取焦点并弹出软键盘影响用户体验的问题
- 用 NSIS制作64位安装包 步骤
- iOS开发中设置tabbar选中图标的颜色
- JMS -- 概念入门
- 每天一个命令(20) cut (remove sections from each line of files)
- SMA、SMB、SMC封装的二极管尺寸区分
- oracle存储过程举例
- Elasticsearch java API (21)查询 DSL 复合查询