您的位置:首页 > 其它

TensorFlow入门(十-III)tfrecord 图片数据 读写

2017-11-21 17:19 501 查看
本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord

关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。

- 维度固定的 numpy 矩阵

- 可变长度的 序列 数据

- 图片数据

在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。

tfrecord_3_img_writer.py

# -*- coding:utf-8 -*-

import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys
import os
import time

'''tfrecord 写入数据.
将图片数据写入 tfrecord 文件。以 MNIST png格式数据集为例。

首先将图片解压到 ../../MNIST_data/mnist_png/ 目录下。
解压以后会有 training 和 testing 两个数据集。在每个数据集下,有十个文件夹,分别存放了这10个类别的数据。
每个文件夹名为对应的类别编码。

现在网上关于打包图片的例子非常多,实现方式各式各样,效率也相差非常多。
选择合适的方式能够有效地节省时间和硬盘空间。
有几点需要注意:
1.打包 tfrecord 的时候,千万不要使用 Image.open() 或者 matplotlib.image.imread() 等方式读取。
1张小于10kb的png图片,前者(Image.open) 打开后,生成的对象100+kb, 后者直接生成 numpy 数组,大概是原图片的几百倍大小。
所以应该直接使用 tf.gfile.FastGFile() 方式读入图片。
2.从 tfrecord 中取数据的时候,再用 tf.image.decode_png() 对图片进行解码。
3.不要随便使用 tf.image.resize_image_with_crop_or_pad 等函数,可以直接使用 tf.reshape()。前者速度极慢。
4.如果有固态硬盘的话,图片数据一定要放在固态硬盘中进行读取,速度能高几十倍几十倍几十倍!生成的 tfrecord 文件就无所谓了,找个机械盘放着就行。
'''

# png 文件路径
TRAINING_DIR = '../../MNIST_data/mnist_png/training/'
TESTING_DIR = '../../MNIST_data/mnist_png/testing/'
# tfrecord 文件保存路径,这里只保存一个 tfrecord 文件
TRAINING_TFRECORD_NAME = 'training.tfrecord'
TESTING_TFRECORD_NAME = 'testing.tfrecord'

DICT_LABEL_TO_ID = {  # 把 label(文件名) 转为对应 id
'0': 0,
'1': 1,
'2': 2,
'3': 3,
'4': 4,
'5': 5,
'6': 6,
'7': 7,
'8': 8,
'9': 9,
}

def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def int64_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))

def convert_tfrecord_dataset(dataset_dir, tfrecord_name, tfrecord_path='../data/'):
""" convert samples to tfrecord dataset.
Args:
dataset_dir: 数据集的路径。
tfrecord_name: 保存为 tfrecord 文件名
tfrecord_path: 保存 tfrecord 文件的路径。
"""
if not os.path.exists(dataset_dir):
print(u'png文件路径错误,请检查是否已经解压png文件。')
exit()
if not os.path.exists(os.path.dirname(tfrecord_path)):
os.makedirs(os.path.dirname(tfrecord_path))
tfrecord_file = os.path.join(tfrecord_path, tfrecord_name)
class_names = os.listdir(dataset_dir)
n_class = len(class_names)
print(u'一共有 %d 个类别' % n_class)
with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
for class_name in class_names:  # 对于每个类别
class_dir = os.path.join(dataset_dir, class_name)  # 获取类别对应的文件夹路径
file_names = os.listdir(class_dir)  # 在该文件夹下,获取所有图片文件名
label_id = DICT_LABEL_TO_ID.get(class_name)  # 获取类别 id
print(u'\n正在处理类别 %d 的数据' % label_id)
time0 = time.time()
n_sample = len(file_names)
for i in range(n_sample):
file_name = file_names[i]
sys.stdout.write('\r>> Converting image %d/%d , %g s' % (
i + 1, n_sample, time.time() - time0))
png_path = os.path.join(class_dir, file_name)  # 获取每个图片的路径
# CNN inputs using
img = tf.gfile.FastGFile(png_path, 'rb').read()  # 读入图片
example = tf.train.Example(
features=tf.train.Features(
feature={
'image': bytes_feature(img),
'label': int64_feature(label_id)
}))
serialized = example.SerializeToString()
writer.write(serialized)
print('\nFinished writing data to tfrecord files.')

if __name__ == '__main__':
convert_tfrecord_dataset(TRAINING_DIR, TRAINING_TFRECORD_NAME)
convert_tfrecord_dataset(TESTING_DIR, TESTING_TFRECORD_NAME)


tfrecord_3_img_reader.py

# -*- coding:utf-8 -*-

import tensorflow as tf

'''read data
从 tfrecord 文件中读取数据,对应数据的格式为png / jpg 等图片数据。
'''

# **1.把所有的 tfrecord 文件名列表写入队列中
filename_queue = tf.train.string_input_producer(['../data/training.tfrecord'], num_epochs=None, shuffle=True)
# **2.创建一个读取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根据你写入的格式对应说明读取的格式
features = tf.parse_single_example(serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
}
)
img = features['image']
# 这里需要对图片进行解码
img = tf.image.decode_png(img, channels=3)  # 这里,也可以解码为 1 通道
img = tf.reshape(img, [28, 28, 3])  # 28*28*3

label = features['label']

print('img is', img)
print('label is', label)
# **4.通过 tf.train.shuffle_batch 或者 tf.train.batch 函数读取数据
"""
这里,你会发现每次取出来的数据都是一个类别的,除非你把 capacity 和 min_after_dequeue 设得很大,如
X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100,
capacity=20000, min_after_dequeue=10000, num_threads=3)
这是因为在打包的时候都是一个类别一个类别的顺序打包的,所以每次填数据都是按照那个顺序填充进来。
只有当我们把队列容量舍得非常大,这样在队列中才会混杂各个类别的数据。但是这样非常不好,因为这样的话,
读取速度就会非常慢。所以解决方法是:
1.在写入数据的时候先进行数据 shuffle。
2.多存几个 tfrecord 文件,比如 64 个。
"""

X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100,
capacity=200, min_after_dequeue=100, num_threads=3)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.启动队列进行数据读取
# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。
# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in xrange(5):
_X_batch, _y_batch = sess.run([X_batch, y_batch])
print('** batch %d' % i)
print('_X_batch.shape:', _X_batch.shape)
print('_y_batch:', _y_batch)
y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最后记得把队列关掉
coord.request_stop()
coord.join(threads)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: