TensorFlow入门(十-III)tfrecord 图片数据 读写
2017-11-21 17:19
501 查看
本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord
关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据
在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。
tfrecord_3_img_writer.py
tfrecord_3_img_reader.py
关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据
在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。
tfrecord_3_img_writer.py
# -*- coding:utf-8 -*- import tensorflow as tf import numpy as np from tqdm import tqdm import sys import os import time '''tfrecord 写入数据. 将图片数据写入 tfrecord 文件。以 MNIST png格式数据集为例。 首先将图片解压到 ../../MNIST_data/mnist_png/ 目录下。 解压以后会有 training 和 testing 两个数据集。在每个数据集下,有十个文件夹,分别存放了这10个类别的数据。 每个文件夹名为对应的类别编码。 现在网上关于打包图片的例子非常多,实现方式各式各样,效率也相差非常多。 选择合适的方式能够有效地节省时间和硬盘空间。 有几点需要注意: 1.打包 tfrecord 的时候,千万不要使用 Image.open() 或者 matplotlib.image.imread() 等方式读取。 1张小于10kb的png图片,前者(Image.open) 打开后,生成的对象100+kb, 后者直接生成 numpy 数组,大概是原图片的几百倍大小。 所以应该直接使用 tf.gfile.FastGFile() 方式读入图片。 2.从 tfrecord 中取数据的时候,再用 tf.image.decode_png() 对图片进行解码。 3.不要随便使用 tf.image.resize_image_with_crop_or_pad 等函数,可以直接使用 tf.reshape()。前者速度极慢。 4.如果有固态硬盘的话,图片数据一定要放在固态硬盘中进行读取,速度能高几十倍几十倍几十倍!生成的 tfrecord 文件就无所谓了,找个机械盘放着就行。 ''' # png 文件路径 TRAINING_DIR = '../../MNIST_data/mnist_png/training/' TESTING_DIR = '../../MNIST_data/mnist_png/testing/' # tfrecord 文件保存路径,这里只保存一个 tfrecord 文件 TRAINING_TFRECORD_NAME = 'training.tfrecord' TESTING_TFRECORD_NAME = 'testing.tfrecord' DICT_LABEL_TO_ID = { # 把 label(文件名) 转为对应 id '0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, } def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def int64_feature(values): return tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) def convert_tfrecord_dataset(dataset_dir, tfrecord_name, tfrecord_path='../data/'): """ convert samples to tfrecord dataset. Args: dataset_dir: 数据集的路径。 tfrecord_name: 保存为 tfrecord 文件名 tfrecord_path: 保存 tfrecord 文件的路径。 """ if not os.path.exists(dataset_dir): print(u'png文件路径错误,请检查是否已经解压png文件。') exit() if not os.path.exists(os.path.dirname(tfrecord_path)): os.makedirs(os.path.dirname(tfrecord_path)) tfrecord_file = os.path.join(tfrecord_path, tfrecord_name) class_names = os.listdir(dataset_dir) n_class = len(class_names) print(u'一共有 %d 个类别' % n_class) with tf.python_io.TFRecordWriter(tfrecord_file) as writer: for class_name in class_names: # 对于每个类别 class_dir = os.path.join(dataset_dir, class_name) # 获取类别对应的文件夹路径 file_names = os.listdir(class_dir) # 在该文件夹下,获取所有图片文件名 label_id = DICT_LABEL_TO_ID.get(class_name) # 获取类别 id print(u'\n正在处理类别 %d 的数据' % label_id) time0 = time.time() n_sample = len(file_names) for i in range(n_sample): file_name = file_names[i] sys.stdout.write('\r>> Converting image %d/%d , %g s' % ( i + 1, n_sample, time.time() - time0)) png_path = os.path.join(class_dir, file_name) # 获取每个图片的路径 # CNN inputs using img = tf.gfile.FastGFile(png_path, 'rb').read() # 读入图片 example = tf.train.Example( features=tf.train.Features( feature={ 'image': bytes_feature(img), 'label': int64_feature(label_id) })) serialized = example.SerializeToString() writer.write(serialized) print('\nFinished writing data to tfrecord files.') if __name__ == '__main__': convert_tfrecord_dataset(TRAINING_DIR, TRAINING_TFRECORD_NAME) convert_tfrecord_dataset(TESTING_DIR, TESTING_TFRECORD_NAME)
tfrecord_3_img_reader.py
# -*- coding:utf-8 -*- import tensorflow as tf '''read data 从 tfrecord 文件中读取数据,对应数据的格式为png / jpg 等图片数据。 ''' # **1.把所有的 tfrecord 文件名列表写入队列中 filename_queue = tf.train.string_input_producer(['../data/training.tfrecord'], num_epochs=None, shuffle=True) # **2.创建一个读取器 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # **3.根据你写入的格式对应说明读取的格式 features = tf.parse_single_example(serialized_example, features={ 'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64) } ) img = features['image'] # 这里需要对图片进行解码 img = tf.image.decode_png(img, channels=3) # 这里,也可以解码为 1 通道 img = tf.reshape(img, [28, 28, 3]) # 28*28*3 label = features['label'] print('img is', img) print('label is', label) # **4.通过 tf.train.shuffle_batch 或者 tf.train.batch 函数读取数据 """ 这里,你会发现每次取出来的数据都是一个类别的,除非你把 capacity 和 min_after_dequeue 设得很大,如 X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100, capacity=20000, min_after_dequeue=10000, num_threads=3) 这是因为在打包的时候都是一个类别一个类别的顺序打包的,所以每次填数据都是按照那个顺序填充进来。 只有当我们把队列容量舍得非常大,这样在队列中才会混杂各个类别的数据。但是这样非常不好,因为这样的话, 读取速度就会非常慢。所以解决方法是: 1.在写入数据的时候先进行数据 shuffle。 2.多存几个 tfrecord 文件,比如 64 个。 """ X_batch, y_batch = tf.train.shuffle_batch([img, label], batch_size=100, capacity=200, min_after_dequeue=100, num_threads=3) sess = tf.Session() init = tf.global_variables_initializer() sess.run(init) # **5.启动队列进行数据读取 # 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。 # 这样,在数据读取完毕以后,调用协调器把线程全部都关了。 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) y_outputs = list() for i in xrange(5): _X_batch, _y_batch = sess.run([X_batch, y_batch]) print('** batch %d' % i) print('_X_batch.shape:', _X_batch.shape) print('_y_batch:', _y_batch) y_outputs.extend(_y_batch.tolist()) print(y_outputs) # **6.最后记得把队列关掉 coord.request_stop() coord.join(threads)
相关文章推荐
- TensorFlow入门(十-I)tfrecord 固定维度数据读写
- TensorFlow入门(十-II)tfrecord 可变长度的序列数据
- 【TensorFlow系列】【一】利用TFRecordDataset读取图片数据
- tensorflow入门:TFRecordDataset变长数据的batch读取
- TFRecord —— tensorflow 下的统一数据存储格式
- TensorFlow高效读取数据——TFRecord
- Tensorflow中创建自己的TFRecord格式数据集
- tensorflow入门:tfrecord 和tf.data.TFRecordDataset
- How to transform our data into TFRecord(怎样将自己的图片数据转换成TF的格式)
- Tensorflow读取数据2-tfrecord
- Tensorflow中使用tfrecord方式读取数据
- TensorFlow 制作自己的TFRecord数据集
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
- Tensorflow使用tfrecord输入数据格式
- tensorflow之图片与tfrecord之间的转化
- Tensorflow中使用tfrecord方式读取数据的方法
- 深入浅出的TensorFlow数据格式化存储工具TFRecord用法教程
- TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
- tensorflow读取数据-tfrecord格式
- 第一阶段-入门详细图文讲解tensorflow1.4 -(八)tf.estimator构建数据预处理bostonHouse