TensorFlow 训练接口
2018-01-15 09:22
169 查看
learn_runner
experiment
核心
初始化
train
experiment.train
estimator.train
train_and_evaluate
experiment.train_and_evaluate
estimator.evaluate
run()
run_config
参考:
https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig
https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig#master
调用“schedule”函数:
estimator
由 experiment 的 train() 触发调用 estimator
experiment
核心
初始化
train
experiment.train
estimator.train
train_and_evaluate
experiment.train_and_evaluate
estimator.evaluate
learn_runner
# Usage example.
from tensorflow.contrib.learn import learn_runner


def run_experiment(argv=None):
    """Build an Experiment through `experiment_fn` and hand it to learn_runner.

    Args:
        argv: Not used by this function.

    NOTE(review): `_schedule`, `run_config` and `params` are module-level
    names defined elsewhere in the example -- confirm they exist.
    """
    learn_runner.run(
        experiment_fn=experiment_fn,  # first-class function
        schedule=_schedule,           # which Experiment method to run, e.g. "train"
                                      # or "train_and_evaluate"
        run_config=run_config,        # RunConfig
        hparams=params,               # HParams
    )


def experiment_fn(run_config, params):
    """Create the Experiment used to train and evaluate the model.

    Args:
        run_config (RunConfig): Configuration for the Estimator run.
        params (HParams): Hyperparameters.

    Returns:
        (Experiment) Experiment for training the mnist model.
    """
    estimator = get_estimator(run_config, params)

    # Input pipelines. Each helper returns an input_fn plus a hook. The hooks
    # are not handed to the Experiment here (the original left
    # `train_monitors` / `eval_hooks` commented out).
    # NOTE(review): if those hooks are needed to initialize the input
    # pipeline, pass them as train_monitors=[...] / eval_hooks=[...].
    train_input_fn, train_input_hook = get_train_inputs(
        params.train_batch_size,
        params.dataset_dir,
        params.dataset_file_pattern)
    eval_input_fn, eval_input_hook = get_val_inputs(
        params.eval_batch_size,
        params.dataset_dir,
        params.dataset_file_pattern)

    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,                  # first-class function
        eval_input_fn=eval_input_fn,                    # first-class function
        train_steps=params.train_steps,                 # number of minibatch steps
        min_eval_frequency=params.eval_min_frequency,   # steps between evaluations
        eval_steps=params.eval_steps,                   # batches per evaluation
    )
run()
def run(experiment_fn, schedule=None, run_config=None, hparams=None):
    """Create an Experiment with `experiment_fn`, then run one of its methods.

    The method named by `schedule` is called on the Experiment and its result
    is returned. If `schedule` is not provided, a default is derived from the
    RunConfig (see `_get_default_schedule`):

      * no config, or a non-distributed job -> 'train_and_evaluate'
      * task type 'master'                  -> 'train_and_evaluate'
      * task type 'ps'                      -> 'run_std_server'
      * task type 'worker'                  -> 'train'

    Args:
      experiment_fn: A function that creates an `Experiment`. It accepts two
        arguments `run_config` and `hparams`, which should be used to create
        the `Estimator` (`run_config` passed as `config` to its constructor;
        `hparams` used as the hyper-parameters of the model). It must return
        an `Experiment`.
      schedule: The name of the method in the `Experiment` to run.
      run_config: `RunConfig` instance. The `run_config.model_dir` must be
        non-empty. If `None`, the Experiment's estimator config is used to pick
        the default schedule.
      hparams: `HParams` instance. The default hyper-parameters, which will be
        passed to the `experiment_fn` if `run_config` is not None.

    Returns:
      The return value of the method named `schedule`.

    Raises:
      ValueError: if no `schedule` is given for a distributed job whose config
        has no task type, or whose task type has no default schedule.
    """
    # 1. Build the experiment; the wrapper checks the experiment's uid.
    wrapped_experiment_fn = _wrapped_experiment_fn_with_uid_check(experiment_fn)
    experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)

    # 2. Resolve the schedule, falling back to the estimator's own config.
    run_config = run_config or experiment.estimator.config
    schedule = schedule or _get_default_schedule(run_config)

    # 3. Execute it and return its result, as documented above.
    return _execute_schedule(experiment, schedule)


def _execute_schedule(experiment, schedule):
    """Execute the method named `schedule` of `experiment`."""
    task = getattr(experiment, schedule)
    return task()


def _get_default_schedule(config):
    """Returns the default schedule for the provided RunConfig.

    Raises:
      ValueError: if the job is distributed but `config.task_type` is unset or
        is not one of master / ps / worker.
    """
    if not config or not _is_distributed(config):
        return 'train_and_evaluate'

    if not config.task_type:
        raise ValueError('Must specify a schedule')

    if config.task_type == run_config_lib.TaskType.MASTER:
        # TODO(rhaertel): handle the case where there is more than one master
        # or explicitly disallow such a case.
        return 'train_and_evaluate'
    elif config.task_type == run_config_lib.TaskType.PS:
        return 'run_std_server'
    elif config.task_type == run_config_lib.TaskType.WORKER:
        return 'train'

    # Previously fell through and returned None, which only failed later
    # (and obscurely) inside getattr(experiment, None).
    raise ValueError(
        'No default schedule for task type: %s' % (config.task_type,))


def _is_distributed(config):
    """Returns true if this is a distributed job.

    A job counts as distributed when its cluster spec lists more than one
    task across all jobs.
    """
    if not config.cluster_spec:
        return False
    cluster = config.cluster_spec
    task_count = sum(1 for job in cluster.jobs for _ in cluster.job_tasks(job))
    return task_count > 1
run_config
参考:
https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig
https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig#master
experiment
核心
触发初始化:experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)
def experiment_fn(run_config, params):
    """Build the Experiment (abridged excerpt).

    Args:
        run_config (RunConfig): Configuration for the Estimator run.
        params (HParams): Hyperparameters.

    Returns:
        (Experiment) the configured Experiment. The original excerpt built it
        but never returned it, so callers received None.
    """
    # ... (train_input_fn / eval_input_fn are built here; elided in this excerpt)

    # get_estimator requires both arguments; the original called it with none.
    estimator = get_estimator(run_config, params)

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,                            # Estimator
        train_input_fn=train_input_fn,                  # first-class function
        eval_input_fn=eval_input_fn,                    # first-class function
        train_steps=params.train_steps,                 # minibatch steps
        min_eval_frequency=params.eval_min_frequency,   # eval frequency
        # train_monitors=[],                            # hooks for training
        # eval_hooks=[eval_input_hook],                 # hooks for evaluation
        eval_steps=params.eval_steps,                   # batches per evaluation
    )
    return experiment


def get_estimator(run_config, params):
    """Return the model as a TensorFlow Estimator object.

    Args:
        run_config (RunConfig): Configuration for the Estimator run.
        params (HParams): hyperparameters.

    Returns:
        (tf.estimator.Estimator) Estimator wrapping `model_fn`.
    """
    return tf.estimator.Estimator(
        model_fn=model_fn,  # first-class function
        params=params,      # HParams
        config=run_config,  # RunConfig
    )


def model_fn(features, labels, mode, params):
    """Model function used in the estimator.

    Args:
        features (Tensor): Input features to the model.
        labels (Tensor): Labels tensor for training and evaluation.
        mode (ModeKeys): Specifies if training, evaluation or prediction.
        params (HParams): hyperparameters.

    Returns:
        (EstimatorSpec): Model to be run by Estimator. The original excerpt
        constructed the spec but dropped it instead of returning it.
    """
    # ... (predictions, loss, train_op and eval_metric_ops are built here;
    # elided in this excerpt)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),  # optional
    )
调用“schedule”函数:
# Excerpt from `_execute_schedule`: look up the Experiment method named by
# `schedule` and call it. (`return` is only valid inside that function.)
task = getattr(experiment, schedule)
return task()
初始化
class Experiment(object):
    def __init__(self, estimator, train_input_fn, eval_input_fn,
                 eval_metrics=None, train_steps=None, eval_steps=100,
                 train_monitors=None, eval_hooks=None,
                 local_eval_frequency=None, eval_delay_secs=120,
                 continuous_eval_throttle_secs=60, min_eval_frequency=None,
                 delay_workers_by_global_step=False, export_strategies=None,
                 train_steps_per_iteration=None, checkpoint_and_export=False,
                 saving_listeners=None):
        """Constructor for `Experiment`.

        Creates an Experiment instance. None of the functions passed to this
        constructor are executed at construction time. They are stored and
        used when a method that requires them is executed.

        Args:
          estimator: Object implementing the Estimator interface, which could
            be a combination of @{tf.contrib.learn.Trainable} and
            @{tf.contrib.learn.Evaluable} (deprecated), or
            @{tf.estimator.Estimator}.
          train_input_fn: function, returns features and labels for training.
          eval_input_fn: function, returns features and labels for evaluation.
            If `eval_steps` is `None`, this should be configured only to
            produce a finite number of batches (generally, 1 epoch over the
            evaluation data).
          eval_metrics: passed to the Estimator's `evaluate` call as
            `metrics`. Must be `None` when `estimator` is a
            `tf.estimator.Estimator`.
          train_steps: Perform this many steps of training. `None`, the
            default, means train forever.
          eval_steps: `evaluate` runs until input is exhausted (or another
            exception is raised), or for `eval_steps` steps, if specified.
          train_monitors: A list of monitors to pass to the `Estimator`'s
            `fit` function.
          eval_hooks: A list of `SessionRunHook` hooks to pass to the
            `Estimator`'s `evaluate` function.
          eval_delay_secs: Start evaluating after waiting for this many
            seconds.
          continuous_eval_throttle_secs: Do not re-evaluate unless the last
            evaluation was started at least this many seconds ago for
            continuous_eval().
          min_eval_frequency: (applies only to train_and_evaluate). The
            minimum number of steps between evaluations. Evaluation does not
            occur if no new snapshot is available, hence this is only a
            minimum. If 0, evaluation happens only after training. If None,
            defaults to 1, unless model_dir is on GCS, in which case the
            default is 1000.
          delay_workers_by_global_step: if `True`, delays training workers
            based on global step instead of time.
          export_strategies: Iterable of `ExportStrategy`s, or a single one,
            or `None`.
          train_steps_per_iteration: (applies only to
            continuous_train_and_eval). Perform this many (integer) train
            steps for each training-evaluation iteration. With a small value,
            the model is evaluated more frequently and more checkpoints are
            saved. If `None`, a default value is used (smaller than
            `train_steps` if provided).
          checkpoint_and_export: (applies only to train_and_evaluate). If
            `True`, performs intermediate model checkpoints and exports during
            the training process, rather than only once model training is
            complete. This parameter is experimental and may be changed or
            removed in the future. Setting it has these effects:
            `min_eval_frequency` is ignored, and the number of steps between
            evaluations and exports is instead determined by the Estimator
            configuration parameters `save_checkpoints_secs` and
            `save_checkpoints_steps`. It also creates a default
            `CheckpointSaverHook` instead of a `ValidationMonitor`, so the
            provided `train_monitors` need to be adjusted accordingly.
          saving_listeners: list of `CheckpointSaverListener` objects. Used by
            tf.estimator.Estimator for callbacks that run immediately before
            or after checkpoint savings.

        Raises:
          ValueError: if `estimator` does not implement the Estimator
            interface, or if export_strategies has the wrong type.
        """
        # NOTE(review): the constructor body is elided in this excerpt.
train()
experiment.train()
def train(self, delay_secs=None):
    """Fit the estimator using the training data.

    Trains the estimator for `self._train_steps` steps (forever if that is
    `None`) after waiting for `delay_secs` seconds.

    Args:
        delay_secs: Start training after this many seconds. If `None`, the
            delay is derived from the task id. With
            `delay_workers_by_global_step` set, workers wait on the global
            step through a hook and do not sleep. Otherwise they sleep 5 s per
            task id, capped at 60 s.

    Returns:
        The trained estimator.
    """
    start = time.time()
    config = self._estimator.config

    # 1. Start the server, if needed. It's important to start the server
    #    before we (optionally) sleep for the case where no device_filters are
    #    set. Otherwise the servers wait to connect to each other before
    #    starting to train.
    #    NOTE(review): the original excerpt used the pseudo-condition
    #    `if (分布式):` ("if distributed"), which raises NameError. It is
    #    approximated here as "cluster_spec and master are both set". The
    #    full TF implementation also requires a tf.contrib.learn.RunConfig
    #    and raises otherwise -- confirm against the TF source.
    if config.cluster_spec and config.master:
        self._start_server()

    # 2. Start-up delay. With delay_workers_by_global_step, workers block on a
    #    GlobalStepWaiterHook instead of sleeping.
    extra_hooks = []
    if delay_secs is None:
        task_id = config.task_id or 0
        if self._delay_workers_by_global_step:
            # Wait ~5500 global steps for the second worker. Each worker waits
            # more than the previous one, but with a diminishing increment.
            extra_hooks.append(
                basic_session_run_hooks.GlobalStepWaiterHook(
                    int(8000.0 * math.log(task_id + 1))))
            delay_secs = 0
        else:
            # Wait 5 secs more for each new worker, up to 60 secs.
            delay_secs = min(60, task_id * 5)

    if delay_secs > 0:
        # Time already spent (e.g. starting the server) counts towards the
        # delay. The original logged `remaining` but slept the full delay_secs.
        remaining = delay_secs - (time.time() - start)
        if remaining > 0:
            logging.info("Waiting %d secs before starting training.", remaining)
            time.sleep(remaining)

    # 3. Train.
    return self._call_train(
        input_fn=self._train_input_fn,
        max_steps=self._train_steps,
        hooks=self._train_monitors + extra_hooks,
        saving_listeners=self._saving_listeners)


def _call_train(self, input_fn=None, steps=None, hooks=None, max_steps=None,
                saving_listeners=None):
    """Dispatch training to the wrapped estimator.

    Core `tf.estimator.Estimator`s use `train`. Other (contrib) estimators use
    `fit`, which the original excerpt omitted; that branch silently returned
    None.
    """
    # Estimator in core cannot work with monitors, so convert them to hooks.
    # Estimator in contrib converts them internally, so converting here is
    # safe in both cases.
    hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
    if self._core_estimator_used:
        return self._estimator.train(
            input_fn=input_fn,
            steps=steps,
            max_steps=max_steps,
            hooks=hooks,
            saving_listeners=saving_listeners)
    # NOTE(review): saving_listeners is not forwarded to contrib `fit`;
    # confirm against the contrib Estimator API before adding it.
    return self._estimator.fit(
        input_fn=input_fn,
        steps=steps,
        max_steps=max_steps,
        monitors=hooks)
estimator
由 experiment 的 train() 触发调用 estimator.train()
# Excerpt from Experiment._call_train: the call made when a core
# tf.estimator.Estimator is used.
self._estimator.train(
    input_fn=input_fn,
    steps=steps,
    max_steps=max_steps,
    hooks=hooks,
    saving_listeners=saving_listeners)
estimator.train()
def train(self, input_fn, hooks=None, steps=None, max_steps=None,
          saving_listeners=None):
    """Trains a model given training data input_fn.

    Args:
      input_fn: Input function returning a tuple of:
          features - `Tensor` or dictionary of string feature name to `Tensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or until input_fn generates the `OutOfRange` error or
        `StopIteration` exception. `steps` works incrementally: calling
        train(steps=10) twice trains for 20 steps in total. If `OutOfRange`
        or `StopIteration` occurs in the middle, training stops before 20
        steps. If you don't want incremental behavior, set `max_steps`
        instead. If set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train the model. If
        `None`, train forever or until input_fn generates the `OutOfRange`
        error or `StopIteration` exception. If set, `steps` must be `None`.
        If `OutOfRange` or `StopIteration` occurs in the middle, training
        stops before `max_steps` steps. Two calls to `train(steps=100)` mean
        200 training iterations; with `train(max_steps=100)` the second call
        does no iterations since the first call already did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps` is <= 0.
      ValueError: If `saving_listeners` is given but no CheckpointSaverHook
        is available to attach them to.
    """
    # The excerpt documented these errors but never raised them.
    if steps is not None and max_steps is not None:
        raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
            logging.info('Skipping training since max_steps has already saved.')
            return self

    hooks = _check_hooks_type(hooks)
    hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
    saving_listeners = _check_listeners_type(saving_listeners)
    loss = self._train_model(input_fn, hooks, saving_listeners)
    logging.info('Loss for final step: %s.', loss)
    return self


def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    """Return the global step stored in the latest checkpoint, or 0.

    Best effort: any failure to locate or read a checkpoint counts as step 0.
    """
    try:
        checkpoint_reader = training.NewCheckpointReader(
            training.latest_checkpoint(checkpoint_dir))
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except Exception:  # pylint: disable=broad-except
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        return 0


def _check_hooks_type(hooks):
    """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
    hooks = list(hooks or [])
    for h in hooks:
        if not isinstance(h, training.SessionRunHook):
            raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
    return hooks


def _convert_train_steps_to_hooks(self, steps, max_steps):
    """Return a StopAtStepHook for `steps`/`max_steps`, or [] if neither is set."""
    if steps is not None or max_steps is not None:
        return [training.StopAtStepHook(steps, max_steps)]
    return []


def _train_model(self, input_fn, hooks, saving_listeners):
    """Build the training graph and run it in a MonitoredTrainingSession.

    Returns:
      The loss of the last executed training step (None if no step ran).
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        global_step_tensor = self._create_and_assert_global_step(g)

        # 1. Input data, plus any hooks the input_fn requires.
        features, labels, input_hooks = (
            self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.TRAIN))
        worker_hooks.extend(input_hooks)

        # 2. Call model_fn to build the model graph (held in an EstimatorSpec).
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

        # 3. Complete the graph.
        # 3.1 Add the loss to the summaries (unless model_fn already did) and
        #     to the LOSSES collection.
        if not any(x.op.name == 'loss'
                   for x in ops.get_collection(ops.GraphKeys.SUMMARIES)):
            summary.scalar('loss', estimator_spec.loss)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

        # 3.2 Worker hooks: the caller's hooks (e.g. from the Experiment),
        #     NaN checking and loss logging, and the spec's training_hooks.
        worker_hooks.extend(hooks)
        worker_hooks.extend([
            training.NanTensorHook(estimator_spec.loss),
            training.LoggingTensorHook(
                {'loss': estimator_spec.loss, 'step': global_step_tensor},
                every_n_iter=100),
        ])
        worker_hooks.extend(estimator_spec.training_hooks)

        # 3.3 Create a Saver if neither the scaffold nor the graph has one.
        if not (estimator_spec.scaffold.saver or
                ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    keep_checkpoint_every_n_hours=(
                        self._config.keep_checkpoint_every_n_hours),
                    defer_build=True,
                    save_relative_paths=True))

        # 3.4 Chief-only hooks: the spec's training_chief_hooks, plus a
        #     CheckpointSaverHook if checkpointing is configured and none
        #     exists yet.
        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]
                saver_hooks = [chief_hooks[0]]

        # 3.5 Attach saving_listeners. The excerpt accepted them and then
        #     silently dropped them.
        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use '
                    'saving_listeners. Please set one of the '
                    'RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            # NOTE(review): relies on CheckpointSaverHook's private
            # `_listeners` list. If several saver hooks exist, only the first
            # receives the listeners.
            saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        # 4. Run the training loop in a MonitoredTrainingSession.
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(
                    tuple(chief_hooks) +
                    tuple(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
    return loss
train_and_evaluate
experiment.train_and_evaluate
def train_and_evaluate(self):
    """Interleaves training and evaluation.

    The frequency of evaluation is controlled by the constructor arg
    `min_eval_frequency`. When this parameter is 0, evaluation happens only
    after training has completed. Evaluation cannot happen more frequently
    than checkpoints are taken. If no new snapshot is available when
    evaluation is supposed to occur, evaluation doesn't happen for another
    `min_eval_frequency` steps (assuming a checkpoint is available at that
    point). Thus, setting `min_eval_frequency` to 1 means that the model is
    evaluated every time there is a new checkpoint.

    This is particularly useful for a "Master" task in the cloud, whose
    responsibility it is to take checkpoints, evaluate those checkpoints, and
    write out summaries. Participating in training as the supervisor allows
    such a task to accomplish the first and last items, while performing
    evaluation allows for the second.

    Returns:
      The result of the `evaluate` call to the `Estimator` as well as the
      export results using the specified `ExportStrategy`.
    """
    # The directory to which evaluation summaries are written is determined
    # by adding a suffix to 'eval'; that suffix is the 'name' parameter to the
    # various evaluate(...) methods. Setting it to None forces the directory
    # name to simply be 'eval'.
    eval_dir_suffix = None

    # We set every_n_steps to 1, but evaluation only occurs when a new
    # snapshot is available. If, by the time we finish evaluation, there is a
    # new snapshot, then we just evaluate again. Otherwise, we keep training
    # until one becomes available.
    with _new_attr_context(self, "_train_monitors"):
        # 1. Periodic evaluation during training is implemented through
        #    monitors: a ValidationMonitor is appended to the train_monitors
        #    given to the Experiment constructor.
        self._train_monitors = self._train_monitors or []
        if self._min_eval_frequency:
            self._train_monitors += [
                monitors.ValidationMonitor(
                    input_fn=self._eval_input_fn,
                    eval_steps=self._eval_steps,
                    metrics=self._eval_metrics,
                    every_n_steps=self._min_eval_frequency,
                    name=eval_dir_suffix,
                    hooks=self._eval_hooks)
            ]
        # 2. Train; the monitors above are converted to hooks before use.
        self.train(delay_secs=0)

    # 3. Run one final evaluation once training has finished.
    eval_result = self._call_evaluate(
        input_fn=self._eval_input_fn,
        steps=self._eval_steps,
        # The Experiment's eval_metrics: must be None with
        # tf.estimator.Estimator; only usable with the contrib Estimator.
        metrics=self._eval_metrics,
        name=eval_dir_suffix,
        hooks=self._eval_hooks)  # the Experiment's eval_hooks
    export_results = self._maybe_export(eval_result)
    return eval_result, export_results


def _call_evaluate(self, _sentinel=None,  # pylint: disable=invalid-name
                   input_fn=None, steps=None, metrics=None, name=None,
                   checkpoint_path=None, hooks=None):
    """Dispatch evaluation to the core or contrib Estimator API.

    `_sentinel` forces every other argument to be passed by keyword.

    Raises:
      ValueError: if called positionally, or if `metrics` is given while a
        core `tf.estimator.Estimator` is used.
    """
    if _sentinel is not None:
        raise ValueError("_call_evaluate should be called with keyword args only")

    if self._core_estimator_used:
        if metrics is not None:
            raise ValueError(
                "`eval_metrics` must be `None` with `tf.estimator.Estimator`")
        return self._estimator.evaluate(input_fn=input_fn,
                                        steps=steps,
                                        name=name,
                                        checkpoint_path=checkpoint_path,
                                        hooks=hooks)
    return self._estimator.evaluate(input_fn=input_fn,
                                    steps=steps,
                                    metrics=metrics,
                                    name=name,
                                    checkpoint_path=checkpoint_path,
                                    hooks=hooks)


class ValidationMonitor(EveryN):
    # NOTE(review): abridged -- __init__ and the remaining members are not
    # shown in this excerpt.

    def _evaluate_estimator(self):
        """Run one evaluation of the monitored estimator on the validation input."""
        return self._estimator.evaluate(
            input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks,
            name=self.name)

    def every_n_step_end(self, step, outputs):
        """Evaluate the latest checkpoint and apply early stopping.

        Returns:
          True to request that training stop (early stopping triggered),
          False otherwise.
        """
        super(ValidationMonitor, self).every_n_step_end(step, outputs)

        # Evaluation reads from a checkpoint. Skip it when none has been saved
        # yet (Estimator evaluation would raise "Could not find trained
        # model"), or when the checkpoint is the one already evaluated. The
        # excerpt described this check but never performed it.
        latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
        if latest_path is None:
            logging.debug("Skipping evaluation since model has not been saved "
                          "yet at step %d.", step)
            return False
        if latest_path == self._latest_path:
            logging.debug("Skipping evaluation due to same checkpoint %s for "
                          "step %d as for step %d.", latest_path, step,
                          self._latest_path_step)
            return False
        self._latest_path = latest_path
        self._latest_path_step = step

        # Run evaluation and log it.
        validation_outputs = self._evaluate_estimator()
        stats = []
        for name in validation_outputs:
            stats.append("%s = %s" % (name, str(validation_outputs[name])))
        logging.info("Validation (step %d): %s", step, ", ".join(stats))

        # Early stopping logic.
        if self.early_stopping_rounds is not None:
            if self.early_stopping_metric not in validation_outputs:
                raise ValueError("Metric %s missing from outputs %s." % (
                    self.early_stopping_metric, set(validation_outputs.keys())))
            current_value = validation_outputs[self.early_stopping_metric]
            if (self._best_value is None or
                    (self.early_stopping_metric_minimize and
                     (current_value < self._best_value)) or
                    (not self.early_stopping_metric_minimize and
                     (current_value > self._best_value))):
                self._best_value = current_value
                self._best_metrics = copy.deepcopy(validation_outputs)
                self._best_value_step = step
            stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
            if stop_now:
                logging.info("Stopping. Best step: {} with {} = {}."
                             .format(self._best_value_step,
                                     self.early_stopping_metric,
                                     self._best_value))
                self._early_stopped = True
                return True  # ask training to stop
        return False
estimator.evaluate
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
             name=None):
    """Evaluates the model given evaluation data input_fn.

    Each step calls `input_fn`, which returns one batch of data. Evaluation
    runs until:
    - `steps` batches are processed, or
    - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
      `StopIteration`).

    Args:
      input_fn: Input function returning a tuple of:
          features - Dictionary of string feature name to `Tensor` or
            `SparseTensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      steps: Number of steps for which to evaluate the model. If `None`,
        evaluates until `input_fn` raises an end-of-input exception.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the evaluation call.
      checkpoint_path: Path of a specific checkpoint to evaluate. If `None`,
        the latest checkpoint in `model_dir` is used.
      name: Name of the evaluation if the user needs to run multiple
        evaluations on different data sets, such as on training data vs test
        data. Metrics for different evaluations are saved in separate folders
        and appear separately in TensorBoard.

    Returns:
      A dict containing the evaluation metrics specified in `model_fn` keyed
      by name, plus an entry `global_step` holding the global step at which
      this evaluation was performed.

    Raises:
      ValueError: If `steps <= 0`.
      ValueError: If no model has been trained, namely `model_dir`, or the
        given `checkpoint_path` is empty.
    """
    eval_hooks = _check_hooks_type(hooks)
    eval_hooks.extend(self._convert_eval_steps_to_hooks(steps))
    return self._evaluate_model(input_fn=input_fn,
                                hooks=eval_hooks,
                                checkpoint_path=checkpoint_path,
                                name=name)


def _convert_eval_steps_to_hooks(self, steps):
    """Translate an evaluation step budget into stop hooks.

    Returns [] when `steps` is None (evaluate until input is exhausted);
    raises ValueError for a non-positive budget.
    """
    if steps is None:
        return []
    if steps > 0:
        return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access
    raise ValueError('Must specify steps > 0, given: {}'.format(steps))


def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''):
    """Evaluates the model using the training.evaluation library."""
    # With no explicit checkpoint, fall back to the latest one; a model that
    # has never been trained cannot be evaluated.
    if not checkpoint_path:
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.
                             format(self._model_dir))

    # Summaries go to <model_dir>/eval or <model_dir>/eval_<name>.
    eval_dir = os.path.join(self._model_dir,
                            'eval_' + name if name else 'eval')

    with ops.Graph().as_default() as graph:
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = self._create_and_assert_global_step(graph)

        # 1. Input data, plus any hooks the input_fn requires.
        features, labels, input_hooks = (
            self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.EVAL))

        # 2. Call model_fn to build the evaluation graph (an EstimatorSpec).
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.EVAL, self.config)

        # 3. Complete the graph.
        # 3.1 Track the mean loss as an extra metric (this writes into the
        #     spec's eval_metric_ops dict).
        estimator_spec.eval_metric_ops[
            model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss)

        # 3.2 Split the metrics into one grouped update op and the value ops.
        update_op, eval_dict = _extract_metric_update_ops(
            estimator_spec.eval_metric_ops)

        # 3.3 Also report the global step, so results show the iteration.
        eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor

        # 3.4 Gather hooks: input hooks, caller hooks, then the spec's own.
        all_hooks = list(input_hooks)
        all_hooks.extend(hooks)
        all_hooks.extend(list(estimator_spec.evaluation_hooks or []))

        # 4. Run a single evaluation pass.
        eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
            checkpoint_path=checkpoint_path,
            master=self._config.evaluation_master,
            scaffold=estimator_spec.scaffold,
            eval_ops=update_op,
            final_ops=eval_dict,
            hooks=all_hooks,
            config=self._session_config)

        # 5. Write the results to the summary directory.
        _write_dict_to_summary(
            output_dir=eval_dir,
            dictionary=eval_results,
            current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])

    return eval_results


def _extract_metric_update_ops(eval_dict):
    """Separate update operations from metric value operations.

    Each entry of `eval_dict` maps a name to a (value_op, update_op) pair.
    Returns the update ops grouped into one op (None if there are none) and a
    dict of name -> value_op.
    """
    # Sort metrics lexicographically so the graph is identical every time.
    ordered_metrics = sorted(six.iteritems(eval_dict))
    value_ops = {name: pair[0] for name, pair in ordered_metrics}
    update_ops = [pair[1] for _, pair in ordered_metrics]
    update_op = control_flow_ops.group(*update_ops) if update_ops else None
    return update_op, value_ops
相关文章推荐
- TensorFlow——训练自己的数据(四)模型测试
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- TensorFlow计算框架训练图片
- Tensorflow-SSD测试及训练自己的数据集
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow1.2训练cifar10步骤以及问题汇总
- TensorFlow使用C++加载使用训练好的模型,.cc文件代码实现的相关类及方法总结
- 如何用Tensorflow训练模型成pb文件(二)——基于tfrecord的读取
- 用tensorflow训练自己的图片集-用TFRecords将代码导入神经网络
- tensorflow训练权重保存和调用——tf.saver()
- TensorFlow训练线性回归
- Caffe训练Mnist实例:使用pycaffe与cmdcaffe接口
- tensorflow使用GPU训练时的显存占用问题
- 利用TensorFlow训练简单的RNN
- tensorflow入门之训练简单的神经网络方法
- tensorflow 学习笔记(四) - mnist实例--用简单的神经网络来训练和测试
- 使用Tensorflow实现多GPU并行训练
- TensorFlow实现人脸识别(4)--------对人脸样本进行训练,保存人脸识别模型
- tensorflow使用GPU训练时的显存占用问题
- TensorFlow自己训练的SSD mobilenet模型 安卓移植