Tensorflow使用笔记(2): 如何构建TFRecords并进行Mini Batch训练
2017-06-24 17:10
666 查看
引言
前段时间在做一门课程的期末大作业的时候,用到了TensorFlow,构建了含有两层卷积层的神经网络去做 交通标志的识别,一开始使用 24x24 的图像作为输入(把数据集的图像都resize为24x24)后来感觉应该设计大一点会可靠一点,那就想把输入的图像都改为 64x64 的大小,相应修改了网络一些参数后,run的时候发现出问题了,我还以为是代码没有改好,仔细看一下提示信息:run out of memory 。原来是是内存不足那在了解到情况后,就上网找方法,于是乎,找到了个普遍的解决方法:使用mini_batch方法训练,好,那下面就是一些整理网上的资料了
什么是Mini_Batch方法
点这里可以看比较简略的介绍。先简单介绍一下这三个常见的名词:batch_size ,iteration,epoch
batchsize:批大小。在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练;
iteration:1个iteration等于使用batchsize个样本训练一次;
epoch:1个epoch等于使用训练集中的全部样本训练一次
总体来说Mini_Batch就是介于SGD(随机梯度下降)和BGD(批梯度下降)之间的一种比较不错的方法,batch_size选择合适了,既能提高训练速度,又能求得一个逼近全局最优解的结果(但是在实际运用中应该要多次修改才能获得合适的size),点这里可以看一些关于怎么选择好batch_size的建议
怎么实现?
那问题来了,怎么在tensorflow中实现Mini_batch训练呢?一开始我的数据集是使用pickle存在硬盘上的with open( 'images.pkl') as f: # training_images 就是存储了很多的使用opencv读取的图像,它们的类型都是np.array training_images = pickle.load(f)
那如果是按照顺序读取,每次从取training_images的一部分,然后再取它的下一部分,似乎很容易实现,但是如果每次喂进去的样本没有随机性的话,那似乎便失去了Mini Batch的意义,如果想实现随机取的话,似乎不太容易,那我们可以利用tensorflow,制作适合Tensorflow的数据集TFRecords
简单制作TFRecords
使用TFRecord有什么好处,能把它转成二进制,tensorflow对它会加速。处理起来更快;可以配合tensorflow里的函数,配合使用起来更方便直接上代码(参考网上)
with open('Training_images64x64.pkl') as f1,open('Training_labels.pkl') as f2: print "loading...,please waitting for few seconds" images = pickle.load(f1) # 数据 labels = pickle.load(f2) # 样本标签 # 构建一个writer,待会用来把TFRecords写入硬盘的 writer = tf.python_io.TFRecordWriter("train.tfrecords") num = len(labels) print "the total account of the samples:",num for i in range(num): # build tf record label = eval(labels[i]) # 写入的标签是要数字 img = images[i] img_raw = img.tobytes() # TFRecord 要把它转化成字节 # 重点,开始映射了 example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) # 写入到硬盘 writer.write(example.SerializeToString()) # Serialize To String print "all Done" # 结束,close。跟文件操作挺类似的 writer.close()
读取并使用TFRecords
首先先定义一个读取 TFRecords的函数吧# 使用队列读取数据,这个队列在tensorflow里面有特别的含义 # 把数据放在队列里有很多好处,可以完成训练数据和测试数据的解耦 def read_and_decode(filename): # 根据文件名生成一个队列 filename_queue = tf.train.string_input_producer([filename]) # 定义reader,跟之前定义writer是对应的 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) # 下面这个reshape 可以根据自己的需要来决定要不要重新定义大小 # img = tf.reshape(img, [224, 224, 3]) # 转化为tensorflow 的 float32, img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, label
接下来是怎么使用的问题了。假设已经定义好计算图了
假设计算图是已经定义好的 为 graph
# 先度硬盘上读取出来,利用上面定义好的函数 img, label = read_and_decode("train.tfrecords") sess = tf.Session(graph=graph) # 定义回话 init.run(session=sess) # 初始化 # 启动队列线程 threads = tf.train.start_queue_runners(sess=sess) sess.run(fetches=init,feed_dict={images_ph:training_images, labels_ph:training_labels}) # !!!划重点了!!! # 使用shuffle_batch,tensorflow可以有效地帮我们随机从训练数据中随机抽出batch_size个数据样本 image_batch ,label_batch = tf.train.shuffle_batch([img, label] ,batch_size=30) # 上面这个函数,还有capacity等其他参数,作用还未明白 # image_train ,label_train = tf.train.shuffle_batch([img, label] ,batch_size=30, capacity=2000,min_after_dequeue=1000) iteration_times = 2000 # 假定迭代次数为200 # 开始训练 for i in range(0,iteration_times): # 每次都要run一次,否则取不到说好的batch的数据哟 sess.run([image_train,label_train]) # 接下来就可以把image_train,label_train # 喂给你的训练节点了 _, loss_value =sess.run( [train_op,loss], feed_dict={images_ph: image_train, labels_ph:label_train} ) # 其他代码
训练完之后别忘记保存模型参数哟,具体的介绍可以看我的另一篇文章
关于队列的详细介绍,可以看这里,还挺复杂的
小结一下
我在自己的问题上,在实际训练的时候,使用实验室的带GPU的服务器,训练的结果有些慢,损失loss震荡比较厉害,一直没有收敛,可能还是网络设计的有问题吧。后来在提交作业的时候,选择了32x32的输入图像(梯度下降法,没有使用分批训练),反而在测试集上的准确率比较高,对于Mini_batch的训练方法。batch_size的选择还是缺少指导方针呀。大过小都不好,实践是检验真理的唯一标准呀,此次作业收获良多。
相关文章推荐
- 【hadoop】Hadoop学习笔记(九):如何在windows上使用eclipse远程连接hadoop进行程序开发
- Xcode学习笔记---如何使用Xcode中的storyboard构建你的第一个IOS应用
- Tensorflow之构建自己的图片数据集TFrecords
- 使用笔记:mysql与oracle进行sql查询时如何表示日期
- Docker学习笔记(3)-- 如何使用Dockerfile构建镜像
- 如何使用印象笔记进行更好的学习呢?
- Docker学习笔记(3)-- 如何使用Dockerfile构建镜像
- Tensorflow构建自己的图片数据集TFrecords
- tensorflow笔记:使用tf来实现word2vec
- 【深度学习】笔记7: CNN训练Cifar-10技巧 ---如何进行实验,如何进行构建自己的网络模型,提高精度
- Docker学习笔记(3)-- 如何使用Dockerfile构建镜像
- 如何使用TestFlight进行App构建版本测试
- iOS学习笔记9- iOS 如何使用TestFlight进行Beta测试
- tensorflow笔记:使用tf来实现word2vec
- Hadoop学习笔记(八):如何使用Maven构建《hadoop权威指南3》随书的源码包
- Docker学习笔记(3)-- 如何使用Dockerfile构建镜像
- 如何使用TestFlight进行App构建版本测试
- 如何使用TestFlight进行App构建版本测试(转)
- 使用笔记:mysql与oracle进行sql查询时如何表示日期
- ArcGIS API for JavaScript 4.2学习笔记[23] 没有地图如何进行查询?【FindTask类的使用】