您的位置:首页 > 其它

Tensorflow中创建自己的TFRecord格式数据集

2018-01-03 22:03 459 查看
参考文献《TensorFlow实战Google深度学习框架》

TFRecord格式介绍

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer格式(即二进制文件)存储,具体定义如下:

// Top-level record: wraps a single Features message.
message Example{
Features features = 1;
};
// Dictionary from a feature name (string) to its Feature value.
message Features{
map<string,Feature> feature = 1;
};
// One feature value; the oneof means exactly one of the three list kinds is set.
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};


它实际上存储了一个从属性名到取值的字典。其中属性名为一个字符串,属性取值可以为字符串列表(BytesList)、实数列表(FloatList)或整数列表(Int64List)。比如对于一幅图像而言,可以将图像的像素信息保存成一个字符串,将图像对应的标签保存成整数列表。

创建TFRecord文件

先导入一些必要的库:(jupyter-notebook中实现)

import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
%matplotlib inline


数据预处理

我自己从网上下载了10张图片(3张猫,4张狗,3张马),分别存放在cat, dog和horse文件夹下,因为从网上下载的图片大小格式不统一,先将这些图片做预处理,函数如下(这里只是附上函数部分代码,文末会附上完整测试代码):

def preprocess(imageRawDir, imageDir):
    """
    Resize every raw image to 256x256 RGB and save it as a JPEG.

    Output files are named "<label>_<index>.jpg", where <label> is the last
    path component of imageDir (e.g. "cat" for "./data/cat/") and <index> is
    the position of the source file in the sorted directory listing.

    Arguments:
    imageRawDir -- directory of primary images.
    imageDir -- directory of processed images; created if it does not exist.

    Return: none.
    """
    # normpath drops a trailing separator, so "./data/cat/" and "./data/cat"
    # both give the label "cat" (split("/")[-2] required the trailing slash).
    label = os.path.basename(os.path.normpath(imageDir))
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)

    # os.listdir returns entries in arbitrary order; sort so the index used
    # in the output file name is reproducible between runs.
    for index, imageName in enumerate(sorted(os.listdir(imageRawDir))):
        image = Image.open(os.path.join(imageRawDir, imageName))
        # JPEG cannot store alpha or palette modes, so force RGB before saving.
        image = image.convert("RGB").resize((256, 256))
        savePath = os.path.join(imageDir, label + "_" + str(index) + ".jpg")
        image.save(savePath)


预处理后的图片会保存在另一个指定的文件夹下。

写入到TFRecord文件

下面两个函数会在创建TFRecord文件的时候用到。因为如果不写成函数的形式,代码会很长,看起来也很头疼。

def _int64_feature(value):
    """
    Wrap one integer in a tf.train.Feature backed by a single-element Int64List.
    """
    int64List = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64List)


def _bytes_feature(value):
    """
    Wrap one byte string in a tf.train.Feature backed by a single-element BytesList.
    """
    bytesList = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytesList)


下面这个函数完成的功能是读取我们之前预处理后的所有图片,并依次将每张图片写入到TFRecord文件中,这里因为数据很少,只写入到了一个TFRecord文件中,当数据量很大时,也可以写入多个文件中。

def createRecord(imageDir):
    """
    create TFRecord data.

    Reads every image under imageDir/cat, imageDir/dog and imageDir/horse and
    writes one tf.train.Example per image to imageDir/train.tfrecords. Each
    Example carries the class index ("label") and the raw RGB pixel bytes
    ("image_raw").

    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    classNames = ["cat", "dog", "horse"]

    # create a writer to write TFRecord file
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    try:
        for classIndex, className in enumerate(classNames):
            # Single-argument print(...) gives the same output on Python 2 and 3.
            print("class name =  %s" % className)
            currentClassDir = os.path.join(imageDir, className)
            print("current dir =  %s" % currentClassDir)
            # Sort the listing: os.listdir order is arbitrary.
            for index, imageName in enumerate(sorted(os.listdir(currentClassDir))):
                image = Image.open(os.path.join(currentClassDir, imageName))
                # Force 3 channels so every record has the same byte layout.
                image_raw = image.convert("RGB").tobytes()
                print("%s %s" % (index, imageName))

                # write image data(pixel values and label) to Example Protocol Buffer
                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": _int64_feature(classIndex),
                    "image_raw": _bytes_feature(image_raw),
                }))

                # write an example to TFRecord file
                writer.write(example.SerializeToString())
    finally:
        # Close the file even if an image fails to load.
        writer.close()


读取TFRecord文件

def readRecord(recordName):
    """
    Build the ops that read one (image, label) example from a TFRecord file.

    Arguments:
    recordName -- the TFRecord file to be read.
    return: (image, label) tensors; image is uint8 reshaped to [256, 256, 3]
    and label is cast to int32.
    """
    # Schema of each serialized Example: a scalar int64 label and the pixel
    # bytes stored as a scalar string.
    exampleSchema = {
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    }

    fileQueue = tf.train.string_input_producer([recordName])
    recordReader = tf.TFRecordReader()
    _, rawExample = recordReader.read(fileQueue)
    parsed = tf.parse_single_example(rawExample, features=exampleSchema)

    # Raw bytes -> flat uint8 vector -> [256, 256, 3].
    pixels = tf.reshape(tf.decode_raw(parsed["image_raw"], tf.uint8), [256, 256, 3])
    classId = tf.cast(parsed["label"], tf.int32)
    return pixels, classId


注意,这里我们得到的返回值都是张量,需要在tensorflow中创建session后才能得到实际的数据。如下:

##test code
# Build the input pipeline: a single-example reader feeding a shuffling batch queue.
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# The coordinator lets the queue-runner threads be stopped cleanly afterwards.
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for i in range(10):
        images, labels = sess.run([imageBatch, labelBatch])
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print("batch shape =  %s labels =  %s" % (images.shape, labels))
finally:
    coord.request_stop()
    coord.join(thread)
    sess.close()

# Show the last batch: its labels, and its images side by side.
print("label =  %s" % (labels,))
for column, picture in enumerate(images):
    plt.subplot(1, len(images), column + 1)
    plt.axis("off")
    plt.imshow(picture)


输出结果:

batch shape =  (4, 256, 256, 3) labels =  [0 1 1 0]
batch shape =  (4, 256, 256, 3) labels =  [1 0 0 0]
batch shape =  (4, 256, 256, 3) labels =  [1 2 2 1]
batch shape =  (4, 256, 256, 3) labels =  [1 2 1 2]
batch shape =  (4, 256, 256, 3) labels =  [0 0 1 0]
batch shape =  (4, 256, 256, 3) labels =  [2 1 2 1]
batch shape =  (4, 256, 256, 3) labels =  [1 1 0 0]
batch shape =  (4, 256, 256, 3) labels =  [0 2 1 1]
batch shape =  (4, 256, 256, 3) labels =  [2 2 1 0]
batch shape =  (4, 256, 256, 3) labels =  [1 2 0 1]
label =  [1 2 0 1]




可以发现,最后一个batch中的图像和标签是一一对应的(0: cat; 1: dog; 2: horse),说明我们已经成功从TFRecord文件中读出了数据。

完整样例代码

import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt

# Remember the launch directory; createRecord is later pointed at <currentDir>/data/.
# (The former os.chdir(currentDir) was a no-op: it changed into the directory
# the process was already in.)
currentDir = os.getcwd()
print(currentDir)

def preprocess(imageRawDir, imageDir):
    """
    Resize every raw image to 256x256 RGB and save it as a JPEG.

    Output files are named "<label>_<index>.jpg", where <label> is the last
    path component of imageDir (e.g. "cat" for "./data/cat/") and <index> is
    the position of the source file in the sorted directory listing.

    Arguments:
    imageRawDir -- directory of primary images.
    imageDir -- directory of processed images; created if it does not exist.

    Return: none.
    """
    # normpath drops a trailing separator, so "./data/cat/" and "./data/cat"
    # both give the label "cat" (split("/")[-2] required the trailing slash).
    label = os.path.basename(os.path.normpath(imageDir))
    if not os.path.isdir(imageDir):
        os.makedirs(imageDir)

    # os.listdir returns entries in arbitrary order; sort so the index used
    # in the output file name is reproducible between runs.
    for index, imageName in enumerate(sorted(os.listdir(imageRawDir))):
        image = Image.open(os.path.join(imageRawDir, imageName))
        # JPEG cannot store alpha or palette modes, so force RGB before saving.
        image = image.convert("RGB").resize((256, 256))
        savePath = os.path.join(imageDir, label + "_" + str(index) + ".jpg")
        image.save(savePath)

##test code
# Resize the raw images of each class from ./data_raw/<class>/ into ./data/<class>/.
catRawDir, catDir = "./data_raw/cat/", "./data/cat/"
dogRawDir, dogDir = "./data_raw/dog/", "./data/dog/"
horseRawDir, horseDir = "./data_raw/horse/", "./data/horse/"

for rawDir, processedDir in (
    (catRawDir, catDir),
    (dogRawDir, dogDir),
    (horseRawDir, horseDir),
):
    preprocess(rawDir, processedDir)

def _int64_feature(value):
    """
    Wrap one integer in a tf.train.Feature backed by a single-element Int64List.
    """
    int64List = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64List)

def _bytes_feature(value):
    """
    Wrap one byte string in a tf.train.Feature backed by a single-element BytesList.
    """
    bytesList = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytesList)

def createRecord(imageDir):
    """
    create TFRecord data.

    Reads every image under imageDir/cat, imageDir/dog and imageDir/horse and
    writes one tf.train.Example per image to imageDir/train.tfrecords. Each
    Example carries the class index ("label") and the raw RGB pixel bytes
    ("image_raw").

    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    classNames = ["cat", "dog", "horse"]

    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    try:
        for classIndex, className in enumerate(classNames):
            # Single-argument print(...) gives the same output on Python 2 and 3.
            print("class name =  %s" % className)
            currentClassDir = os.path.join(imageDir, className)
            print("current dir =  %s" % currentClassDir)
            # Sort the listing: os.listdir order is arbitrary.
            for index, imageName in enumerate(sorted(os.listdir(currentClassDir))):
                image = Image.open(os.path.join(currentClassDir, imageName))
                # Force 3 channels so every record has the same byte layout.
                image_raw = image.convert("RGB").tobytes()
                print("%s %s" % (index, imageName))

                example = tf.train.Example(features=tf.train.Features(feature={
                    "label": _int64_feature(classIndex),
                    "image_raw": _bytes_feature(image_raw),
                }))
                writer.write(example.SerializeToString())
    finally:
        # Close the file even if an image fails to load.
        writer.close()

##test code
# Write train.tfrecords into <currentDir>/data/ from the class folders found there.
createRecord(os.path.join(currentDir, "data/"))

def readRecord(recordName):
    """
    Build the ops that read one (image, label) example from a TFRecord file.

    Arguments:
    recordName -- the TFRecord file to be read.
    return: (image, label) tensors; image is uint8 reshaped to [256, 256, 3]
    and label is cast to int32.
    """
    # Schema of each serialized Example: a scalar int64 label and the pixel
    # bytes stored as a scalar string.
    exampleSchema = {
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    }

    fileQueue = tf.train.string_input_producer([recordName])
    recordReader = tf.TFRecordReader()
    _, rawExample = recordReader.read(fileQueue)
    parsed = tf.parse_single_example(rawExample, features=exampleSchema)

    # Raw bytes -> flat uint8 vector -> [256, 256, 3].
    pixels = tf.reshape(tf.decode_raw(parsed["image_raw"], tf.uint8), [256, 256, 3])
    classId = tf.cast(parsed["label"], tf.int32)
    return pixels, classId

##test code
# Build the input pipeline: a single-example reader feeding a shuffling batch queue.
image, label = readRecord("./data/train.tfrecords")
print("%s %s" % (image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)

##test code
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# The coordinator lets the queue-runner threads be stopped cleanly afterwards.
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for i in range(10):
        images, labels = sess.run([imageBatch, labelBatch])
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print("batch shape =  %s labels =  %s" % (images.shape, labels))
finally:
    coord.request_stop()
    coord.join(thread)
    sess.close()

# Show the last batch: its labels, and its images side by side.
print("label =  %s" % (labels,))
for column, picture in enumerate(images):
    plt.subplot(1, len(images), column + 1)
    plt.axis("off")
    plt.imshow(picture)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息