您的位置:首页 > 其它

tensorflow 训练mnist数据

2017-04-11 18:04 537 查看
tensorflow 训练mnist数据

import tensorflow as tf
import numpy as np
import os
import struct
import gzip
from tensorflow.examples.tutorials.mnist import input_data
import test8

def readata(label, image):
    """Load a gzipped MNIST IDX label/image file pair.

    Args:
        label: path to a gzipped IDX1 label file
            (e.g. 'train-labels-idx1-ubyte.gz').
        image: path to a gzipped IDX3 image file
            (e.g. 'train-images-idx3-ubyte.gz').

    Returns:
        (labels, images): ``labels`` is a float64 one-hot array of shape
        (N, 10); ``images`` is a writable uint8 array of shape
        (N, rows, cols, 1).

    Raises:
        ValueError: if either file has the wrong magic number, or the image
            count in the header disagrees with the number of labels.
    """
    with gzip.open(label, 'rb') as flbl:
        magic, _ = struct.unpack('>II', flbl.read(8))
        if magic != 2049:  # IDX1 magic: unsigned-byte data, 1 dimension
            raise ValueError('%s: bad label-file magic number %d (expected 2049)'
                             % (label, magic))
        # frombuffer replaces the deprecated binary mode of np.fromstring;
        # IDX label bytes are unsigned, hence uint8.
        lab = np.frombuffer(flbl.read(), dtype=np.uint8)
    # Row k of the 10x10 identity is the one-hot vector for class k
    # (float64, matching the original np.zeros-based encoding).
    labels = np.eye(10)[lab]

    with gzip.open(image, 'rb') as fimg:
        magic, num, rows, cols = struct.unpack('>IIII', fimg.read(16))
        if magic != 2051:  # IDX3 magic: unsigned-byte data, 3 dimensions
            raise ValueError('%s: bad image-file magic number %d (expected 2051)'
                             % (image, magic))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8)
    if num != len(labels):
        raise ValueError('%d images in %s but %d labels in %s'
                         % (num, image, len(labels), label))
    # frombuffer yields a read-only view; copy so callers can modify the
    # images in place, as they could with np.fromstring's result.
    images = pixels.reshape(num, rows, cols, 1).copy()
    return labels, images

# Model graph (TensorFlow 1.x): two conv/ReLU/max-pool stages, a 1024-unit
# dense layer with dropout, and a 10-way linear output trained with softmax
# cross-entropy and plain gradient descent.
#
# NOTE: the tf.Variable objects below are created without explicit names, so
# their checkpoint keys are derived from the enclosing name scope and the
# creation order. Keep both unchanged, or saver.restore() of an existing
# checkpoint will no longer find its variables.

LEARNING_RATE = 0.001

with tf.name_scope('input'):
    # NHWC batch of 28x28 single-channel images, and one-hot labels over 10 classes.
    image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    label = tf.placeholder(tf.float32, [None, 10])

with tf.name_scope('conv1'):
    # 5x5 convolution, 1 -> 32 channels, ReLU, then 2x2 max-pool (28x28 -> 14x14).
    W_con1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
    b_con1 = tf.Variable(tf.constant(0.1, shape=[32]))
    c_con1 = tf.nn.conv2d(image, W_con1, strides=[1, 1, 1, 1], padding='SAME') + b_con1
    h_con1 = tf.nn.relu(c_con1)
    # Output of the first pooling stage.
    m_pool1 = tf.nn.max_pool(h_con1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

with tf.name_scope('conv2'):
    # 5x5 convolution, 32 -> 64 channels, ReLU, then 2x2 max-pool (14x14 -> 7x7).
    W_con2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
    b_con2 = tf.Variable(tf.constant(0.1, shape=[64]))
    c_con2 = tf.nn.conv2d(m_pool1, W_con2, strides=[1, 1, 1, 1], padding='SAME') + b_con2
    h_con2 = tf.nn.relu(c_con2)
    m_pool2 = tf.nn.max_pool(h_con2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

with tf.name_scope('fc1'):
    # Flatten the 7x7x64 feature maps and apply a 1024-unit ReLU layer.
    W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
    b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
    m_pool2_flat = tf.reshape(m_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(m_pool2_flat, W_fc1) + b_fc1)

with tf.name_scope('drop'):
    # Probability of keeping each unit; fed at run time (1.0 disables dropout).
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

with tf.name_scope('fc2'):
    # Linear output layer producing unnormalised logits for the 10 classes.
    W_fc2 = tf.Variable(tf.truncated_normal(shape=[1024, 10], stddev=0.1))
    b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
    y_con = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

with tf.name_scope('cross_entry'):
    # Mean softmax cross-entropy over the batch ("cross_entry" is the existing
    # name for this loss and is referenced elsewhere, so it is kept).
    cross_entry = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=y_con))
    train = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entry)

# Optional accuracy metric (currently disabled):
# correct_prediction = tf.equal(tf.argmax(y_con, 1), tf.argmax(label, 1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Load the MNIST IDX archives from the working directory.
train_label, train_image = readata('train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
# The test split is currently unused by the code below.
test_label, test_image = readata('t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')

# Sequential, wrap-around mini-batch iterator over the training set.
train_image1 = test8.Dataset(train_image, train_label)

saver = tf.train.Saver()

# Training configuration.
# True: start from freshly initialised weights; False: resume from MODEL_PATH.
TRAIN_FROM_SCRATCH = False
MODEL_PATH = '/home/dms/model.ckpt'
BATCH_SIZE = 60
NUM_STEPS = 20000 if TRAIN_FROM_SCRATCH else 10000
LOG_EVERY = 100
# NOTE(review): 1.0 disables dropout during training; 0.5 is the usual
# training value. Kept at 1.0 to match the original runs - confirm intent.
TRAIN_KEEP_PROB = 1.0

with tf.Session() as sess:
    if TRAIN_FROM_SCRATCH:
        sess.run(tf.global_variables_initializer())
        # Write the graph so it can be inspected in TensorBoard.
        train_writer = tf.summary.FileWriter('./train', sess.graph)
    else:
        saver.restore(sess, MODEL_PATH)

    for i in range(NUM_STEPS):
        batchimage, batchlabel = train_image1.next_batch(BATCH_SIZE)
        _, loss = sess.run(
            [train, cross_entry],
            feed_dict={image: batchimage, label: batchlabel, keep_prob: TRAIN_KEEP_PROB})
        if i % LOG_EVERY == 0:
            # Single-argument print behaves the same under Python 2 and 3.
            print('loss %s' % loss)

    if TRAIN_FROM_SCRATCH:
        train_writer.close()
    # Save in both modes so the steps of a resumed run are persisted too.
    save_path = saver.save(sess, MODEL_PATH)
    print(save_path)


# Iterator: the Dataset mini-batch helper (used above as test8.Dataset)

import numpy as np

class Dataset(object):
    """Iterate over parallel (images, labels) in fixed order as mini-batches.

    Batches are taken sequentially. When a batch runs past the end of the
    data it wraps around to the beginning, so every call returns exactly
    ``batch_size`` examples. No shuffling is performed.
    """

    def __init__(self, images, labels):
        """Wrap parallel sequences of examples.

        Args:
            images: indexable sequence (e.g. numpy array) of examples along axis 0.
            labels: sequence with the same length as ``images``.

        Raises:
            ValueError: if ``images`` and ``labels`` differ in length.
        """
        if len(images) != len(labels):
            raise ValueError('images and labels differ in length: %d vs %d'
                             % (len(images), len(labels)))
        self._images = images
        self._labels = labels
        self._num_examples = len(images)
        # Position of the next example to be returned.
        self._index_in_epoch = 0

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` (images, labels), wrapping at the end.

        Args:
            batch_size: number of examples to return; may exceed the dataset
                size, in which case examples repeat.

        Returns:
            A tuple (images, labels). Contiguous batches are plain slices of the
            inputs; wrapped batches are new numpy arrays.

        Raises:
            ValueError: if a non-empty batch is requested from an empty dataset.
        """
        start = self._index_in_epoch
        end = start + batch_size
        if end <= self._num_examples:
            # The batch fits before the end of the data: a contiguous slice.
            self._index_in_epoch = end
            return self._images[start:end], self._labels[start:end]

        if self._num_examples == 0:
            raise ValueError('cannot draw a batch from an empty Dataset')
        # Wrap around: indices modulo the dataset size continue the batch from
        # the start (this also handles batch_size larger than the dataset).
        idx = np.arange(start, end) % self._num_examples
        self._index_in_epoch = end % self._num_examples
        return np.take(self._images, idx, axis=0), np.take(self._labels, idx, axis=0)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: