
TensorFlow Learning Notes 2: Classification with Softmax Regression

2019-05-29 16:16
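Copyright notice: this is an original article by the blogger, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice. Article link: https://blog.csdn.net/JavenLau/article/details/90673061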



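Main reference posts:
LightRNN:深度学习之以小见大 ("LightRNN: seeing the big picture through small things in deep learning")

Deep MNIST (深入MNIST)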
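As a reference point, the code below fits softmax (multinomial logistic) regression. The following sketch gives the standard equations using the code's own names:

- x is a batch of N flattened images (N×784).
- W is 784×10 and b has 10 entries.
- ŷ stands for the code's `y`, and y' stands for `y_`.
- η = 0.01 is the learning rate passed to `GradientDescentOptimizer`.

The gradient is the textbook result for softmax combined with cross-entropy. The original post does not state it; it is added here as an assumption-free standard derivation.

```latex
\begin{aligned}
z &= xW + b, \qquad
\hat{y}_{nk} = \frac{\exp(z_{nk})}{\sum_{j=1}^{10} \exp(z_{nj})} \\
L &= -\sum_{n=1}^{N} \sum_{k=1}^{10} y'_{nk} \log \hat{y}_{nk} \\
\frac{\partial L}{\partial z_{nk}} &= \hat{y}_{nk} - y'_{nk}, \qquad
\frac{\partial L}{\partial W} = x^{\top} (\hat{y} - y'), \qquad
\frac{\partial L}{\partial b} = \sum_{n=1}^{N} (\hat{y}_{n} - y'_{n}) \\
W &\leftarrow W - \eta \, \frac{\partial L}{\partial W}, \qquad
b \leftarrow b - \eta \, \frac{\partial L}{\partial b}, \qquad \eta = 0.01
\end{aligned}
```

Because L is summed rather than averaged over the batch, the effective step size grows with the batch size of 50.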

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

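# prepare MNIST data with one-hot labels (tensorflow.examples.tutorials is not included in TensorFlow 2)
from tensorflow.examples.tutorials.mnist import input_data
# raw string so the backslashes in the Windows path are not read as escape sequences
MNIST_data_folder = r"D:\pycharm\OCR_Test_SH\src\LightRNN\MNIST_data"
mnist = input_data.read_data_sets(MNIST_data_folder, one_hot=True)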

# take one training image (a flat 784-vector) and reshape it to 28x28 for display
im = mnist.train.images[1]
im = im.reshape(-1, 28)
print('input:', mnist.train.images.shape)
# plt.imshow(im)
# plt.show()

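sess = tf.InteractiveSession()

# x: a batch of flattened 28x28 images; y_: the matching one-hot labels
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# softmax regression parameters, initialized to zero
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# tf.initialize_all_variables() is deprecated; tf.global_variables_initializer() replaces it
sess.run(tf.global_variables_initializer())

y = tf.nn.softmax(tf.matmul(x, W) + b)
# cross-entropy summed over the batch; tf.log(y) can produce NaN if a softmax output underflows to 0
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# accuracy: fraction of samples whose predicted class (argmax of y) equals the true class (argmax of y_)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))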

for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
    # accuracy on the current training batch
    print(i, accuracy.eval(feed_dict={x: batch[0], y_: batch[1]}))

# accuracy on the full test set
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# show every misclassified test image
for i in range(len(mnist.test.images)):
    # slicing i:i+1 keeps the batch dimension the placeholders expect
    feed = {x: mnist.test.images[i:i + 1], y_: mnist.test.labels[i:i + 1]}
    if not correct_prediction.eval(feed_dict=feed)[0]:
        # y is the model's prediction, y_ is the ground-truth label
        predict_arr, label_arr = sess.run([y, y_], feed_dict=feed)
        # np.argmax returns plain integers; tf.argmax here would add graph ops on every pass and print Tensor objects
        label = np.argmax(label_arr, 1)[0]
        predict = np.argmax(predict_arr, 1)[0]
        print('test image {} is misclassified: the label is {} and the prediction is {}'.format(i, label, predict))
        current_image = np.reshape(mnist.test.images[i], (28, 28))
        plt.imshow(current_image)
        plt.show()
        # break
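Without a TensorFlow 1.x install, the same procedure can be sketched in plain NumPy. `tf.placeholder` and `tensorflow.examples.tutorials` are not part of TensorFlow 2's default API.

This is a minimal sketch under stated assumptions:

- It uses full-batch gradient descent instead of mini-batches of 50.
- It trains on a hypothetical toy dataset of three Gaussian blobs instead of MNIST.
- The helper names `softmax`, `cross_entropy`, `train_softmax_regression` and `evaluate` are made up for illustration.

It keeps the post's zero-initialized W and b, summed cross-entropy, and learning rate of 0.01. Unlike `tf.log(y)`, the softmax here subtracts the row maximum and the log gets a small epsilon, so the loss stays finite.

```python
import numpy as np


def softmax(z):
    # subtract each row's max before exponentiating so exp() cannot overflow
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(y_true, y_prob, eps=1e-12):
    # summed (not averaged) over the batch, like -tf.reduce_sum(y_ * tf.log(y));
    # eps keeps log() finite if a probability underflows to 0
    return -np.sum(y_true * np.log(y_prob + eps))


def train_softmax_regression(X, Y, lr=0.01, steps=1000):
    """Full-batch gradient descent on W (d x k) and b (k,), both zero-initialized."""
    d, k = X.shape[1], Y.shape[1]
    W = np.zeros((d, k))
    b = np.zeros(k)
    losses = []
    for _ in range(steps):
        P = softmax(X @ W + b)
        losses.append(cross_entropy(Y, P))
        G = P - Y                  # dL/dlogits for softmax + cross-entropy
        W -= lr * X.T @ G          # dL/dW = X^T (P - Y)
        b -= lr * G.sum(axis=0)    # dL/db = column sums of (P - Y)
    return W, b, losses


def evaluate(X, Y, W, b):
    # accuracy plus the indices of misclassified rows
    pred = np.argmax(softmax(X @ W + b), axis=1)
    label = np.argmax(Y, axis=1)
    wrong = np.flatnonzero(pred != label)
    return float(np.mean(pred == label)), wrong


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # hypothetical toy data: 3 well-separated Gaussian blobs in 2-D, one-hot labels
    centers = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0]])
    labels = np.repeat(np.arange(3), 30)
    X = centers[labels] + rng.normal(scale=0.3, size=(90, 2))
    Y = np.eye(3)[labels]
    W, b, losses = train_softmax_regression(X, Y)
    acc, wrong = evaluate(X, Y, W, b)
    print("final loss:", losses[-1], "accuracy:", acc, "misclassified:", wrong)
```

`evaluate` returns the indices of misclassified rows, playing the same role as the post's loop over `mnist.test.images`.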

TODO: line 46 of the code.
