Creating Your Own TFRecord-Format Dataset in TensorFlow
2018-01-03 22:03
Reference: 《TensorFlow实战Google深度学习框架》
Introduction to the TFRecord format
All data in a TFRecord file is stored as tf.train.Example Protocol Buffers (that is, in a binary format). The definitions are:
message Example {
    Features features = 1;
};
message Features {
    map<string, Feature> feature = 1;
};
message Feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
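In effect, an Example stores a dictionary that maps attribute names to values. Each attribute name is a string, and each value is one of three kinds: a list of byte strings (BytesList), a list of real numbers (FloatList) or a list of integers (Int64List). For an image, for example, the pixel data can be stored as a byte string and the image's label as an integer list.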
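To make the "dictionary of typed lists" idea concrete, here is a plain-Python mirror of that structure built from nested dicts instead of the real tf.train.Example class. The field names (features, feature, bytes_list, float_list, int64_list, value) follow the protobuf definitions above; the helpers feature_kind and make_example are hypothetical, and their type check is a simplified one assumed for illustration.

```python
INT64_MIN, INT64_MAX = -2 ** 63, 2 ** 63 - 1


def feature_kind(values):
    """Pick the Feature list type for a homogeneous list of Python values."""
    if not values:
        raise ValueError("empty list: the kind cannot be inferred")
    if all(isinstance(v, bytes) for v in values):
        return "bytes_list"
    if all(isinstance(v, int) for v in values):
        # Int64List holds signed 64-bit integers only.
        if not all(INT64_MIN <= v <= INT64_MAX for v in values):
            raise OverflowError("value does not fit in int64")
        return "int64_list"
    if all(isinstance(v, float) for v in values):
        return "float_list"
    raise TypeError("values must be all bytes, all int, or all float")


def make_example(features):
    """Mirror tf.train.Example as nested dicts: name -> {kind: {"value": [...]}}."""
    return {
        "features": {
            "feature": {
                name: {feature_kind(values): {"value": list(values)}}
                for name, values in features.items()
            }
        }
    }


# An image: raw pixels as one byte string, the class label as one integer.
example = make_example({"image_raw": [bytes([0, 128, 255])], "label": [1]})
print(example)
```

The oneof in the real Feature message is what the single-key inner dict imitates: each attribute holds exactly one kind of list.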
Creating a TFRecord file
First, import the necessary libraries (this was done in a Jupyter notebook):
import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
%matplotlib inline
Data preprocessing
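I downloaded 10 images from the web (3 cats, 4 dogs and 3 horses) and stored them in the cat, dog and horse folders. Because images downloaded from the web vary in size and format, they are preprocessed first with the function below (only the function is shown here; the complete test code is given at the end of the article):
def preprocess(imageRawDir, imageDir):
    """
    Preprocess images: convert to RGB and resize to 256x256.
    Arguments:
    imageRawDir -- directory of the raw (original) images.
    imageDir -- directory of processed images.
    Return: none.
    """
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)  # create the output folder if it does not exist yet
    imageNames = os.listdir(imageRawDir)
    label = imageDir.split("/")[-2] # directory format:"./data/cat/"
    for index, imageName in enumerate(imageNames):
        image = Image.open(os.path.join(imageRawDir, imageName))
        # force 3-channel RGB; grayscale or RGBA inputs would otherwise yield
        # the wrong number of bytes for the later reshape to [256, 256, 3]
        image = image.convert("RGB")
        image = image.resize((256, 256))
        savePath = os.path.join(imageDir, str(label+"_"+str(index))+".jpg")
        image.save(savePath)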
The preprocessed images are saved to a separate, specified folder.
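preprocess derives the class label with imageDir.split("/")[-2], which only works when the directory path ends in a slash, as in "./data/cat/"; for "./data/cat" it would return "data". A slightly more forgiving sketch using os.path is shown below; the helper names label_from_dir and processed_name are hypothetical and not part of the original code.

```python
import os


def label_from_dir(image_dir):
    """Return the last directory name, with or without a trailing slash."""
    return os.path.basename(os.path.normpath(image_dir))


def processed_name(label, index):
    """Same naming scheme as preprocess: <label>_<index>.jpg."""
    return "%s_%d.jpg" % (label, index)


print(label_from_dir("./data/cat/"), label_from_dir("./data/cat"))
print(processed_name("dog", 3))
```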
Writing to a TFRecord file
The following two helper functions are used when creating the TFRecord file. Wrapping these calls in functions keeps the writing code short and readable.
def _int64_feature(value):
    """ generate int64 feature. """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    """ generate byte feature. """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
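The function below reads all of the preprocessed images and writes them, one by one, to a TFRecord file. Because there is very little data here, everything goes into a single TFRecord file; for large datasets the examples can also be spread over several files.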
def createRecord(imageDir):
    """
    create TFRecord data.
    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    # create a writer to write TFRecord file
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    classNames = ["cat", "dog", "horse"]
    for classIndex, className in enumerate(classNames):
        print("class name = %s" % className)
        currentClassDir = os.path.join(imageDir, className)
        print("current dir = %s" % currentClassDir)
        for index, imageName in enumerate(os.listdir(currentClassDir)):
            image = Image.open(os.path.join(currentClassDir, imageName))
            # raw pixel bytes (height * width * channels values), not the encoded JPEG file
            image_raw = image.tobytes()
            print("%d %s" % (index, imageName))
            # write image data (pixel values and label) to an Example Protocol Buffer
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(classIndex),
                "image_raw": _bytes_feature(image_raw),
            }))
            # write an example to TFRecord file
            writer.write(example.SerializeToString())
    writer.close()
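As mentioned above, a large dataset can be split across several TFRecord files instead of one. The sketch below shows one way the examples might be distributed; round-robin assignment and the "train-00000-of-00003.tfrecords" naming scheme are illustrative assumptions, and shard_filenames / assign_round_robin are hypothetical helpers.

```python
def shard_filenames(prefix, num_shards):
    """Hypothetical naming scheme: <prefix>-00000-of-00003.tfrecords, ..."""
    return ["%s-%05d-of-%05d.tfrecords" % (prefix, i, num_shards)
            for i in range(num_shards)]


def assign_round_robin(items, num_shards):
    """Send item i to shard i % num_shards, so classes written in order
    (all cats, then dogs, then horses) end up mixed across the shards."""
    shards = [[] for _ in range(num_shards)]
    for i, item in enumerate(items):
        shards[i % num_shards].append(item)
    return shards


# 10 labels in the order createRecord writes them: 3 cats, 4 dogs, 3 horses.
labels = [0] * 3 + [1] * 4 + [2] * 3
for name, shard in zip(shard_filenames("train", 3), assign_round_robin(labels, 3)):
    print(name, shard)
```

In createRecord this would amount to opening one tf.python_io.TFRecordWriter per filename and writing the i-th example to writers[i % num_shards]; on the reading side, tf.train.string_input_producer already accepts a list of filenames.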
Reading the TFRecord file
def readRecord(recordName):
    """
    read TFRecord data (images).
    Arguments:
    recordName -- the TFRecord file to be read.
    return: data saved in recordName (image and label).
    """
    # queue of input file names; without num_epochs it cycles through them forever
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    # parse one serialized Example using the same keys that were used for writing
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    })
    label = features["label"]
    image = features["image_raw"]
    # turn the raw pixel bytes back into uint8 values and restore the image shape
    image = tf.decode_raw(image, tf.uint8)
    image = tf.reshape(image, [256, 256, 3])
    label = tf.cast(label, tf.int32)
    return image, label
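readRecord can only reshape the decoded bytes to [256, 256, 3] if every stored image contributes exactly 256 × 256 × 3 = 196608 bytes, which holds for 8-bit RGB images; a grayscale ("L") or RGBA image would produce 65536 or 262144 bytes and the reshape would fail at run time. The standard-library sketch below (raw_length and pixel_at are hypothetical helpers) shows that arithmetic and the row-major, channels-last layout that the reshape to [height, width, channels] relies on.

```python
def raw_length(height, width, channels):
    """Number of uint8 values decode_raw yields for one 8-bit image."""
    return height * width * channels


def pixel_at(raw, width, channels, row, col, channel):
    """Index a flat, row-major, channels-last buffer like image[row, col, channel]."""
    return raw[(row * width + col) * channels + channel]


# Byte counts for a 256x256 image in a few PIL modes (bytes per pixel: L=1, RGB=3, RGBA=4).
for mode, channels in [("L", 1), ("RGB", 3), ("RGBA", 4)]:
    print(mode, raw_length(256, 256, channels))

# A tiny 2x3 RGB "image" whose bytes are simply 0..17.
raw = bytes(range(2 * 3 * 3))
print(pixel_at(raw, width=3, channels=3, row=1, col=2, channel=0))
```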
Note that readRecord returns tensors rather than actual data; the values are only obtained by running those tensors in a TensorFlow session, as follows:
##test code
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# a Coordinator lets the queue-runner threads be stopped cleanly at the end
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
    images, labels = sess.run([imageBatch, labelBatch])
    print("batch shape = %s labels = %s" % (images.shape, labels))
print("label = %s" % labels)
# show the four images of the last batch
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.axis("off")
    plt.imshow(images[i])
coord.request_stop()
coord.join(threads)
sess.close()
Output:
batch shape = (4, 256, 256, 3) labels = [0 1 1 0]
batch shape = (4, 256, 256, 3) labels = [1 0 0 0]
batch shape = (4, 256, 256, 3) labels = [1 2 2 1]
batch shape = (4, 256, 256, 3) labels = [1 2 1 2]
batch shape = (4, 256, 256, 3) labels = [0 0 1 0]
batch shape = (4, 256, 256, 3) labels = [2 1 2 1]
batch shape = (4, 256, 256, 3) labels = [1 1 0 0]
batch shape = (4, 256, 256, 3) labels = [0 2 1 1]
batch shape = (4, 256, 256, 3) labels = [2 2 1 0]
batch shape = (4, 256, 256, 3) labels = [1 2 0 1]
label = [1 2 0 1]
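As can be seen, the images and labels in the last batch correspond one to one (0: cat; 1: dog; 2: horse), which shows that the data was read back from the TFRecord file successfully.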
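The labels come out mixed even though createRecord wrote all cats, then all dogs, then all horses, because tf.train.shuffle_batch draws from a random-shuffle queue. The toy model below is a simplification, not TensorFlow's implementation: it keeps a buffer of at most capacity examples and pops a uniformly random one for each batch slot. In the real queue a dequeue waits until more than min_after_dequeue elements are available; here that floor is met automatically because the buffer is refilled to capacity before every pop. The input is an endless iterator, matching string_input_producer, which cycles over its files indefinitely when num_epochs is not set. The function name shuffle_batches is hypothetical.

```python
import itertools
import random


def shuffle_batches(stream, batch_size, capacity, min_after_dequeue, num_batches, seed=0):
    """Toy random-shuffle queue producing fixed-size batches from an endless iterator."""
    if not 0 <= min_after_dequeue < capacity:
        raise ValueError("need 0 <= min_after_dequeue < capacity")
    rng = random.Random(seed)
    buffer = []
    for _ in range(num_batches):
        batch = []
        while len(batch) < batch_size:
            # Top the buffer up to capacity (background threads do this in TF).
            while len(buffer) < capacity:
                buffer.append(next(stream))
            # Pop a uniformly random element; capacity - 1 >= min_after_dequeue
            # elements stay behind, so the mixing floor is always respected.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        yield batch


# Labels in file order (3 cats, 4 dogs, 3 horses), cycled like an input producer
# without num_epochs; same batch_size/capacity/min_after_dequeue as the test code.
labels = [0] * 3 + [1] * 4 + [2] * 3
for batch in shuffle_batches(itertools.cycle(labels), 4, 10, 5, num_batches=3):
    print(batch)
```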
Complete example code
import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
currentDir = os.getcwd()
os.chdir(currentDir)
print(currentDir)
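def preprocess(imageRawDir, imageDir):
    """
    Preprocess images: convert to RGB and resize to 256x256.
    Arguments:
    imageRawDir -- directory of the raw (original) images.
    imageDir -- directory of processed images.
    Return: none.
    """
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)  # create the output folder if it does not exist yet
    imageNames = os.listdir(imageRawDir)
    label = imageDir.split("/")[-2] # directory format:"./data/cat/"
    for index, imageName in enumerate(imageNames):
        image = Image.open(os.path.join(imageRawDir, imageName))
        # force 3-channel RGB so every image later yields 256*256*3 bytes
        image = image.convert("RGB")
        image = image.resize((256, 256))
        savePath = os.path.join(imageDir, str(label+"_"+str(index))+".jpg")
        image.save(savePath)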
##test code
catRawDir = "./data_raw/cat/"
catDir = "./data/cat/"
preprocess(catRawDir, catDir)
dogRawDir = "./data_raw/dog/"
dogDir = "./data/dog/"
preprocess(dogRawDir, dogDir)
horseRawDir = "./data_raw/horse/"
horseDir = "./data/horse/"
preprocess(horseRawDir, horseDir)
def _int64_feature(value):
    """ generate int64 feature. """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    """ generate byte feature. """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def createRecord(imageDir):
    """
    create TFRecord data.
    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    classNames = ["cat", "dog", "horse"]
    for classIndex, className in enumerate(classNames):
        print("class name = %s" % className)
        currentClassDir = os.path.join(imageDir, className)
        print("current dir = %s" % currentClassDir)
        for index, imageName in enumerate(os.listdir(currentClassDir)):
            image = Image.open(os.path.join(currentClassDir, imageName))
            image_raw = image.tobytes() # raw pixel bytes, not the encoded JPEG file
            print("%d %s" % (index, imageName))
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": _int64_feature(classIndex),
                "image_raw": _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
    writer.close()
##test code
createRecord(os.path.join(currentDir, "data/"))
def readRecord(recordName):
    """
    read TFRecord data (images).
    Arguments:
    recordName -- the TFRecord file to be read.
    return: data saved in recordName (image and label).
    """
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    })
    label = features["label"]
    image = features["image_raw"]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.reshape(image, [256, 256, 3])
    label = tf.cast(label, tf.int32)
    return image, label
##test code
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)
##test code
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
    images, labels = sess.run([imageBatch, labelBatch])
    print("batch shape = %s labels = %s" % (images.shape, labels))
print("label = %s" % labels)
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.axis("off")
    plt.imshow(images[i])
coord.request_stop()
coord.join(threads)
sess.close()
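To double-check the generated train.tfrecords without building a TensorFlow graph, it helps to know that a TFRecord file is just a sequence of framed records. As described in the TensorFlow documentation, each record is laid out as: uint64 length, uint32 masked CRC32C of the length bytes, the data, uint32 masked CRC32C of the data (little-endian), where masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8. The sketch below implements that framing with the standard library only, using a slow table-driven CRC32C; its byte-for-byte compatibility with files written by TensorFlow is an assumption worth verifying before relying on it, and all function names are hypothetical.

```python
import io
import struct


def _crc32c_table():
    # Table for the reflected CRC-32C (Castagnoli) polynomial.
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for b in bytearray(data):
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, data):
    header = struct.pack("<Q", len(data))  # uint64 length, little-endian
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))
    f.write(data)
    f.write(struct.pack("<I", masked_crc(data)))


def _read_exact(f, n):
    chunk = f.read(n)
    if len(chunk) != n:
        raise ValueError("truncated record")
    return chunk


def read_records(f):
    """Yield each record's payload, checking both CRCs."""
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length CRC")
        data = _read_exact(f, length)
        (data_crc,) = struct.unpack("<I", _read_exact(f, 4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupt data CRC")
        yield data


# Round trip through an in-memory file.
buf = io.BytesIO()
for payload in [b"cat", b"", b"\x00" * 300]:
    write_record(buf, payload)
buf.seek(0)
print([len(r) for r in read_records(buf)])
```

Pointed at the file produced by the complete example, for instance with open("./data/train.tfrecords", "rb") as f: print(sum(1 for _ in read_records(f))), it should report one record per image, i.e. 10 for the author's 3 + 4 + 3 images.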