您的位置:首页 > 编程语言 > Python开发

tensorflow入门Day3-MNIST

2017-03-17 19:54 176 查看
# -*- coding: utf-8 -*-
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

x = tf.placeholder("float",[None,784])

y = tf.nn.softmax(tf.matmul(x,W)+b)

#真实标签占坑
y_ = tf.placeholder("float",[None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

for i in range(1000):
#每一次迭代随机取100张图,取1000次,相当于取了100000张图,这100000张图可能有重复的
batch_xs,batch_ys = mnist.train.next_batch(100)
sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})

#tf.argmax(y,1)在第1维上的最大值所在的位置
#y输出Tensor("Softmax:0", shape=(?, 10), dtype=float32) ?表示不知道维数
print y
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#tf.cast()表示将false转化为0,true转化为1
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
print sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
#直接输不出占位符,需要给模型喂东西才有输出,可以看到训练得到的标签和实际的标签
print sess.run(y,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
print sess.run(y_,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  python tensorflow ubuntu