Fixing ValueError: GraphDef cannot be larger than 2GB.
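When using the estimator API in TensorFlow 1.x, you will often run into an error like
ValueError: GraphDef cannot be larger than 2GB. A likely cause is that the data is too large to be written into the graph.
A typical way of building the input data looks like this: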
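import numpy as np
import tensorflow as tf  # TensorFlow 1.x

def input_fn():
    # NumPy arrays passed to from_tensor_slices end up inside the graph.
    features, labels = (np.random.sample((100, 2)), np.random.sample((100, 1)))
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    # batch_size is assumed to be defined elsewhere.
    dataset = dataset.shuffle(100000).repeat().batch(batch_size)
    return dataset

...
estimator.train(input_fn)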
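When TensorFlow reads data this way, the data itself is written into the graph: from_tensor_slices embeds NumPy arrays as constants. With a large dataset this runs into the 2GB limit. I also hit it in multi-GPU experiments, even after lowering the batch size considerably. There are two ways around it: one is to stop writing the graph out, and the other is to build the input pipeline with
feed_dict.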
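As a rough sanity check, the raw size of the arrays that would be embedded as constants can be compared against the 2GB threshold. The sketch below is only an estimate under stated assumptions: the helper names constant_bytes and exceeds_graphdef_limit are hypothetical, the default of 8 bytes per element assumes float64 (which is what np.random.sample returns), and the rest of the graph plus any serialization overhead are ignored.

```python
from math import prod

# Assumption: the error is raised once the serialized graph reaches 2**31 bytes.
GRAPHDEF_LIMIT_BYTES = 1 << 31


def constant_bytes(shape, itemsize=8):
    """Raw bytes of one dense array with the given shape (float64 by default)."""
    return prod(shape) * itemsize


def exceeds_graphdef_limit(shapes, itemsize=8):
    """True if the embedded arrays alone would reach the 2GB threshold."""
    total = sum(constant_bytes(shape, itemsize) for shape in shapes)
    return total >= GRAPHDEF_LIMIT_BYTES


# The toy arrays from the input_fn example are tiny...
small = [(100, 2), (100, 1)]
# ...while a hypothetical 100M x 3 float64 feature matrix is not.
large = [(100_000_000, 3)]

print(exceeds_graphdef_limit(small))
print(exceeds_graphdef_limit(large))
```

An estimate like this only tells you when the data alone is already too big; a graph can still exceed the limit for other reasons.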
Not writing the graph
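My environment is TensorFlow 1.14, so I use that version as the example.
First, a summary of how the estimator runs (assuming a single GPU). Taking
estimator.train as the example (eval and predict are similar), the call sequence is: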
class Estimator():
    ...
    def train():
        ...
        loss = self._train_model(input_fn, hooks, saving_listeners)
        ...

    def _train_model(self, input_fn, hooks, saving_listeners):
        if self._train_distribution:
            return self._train_model_distributed(input_fn, hooks, saving_listeners)
        else:
            return self._train_model_default(input_fn, hooks, saving_listeners)

    def _train_model_default(self, input_fn, hooks, saving_listeners):
        ...
        return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                               hooks, global_step_tensor,
                                               saving_listeners)

    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners):
        ....
        with training.MonitoredTrainingSession(
            master=self._config.master,
            is_chief=self._config.is_chief,
            checkpoint_dir=self._model_dir,
            scaffold=estimator_spec.scaffold,
            hooks=worker_hooks,
            chief_only_hooks=(tuple(chief_hooks) +
                              tuple(estimator_spec.training_chief_hooks)),
            save_checkpoint_secs=0,  # Saving is handled by a hook.
            save_summaries_steps=save_summary_steps,
            config=self._session_config,
            max_wait_secs=self._config.session_creation_timeout_secs,
            log_step_count_steps=log_step_count_steps) as mon_sess:
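Stepping through with a debugger shows that the estimator's event-file writing is set up at the moment MonitoredTrainingSession
is called, while the events are actually written when the hooks run. In my experiment, for example, I set
log_step_count_steps, which makes the estimator print the computation speed and the current loss every given number of steps. The hook that implements this is
StepCounterHook, defined in
tensorflow/tensorflow/python/training/basic_session_run_hooks.py. Part of its definition is: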
class StepCounterHook(session_run_hook.SessionRunHook):
    """Hook that counts steps per second."""

    def __init__(...):
        ...
        self._summary_writer = summary_writer

    def begin(self):
        if self._summary_writer is None and self._output_dir:
            self._summary_writer = SummaryWriterCache.get(self._output_dir)
        self._summary_tag = training_util.get_global_step().op.name + "/sec"

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._global_step_tensor)

    def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
        steps_per_sec = elapsed_steps / elapsed_time
        if self._summary_writer is not None:
            summary = Summary(value=[
                Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
            ])
            self._summary_writer.add_summary(summary, global_step)
        logging.info("%s: %g", self._summary_tag, steps_per_sec)
So we only need to comment out calls such as self._summary_writer.add_summary.
The estimator then no longer generates event files while it runs, and the 2GB problem goes away.
feed_dict
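To use a dataset with a large amount of data, we can create the dataset from a placeholder. The data is then no longer written directly into the graph; the graph contains only the placeholder. The placeholder does, however, have to be fed with data once at the start, by calling
sess.run(iter.initializer, feed_dict={ x: data }).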
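With an explicit session, the placeholder approach can be sketched as follows. This is a minimal illustration rather than the estimator code path: it assumes a TensorFlow build that provides tf.compat.v1, the function name first_batch is hypothetical, and shuffle and repeat are left out so that the result is deterministic.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph and session APIs


def first_batch(data, batch_size=4):
    """Feed `data` through a placeholder and return the first batch."""
    graph = tf.Graph()
    with graph.as_default():
        # Only the placeholder lives in the graph, not the data itself.
        x = tf1.placeholder(tf.float32, shape=[None, 2])
        dataset = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)
        iterator = tf1.data.make_initializable_iterator(dataset)
        next_batch = iterator.get_next()

        with tf1.Session(graph=graph) as sess:
            # The data enters at initialization time via feed_dict.
            sess.run(iterator.initializer, feed_dict={x: data})
            return sess.run(next_batch)


data = np.arange(20, dtype=np.float32).reshape(10, 2)
print(first_batch(data).shape)
```

The point to notice is that the initializer must be run inside a session before the first get_next call, which is exactly the step an estimator does not expose directly.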
But an estimator has no explicit session that we can call, so what can we do? We can use a
SessionRunHook to solve this.
The tf.train.SessionRunHook() class is defined in
tensorflow/python/training/session_run_hook.py; for a detailed introduction to the class, see the repost "tf.SessionRunHook使用方法" (how to use tf.SessionRunHook).
A close look at the estimator's train and evaluate functions shows that both accept a hooks argument, documented as: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop. In other words, we can define our own SessionRunHook and pass it in through hooks.
train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)
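To see why a hook can stand in for the missing session handle, it helps to picture the order in which hook methods are called. The sketch below is a plain-Python stand-in, not TensorFlow code: RecordingHook and run_training_loop are hypothetical names, and the loop only mimics the documented order of the SessionRunHook callbacks (begin, after_create_session, before_run, after_run, end).

```python
class RecordingHook:
    """Stand-in with the same method names as tf.train.SessionRunHook."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def after_create_session(self, session, coord):
        self.calls.append("after_create_session")

    def before_run(self, run_context):
        self.calls.append("before_run")

    def after_run(self, run_context, run_values):
        self.calls.append("after_run")

    def end(self, session):
        self.calls.append("end")


def run_training_loop(hooks, steps):
    """Simplified imitation of a monitored training loop (an assumption)."""
    for hook in hooks:
        hook.begin()
    session = object()  # placeholder for the real session object
    for hook in hooks:
        hook.after_create_session(session, None)
    for _ in range(steps):
        for hook in hooks:
            hook.before_run(None)
        # The real loop would call session.run(train_op) here.
        for hook in hooks:
            hook.after_run(None, None)
    for hook in hooks:
        hook.end(session)


hook = RecordingHook()
run_training_loop([hook], steps=2)
print(hook.calls)
```

In this imitation, after_create_session fires once the session exists and before the first step, which is what makes it a natural place to run an iterator initializer.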
We want to initialize the dataset's placeholder before training starts, so we implement the after_create_session method of SessionRunHook:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs the iterator initializer once the session has been created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_fn = None  # set inside input_fn

    def after_create_session(self, session, coord):
        del coord  # unused
        self.iterator_initializer_fn(session)

def make_input_fn():
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        x = tf.placeholder(tf.float32, shape=[None, 2])
        dataset = tf.data.Dataset.from_tensor_slices(x)
        # batch_size is assumed to be defined elsewhere.
        dataset = dataset.shuffle(100000).repeat().batch(batch_size)
        iter = dataset.make_initializable_iterator()
        data = np.random.sample((100, 2))
        # Feed the data through the placeholder instead of embedding it in the graph.
        iterator_initializer_hook.iterator_initializer_fn = (
            lambda sess: sess.run(iter.initializer, feed_dict={x: data})
        )
        return iter.get_next()

    return input_fn, iterator_initializer_hook

...
input_fn, iterator_initializer_hook = make_input_fn()
estimator.train(input_fn, hooks=[iterator_initializer_hook])
Original post by MARSGGBO♥
2019-10-21 11:04:22