tensorflow使用range_input_producer多线程读取数据
2017-12-07 19:13
363 查看
转载:http://blog.csdn.net/lyg5623/article/details/69387917
有少许代码修改
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2…
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
代码:
输出:
如果range_input_producer去掉参数num_epochs=1,则输出:
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。
有少许代码修改
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2…
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
代码:
# encoding: UTF-8 import tensorflow as tf import codecs BATCH_SIZE = 6 NUM_EXPOCHES = 5 def input_producer(): array = codecs.open("test.txt").readlines() print(array) array = list(map(lambda line: line.strip('\n'), array)) print(array) i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE]) return inputs class Inputs(object): def __init__(self): self.inputs = input_producer() def main(*args, **kwargs): inputs = Inputs() init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) sess.run(init) try: index = 0 while not coord.should_stop() and index<10: datalines = sess.run(inputs.inputs) index += 1 print("step: %d, batch data: %s" % (index, str(datalines))) except tf.errors.OutOfRangeError: print("Done traing:-------Epoch limit reached") except KeyboardInterrupt: print("keyboard interrput detected, stop training") finally: coord.request_stop() coord.join(threads) sess.close() del sess if __name__ == "__main__": main()
输出:
step: 1, batch data: [b'1' b'2' b'3' b'4' b'5' b'6'] step: 2, batch data: [b'7' b'8' b'9' b'10' b'11' b'12'] step: 3, batch data: [b'13' b'14' b'15' b'16' b'17' b'18'] step: 4, batch data: [b'19' b'20' b'21' b'22' b'23' b'24'] step: 5, batch data: [b'25' b'26' b'27' b'28' b'29' b'30'] Done traing:-------Epoch limit reached
如果range_input_producer去掉参数num_epochs=1,则输出:
step: 1, batch data: [b'1' b'2' b'3' b'4' b'5' b'6'] step: 2, batch data: [b'7' b'8' b'9' b'10' b'11' b'12'] step: 3, batch data: [b'13' b'14' b'15' b'16' b'17' b'18'] step: 4, batch data: [b'19' b'20' b'21' b'22' b'23' b'24'] step: 5, batch data: [b'25' b'26' b'27' b'28' b'29' b'30'] step: 6, batch data: [b'1' b'2' b'3' b'4' b'5' b'6'] step: 7, batch data: [b'7' b'8' b'9' b'10' b'11' b'12'] step: 8, batch data: [b'13' b'14' b'15' b'16' b'17' b'18'] s 4000 tep: 9, batch data: [b'19' b'20' b'21' b'22' b'23' b'24'] step: 10, batch data: [b'25' b'26' b'27' b'28' b'29' b'30']
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6 [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。
相关文章推荐
- tensorflow使用range_input_producer多线程读取数据
- tensorflow使用range_input_producer多线程读取数据
- tensorflow使用range_input_producer多线程读取数据
- Tensorflow中怎么使用queue读取数据的情况下,在同一个session中边训练边测试
- 使用python读取tensorflow实例中的MNIST模拟数据
- Tensorflow中使用TFRecords高效读取数据--结合NLP数据实践
- Tensorflow中使用tfrecord方式读取数据
- Tensorflow中使用TFRecords高效读取数据--结合NLP数据实践
- Tensorflow中使用tfrecord方式读取数据的方法
- 使用logstash的logstash-input-kafka插件读取kafka中的数据
- [C#]使用Process的StandardInput与StandardOutput写入读取控制台数据
- Linux下C语言实现的简单使用线程向FIFO里写入与读取数据的例子
- 使用Robot循环读取Excel中的数据
- 使用SqlDataReader读取数据示例
- 使用SqlDataReader读取数据示例
- 使用ifstream::get()方法从文本文件中读取数据
- [教程]使用AODKeycap读取数据
- [教程]在ADOKeycap中使用DataReader读取数据
- ■ASP中使用XMLHTTP读取远程数据3
- 使用 CFile 来读取特定格式的数据