您的位置:首页 > 编程语言 > Python开发

tensorflow 学习专栏(四):使用tensorflow在mnist数据集上使用逻辑回归logistic Regression进行分类

2018-02-13 15:17 951 查看
在面对分类问题时,我们常用的一个算法便是逻辑回归(logistic Regression)
在本次实验中,我们的实验对象是mnist手写数据集,在该数据集中每张图像包含28*28个像素点如下图所示:



我们使用逻辑回归算法来对mnist数据集的数据进行分类,判断图像所表示的数字是几。
代码如下:import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

tf.set_random_seed(1)
np.random.seed(1)

BATCH_SIZE = 50
LR = 0.001

mnist = input_data.read_data_sets('./mnist',one_hot=True) #导入MNIST数据集
test_x = mnist.test.images[:2000] #将MNIST.TEST前2000个数据设置为测试数据集
test_y = mnist.test.labels[:2000]

x = tf.placeholder(tf.float32,[None,784])/255.
y = tf.placeholder(tf.int32,[None,10])

def addlayer(input,in_size,out_size,activiation_function=None): #定义addlayer函数
Weight = tf.Variable(tf.zeros([in_size,out_size]))
Baise = tf.Variable(tf.zeros([out_size]))
wx_b = tf.matmul(input,Weight)+Baise
if activiation_function is None:
out = wx_b
else:
out = activiation_function(wx_b)
return out

#build model
pred = addlayer(x,784,10,tf.nn.softmax) #构建模型

loss = tf.losses.softmax_cross_entropy(onehot_labels=y,logits=pred) #计算误差
train = tf.train.AdamOptimizer(LR).minimize(loss) #训练优化
accuracy = tf.metrics.accuracy(labels=tf.argmax(y,axis=1),predictions=tf.argmax(pred,axis=1),)[1]
#计算准确率
sess = tf.Session() #初始化
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

for step in range(10000): #训练
b_x,b_y = mnist.train.next_batch(BATCH_SIZE)
_,loss_ = sess.run([train,loss],feed_dict={x:b_x,y:b_y})
if step%50==0:
accuracy_ = sess.run(accuracy,feed_dict={x:test_x,y:test_y})
print('train loss:%.4f'%loss_, '|test accuracy%.4f'%accuracy_)

for i in range(5): #将test数据集前5个数据进行可视化
X = test_x[i][np.newaxis,:]
Y = test_y[i]
test_output = sess.run(pred,feed_dict={x:X})
pred_y = np.argmax(test_output,axis=1)
real_y = np.argmax(Y)
img = X.reshape((28,28))
plt.imshow(img,cmap='gray')
plt.text(1.5,2.5,'real number=%.4f'%real_y,fontdict={'size':20,'color':'green'})
plt.text(1.5,5,'pred number=%.4f'%pred_y,fontdict={'size':20,'color':'red'})
plt.show()

训练结果如下:



测试数据集中前五个数据可视化结果如下:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐