TensorFlow study notes (4): how Saver.save() names parameters, and restoring them to build a handwritten-digit recognition engine
2017-03-26 22:20
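Preface
In the previous post we covered how to train a network (see that post). In this one we look at how TensorFlow names the different parameters when it saves a network, and how to restore the saved parameters into a rebuilt network structure. Finally, we use the rebuilt network to predict the digit (0-9) in an image of any size. The code is mainly based on a GitHub project (GitHub link).
Main text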
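1. How to inspect the parameters saved in the binary checkpoint file
TensorFlow also provides a way to inspect the saved parameters. The following code opens the checkpoint and reads the name and shape of every saved tensor into a dictionary.
```python
from tensorflow.python import pywrap_tensorflow

# Open the checkpoint by its prefix (without the .index / .data-* suffixes)
reader2 = pywrap_tensorflow.NewCheckpointReader('./model2/mnistModel2-2')
# Dictionary: tensor name -> shape, for every tensor stored in the checkpoint
dic2 = reader2.get_variable_to_shape_map()
for i in dic2:
    print(i, ':', dic2[i])
print(len(dic2))
```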
The code above prints the following:
```
('Variable_7/Adam', ':', [10])
('Variable_7', ':', [10])
('Variable_6', ':', [1024, 10])
('Variable_5', ':', [1024])
('Variable_4', ':', [3136, 1024])
('Variable/Adam', ':', [5, 5, 1, 32])
('Variable_2', ':', [5, 5, 32, 64])
('Variable_1', ':', [32])
('Variable_5/Adam_1', ':', [1024])
('Variable_4/Adam_1', ':', [3136, 1024])
('Variable_2/Adam', ':', [5, 5, 32, 64])
('Variable_7/Adam_1', ':', [10])
('Variable', ':', [5, 5, 1, 32])
('Variable_5/Adam', ':', [1024])
('Variable_4/Adam', ':', [3136, 1024])
('Variable_1/Adam_1', ':', [32])
('Variable_6/Adam_1', ':', [1024, 10])
('beta2_power', ':', [])
('Variable_1/Adam', ':', [32])
('beta1_power', ':', [])
('Variable_3/Adam_1', ':', [64])
('Variable/Adam_1', ':', [5, 5, 1, 32])
('Variable_3/Adam', ':', [64])
('Variable_6/Adam', ':', [1024, 10])
('Variable_3', ':', [64])
('Variable_2/Adam_1', ':', [5, 5, 32, 64])
```
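The listing mixes the model's own weights with optimizer state. The pure-Python sketch below separates the two and counts the trainable parameters. The shapes dict is the printed output above, retyped by hand. With TensorFlow installed, it would come from reader2.get_variable_to_shape_map() instead. The helper names (is_optimizer_state, num_params, split_checkpoint) are hypothetical. The rule that names containing "/Adam" or named beta1_power/beta2_power are optimizer state is an assumption based on this particular listing.

```python
from functools import reduce
import operator

# Name -> shape map, retyped from the printed checkpoint listing above
shapes = {
    'Variable': [5, 5, 1, 32], 'Variable/Adam': [5, 5, 1, 32], 'Variable/Adam_1': [5, 5, 1, 32],
    'Variable_1': [32], 'Variable_1/Adam': [32], 'Variable_1/Adam_1': [32],
    'Variable_2': [5, 5, 32, 64], 'Variable_2/Adam': [5, 5, 32, 64], 'Variable_2/Adam_1': [5, 5, 32, 64],
    'Variable_3': [64], 'Variable_3/Adam': [64], 'Variable_3/Adam_1': [64],
    'Variable_4': [3136, 1024], 'Variable_4/Adam': [3136, 1024], 'Variable_4/Adam_1': [3136, 1024],
    'Variable_5': [1024], 'Variable_5/Adam': [1024], 'Variable_5/Adam_1': [1024],
    'Variable_6': [1024, 10], 'Variable_6/Adam': [1024, 10], 'Variable_6/Adam_1': [1024, 10],
    'Variable_7': [10], 'Variable_7/Adam': [10], 'Variable_7/Adam_1': [10],
    'beta1_power': [], 'beta2_power': [],
}


def is_optimizer_state(name):
    # Assumption: Adam slot variables contain '/Adam' and Adam's
    # accumulators are named beta1_power / beta2_power, as in the listing
    return '/Adam' in name or name in ('beta1_power', 'beta2_power')


def num_params(shape):
    # Number of scalars in a tensor of this shape (scalar shape [] -> 1)
    return reduce(operator.mul, shape, 1)


def split_checkpoint(shape_map):
    model = {n: s for n, s in shape_map.items() if not is_optimizer_state(n)}
    optim = {n: s for n, s in shape_map.items() if is_optimizer_state(n)}
    return model, optim


model_vars, optim_vars = split_checkpoint(shapes)
for name in sorted(model_vars):
    print(name, model_vars[name], num_params(model_vars[name]))
print('model variables:', len(model_vars), 'optimizer entries:', len(optim_vars))
print('trainable parameters:', sum(num_params(s) for s in model_vars.values()))
```

The sketch finds 8 model variables and 18 optimizer entries. Variable_4's first dimension, 3136, equals 7 × 7 × 64. The 28 × 28 input is halved by each of the two 2 × 2 max-pooling layers (28 → 14 → 7), and the second convolution has 64 output channels.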
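From this output it appears (this is my guess) that TensorFlow assigns default names when you do not name the tensors yourself. You can name them either with the name= argument when creating them, or by passing a name-to-variable dict to tf.train.Saver. Without either, the default names follow this pattern:
If the tensor is a constant, the names are Const, Const_1, Const_2, Const_3, …
If the tensor is a variable, the names are Variable, Variable_1, Variable_2, Variable_3, …
Only variables are written to the checkpoint. The remaining entries come from the Adam optimizer. Variable_k/Adam and Variable_k/Adam_1 hold Adam's first- and second-moment estimates for Variable_k. beta1_power and beta2_power hold beta1^t and beta2^t, which Adam uses for bias correction.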
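The naming pattern above can be mimicked with a per-name counter. The sketch below is a simplified model, not TensorFlow's actual implementation: the real Graph.unique_name also handles clashes with explicitly chosen names such as a user-supplied "Variable_1". The class NameRegistry is hypothetical.

The sketch also shows the effect of tf.name_scope, which prefixes the scope to variables created with tf.Variable. This is why the restore code below creates its weights inside scopes such as hidden1 and therefore expects checkpoint keys like hidden1/Variable rather than plain Variable.

```python
from collections import defaultdict


class NameRegistry:
    """Simplified, hypothetical model of TensorFlow 1.x default naming."""

    def __init__(self):
        self._counts = defaultdict(int)  # full name -> times already used

    def unique_name(self, base, scope=''):
        full = scope + '/' + base if scope else base
        n = self._counts[full]
        self._counts[full] += 1
        # First use keeps the bare name; later uses get _1, _2, ...
        return full if n == 0 else '%s_%d' % (full, n)


# A training script without name scopes: eight tf.Variable calls, three constants
reg = NameRegistry()
flat = [reg.unique_name('Variable') for _ in range(8)]
consts = [reg.unique_name('Const') for _ in range(3)]
print(flat)
print(consts)

# The same calls inside name scopes, as in the restore script
reg2 = NameRegistry()
scoped = [reg2.unique_name('Variable', 'hidden1'),   # W_conv1
          reg2.unique_name('Variable', 'hidden1'),   # b_conv1
          reg2.unique_name('Variable', 'hidden2')]   # W_conv2
print(scoped)
```

Because the names depend only on creation order and scope, rebuilding the graph in a different order would silently swap names between variables.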
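2. How to restore the parameters into the rebuilt network
As far as I know, restoring requires rebuilding the same network structure that was used for training. The Saver matches each variable of the current graph to a checkpoint entry by name. The rebuilt graph must therefore create the same variables, with the same shapes, in the same order and under the same name scopes, since these determine the auto-generated names. If you have managed to restore the parameters without rebuilding the network, please let me know how.

Restoring itself is simple: define a Saver and call restore directly (there is no training step here). The code below performs the restore and then recognizes your own handwriting (you can draw a digit in Paint).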
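Before the full listing, here is a small sketch of that name matching. A default tf.train.Saver looks up every variable of the current graph in the checkpoint. Each rebuilt variable must therefore exist there with the same shape. Extra checkpoint entries, such as the Adam slots (the rebuilt graph has no optimizer), are simply not loaded.

The pure-Python helper check_restore below is hypothetical and does this comparison before restore is called. With TensorFlow 1.x, its inputs could come from reader.get_variable_to_shape_map() and [(v.op.name, v.get_shape().as_list()) for v in tf.global_variables()]. Here they are written out by hand, assuming the training graph used the same hidden1/hidden2/fc1/fc2 name scopes as the restore code.

```python
def check_restore(ckpt_shapes, graph_vars):
    """Compare graph variables with a checkpoint's name -> shape map.

    Returns (missing, mismatched, unused):
      missing    - graph variables with no entry in the checkpoint
      mismatched - (name, graph_shape, ckpt_shape) where the shapes differ
      unused     - checkpoint entries the graph would not load
    """
    missing, mismatched = [], []
    graph_names = set()
    for name, shape in graph_vars:
        graph_names.add(name)
        if name not in ckpt_shapes:
            missing.append(name)
        elif list(ckpt_shapes[name]) != list(shape):
            mismatched.append((name, list(shape), list(ckpt_shapes[name])))
    unused = sorted(n for n in ckpt_shapes if n not in graph_names)
    return missing, mismatched, unused


# Hypothetical checkpoint saved from a graph that used name scopes
ckpt = {
    'hidden1/Variable': [5, 5, 1, 32], 'hidden1/Variable_1': [32],
    'hidden2/Variable': [5, 5, 32, 64], 'hidden2/Variable_1': [64],
    'fc1/Variable': [3136, 1024], 'fc1/Variable_1': [1024],
    'fc2/Variable': [1024, 10], 'fc2/Variable_1': [10],
    'hidden1/Variable/Adam': [5, 5, 1, 32], 'beta1_power': [],
}
# Variables of a rebuilt graph without an optimizer
rebuilt = [(n, s) for n, s in ckpt.items() if '/Adam' not in n and 'power' not in n]
print(check_restore(ckpt, rebuilt))

# A rebuilt graph that forgot the name scopes cannot be restored
no_scopes = [('Variable', [5, 5, 1, 32]), ('Variable_1', [32])]
print(check_restore(ckpt, no_scopes))
```

An empty missing list and an empty mismatched list mean that every variable in the rebuilt graph can be filled from the checkpoint. The unused entries are harmless.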
```python
# encoding=utf-8
import tensorflow as tf  # written against the TensorFlow 1.x API
from PIL import Image, ImageFilter


def imageprepare(argv):
    """Read an image, preprocess it, and return an array to feed into the network.

    argv is the path of the image file (e.g. a PNG). Returns 784 pixel
    values in [0, 1], laid out like an MNIST sample.
    """
    im = Image.open(argv).convert('L')  # 8-bit grayscale
    width = float(im.size[0])
    height = float(im.size[1])
    newImage = Image.new('L', (28, 28), 255)  # white 28x28 canvas

    if width > height:  # check which dimension is bigger
        # Width is bigger: it becomes 20 pixels, height keeps the aspect ratio
        nheight = int(round((20.0 / width * height), 0))
        if nheight == 0:  # rare case, but the minimum is 1 pixel
            nheight = 1
        # Resize and sharpen (LANCZOS is the filter formerly aliased as ANTIALIAS)
        img = im.resize((20, nheight), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        wtop = int(round(((28 - nheight) / 2), 0))  # vertical position
        newImage.paste(img, (4, wtop))  # paste resized image on the canvas
    else:
        # Height is bigger: it becomes 20 pixels, width keeps the aspect ratio
        nwidth = int(round((20.0 / height * width), 0))
        if nwidth == 0:  # rare case, but the minimum is 1 pixel
            nwidth = 1
        img = im.resize((nwidth, 20), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        wleft = int(round(((28 - nwidth) / 2), 0))  # horizontal position
        newImage.paste(img, (wleft, 4))  # paste resized image on the canvas

    # newImage.save("sample.png")
    tv = list(newImage.getdata())  # pixel values, row by row
    # Normalize pixels to [0, 1]: 0 is pure white, 1 is pure black
    tva = [(255 - x) * 1.0 / 255.0 for x in tv]
    return tva


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


myGraph = tf.Graph()
with myGraph.as_default():
    # Rebuild exactly the same network as in training
    with tf.name_scope('inputsAndLabels'):
        x_raw = tf.placeholder(tf.float32, shape=[None, 784])
        y = tf.placeholder(tf.float32, shape=[None, 10])  # unused at inference time

    with tf.name_scope('hidden1'):
        x = tf.reshape(x_raw, shape=[-1, 28, 28, 1])
        W_conv1 = weight_variable([5, 5, 1, 32])
        b_conv1 = bias_variable([32])
        l_conv1 = tf.nn.relu(tf.nn.conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1)
        l_pool1 = tf.nn.max_pool(l_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    with tf.name_scope('hidden2'):
        W_conv2 = weight_variable([5, 5, 32, 64])
        b_conv2 = bias_variable([64])
        l_conv2 = tf.nn.relu(tf.nn.conv2d(l_pool1, W_conv2, strides=[1, 1, 1, 1], padding='SAME') + b_conv2)
        l_pool2 = tf.nn.max_pool(l_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    with tf.name_scope('fc1'):
        W_fc1 = weight_variable([64 * 7 * 7, 1024])
        b_fc1 = bias_variable([1024])
        l_pool2_flat = tf.reshape(l_pool2, [-1, 64 * 7 * 7])
        l_fc1 = tf.nn.relu(tf.matmul(l_pool2_flat, W_fc1) + b_fc1)
        keep_prob = tf.placeholder(tf.float32)
        l_fc1_drop = tf.nn.dropout(l_fc1, keep_prob)

    with tf.name_scope('fc2'):
        W_fc2 = weight_variable([1024, 10])
        b_fc2 = bias_variable([10])
        y_conv = tf.matmul(l_fc1_drop, W_fc2) + b_fc2

with tf.Session(graph=myGraph) as sess:
    # Not strictly needed: restore() overwrites every variable it loads
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, './model/mnistmodel-1')  # restore the parameters

    array = imageprepare('./1.png')  # read an image containing a digit
    prediction = tf.argmax(y_conv, 1)  # index of the largest logit = predicted digit
    prediction = prediction.eval(feed_dict={x_raw: [array], keep_prob: 1.0}, session=sess)
    print('The digit in this image is: %d' % prediction[0])
```
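The preprocessing in imageprepare works in three steps. It scales the longer side to 20 pixels while keeping the aspect ratio, centers the result on a 28 × 28 white canvas, and then inverts and scales the pixels to [0, 1]. This logic can be checked without PIL or TensorFlow.

The pure-Python functions below (fit_box and normalize, both hypothetical names) reproduce only the size/offset arithmetic and the pixel normalization. They do not reproduce the resizing or sharpening itself.

```python
def fit_box(width, height, box=20, canvas=28):
    """Return ((new_w, new_h), (left, top)) following imageprepare's arithmetic."""
    if width > height:
        # Width is the longer side: it becomes `box` pixels, minimum 1 for the other side
        new_h = max(1, int(round(float(box) / width * height)))
        new_size = (box, new_h)
        offset = ((canvas - box) // 2, int(round((canvas - new_h) / 2.0)))
    else:
        # Height is the longer side (or the image is square)
        new_w = max(1, int(round(float(box) / height * width)))
        new_size = (new_w, box)
        offset = (int(round((canvas - new_w) / 2.0)), (canvas - box) // 2)
    return new_size, offset


def normalize(pixels):
    # Grayscale 0..255 (255 = white) -> 0.0..1.0 (1.0 = black), MNIST-style
    return [(255 - p) / 255.0 for p in pixels]


print(fit_box(100, 50))     # wide image
print(fit_box(30, 60))      # tall image
print(normalize([255, 0]))  # white and black pixels
```

The fixed offset of 4 in the original code is simply (28 - 20) / 2: the 20-pixel side leaves a 4-pixel margin on each side. MNIST digits were prepared the same way, which is why the network sees a similar layout for your own drawings.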
Summary
The recognition engine works quite well. At its core is a convolutional neural network.