您的位置:首页 > 理论基础 > 计算机网络

深度学习与TensorFlow实战(六)全连接网络基础—MNIST数据集输出手写数字识别准确率

2018-08-22 18:32 1126 查看

mnist数据集:包含7万张黑底白字手写数字图片,其中55000张作为训练集,5000张作为验证集,10000作为测试集。每张图片大小为28X28像素,图片中纯黑色像素值为0,纯白1。数据集的标签长度为10的一维数组,数组每个元素索引号表示对应数字出现的概率。
在将mnist数据集作为输入喂入神经网络时,需先将数据集中每张图片变成长度784一维数组,将该数组作为输入特征喂入神经网络。

from tensorflow.example.tutorials.minst import input_data
mnist=input_data.read_data_sets(‘.data/,one_hot=True)
第一个参数表示数据集存放路径,第二个参数表示数据集的存取形式。当第二个参数为true时,表示以独热码形式存取数据集。read_data_sets()函数运行时,会检查指定路径内是否已经有数据集,若指定路径没有数据集,则自动下载,并将mnist数据集分为训练集train,验证集validation和测试集test存放。

常用函数:
1),tf.get_collection(“”)表示从collection集合中取出全部变量生成一个列表。
2),tf.cast(x,dtype)表示将参数转为指定数据类型。
3),tf.equal()表示对比两个矩阵或向量的元素。若对应元素相等,则返回true,否则false
4),tf.reduce_mean(x,axis)表示求矩阵或张量指定维度的平均值。若不指定第二份参数,则在所有元素中去平均值,指定第二个参数为0,则每一列求平均值,指定第二个参数为1,则每一行求平均值。
5),tf.argmax(x,axis)表示返回指定维度axis下,参数x中最大索引号。
6),os.path.join()表示把参数字符串按照路径命名规则拼接。、
例如:
import os
os.path.join(‘/hello’,’/good/’,’boy/’)
输出结果:/hello/good/boy/
7),字符串.split()表示按照指定“拆分符”对字符串拆分,返回拆分列表。
例如:
‘./model/mnist_model-1001’.split(‘/’)[-1].split(‘-1’)[-1]
在该例子中,共进行两次拆分。第一个拆分’/’,返回拆分列表,并提取拆分列表,并提取列表索引为-1的元素即倒数第一个元素,第二个拆分符为‘-’返回拆分列表,并提取列表索引为-1的元素即倒数第一个元素,故函数返回值为1001.
8),tf,Graph().as_default()函数表示当前图设置为默认图,并返回一个上下文管理器。
例如:
with tf.Graph().as_default() as g,表示将在Graph()内定义的节点加入到计算图g中。

神经网络模型保存:
在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型,并产生三个文件(保存当前图结构的.meta文件、保存当前参数名的.index文件、保存当前参数的.data文件)

saver=tf.train.Saver()
with tf.Session() as sess:
for i in range(steps):
if i %轮数==0:
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

上述代码表示,神经网络没循环规定的轮数,将神经网络模型中所有的参数等信息保存到指定的路径中,并在存放模型的文件夹名称中注明保存模型时的训练轮数。

神经网络模型的加载:
在测试网络效果时,需要将训练好的神经网络模型加载:

with tf.Session() as sess:
ckpt=tf.train.get_checkpoint_state(存储路径)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)

若ckpt和保存的模型在指定路径中,则将保存的模型加载到当前会话中

加载模型中参数的滑动平均值:
在保存模型时,若模型中采用滑动平均,则参数的滑动平均值会保存在相应的文件中。通过实例化saver对象,实现参数滑动平均值的加载。

ema=tf.train.ExponentialMovingAverage(滑动平均基数)
ema_restore=ema.variables_to_restory()
saver=tf.train.Saver(ema_restore)

神经网络模型准确率评估方法:
在网络评估时,一般通过计算在一组数据上的识别准确率,评估神经网络的效果。

correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.floar32))

在上述中,y表示形状为一组数据(batch_size)上神经网络模型的预测结果,y的形状为[batch_size,10],没一行表示一张图片的识别结果。通过,tf.argmax()函数取出每张图片对应向量中最大值元素对应的索引值,组成长度为输入数据batch_size个数的一维数组。通过tf.equal()判断预测结果张量和实际标签张量的每个维度是否相等,若相等则返回True。通过tf,cast()将得到的布尔型转化为实数型,在通过tf.reduce_mean()求平均值,最终得到模型在本组数据上的准确率。

代码展示:
fb_mnist.py文件代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

#定义前向传播
INPUT_NODE=784
OUTPUT_NODE=10
LAYER_NODE=500

def get_weighe(shape,regularizer):
w=tf.Variable(tf.random_normal(shape,stddev=0.1))

if regularizer!=None:
tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))
return w

def get_bias(shape):
b=tf.Variable(tf.zeros(shape))
return b

def forward(x,regularizer):
w1=get_weighe([INPUT_NODE,LAYER_NODE],regularizer)
b1=get_bias([LAYER_NODE])
y1=tf.nn.relu(tf.matmul(x,w1)+b1)

w2=get_weighe([LAYER_NODE,OUTPUT_NODE],regularizer)
b2=get_bias([OUTPUT_NODE])
y=tf.matmul(y1,w2)+b2
return y

#定义反向传播
BATCH_SIZE=200
LEARNING_RATE_BASE=0.1
LEARNING_RATE_DECAY=0.99
REGULARIZER=0.0001
MOVING_AVERAGE_DECAY=0.99
MODEL_SAVE_PATH="./model/"
MODEL_NAME="mnist_model"

def backward(mnist):
x=tf.placeholder(tf.float32,shape=[None,INPUT_NODE])
y_=tf.placeholder(tf.float32,shape=[None,OUTPUT_NODE])
y=forward(x,REGULARIZER)
global_step=tf.Variable(0,trainable=False)

#加入正则化
ce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
cem=tf.reduce_mean(ce)
loss=cem+tf.add_n(tf.get_collection('losses'))

#加入学习衰减率
learning_rate=tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples/BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True
)

train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)

#加入滑动平均

ema=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
ema_op=ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step,ema_op]):
train_op=tf.no_op(name='train')

#保存模型
saver=tf.train.Saver()

#生成会话,训练模型
with tf.Session() as sess:
init_op=tf.global_variables_initializer()
sess.run(init_op)

for i in range(80000):
xs,ys=mnist.train.next_batch(BATCH_SIZE)
_,loss_value,step=sess.run(
[train_op,loss,global_step],
feed_dict={x:xs,y_:ys}
)
if i%1000==0:
print("after %d training step,loss on training date is%g"%(step,loss_value))
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

def main():
mnist=input_data.read_data_sets("./data/",one_hot=True)
backward(mnist)

if __name__=='__main__':
main()

test_mnist.py文件代码:

import tensorflow as tf
import fb_mnist
import time
from tensorflow.examples.tutorials.mnist import input_data

def test(mnist):
with tf.Graph().as_default() as g:
x=tf.placeholder(tf.float32,shape=[None,fb_mnist.INPUT_NODE])
y_=tf.placeholder(tf.float32,shape=[None,fb_mnist.OUTPUT_NODE])
y=fb_mnist.forward(x,None)
#通过实例化saver对象,实现参数滑动平均值的加载。
ema=tf.train.ExponentialMovingAverage(fb_mnist.MOVING_AVERAGE_DECAY)
ema_restore=ema.variables_to_restore()
saver=tf.train.Saver(ema_restore)

correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

while True:#循环
with tf.Session() as sess:
ckpt=tf.train.get_checkpoint_state(fb_mnist.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-1')[-1]
accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
print("after %s training steps,test accuracy=%g"%(global_step,accuracy_score))
else:
print("no checkpoint file found")
return
time.sleep(5)

def main():
mnist=input_data.read_data_sets("./data/",one_hot=True)
test(mnist)

if __name__=='__main__':
main()

结果:



断点训练:
关键处理:计入ckpt操作
ckpt=tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)

注解:tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。
checkpoint_dir:表示存储断点文件的目录
latest_filename=None:断点存储文件的名称,默认为checkpoint

saver.restore(sess,ckpt.model_checkpoint_path)
该函数表示恢复当前会话,将ckpt中的值赋给w,b
sess:表示当前会话,之前保存的结果将被加入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新是什么。

ckpt代码位置:

结果:

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐