损失函数震荡不收敛可能原因:tf.train.shuffle_batch
2018-10-19 09:09
134 查看
在制作tfrecords数据集的时候,比如说将cifar数据转换成tfrecords数据集,一般会用到tf.train.shuffle_batch函数,而损失函数震荡不收敛的原因就可能就是数据集制作出现问题。
Cifar-10数据集包含了airlane、automobile、bird、cat、deer、dog、frog、horse、ship、truck,10种分类 ,分别放在十个文件夹中。共60000张图片,其中训练集50000张,测试集10000张。
开始在制作数据集的时候,我是先将一个文件夹中的所有图片写入tfrecords,这样制作的问题就是:将同一类的图片按照顺序写入到了tfrecords中,然而后面再读取tfrecords时,使用到了tf.train.shuffle_batch,此函数只是在batch_size中reshuffle,总体的顺序并没有改变,所以喂入网络的数据都是同一类的图片,并不能起到训练网络的效果。
正确的做法是:应该将这些图片打乱之后写入到tfrecords中。我采取的方法是:因为每个文件夹中图片数量是固定的,所以将这些图片名称全部读取出来,存储到字典中,因为batch_size为200,所以依次从十个文件夹中读取20张图片写入到tfrecords中,这样再训练的时候,取出的数据就不再会是同一类的图片。
def write_tfRecord(tfRecordName, image_path, label_path): writer = tf.python_io.TFRecordWriter(tfRecordName) num_pic = 0 dirs = os.listdir(image_path) # print(dirs) contents = {} for _dir in dirs: temp_path = os.path.join(image_path, _dir) temp = os.listdir(temp_path) contents[_dir] = temp # print(len(contents[_dir])) # print(contents) for i in range(int(len(contents[dirs[0]]) / 20)): for index in range(len(dirs)): for j in range(i*20, i*20+20): ima_path = os.path.join(image_path, dirs[index], contents[dirs[index]][j]) img = Image.open(ima_path) img_raw = img.tobytes() labels = [0] * 10 labels[index] = 1 example = tf.train.Example(features=tf.train.Features(feature={ 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)) })) writer.write(example.SerializeToString()) num_pic += 1 print ("the number of picture:", num_pic) writer.close() print("write tfrecord successful")阅读更多
相关文章推荐
- tf.train.shuffle_batch和tf.train.batch接口说明
- tf.train.batch和tf.train.shuffle_batch的用法
- tf.train.batch 和 tf.train.shuffle_batch 的知识笔记
- tf.train.batch和tf.train.shuffle_batch的理解
- tf.train.shuffle_batch
- tf.train.batch和tf.train.shuffle_batch的用法
- 遇到的问题与解决办法(tf.train.shuffle_batch与tf.train.slice_input_producer)
- tf.train.batch()和tf.train.shuffle_batch()函数
- [tensorflow教程] [cifar10] tf.train.batch和tf.train.shuffle_batch的用法
- tf.train.batch 和 tf.train.batch_join的区别
- tf.train.batch()
- tf.train.shuffle_batch函数解析
- tensorflow学习——tf.floor与tf.train.batch
- tensorflow tf.train.batch之数据批量读取
- springmvc No mapping found for HTTP request with URI 可能原因统计
- 移植boa后运行CGI程序可能出现的原因及解决方法
- wamp服务器访问php非常缓慢的可能原因以及解决方法
- 电脑无故死机的可能原因……
- 莫一种可能的原因 gzip: stdin: not in gzip format
- web项目编译出错时,原因之一,可能是build path 中order and Export引起