您的位置:首页 > 编程语言 > Python开发

学习笔记(五)Tensorflow实现Soft Regression简单识别MNIST手写数字

2018-01-15 14:48 956 查看
MNIST数据库官网

推荐阅读Tensorflow官方文档中文版 MNIST机器学习入门

以及黄文坚、唐源编著的《TensorFlow实战》

一、加载MNIST数据

在这之前需要先安装IPYTHON以便能在terminal中编写python代码

sudo apt-get install ipython-notebook


虽然直接输入python命令也可以编写python代码,但ipython的优点是提供了代码自动补全,自动缩进,高亮显示等功能。

打开terminal输入

ipython




输入以下代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)  #在MNIST_data目录中加载MNIST数据集

import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

tf.global_variables_initializer().run()   #执行

for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)
train_step.run({x:batch_xs,y_:batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))   #打印准确率


每行代码的含义在《TensorFlow实战》中有详细介绍

极客学院中TensorFlow官方文档中文版 MNIST机器学习入门也有介绍

如下图是运行成功后的截图:



如图最后一行可以看到训练的结果的平均准确率为91.99%

以上例子仅是使用tensorflow实现了一个简单的机器学习算法Soft Regressin,结果是较为不精确的,这可以算作是一个没有隐含层的最浅的神经网络。

整个流程我们做的事情有4个部分:

(1)定义算法公式
(2)定义loss,选定优化器
(3)迭代地对数据进行训练
(4)在测试集或验证集上对准确率进行评测


这几个步骤是我们使用TensorFlow进行算法设计、训练的核心步骤,也贯穿其他类型神经网络。接下来我将继续学习其他算法提高准确率

二、可能出现的错误

1.

version `CXXABI_1.3.8' not found




解决方法:

conda update libgcc
cd ~/anaconda3/lib
mv libstdc++.so libstdc++.so.bkp
mv libstdc++.so.6 libstdc++.so.6.bkp
进一步(可选)在anoconda库中创建一个软链接 ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6


解决方法参考网址:https://stackoverflow.com/questions/39844772/cxxabi-1-3-8-not-found-in-tensorflow-gpu-install-from-source

2.

name 'input_data' is not defined




那么由这两行代码

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)


替代

import tensorflow.examples.tutorials.mnist.input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


这两行代码的作用是在MNIST_data目录中加载生成这四个小可爱

其实就是MNIST的训练集和测试集





3.报错

urlopen error [Errno 101] Network is unreachable


大概是网络不稳,重新再来一次就行了

三、查看MNIST数据集的情况

print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)




可以看到训练集有55000个样本,测试集有10000个样本,验证集有5000个样本

每个样本都有它对应的标注信息,就是label

训练数据是55000×784的Tensor:



训练数据label是一个55000×10的Tensor:



只有一个值为1,其余为0,label是一个十维向量。

如图,0为[1,0,0,0,0,0,0,0,0,0,0]

1为[0,1,0,0,0,0,0,0,0,0,0]以此类推,数字n就表示为对应位置的值为1。

四、Soft Regression算法

以上代码使用了Soft Regression算法,这个算法是训练手写数字识别的分类模型。

数字0~9一共有10个类别,Soft Regression对每一个类别估算概率,比如当模型对数字3的图片进行预测时,假设预测是数字3的概率为80%,预测是数字5的概率是5%,最后取概率最大的那个为模型的输出结果。



如图,明亮区域代表负的权重,灰暗区域代表真正的权重。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐