TensorFlow Learning Notes -- 7. Efficient data reading in TensorFlow: a detailed look at TFRecord
2016-12-06 10:06
一 Why TFRecord?
For a small dataset, it is usually fine to load everything into memory and then feed it to the network batch by batch (tip: this approach pairs nicely with a generator using yield; try it yourself, I won't go into detail here). For a large dataset, however, this no longer works because it consumes too much memory. In that case it is better to use the queue mechanism TensorFlow provides, i.e. the second approach: reading data from files. For certain formats, such as CSV, the official documentation already has instructions; here I introduce a more general and efficient method that the official docs cover only briefly: TensorFlow's own standard format, TFRecords. Reference: TensorFlow高效读取数据.
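As a quick illustration of the yield tip above, a minimal batch generator for an in-memory dataset might look like the following sketch (plain Python + NumPy; the toy dataset and batch size are made up for illustration):

```python
import numpy

def batch_generator(data, batch_size):
    """Yield successive slices of an in-memory array as batches."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]

# Toy stand-in for a small dataset already loaded into memory.
data = numpy.arange(10)
batches = list(batch_generator(data, batch_size=3))
# batch sizes: 3, 3, 3, 1
```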
二 Code walkthrough
1. Import the libraries
```python
import tensorflow as tf
import numpy
```
2. Create a writer for serializing data to disk
```python
writer = tf.python_io.TFRecordWriter('test.tfrecord')
```
3. Create three values a, b, c of different types and save them through the writer
```python
for i in range(0, 2):
    a = 0.618 + i
    b = [2016 + i, 2017 + i]
    c = numpy.array([[0, 1, 2], [3, 4, 5]]) + i
    c = c.astype(numpy.uint8)
    c_raw = c.tostring()  # serialize c into a byte string for storage
    print 'i:', i
    print '  a:', a
    print '  b:', b
    print '  c:', c
    example = tf.train.Example(
        features=tf.train.Features(  # fixed pattern: features are stored as a dict
            feature={'a': tf.train.Feature(float_list=tf.train.FloatList(value=[a])),
                     'b': tf.train.Feature(int64_list=tf.train.Int64List(value=b)),
                     'c': tf.train.Feature(bytes_list=tf.train.BytesList(value=[c_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
    print '  writer', i, 'DOWN!'
writer.close()
```
```
i: 0
  a: 0.618
  b: [2016, 2017]
  c: [[0 1 2]
 [3 4 5]]
  writer 0 DOWN!
i: 1
  a: 1.618
  b: [2017, 2018]
  c: [[1 2 3]
 [4 5 6]]
  writer 1 DOWN!
```
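For the curious, the .tfrecord file produced here is just a sequence of framed records: an 8-byte little-endian length, a 4-byte masked CRC of the length, the serialized Example bytes, and a 4-byte masked CRC of the data. The sketch below illustrates that framing only. Note one deliberate simplification: real TFRecord files use CRC-32C (Castagnoli, e.g. from the third-party crc32c package), while zlib.crc32 stands in here as a placeholder, so the checksums would not match a genuine file:

```python
import struct
import zlib

def masked_crc(data):
    # Placeholder checksum: real TFRecord uses CRC-32C, not zlib's CRC-32.
    crc = zlib.crc32(data) & 0xFFFFFFFF
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def frame_record(payload):
    """Frame one record the way a .tfrecord file lays it out."""
    header = struct.pack('<Q', len(payload))           # 8-byte length
    return (header
            + struct.pack('<I', masked_crc(header))    # 4-byte length CRC
            + payload                                  # serialized Example
            + struct.pack('<I', masked_crc(payload)))  # 4-byte data CRC

record = frame_record(b'serialized-example-bytes')
# layout: 8 (length) + 4 (crc) + 24 (payload) + 4 (crc) = 40 bytes
```

This framing is why TFRecord reading is fast: a reader only needs the length header to skip from record to record, with no parsing of the payload required.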
4. Create the filename queue and read its contents (as a dict of features)
```python
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'a': tf.FixedLenFeature([], tf.float32),
                                       'b': tf.FixedLenFeature([2], tf.int64),
                                       'c': tf.FixedLenFeature([], tf.string)
                                   })
```
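The features argument maps each key to a declared shape and dtype: [] means a scalar, [2] a length-2 vector. That shape contract can be illustrated in plain Python (a toy stand-in for the check that parse_single_example performs, not TF code; the helper name is made up):

```python
def check_features(parsed, spec):
    """Toy check that each parsed value matches its declared fixed shape."""
    for key, shape in spec.items():
        value = parsed[key]
        n = len(value) if isinstance(value, list) else 1
        expected = shape[0] if shape else 1  # [] declares a scalar
        if n != expected:
            raise ValueError('feature %r: got %d values, declared shape %r'
                             % (key, n, shape))
    return True

spec = {'a': [], 'b': [2], 'c': []}   # mirrors the FixedLenFeature shapes above
parsed = {'a': 0.618, 'b': [2016, 2017], 'c': b'raw-bytes'}
ok = check_features(parsed, spec)
```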
5. Extract and decode the features
```python
a_out = features['a']
b_out = features['b']
c_raw_out = features['c']
c_out = tf.decode_raw(c_raw_out, tf.uint8)
c_out = tf.reshape(c_out, [2, 3])
```
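What decode_raw plus reshape do to the 'c' feature can be mimicked in NumPy: reinterpret the byte string as a uint8 buffer and restore the original shape. This is a sketch of the idea, not the TF op itself:

```python
import numpy

# The bytes that were written into the BytesList for i == 0:
c = numpy.array([[0, 1, 2], [3, 4, 5]], dtype=numpy.uint8)
c_raw = c.tobytes()  # modern alias of the tostring() used above

# Reinterpret the byte string as uint8 values and restore the shape,
# mirroring what decode_raw + reshape achieve inside the graph:
c_restored = numpy.frombuffer(c_raw, dtype=numpy.uint8).reshape(2, 3)
```

Note that the shape [2, 3] is not stored in the record; the reader has to know it, which is why the reshape is hard-coded here.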
6. Inspect the tensor shapes and dtypes
```python
print a_out
print b_out
print c_out
```
```
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)
```
7. Feed the data through shuffle_batch
```python
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out],
                                                   batch_size=3, capacity=200,
                                                   min_after_dequeue=100,
                                                   num_threads=2)
```
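Conceptually, shuffle_batch keeps a buffer of up to capacity examples, refills it from the reader, and only starts drawing random elements once at least min_after_dequeue are present (larger values shuffle better but cost memory). A single-threaded toy sketch of that buffering logic, for illustration only and not the actual TF queue implementation:

```python
import random

def shuffle_batches(stream, batch_size, capacity, min_after_dequeue, seed=0):
    """Toy single-threaded version of the shuffle-buffer idea."""
    rng = random.Random(seed)
    buf, batch = [], []
    it = iter(stream)
    exhausted = False
    while True:
        # Refill the buffer up to capacity.
        while not exhausted and len(buf) < capacity:
            try:
                buf.append(next(it))
            except StopIteration:
                exhausted = True
        if not buf:
            break
        # Only draw random elements once the buffer is warm enough.
        if len(buf) >= min_after_dequeue or exhausted:
            batch.append(buf.pop(rng.randrange(len(buf))))
            if len(batch) == batch_size:
                yield batch
                batch = []

batches = list(shuffle_batches(range(10), batch_size=3,
                               capacity=8, min_after_dequeue=4))
# three full batches of 3; the final leftover element is dropped
```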
8. Build a session, read the data, and print it
```python
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print '  a_val:', a_val
print '  b_val:', b_val
print '  c_val:', c_val
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print '  a_val:', a_val
print '  b_val:', b_val
print '  c_val:', c_val
```
```
first batch:
  a_val: [ 0.61799997  1.61800003  0.61799997]
  b_val: [[2016 2017]
 [2017 2018]
 [2016 2017]]
  c_val: [[[0 1 2]
  [3 4 5]]
 [[1 2 3]
  [4 5 6]]
 [[0 1 2]
  [3 4 5]]]
second batch:
  a_val: [ 0.61799997  0.61799997  1.61800003]
  b_val: [[2016 2017]
 [2016 2017]
 [2017 2018]]
  c_val: [[[0 1 2]
  [3 4 5]]
 [[0 1 2]
  [3 4 5]]
 [[1 2 3]
  [4 5 6]]]
```
Since batch_size was set to 3 earlier, each batch yields three examples, drawn in random order.
三 Complete code
(A quick aside on the loop used throughout: range(0, 2) produces 0 and 1, so the writer loop runs exactly twice.)

```python
for i in range(0, 2):
    print i
```

```
0
1
```
1. Write the data to a tfrecord file
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy

writer = tf.python_io.TFRecordWriter('test1111.tfrecord')
for i in range(0, 2):
    a = 0.618 + i
    b = [2016 + i, 2017 + i]
    #c = numpy.array([[0, 1, 2], [3, 4, 5]]) + i
    #c = c.astype(numpy.uint8)
    c = "你好哦" + str(i)
    #c_raw = c.tostring()  # serialize c into a byte string for storage
    c_raw = c
    print 'i:', i
    print '  a:', a
    print '  b:', b
    print '  c:', c
    example = tf.train.Example(
        features=tf.train.Features(
            feature={'a': tf.train.Feature(float_list=tf.train.FloatList(value=[a])),
                     'b': tf.train.Feature(int64_list=tf.train.Int64List(value=b)),
                     'c': tf.train.Feature(bytes_list=tf.train.BytesList(value=[c_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
    print '  writer', i, 'DOWN!'
writer.close()
```
```
i: 0
  a: 0.618
  b: [2016, 2017]
  c: 你好哦0
  writer 0 DOWN!
i: 1
  a: 1.618
  b: [2017, 2018]
  c: 你好哦1
  writer 1 DOWN!
```
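A note on the string feature: BytesList only accepts byte strings. In Python 2 the UTF-8 source literal above is already bytes, while in Python 3 you would have to encode explicitly. The round trip, shown here in Python 3 syntax, is:

```python
text = u'你好哦' + str(0)
raw = text.encode('utf-8')        # what goes into the BytesList
restored = raw.decode('utf-8')    # what you do to the value read back
# raw == b'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'
```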
2. Read the data back and display it
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf

# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test1111.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'a': tf.FixedLenFeature([], tf.float32),
                                       'b': tf.FixedLenFeature([2], tf.int64),
                                       'c': tf.FixedLenFeature([], tf.string)
                                   })
a_out = features['a']
b_out = features['b']
c_out = features['c']
#c_raw_out = features['c']
#c_raw_out = tf.sparse_to_dense(features['c'])
#c_out = tf.decode_raw(c_raw_out, tf.uint8)
print a_out
print b_out
print c_out
#c_out = tf.reshape(c_out, [2, 3])
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out],
                                                   batch_size=3, capacity=200,
                                                   min_after_dequeue=100,
                                                   num_threads=2)
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print '  a_val:', a_val
print '  b_val:', b_val
print '  c_val:', c_val[0].decode('utf-8')
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print '  a_val:', a_val
print '  b_val:', b_val
print '  c_val:', str(c_val).decode('utf-8')
```
```
Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("ParseSingleExample/Squeeze_c:0", shape=(), dtype=string)
first batch:
  a_val: [ 0.61799997  0.61799997  1.61800003]
  b_val: [[2016 2017]
 [2016 2017]
 [2017 2018]]
  c_val: 你好哦0
second batch:
  a_val: [ 1.61800003  0.61799997  1.61800003]
  b_val: [[2017 2018]
 [2016 2017]
 [2017 2018]]
  c_val: ['\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61' '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'
 '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61']
```
Without decoding, the values come back as raw UTF-8 byte strings (shown escaped above); decoding one with .decode('utf-8') restores the original text:
```python
print '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'.decode('utf-8')
```

```
你好哦0
```