Tensorflow框架搭建全连接神经网络训练手写数字mnist数据集
本文将用Tensorflow框架训练Mnist数据集,搭建全连接神经网络,损失将以动态折线图方式展示
全连接神经网络如图所示:
Mnist数据集是0-9十个数字构成的图片形式的数据集,每张图片是28*28的大小在这里插入图片描述
导入tensorflow中带的mnist数据集,以one-hot的形式:
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets(".\MNIST_data",one_hot=True)
这里建立一个Net的类,self.x是神经网络数据的输入,用占位符tf.placeholder占位,输入的形状是[N,V]结构N是批次,V为28*28的=784的数据,整张图片不能直接传入神经网络,每张图片是28乘以28,要变成784乘以1,即把每个像素挨个排列送进网络。
这里用的是两层的神经网络,w是权重,截取自标准正态分布,b为偏置,设为0,因为是10分类问题,所以最后的输出有10个
感知机模型
class Net: def __init__(self): self.x = tf.placeholder(dtype=tf.float32,shape=[None,784]) self.y = tf.placeholder(dtype=tf.float32,shape=[None,10]) self.w1= tf.Variable(tf.truncated_normal(shape=[784,256],stddev=0.01,dtype=tf.float32)) self.b1= tf.Variable(tf.zeros(shape=[256],dtype=tf.float32)) self.w2= tf.Variable(tf.truncated_normal(shape=[256,10],stddev=0.01,dtype=tf.float32)) self.b2= tf.Variable(tf.zeros(shape=[10],dtype=tf.float32))
定义前向:
根据公式f(wx+b),f为激活函数,第一层的输出作为第二层的输入,第一层用rule激活,最后一层用softmax激活(多分类问题最后一层要用softmax激活)
def forward(self): y1 = tf.nn.relu(tf.matmul(self.x,self.w1)+self.b1) self.y2 = tf.matmul(y1,self.w2)+self.b2 self.output = tf.nn.softmax(self.y2)
定义损失函数:loss,使用交叉熵损失函数softmax_cross_entropy_with_logits,具体用法请自行学习
def loss(self): self.error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y,logits=self.y2))
定义后向函数backward,使用Adam优化器优化损失,学习率为0.001
def backward(self): self.optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(self.error)
之后就是主函数,实例化,喂数据,训练、验证,并使用matplotlib将损失以动态折线图的形式展示出来,下面是全部程序
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets(".\MNIST_data",one_hot=True)# mnist = ".\MNIST_data" import matplotlib.pyplot as plt class Net: def __init__(self): self.x = tf.placeholder(dtype=tf.float32,shape=[None,784]) self.y = tf.placeholder(dtype=tf.float32,shape=[None,10]) self.w1= tf.Variable(tf.truncated_normal(shape=[784,256],stddev=0.01,dtype=tf.float32)) self.b1= tf.Variable(tf.zeros(shape=[256],dtype=tf.float32)) self.w2= tf.Variable(tf.truncated_normal(shape=[256,10],stddev=0.01,dtype=tf.float32)) self.b2= tf.Variable(tf.zeros(shape=[10],dtype=tf.float32))def forward(self): y1 = tf.nn.relu(tf.matmul(self.x,self.w1)+self.b1) self.y2 = tf.matmul(y1,self.w2)+self.b2 self.output = tf.nn.softmax(self.y2) def loss(self): self.error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y,logits=self.y2)) def backward(self): self.optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(self.error) def accuracy(self): y = tf.equal(tf.argmax(self.output,axis=1),tf.argmax(self.y,axis=1)) self.acc = tf.reduce_mean(tf.cast(y,dtype=tf.float32)) if __name__ == '__main__': net= Net() net.forward() net.loss() net.backward() net.accuracy() init = tf.global_variables_initializer() plt.ion() a=[] b=[] c=[] with tf.Session() as sess: sess.run(init) for i in range(50000): xs,ys = mnist.train.next_batch(100) error,_ = sess.run([net.error,net.optimizer],feed_dict={net.x:xs,net.y:ys}) if i%100 == 0: xss,yss = mnist.validation.next_batch(100) _error,_output,acc = sess.run([net.error,net.output,net.acc],feed_dict={net.x:xss,net.y:yss}) label= np.argmax(yss[0]) out = np.argmax(_output[0]) print("error:",error) print("label:",label,"output:",out) print(acc) a.append(i) b.append(error) c.append(_error) plt.clf() train, = plt.plot(a,b,linewidth = 1,color = "red") validate, = plt.plot(a,c,linewidth = 1, color = "blue") plt.legend([train,validate],["train","validate"],loc= "right top",fontsize = 10) plt.pause(0.01) plt.ioff()
运行之前请确认mnist数据集是否已经加载进来,如果没有要自行下载mnist数据集并粘贴到这里
运行结果:这里不展示动态图,只截取了刚开始运行时的损失和训练一段时间之后的损失
刚开始训练的损失
训练一段时间之后的损失
结论:用全连接神经网络训练mnist数据集,可以得到较好的效果,不过全连接的计算量比较大,如果用来训练较为复杂的数据,运行速度比较慢,精度低,效果不好,所以后面会介绍卷积神经网络CNN。
如果转载或引用请注明来源!
- 点赞
- 收藏
- 分享
- 文章举报
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 深度学习-传统神经网络使用TensorFlow框架实现MNIST手写数字识别
- 使用tensorflow利用神经网络分类识别MNIST手写数字数据集,转自随心1993
- 深度学习与TensorFlow实战(六)全连接网络基础—MNIST数据集输出手写数字识别准确率
- 【手把手TensorFlow】三、神经网络搭建完整框架+MNIST数据集实践
- 对抗神经网络学习(一)——GAN实现mnist手写数字生成(tensorflow实现)
- tensorflow训练mnist数据集-识别手写数字
- 用tensorflow搭建mnist全连接神经网络
- 神经网络与深度学习 使用Python实现基于梯度下降算法的神经网络和自制仿MNIST数据集的手写数字分类可视化程序 web版本
- MNIST手写字体数据集神经网络实现(tensorflow)
- 利用tensorflow框架搭建CNN网络解析mnist数据集(5)---《深度学习》
- 搭建全连接神经网络识别mnist数据集(下)
- Tensorflow-RNN循环网络-手写数字识别(MNIST数据集)
- 神经网络与深度学习 1.6 使用Python实现基于梯度下降算法的神经网络和MNIST数据集的手写数字分类程序
- 手写汉字数字识别详细过程(构建数据集+CNN神经网络+Tensorflow)
- Tensorflow框架搭建卷积神经网络CNN训练mnist数据集
- 神经网络入门之Tensorflow实战一:MNIST数据集的训练与预测
- python Tensorflow三层全连接神经网络实现手写数字识别
- 基于TensorFlow1.4.0的FNN全连接网络识别MNIST手写数据集
- 搭建全连接神经网络识别mnist数据集(上)