您的位置:首页 > 大数据 > 人工智能

损失函数震荡不收敛可能原因:tf.train.shuffle_batch

2018-10-19 09:09 134 查看

​ 在制作tfrecords数据集的时候,比如说将cifar数据转换成tfrecords数据集,一般会用到tf.train.shuffle_batch函数,而损失函数震荡不收敛的原因就可能就是数据集制作出现问题。

​ Cifar-10数据集包含了airlane、automobile、bird、cat、deer、dog、frog、horse、ship、truck,10种分类 ,分别放在十个文件夹中。共60000张图片,其中训练集50000张,测试集10000张。

​ 开始在制作数据集的时候,我是先将一个文件夹中的所有图片写入tfrecords,这样制作的问题就是:将同一类的图片按照顺序写入到了tfrecords中,然而后面再读取tfrecords时,使用到了tf.train.shuffle_batch,此函数只是在batch_size中reshuffle,总体的顺序并没有改变,所以喂入网络的数据都是同一类的图片,并不能起到训练网络的效果。

​ 正确的做法是:应该将这些图片打乱之后写入到tfrecords中。我采取的方法是:因为每个文件夹中图片数量是固定的,所以将这些图片名称全部读取出来,存储到字典中,因为batch_size为200,所以依次从十个文件夹中读取20张图片写入到tfrecords中,这样再训练的时候,取出的数据就不再会是同一类的图片。

def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName)
num_pic = 0
dirs = os.listdir(image_path)
# print(dirs)
contents = {}
for _dir in dirs:
temp_path = os.path.join(image_path, _dir)
temp = os.listdir(temp_path)
contents[_dir] = temp
# print(len(contents[_dir]))
# print(contents)

for i in range(int(len(contents[dirs[0]]) / 20)):
for index in range(len(dirs)):
for j in range(i*20, i*20+20):
ima_path = os.path.join(image_path, dirs[index], contents[dirs[index]][j])
img = Image.open(ima_path)
img_raw = img.tobytes()
labels = [0] * 10
labels[index] = 1

example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString())
num_pic += 1
print ("the number of picture:", num_pic)

writer.close()
print("write tfrecord successful")
阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: