您的位置:首页 > 其它

统一数据格式TFRecord

2018-02-11 16:03 405 查看

背景

机器学习的数据可以通过很多种方式进行存储,例如csv文件,excel文件,txt文件等。为了能够将各种种类的数据进行统一,我们将采用TRRECORD的格式进行统一,这种格式比较便于对训练数据的属性进行管理,同时也便于进行数据的多线程输入。

介绍

TFRecord数据通过tf.train.Example Protocol Buffer格式存储的。Protocol Buffer可以上网查询。可以把Protocol buffer看作是一些message的类。

//一个数据样本
message Example{
Features feature =1;
};
//一个特征
message Features{
map<string,Feature> feature =1;
};
//对应的特征值类型
message Features{
oneof kind{
BytesList bytes_list=1;
FloatList float_list=2;
Int64List int64_list=3;
}
};


保存数据集

保存一个TFRecord的数据集大致的步骤如下:

解析出一个训练集的图像部分和标签部分

新建一个TFRecord文件

循环写入样本

图像数据字符串化

生成example实例,图像像素、标签、图像数据

example序列化后写入文件

下面是保存Mnist数据的代码。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

#定义两个初始化函数
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets("dataset/", dtype=tf.uint8, one_hot=True)
#解析出图片
images = mnist.train.images
#解析出标签
labels = mnist.train.labels
#解析出像素
pixels = images.shape[1]
num_examples = mnist.train.num_examples
#保存的文件格式是*.tfrecords
filename = "./tfrecord/output.tfrecords"
#定义写入对象writer
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_exampl
4000
es):
image_raw = images[index].tostring()
#定义example实例,根据类型选定函数进行初始化
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(labels[index])),
'image_raw' : _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()




读取数据

import tensorflow as tf
import scipy.misc
#定义读取实例
reader = tf.TFRecordReader()
#定义文件队列
files = ['./tfrecord/output.tfrecords']
file_queue = tf.train.string_input_producer(files)

#定义读取一个样本的操作
_,one_example = reader.read(file_queue)
features = tf.parse_single_example(one_example,features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64),
})
#解析图像的操作
images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32)

with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
image,label,pixel = sess.run([images,labels,pixels])
#注意这里一定要reshape
test = tf.reshape(image,[28,28])
scipy.misc.imsave('./pics/'+str(i)+'.png',sess.run(test))
coord.request_stop()
coord.join(threads)


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: