您的位置:首页 > 其它

TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读

2016-12-06 10:06 756 查看

一 why tfrecord?

对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords

参考 :TensorFlow高效读取数据

二 代码详解

1导入库

import tensorflow as tf
import numpy


2 构建writer,用于写入数据

writer = tf.python_io.TFRecordWriter('test.tfrecord')


3 分俩步创建a,b,c三个不同格式的列表并保存到writer中

for i in range(0, 2):
a = 0.618 + i
b = [2016 + i, 2017+i]
c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
c = c.astype(numpy.uint8)
c_raw = c.tostring()#这里是把c换了一种格式存储
print 'i:',i
print '  a:',a
print '  b:',b
print '  c:',c
example = tf.train.Example(
features = tf.train.Features( #固定模式,字典格式保存
feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
print '   writer',i,'DOWN!'
writer.close()


i: 0
a: 0.618
b: [2016, 2017]
c: [[0 1 2]
[3 4 5]]
writer 0 DOWN!
i: 1
a: 1.618
b: [2017, 2018]
c: [[1 2 3]
[4 5 6]]
writer 1 DOWN!


4 创建文件读取队列并读取其中内容(字典格式)

# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example

features = tf.parse_single_example(serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([2], tf.int64),
'c': tf.FixedLenFeature([], tf.string)
}
)


5 读取内容

a_out = features['a']
b_out = features['b']
c_raw_out = features['c']
c_out = tf.decode_raw(c_raw_out, tf.uint8)
c_out = tf.reshape(c_out, [2, 3])


6 显示格式

print a_out
print b_out
print c_out


Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("Reshape:0", shape=(2, 3), dtype=uint8)


7 通过shuffle_batch喂入数据

a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3,
capacity=200, min_after_dequeue=100, num_threads=2)


8 构建sess,读入数据并显示

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',c_val
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print '  a_val:',a_val
print '  b_val:',b_val
print '  c_val:',c_val


first batch:
a_val: [ 0.61799997  1.61800003  0.61799997]
b_val: [[2016 2017]
[2017 2018]
[2016 2017]]
c_val: [[[0 1 2]
[3 4 5]]

[[1 2 3]
[4 5 6]]

[[0 1 2]
[3 4 5]]]
second batch:
a_val: [ 0.61799997  0.61799997  1.61800003]
b_val: [[2016 2017]
[2016 2017]
[2017 2018]]
c_val: [[[0 1 2]
[3 4 5]]

[[0 1 2]
[3 4 5]]

[[1 2 3]
[4 5 6]]]


之前定义了batch=3,所以每个batch输入三个数据,并且是随机读入的。

三 完整代码

for i in range(0,2):
print i


0
1


1 把数据写成tfrecord文件

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf import numpy
writer = tf.python_io.TFRecordWriter('test1111.tfrecord')
for i in range(0, 2):
a = 0.618 + i
b = [2016 + i, 2017+i]
#c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
#c = c.astype(numpy.uint8)
c = "你好哦"+str(i)
#c_raw = c.tostring()#这里是把c换了一种格式存储
c_raw = c
print 'i:',i
print ' a:',a
print ' b:',b
print ' c:',c
example = tf.train.Example(
features = tf.train.Features(
feature = {'a':tf.train.Feature(float_list = tf.train.FloatList(value=[a])),
'b':tf.train.Feature(int64_list = tf.train.Int64List(value = b)),
'c':tf.train.Feature(bytes_list = tf.train.BytesList(value = [c_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
print ' writer',i,'DOWN!'
writer.close()


i: 0
a: 0.618
b: [2016, 2017]
c: 你好哦0
writer 0 DOWN!
i: 1
a: 1.618
b: [2017, 2018]
c: 你好哦1
writer 1 DOWN!


2数据提取及显示

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# output file name string to a queue
filename_queue = tf.train.string_input_producer(['test1111.tfrecord'], num_epochs=None)
# create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
yyp
features = tf.parse_single_example(serialized_example,
features={
'a': tf.FixedLenFeature([], tf.float32),
'b': tf.FixedLenFeature([2], tf.int64),
'c': tf.FixedLenFeature([],tf.string)
}
)
a_out = features['a']
b_out = features['b']
c_out = features['c']
#c_raw_out = features['c']
#c_raw_out = tf.sparse_to_dense(features['c'])
#c_out = tf.decode_raw(c_raw_out, tf.uint8)
print a_out print b_out print c_out
#c_out = tf.reshape(c_out, [2, 3])
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a_out, b_out, c_out], batch_size=3, capacity=200, min_after_dequeue=100, num_threads=2)sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

tf.train.start_queue_runners(sess=sess)
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
# print(a_val, b_val, c_val)
print 'first batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',c_val[0].decode('utf-8')
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
print 'second batch:'
print ' a_val:',a_val
print ' b_val:',b_val
print ' c_val:',str(c_val).decode('utf-8')


Tensor("ParseSingleExample/Squeeze_a:0", shape=(), dtype=float32)
Tensor("ParseSingleExample/Squeeze_b:0", shape=(2,), dtype=int64)
Tensor("ParseSingleExample/Squeeze_c:0", shape=(), dtype=string)
first batch:
a_val: [ 0.61799997  0.61799997  1.61800003]
b_val: [[2016 2017]
[2016 2017]
[2017 2018]]
c_val: 你好哦0
second batch:
a_val: [ 1.61800003  0.61799997  1.61800003]
b_val: [[2017 2018]
[2016 2017]
[2017 2018]]
c_val: ['\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61'
'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'
'\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa61']


不解码输出的是unicode格式的

print '\xe4\xbd\xa0\xe5\xa5\xbd\xe5\x93\xa60'.decode('utf-8')


你好哦0
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  tensorflow tfrecord