tensorflow(6) mnist.train.next_batch()函数解析
2017-12-07 23:01
736 查看
之前一直用keras,用keras的fit_generator需要写一个无限循环的生成器(while True, yield X,y),然而tensorflow的feed_dict原理不一样,它需要的只是一个batch的数据而已。
那么如何保证每一次调用next还能记住上一次的位置呢?第一个想到的是全局变量。tensorflow源码是将dataset输入写为一个类,self._index_in_epoch就相当于一个全局变量。只要累对象存在,这个变量就不会消失
那么如何保证每一次调用next还能记住上一次的位置呢?第一个想到的是全局变量。tensorflow源码是将dataset输入写为一个类,self._index_in_epoch就相当于一个全局变量。只要累对象存在,这个变量就不会消失
class DataSet(object): def __init__(self, images, labels,.....) self._images = images self._labels = labels self._epochs_completed = 0 self._index_in_epoch = 0 #self._num_examples 是指所有训练数据的样本个数 def next_batch(self, batch_size, fake_data=False, shuffle=True): #.....中间省略过一些fake start = self._index_in_epoch #self._index_in_epoch 所有的调用,总共用了多少个样本,相当于一个全局变量 #start第一个batch为0,剩下的就和self._index_in_epoch一样,如果超过了一个epoch,在下面还会重新赋值。 # Shuffle for the first epoch if self._epochs_completed == 0 and start == 0 and shuffle: perm0 = numpy.arange(self._num_examples) #生成的一个所有样本长度的np.array numpy.random.shuffle(perm0) self._images = self.images[perm0] self._labels = self.labels[perm0] # Go to the next epoch #从这里到下一个else,所做的是一个epoch快运行完了,但是不够一个batch,将这个epoch的结尾和下一个epoch的开头拼接起来,共同组成一个batch——size的数据。 if start + batch_size > self._num_examples: # Finished epoch self._epochs_completed += 1 # Get the rest examples in this epoch rest_num_examples = self._num_examples - start # 一个epoch 最后不够一个batch还剩下几个 images_rest_part = self._images[start:self._num_examples] labels_rest_part = self._labels[start:self._num_examples] # Shuffle the data if shuffle: perm = numpy.arange(self._num_examples) numpy.random.shuffle(perm) self._images = self.images[perm] self._labels = self.labels[perm] # Start next epoch start = 0 self._index_in_epoch = batch_size - rest_num_examples end = self._index_in_epoch images_new_part = self._images[start:end] labels_new_part = self._labels[start:end] return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0) #新的epoch,和上一个epoch的结尾凑成一个batch else: self._index_in_epoch += batch_size #每调用这个函数一次,_index_in_epoch就加上一个batch——size的,它相当于一个全局变量,上不封顶 end = self._index_in_epoch return self._images[start:end], self._labels[start:end]
相关文章推荐
- TensorFlow入门01:MNIST分类的源码及关键函数解析
- TensorFlow教程——nest.flatten()函数解析
- tensorflow函数解析:Session.run和Tensor.eval
- tensorflow函数解析: tf.Session() 和tf.InteractiveSession()
- 13、Tensorflow:Tensorflow数据读取有三种方式(next_batch)
- 详解Tensorflow数据读取有三种方式(next_batch)
- tensorflow实战之二:MNIST手写数字识别的优化1-代价函数优化
- Caffe中对MNIST执行train操作执行流程解析
- TensorFlow入门02:cnn实现MNIST分类的源码及关键函数解析
- TensorFlow使用next_batch()读取/tensorflow.python.framework.errors_impl.InvalidArgumentError: Expect 3 fi
- Tensorflow之MNIST解析
- mnist_train_test.prototxt代码解析
- TensorFlow技术解析与实战 9 TensorFlow在MNIST中的应用
- tensorflow: tf.train.exponential_decay函数
- [TensorFlow 学习笔记-08]tf.pad函数源码解析
- Tensorflow中提供tf.train.ExponentialMovingAverage函数实现(滑动平均模型)
- TensorFlow入门_2_mnist数据集训练与相关函数解释
- mnist的Tensorflow官方模板(fullly_connected_feed.py文件中参数解析问题)
- 《TensorFlow技术解析与实战》09 Tensorflow在mnist中的应用
- tensorflow中next_batch