您的位置:首页 > 其它

tensorflow 训练接口

2018-01-15 09:22 169 查看
learn_runner

experiment
核心

初始化

train
experiment.train

estimator.train

train_and_evaluate
experiment.train_and_evaluate

estimator.evaluate

learn_runner

使用接口

from tensorflow.contrib.learn import learn_runner

def run_experiment(argv=None):
    """Hand `experiment_fn` and its configuration over to `learn_runner.run`.

    Args:
        argv: Not used by this body.

    Returns:
        None; the result of `learn_runner.run` is discarded.
    """
    # NOTE(review): `_schedule`, `run_config` and `params` are not defined in
    # this excerpt; presumably they are module-level settings -- confirm.
    runner_kwargs = {
        'experiment_fn': experiment_fn,  # builds the Experiment lazily
        'schedule': _schedule,  # e.g. "train" or "train_and_evaluate"
        'run_config': run_config,  # RunConfig
        'hparams': params,  # HParams
    }
    learn_runner.run(**runner_kwargs)

def experiment_fn(run_config, params):
    """Build the `tf.contrib.learn.Experiment` that learn_runner will execute.

    Args:
        run_config (RunConfig): Configuration forwarded to the Estimator.
        params (HParam): Hyperparameters; this body reads `train_batch_size`,
            `eval_batch_size`, `dataset_dir`, `dataset_file_pattern`,
            `train_steps`, `eval_min_frequency` and `eval_steps`.

    Returns:
        (Experiment) The experiment wiring the estimator to its input functions.
    """
    model_estimator = get_estimator(run_config, params)

    # Each helper returns an (input_fn, hook) pair.
    # NOTE(review): the hooks are not passed to the Experiment (the
    # `train_monitors` / `eval_hooks` arguments are left unset). If they
    # initialise the input pipelines, they need to be wired in -- confirm.
    train_input_fn, _train_hook = get_train_inputs(
        params.train_batch_size, params.dataset_dir, params.dataset_file_pattern)
    eval_input_fn, _eval_hook = get_val_inputs(
        params.eval_batch_size, params.dataset_dir, params.dataset_file_pattern)

    return tf.contrib.learn.Experiment(
        estimator=model_estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=params.train_steps,  # minibatch steps
        min_eval_frequency=params.eval_min_frequency,  # eval frequency
        eval_steps=params.eval_steps,  # evaluation batches per eval pass
    )


run()

def run(experiment_fn, schedule=None, run_config=None,
        hparams=None):
    """Create an Experiment and run the method named by `schedule` on it.

    It creates an Experiment by calling `experiment_fn` (through a wrapper that
    checks the uid). Then it calls the method of the Experiment whose name is
    `schedule`, and returns that method's result.

    If schedule is not provided, a default is derived from the run config
    (see `_get_default_schedule`):

    * no config, or a non-distributed job -> 'train_and_evaluate'
    * 'master' -> 'train_and_evaluate'
    * 'ps' -> 'run_std_server'
    * 'worker' -> 'train'

    If the job is distributed but the config does not include a task type,
    a ValueError is raised.

    Args:
        experiment_fn: A function that creates an `Experiment`.
            It accepts two arguments `run_config` and `hparams`, which should be
            used to create the `Estimator` (`run_config` passed as `config` to
            its constructor; `hparams` used as the hyper-parameters of the
            model). It must return an `Experiment`.
        schedule: The name of the method in the `Experiment` to run.
        run_config: `RunConfig` instance. The `run_config.model_dir` must be
            non-empty. If falsy, the estimator's own config is used to pick the
            default schedule.
        hparams: `HParams` instance. The default hyper-parameters, forwarded to
            `experiment_fn`.

    Returns:
        The return value of function `schedule`.
    """
    # 1. Build the experiment via the uid-checking wrapper.
    wrapped_experiment_fn = _wrapped_experiment_fn_with_uid_check(experiment_fn)
    experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)

    # 2. Resolve the schedule, falling back to the estimator's config.
    run_config = run_config or experiment.estimator.config
    schedule = schedule or _get_default_schedule(run_config)

    # 3. Run the schedule and return its result, as documented above.
    #    (Previously the function stopped after step 2 and returned None.)
    return _execute_schedule(experiment, schedule)

def _execute_schedule(experiment, schedule):
    """Execute the method named `schedule` of `experiment`.

    Args:
        experiment: The object (normally an `Experiment`) to run.
        schedule: Name of a zero-argument method on `experiment`.

    Returns:
        Whatever the selected method returns.

    Raises:
        AttributeError: If `experiment` has no member named `schedule`; the
            message lists the public callable members that would be valid.
        TypeError: If the member exists but is not callable.
    """
    try:
        task = getattr(experiment, schedule)
    except AttributeError:
        # Give the user the list of valid schedule names instead of a bare
        # "object has no attribute" message.
        valid_tasks = sorted(
            attr for attr in dir(experiment)
            if not attr.startswith('_')
            and callable(getattr(experiment, attr, None)))
        raise AttributeError(
            'Schedule references non-existent task %s; allowed values for '
            'this experiment are: %s' % (schedule, valid_tasks))
    if not callable(task):
        raise TypeError(
            'Schedule references non-callable member %s' % (schedule,))
    return task()

def _get_default_schedule(config):
    """Returns the default schedule name for the provided RunConfig.

    Args:
        config: A RunConfig (or None).

    Returns:
        'train_and_evaluate' for a missing config, a non-distributed job or a
        master task; 'run_std_server' for a ps task; 'train' for a worker task.

    Raises:
        ValueError: If the job is distributed but `config.task_type` is empty,
            or if the task type has no default schedule.
    """
    if not config or not _is_distributed(config):
        return 'train_and_evaluate'

    if not config.task_type:
        raise ValueError('Must specify a schedule')

    if config.task_type == run_config_lib.TaskType.MASTER:
        # TODO(rhaertel): handle the case where there is more than one master
        # or explicitly disallow such a case.
        return 'train_and_evaluate'
    elif config.task_type == run_config_lib.TaskType.PS:
        return 'run_std_server'
    elif config.task_type == run_config_lib.TaskType.WORKER:
        return 'train'

    # Previously this fell through and returned None. That surfaced later as
    # an obscure getattr(experiment, None) error in _execute_schedule.
    raise ValueError(
        'No default schedule for task type: %s' % (config.task_type,))

def _is_distributed(config):
    """Tell whether `config` describes a job spanning more than one task.

    Args:
        config: Object exposing `cluster_spec` (falsy when there is no cluster).

    Returns:
        True if the cluster spec lists more than one task across all jobs.
    """
    spec = config.cluster_spec
    if not spec:
        return False

    # A single task in total still counts as a local (non-distributed) job.
    total_tasks = sum(
        1 for job_name in spec.jobs for _ in spec.job_tasks(job_name))
    return total_tasks > 1


run_config

参考:

https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig

https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig#master

experiment

核心

触发初始化:
# Excerpt quoted from `run`: calling the uid-checking wrapper around
# experiment_fn is what builds the Experiment (and the Estimator inside it).
experiment = wrapped_experiment_fn(run_config=run_config, hparams=hparams)


def experiment_fn(run_config, params):
    """Condensed form of experiment_fn: build an Estimator and wrap it in an Experiment.

    Args:
        run_config (RunConfig): Configuration for the Estimator run.
        params (HParams): Hyperparameters (`train_steps`, `eval_min_frequency`,
            `eval_steps` are read here).

    Returns:
        (Experiment) The experiment for learn_runner to execute.
    """
    # ....... (building train_input_fn / eval_input_fn is elided in this excerpt)

    # get_estimator is declared as get_estimator(run_config, params); calling
    # it without arguments would raise TypeError.
    estimator = get_estimator(run_config, params)

    # Define the experiment
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,  # Estimator
        train_input_fn=train_input_fn,  # First-class function
        eval_input_fn=eval_input_fn,  # First-class function
        train_steps=params.train_steps,  # Minibatch steps
        min_eval_frequency=params.eval_min_frequency,  # Eval frequency
        eval_steps=params.eval_steps  # Use evaluation feeder until its empty
    )
    # learn_runner needs the Experiment back; without this the function
    # returned None.
    return experiment

def get_estimator(run_config, params):
    """Wrap `model_fn` in a `tf.estimator.Estimator`.

    Args:
        run_config (RunConfig): Passed to the Estimator as `config`.
        params (HParams): Passed to the Estimator (and on to model_fn) as `params`.

    Returns:
        The constructed `tf.estimator.Estimator`.
    """
    estimator_kwargs = {
        'model_fn': model_fn,  # first-class function
        'params': params,  # HParams
        'config': run_config,  # RunConfig
    }
    return tf.estimator.Estimator(**estimator_kwargs)

def model_fn(features, labels, mode, params):
    """Model function used in the estimator.

    Args:
        features (Tensor): Input features to the model.
        labels (Tensor): Labels tensor for training and evaluation.
        mode (ModeKeys): Specifies if training, evaluation or prediction.
        params (HParams): hyperparameters.

    Returns:
        (EstimatorSpec): Model to be run by Estimator.
    """
    # NOTE(review): `predictions`, `loss`, `train_op` and `eval_metric_ops`
    # come from model-building code that this excerpt omits.
    # The spec must be *returned*: Estimator reads the graph pieces from it.
    # Previously it was constructed and discarded, so model_fn returned None.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
    )


调用“schedule”函数:
task = getattr(experiment, schedule)  return task()


初始化

class Experiment(object):
    """Bundles an Estimator with its train/eval input functions and schedule settings.

    Only the constructor signature and docstring are reproduced in this
    excerpt. The constructor body is elided; presumably it stores the arguments
    as the `self._*` attributes read by `train`, `train_and_evaluate` and
    `_call_evaluate` -- TODO confirm against the full source.
    """

    def __init__(self,
                 estimator,
                 train_input_fn,
                 eval_input_fn,
                 eval_metrics=None,
                 train_steps=None,
                 eval_steps=100,
                 train_monitors=None,
                 eval_hooks=None,
                 local_eval_frequency=None,
                 eval_delay_secs=120,
                 continuous_eval_throttle_secs=60,
                 min_eval_frequency=None,
                 delay_workers_by_global_step=False,
                 export_strategies=None,
                 train_steps_per_iteration=None,
                 checkpoint_and_export=False,
                 saving_listeners=None):
        """Constructor for `Experiment`.

        Creates an Experiment instance. None of the functions passed to this
        constructor are executed at construction time. They are stored and used
        when a method is executed which requires it.

        Args:
            estimator: Object implementing Estimator interface, which could be a
                combination of @{tf.contrib.learn.Trainable} and
                @{tf.contrib.learn.Evaluable} (deprecated), or
                @{tf.estimator.Estimator}.
            train_input_fn: function, returns features and labels for training.
            eval_input_fn: function, returns features and labels for evaluation. If
                `eval_steps` is `None`, this should be configured only to produce for a
                finite number of batches (generally, 1 epoch over the evaluation data).
            eval_metrics: Metrics passed as `metrics` to the estimator's `evaluate`.
                Must be `None` when `estimator` is a `tf.estimator.Estimator`
                (evaluation raises ValueError otherwise); only usable with a
                contrib Estimator.
            train_steps: Perform this many steps of training. `None`, the default,
                means train forever.
            eval_steps: `evaluate` runs until input is exhausted (or another exception
                is raised), or for `eval_steps` steps, if specified.
            train_monitors: A list of monitors to pass to the `Estimator`'s `fit`
                function.
            eval_hooks: A list of `SessionRunHook` hooks to pass to the
                `Estimator`'s `evaluate` function.
            local_eval_frequency: Not referenced by any method in this excerpt;
                its semantics are not documented here -- TODO confirm.
            eval_delay_secs: Start evaluating after waiting for this many seconds.
            continuous_eval_throttle_secs: Do not re-evaluate unless the last
                evaluation was started at least this many seconds ago for
                continuous_eval().
            min_eval_frequency: (applies only to train_and_evaluate). the minimum
                number of steps between evaluations. Of course, evaluation does not
                occur if no new snapshot is available, hence, this is the minimum.
                If 0, the evaluation will only happen after training.
                If None, defaults to 1, unless model_dir is on GCS, in which case the
                default is 1000.
            delay_workers_by_global_step: if `True` delays training workers
                based on global step instead of time.
            export_strategies: Iterable of `ExportStrategy`s, or a single one, or
                `None`.
            train_steps_per_iteration: (applies only to continuous_train_and_eval).
                Perform this many (integer) number of train steps for each
                training-evaluation iteration. With a small value, the model will be
                evaluated more frequently with more checkpoints saved. If `None`, will
                use a default value (which is smaller than `train_steps` if provided).
            checkpoint_and_export: (applies only to train_and_evaluate). If `True`,
                performs intermediate model checkpoints and exports during the training
                process, rather than only once model training is complete. This
                parameter is experimental and may be changed or removed in the future.
                Setting this parameter leads to the following: the value of
                `min_eval_frequency` will be ignored, and the number of steps between
                evaluations and exports will instead be determined by the Estimator
                configuration parameters `save_checkpoints_secs` and
                `save_checkpoints_steps`. Also, this parameter leads to the creation of
                a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the
                provided `train_monitors` will need to be adjusted accordingly.
            saving_listeners: list of `CheckpointSaverListener` objects. Used by
                tf.estimator.Estimator for callbacks that run immediately before or
                after checkpoint savings.

        Raises:
            ValueError: if `estimator` does not implement Estimator interface,
                or if export_strategies has the wrong type.
        """
        # NOTE(review): the constructor body (argument validation and storing
        # the arguments on self) is elided in this excerpt.


train()

experiment.train()

def train(self, delay_secs=None):
    """Fit the estimator using the training data.

    Train the estimator for `self._train_steps` steps, after waiting for
    `delay_secs` seconds. If `self._train_steps` is `None`, train forever.

    Args:
        delay_secs: Start training after this many seconds. If `None`, the
            delay is derived from the task id: 5 s per task id, capped at 60 s.
            When `_delay_workers_by_global_step` is set, a
            `GlobalStepWaiterHook` is used instead and no sleep happens.

    Returns:
        The result of `_call_train` (the trained estimator).
    """
    start = time.time()

    # 1. Start the server if this is a distributed run. It's important to
    #    start the server before we (optionally) sleep for the case where no
    #    device_filters are set. Otherwise, the servers will wait to connect
    #    to each other before starting to train.
    #    1.1 With tf.contrib.learn.RunConfig (the only RunConfig that
    #        tf.contrib.learn.Experiment supports for distributed runs), check
    #        the distributed configuration and start the server.
    #    1.2 With any other RunConfig, raise an error.
    # NOTE(review): `分布式` ("distributed") is the article's pseudo-code
    # placeholder for that condition, not a real name.
    if (分布式):
        self._start_server()

    # 2. Work out the start-up delay. With _delay_workers_by_global_step set,
    #    workers wait on the global step via a hook instead of sleeping.
    extra_hooks = []
    if delay_secs is None:
        task_id = self._estimator.config.task_id or 0
        if self._delay_workers_by_global_step:
            # Wait 5500 global steps for the second worker. Each worker waits
            # more than the previous one but with a diminishing number of steps.
            extra_hooks.append(
                basic_session_run_hooks.GlobalStepWaiterHook(
                    int(8000.0 * math.log(task_id + 1))))
            delay_secs = 0
        else:
            # Wait 5 secs more for each new worker up to 60 secs.
            delay_secs = min(60, task_id * 5)

    if delay_secs > 0:
        elapsed_secs = time.time() - start
        remaining = delay_secs - elapsed_secs
        if remaining > 0:
            logging.info("Waiting %d secs before starting training.", remaining)
            # Sleep only for the part of the delay not already spent (e.g. on
            # server start-up). Previously the full delay_secs was slept even
            # though the log above reported `remaining`.
            time.sleep(remaining)

    # 3. Train. `_train_monitors` may be None (train_and_evaluate guards the
    #    same attribute with `or []`), so normalise it before concatenating.
    return self._call_train(
        input_fn=self._train_input_fn,
        max_steps=self._train_steps,
        hooks=list(self._train_monitors or []) + extra_hooks,
        saving_listeners=self._saving_listeners)

def _call_train(self,
                input_fn=None, steps=None, hooks=None, max_steps=None,
                saving_listeners=None):
    """Dispatch a training request to the wrapped estimator.

    Args:
        input_fn: Training input function.
        steps: Incremental number of steps to train.
        hooks: Monitors and/or hooks; monitors are converted to hooks first.
        max_steps: Absolute step count to train up to.
        saving_listeners: Checkpoint listeners. Only supported by the core
            `tf.estimator.Estimator`.

    Returns:
        The value returned by the estimator's `train` (core) or `fit` (contrib).

    Raises:
        ValueError: If `saving_listeners` is non-empty with a contrib estimator.
    """
    # Estimator in core cannot work with monitors. We need to convert them
    # to hooks. For Estimator in contrib, it is converted internally. So, it is
    # safe to convert for both cases.
    hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
    if self._core_estimator_used:
        return self._estimator.train(
            input_fn=input_fn,
            steps=steps,
            max_steps=max_steps,
            hooks=hooks,
            saving_listeners=saving_listeners)
    # contrib Estimators train through `fit`, which takes `monitors` and has no
    # saving_listeners parameter. Previously this path fell through and
    # returned None without training.
    if saving_listeners:
        raise ValueError(
            "`saving_listeners` must be empty with a contrib Estimator")
    return self._estimator.fit(
        input_fn=input_fn,
        steps=steps,
        max_steps=max_steps,
        monitors=hooks)


estimator

由experiment中的 train()触发调用 estimator

# Excerpt quoted from Experiment._call_train: on the core-Estimator path the
# converted hooks and the saving listeners go straight to Estimator.train.
self._estimator.train(
    input_fn=input_fn,
    steps=steps,
    max_steps=max_steps,
    hooks=hooks,
    saving_listeners=saving_listeners)


estimator.train()

def train(self,
          input_fn,
          hooks=None,
          steps=None,
          max_steps=None,
          saving_listeners=None):
    """Trains a model given training data input_fn.

    Args:
        input_fn: Input function returning a tuple of:
            features - `Tensor` or dictionary of string feature name to `Tensor`.
            labels - `Tensor` or dictionary of `Tensor` with labels.
        hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the training loop.
        steps: Number of steps for which to train model. If `None`, train forever
            or train until input_fn generates the `OutOfRange` error or
            `StopIteration` exception. 'steps' works incrementally: two calls to
            train(steps=10) train for 20 steps in total. If you don't want
            incremental behavior, set `max_steps` instead. If set, `max_steps`
            must be `None`.
        max_steps: Number of total steps for which to train model. If `None`,
            train forever or train until input_fn generates the `OutOfRange`
            error or `StopIteration` exception. If set, `steps` must be `None`.
            Two calls to `train(max_steps=100)` mean the second call does no
            iteration, since the first call already did all 100 steps.
        saving_listeners: list of `CheckpointSaverListener` objects. Used for
            callbacks that run immediately before or after checkpoint savings.

    Returns:
        `self`, for chaining.

    Raises:
        ValueError: If both `steps` and `max_steps` are not `None`.
        ValueError: If either `steps` or `max_steps` is <= 0.
    """
    # Validate first so the documented ValueErrors are actually raised.
    # Previously a conflicting or non-positive step count went straight into
    # StopAtStepHook.
    if steps is not None and max_steps is not None:
        raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    # With an absolute max_steps, skip everything if the checkpoint in
    # model_dir is already at or past it.
    if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
            logging.info('Skipping training since max_steps has already saved.')
            return self

    hooks = _check_hooks_type(hooks)
    hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

    saving_listeners = _check_listeners_type(saving_listeners)

    loss = self._train_model(input_fn, hooks, saving_listeners)
    logging.info('Loss for final step: %s.', loss)
    return self

def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    """Return the global step saved in the latest checkpoint of `checkpoint_dir`.

    Args:
        checkpoint_dir: Directory searched with `training.latest_checkpoint`.

    Returns:
        The stored global-step value, or 0 when it cannot be read (for
        example, when there is no checkpoint yet).
    """
    try:
        checkpoint_reader = training.NewCheckpointReader(
            training.latest_checkpoint(checkpoint_dir))
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except Exception:  # pylint: disable=broad-except
        # Deliberate best effort: any read failure means "start from step 0".
        # Unlike the previous bare `except:`, this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return 0

def _check_hooks_type(hooks):
    """Copy `hooks` into a fresh list, rejecting anything that is not a SessionRunHook.

    Args:
        hooks: Iterable of hooks, or a falsy value meaning "no hooks".

    Returns:
        A new list containing the hooks.

    Raises:
        TypeError: On the first element that is not a `training.SessionRunHook`.
    """
    checked = list(hooks) if hooks else []
    for candidate in checked:
        if isinstance(candidate, training.SessionRunHook):
            continue
        raise TypeError(
            'Hooks must be a SessionRunHook, given: {}'.format(candidate))
    return checked

def _convert_train_steps_to_hooks(self, steps, max_steps):
    """Turn step limits into a list of stop hooks.

    Returns:
        `[training.StopAtStepHook(steps, max_steps)]` when either limit is
        given, otherwise an empty list.
    """
    if steps is None and max_steps is None:
        return []
    return [training.StopAtStepHook(steps, max_steps)]

def _train_model(self, input_fn, hooks, saving_listeners):
    """Build the TRAIN-mode graph and run the training loop.

    Args:
        input_fn: Input function; turned into features, labels and input hooks
            by `_get_features_and_labels_from_input_fn` in TRAIN mode.
        hooks: Caller-supplied `SessionRunHook`s, added to the worker hooks.
        saving_listeners: Accepted but not referenced anywhere in this body.
            NOTE(review): as excerpted, the listeners are never handed to the
            `CheckpointSaverHook` created below, so they are silently
            dropped -- confirm against the full source.

    Returns:
        The loss from the last `mon_sess.run` call, or `None` if
        `mon_sess.should_stop()` was true before any step ran.
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):

        global_step_tensor = self._create_and_assert_global_step(g)

        # 1. Get the input data (features, labels) plus any hooks the input
        #    pipeline needs.
        features, labels, input_hooks = (
            self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.TRAIN))
        worker_hooks.extend(input_hooks)

        # 2. Call model_fn to build the model graph; the result is a
        #    tf.estimator.EstimatorSpec.
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

        # 3. Fill in graph details the training loop relies on:
        # 3.1 Add a 'loss' scalar summary unless one already exists, and add
        #     the loss to the LOSSES collection.
        if not any([x.op.name == 'loss'
                    for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
            summary.scalar('loss', estimator_spec.loss)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

        # 3.2 Worker hooks: the hooks passed in (e.g. from the Experiment),
        #     a NaN check plus loss/step logging every 100 iterations, and
        #     the spec's own training_hooks.
        worker_hooks.extend(hooks)
        worker_hooks.extend([
            training.NanTensorHook(estimator_spec.loss),
            training.LoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step_tensor
                },
                every_n_iter=100)
        ])
        worker_hooks.extend(estimator_spec.training_hooks)

        # 3.3 Create a sharded, deferred-build Saver if neither the scaffold
        #     nor the SAVERS collection already provides one.
        if not (estimator_spec.scaffold.saver or
                ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    keep_checkpoint_every_n_hours=(
                        self._config.keep_checkpoint_every_n_hours),
                    defer_build=True,
                    save_relative_paths=True))
        # 3.4 Chief-only hooks: estimator_spec.training_chief_hooks, plus a
        #     CheckpointSaverHook when checkpointing is configured and no hook
        #     of that type is already present among all hooks.
        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]

        # 4. Train inside a MonitoredTrainingSession. Periodic checkpointing
        #    is left to the hook (save_checkpoint_secs=0). The loop runs until
        #    mon_sess.should_stop() becomes true (e.g. a stop hook fired).
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(
                    tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        return loss


train_and_evaluate

experiment.train_and_evaluate

def train_and_evaluate(self):
    """Interleaves training and evaluation.

    The frequency of evaluation is controlled by the constructor arg
    `min_eval_frequency`. When this parameter is 0, evaluation happens
    only after training has completed. Note that evaluation cannot happen
    more frequently than checkpoints are taken. If no new snapshots are
    available when evaluation is supposed to occur, then evaluation doesn't
    happen for another `min_eval_frequency` steps (assuming a checkpoint is
    available at that point). Thus, setting `min_eval_frequency` to 1 means
    that the model will be evaluated every time there is a new checkpoint.

    This is particularly useful for a "Master" task in the cloud, whose
    responsibility it is to take checkpoints, evaluate those checkpoints,
    and write out summaries. Participating in training as the supervisor
    allows such a task to accomplish the first and last items, while
    performing evaluation allows for the second.

    Returns:
        The result of the `evaluate` call to the `Estimator` as well as the
        export results using the specified `ExportStrategy`.
    """
    # The directory to which evaluation summaries are written are determined
    # by adding a suffix to 'eval'; that suffix is the 'name' parameter to
    # the various evaluate(...) methods. By setting it to None, we force
    # the directory name to simply be 'eval'.
    eval_dir_suffix = None

    # We set every_n_steps to 1, but evaluation only occurs when a new
    # snapshot is available. If, by the time we finish evaluation
    # there is a new snapshot, then we just evaluate again. Otherwise,
    # we keep training until one becomes available.
    # NOTE(review): _new_attr_context presumably restores the original
    # _train_monitors on exit, so the added ValidationMonitor does not
    # persist -- confirm.
    with _new_attr_context(self, "_train_monitors"):
        # 1. Evaluate periodically during training via monitors: on top of
        #    the train_monitors given at construction, append a
        #    ValidationMonitor (only when min_eval_frequency is truthy).
        #    (The original comment here trailed off unfinished.)
        self._train_monitors = self._train_monitors or []

        if self._min_eval_frequency:
            self._train_monitors += [
                monitors.ValidationMonitor(
                    input_fn=self._eval_input_fn,
                    eval_steps=self._eval_steps,
                    metrics=self._eval_metrics,
                    every_n_steps=self._min_eval_frequency,
                    name=eval_dir_suffix,
                    hooks=self._eval_hooks)
            ]
        # 2. Train using the monitors set above; they are converted to hooks
        #    before use (see _call_train).
        self.train(delay_secs=0)

    # 3. After training finishes, run one final evaluation.
    eval_result = self._call_evaluate(
        input_fn=self._eval_input_fn,
        steps=self._eval_steps,
        metrics=self._eval_metrics,  # eval_metrics from the constructor; must be None with tf.estimator.Estimator (see _call_evaluate), so only usable with a contrib Estimator
        name=eval_dir_suffix,
        hooks=self._eval_hooks)  # eval_hooks from the constructor
    export_results = self._maybe_export(eval_result)
    return eval_result, export_results

def _call_evaluate(self, _sentinel=None,  # pylint: disable=invalid-name,
                   input_fn=None, steps=None, metrics=None, name=None,
                   checkpoint_path=None, hooks=None):
    """Forward an evaluation request to the wrapped estimator.

    `_sentinel` guards against positional use: every real argument must be
    passed by keyword.

    Returns:
        Whatever the estimator's `evaluate` returns.

    Raises:
        ValueError: On positional use, or when `metrics` is given while a core
            `tf.estimator.Estimator` is wrapped.
    """
    if _sentinel is not None:
        raise ValueError("_call_evaluate should be called with keyword args only")

    evaluate_kwargs = dict(input_fn=input_fn, steps=steps, name=name,
                           checkpoint_path=checkpoint_path, hooks=hooks)
    if not self._core_estimator_used:
        # contrib Estimators accept an extra `metrics` argument.
        evaluate_kwargs['metrics'] = metrics
    elif metrics is not None:
        # The core Estimator.evaluate has no `metrics` parameter.
        raise ValueError(
            "`eval_metrics` must be `None` with `tf.estimator.Estimator`")
    return self._estimator.evaluate(**evaluate_kwargs)

class ValidationMonitor(EveryN):
    """Runs evaluation every N training steps, with optional early stopping.

    Only the evaluation and early-stopping parts are shown in this excerpt.
    Attributes such as `input_fn`, `eval_steps`, `hooks`, `name`,
    `early_stopping_*`, `_best_value` and `_latest_path` are presumably set up
    by the elided constructor.
    """

    def _evaluate_estimator(self):
        """Run one evaluation pass on the monitored estimator and return its metrics."""
        return self._estimator.evaluate(
            input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks,
            name=self.name)

    def every_n_step_end(self, step, outputs):
        """Evaluate the newest checkpoint and apply the early-stopping rule.

        Args:
            step: Current global step.
            outputs: Step outputs, forwarded to the parent implementation.

        Returns:
            True to request that training stop (early stopping), else False.

        Raises:
            ValueError: If early stopping is enabled but its metric is missing
                from the evaluation results.
        """
        # Let the parent class do its bookkeeping first.
        super(ValidationMonitor, self).every_n_step_end(step, outputs)

        # Check that we are not running evaluation on the same checkpoint.
        latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
        if latest_path is None:
            # Estimator evaluation requires a checkpoint and raises ValueError
            # without one, so skip until the first checkpoint is written.
            logging.debug("Skipping evaluation since model has not been saved "
                          "yet at step %d.", step)
            return False
        if latest_path == getattr(self, "_latest_path", None):
            logging.debug("Skipping evaluation due to same checkpoint %s for "
                          "step %d as for step %d.", latest_path, step,
                          self._latest_path_step)
            return False

        self._latest_path = latest_path
        self._latest_path_step = step

        # Run evaluation and log it.
        validation_outputs = self._evaluate_estimator()
        stats = []
        for name in validation_outputs:
            stats.append("%s = %s" % (name, str(validation_outputs[name])))
        logging.info("Validation (step %d): %s", step, ", ".join(stats))

        # Early stopping logic.
        if self.early_stopping_rounds is not None:
            if self.early_stopping_metric not in validation_outputs:
                raise ValueError("Metric %s missing from outputs %s." % (
                    self.early_stopping_metric, set(validation_outputs.keys())))
            current_value = validation_outputs[self.early_stopping_metric]
            if (self._best_value is None or (self.early_stopping_metric_minimize and
                                             (current_value < self._best_value)) or
                    (not self.early_stopping_metric_minimize and
                     (current_value > self._best_value))):
                self._best_value = current_value
                self._best_metrics = copy.deepcopy(validation_outputs)
                self._best_value_step = step
            stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
            if stop_now:
                logging.info("Stopping. Best step: {} with {} = {}."
                             .format(self._best_value_step,
                                     self.early_stopping_metric, self._best_value))
                self._early_stopped = True
                return True  # signals that training should stop
        return False


estimator.evaluate

def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
             name=None):
    """Evaluate the model on the batches produced by `input_fn`.

    Evaluation continues until `steps` batches have been processed, or until
    `input_fn` signals end of input (`OutOfRangeError` / `StopIteration`).

    Args:
        input_fn: Input function returning a tuple of:
            features - Dictionary of string feature name to `Tensor` or
            `SparseTensor`.
            labels - `Tensor` or dictionary of `Tensor` with labels.
        steps: Number of batches to evaluate. `None` means run until
            `input_fn` is exhausted.
        hooks: List of `SessionRunHook` subclass instances used as callbacks
            during the evaluation call.
        checkpoint_path: Specific checkpoint to evaluate; `None` selects the
            latest checkpoint in `model_dir`.
        name: Name of this evaluation, so that evaluations on different data
            sets (e.g. training vs test data) are saved in separate folders
            and appear separately in tensorboard.

    Returns:
        A dict of the evaluation metrics specified in `model_fn`, keyed by
        name, plus a `global_step` entry holding the step that was evaluated.

    Raises:
        ValueError: If `steps <= 0`.
        ValueError: If no model has been trained, namely `model_dir`, or the
            given `checkpoint_path` is empty.
    """
    # Type-check the caller's hooks first, then append the step-limit hook (if any).
    all_hooks = _check_hooks_type(hooks) + self._convert_eval_steps_to_hooks(steps)
    return self._evaluate_model(
        input_fn=input_fn,
        hooks=all_hooks,
        checkpoint_path=checkpoint_path,
        name=name)

def _convert_eval_steps_to_hooks(self, steps):
    """Translate an evaluation step budget into stop hooks.

    Returns:
        `[]` when `steps` is None, otherwise a single `_StopAfterNEvalsHook`.

    Raises:
        ValueError: If `steps` is not positive.
    """
    if steps is not None:
        if steps <= 0:
            raise ValueError('Must specify steps > 0, given: {}'.format(steps))
        return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access
    return []

def _evaluate_model(self,
                    input_fn,
                    hooks=None,
                    checkpoint_path=None,
                    name=''):
    """Evaluates the model using the training.evaluation library.

    Builds a fresh EVAL-mode graph from model_fn. Runs
    `evaluation._evaluate_once` against `checkpoint_path` (or the latest
    checkpoint in `model_dir`), with the grouped metric update ops as
    `eval_ops` and the metric value ops as `final_ops`. Writes the results as
    summaries under `model_dir/eval` (or `model_dir/eval_<name>`).

    Args:
        input_fn: Input function, called in EVAL mode.
        hooks: `SessionRunHook`s added after the input hooks.
            NOTE(review): the default `None` would fail at
            `all_hooks.extend(hooks)`; `evaluate` always passes a list.
        checkpoint_path: Checkpoint to evaluate; defaults to the latest one in
            `self._model_dir`.
        name: Optional suffix for the evaluation output directory.

    Returns:
        Whatever `_evaluate_once` returns -- presumably a dict keyed like the
        value ops (metric names, the loss key and the global-step key), since
        the global step is read from it below.

    Raises:
        ValueError: If no `checkpoint_path` is given and `model_dir` holds no
            checkpoint.
    """
    # Check that model has been trained (if nothing has been set explicitly).
    if not checkpoint_path:
        latest_path = saver.latest_checkpoint(self._model_dir)
        if not latest_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.
                             format(self._model_dir))
        checkpoint_path = latest_path

    # Setup output directory.
    eval_dir = os.path.join(self._model_dir, 'eval' if not name else
                            'eval_' + name)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step_tensor = self._create_and_assert_global_step(g)
        # 1. Get the input data (features, labels) and any input hooks.
        features, labels, input_hooks = (
            self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.EVAL))
        # 2. Call model_fn to build the EVAL graph (returned as an EstimatorSpec).
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.EVAL, self.config)

        # 3. Add the extra graph pieces evaluation needs:
        # 3.1 Add a mean-of-loss metric to eval_metric_ops.
        #     NOTE(review): this mutates the spec's dict in place and silently
        #     overwrites any user metric already stored under LOSS_METRIC_KEY.
        estimator_spec.eval_metric_ops[
            model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss)
        # 3.2 Split the metrics into one grouped update op and a dict of
        #     value ops.
        update_op, eval_dict = _extract_metric_update_ops(
            estimator_spec.eval_metric_ops)
        # 3.3 Add global_step to the value ops so the results report which
        #     step was evaluated.
        eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor

        # 3.4 Collect all hooks: input hooks, caller hooks, and the spec's
        #     evaluation hooks.
        all_hooks = list(input_hooks)
        all_hooks.extend(hooks)
        all_hooks.extend(list(estimator_spec.evaluation_hooks or []))

        # 4. Run a single evaluation pass.
        eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
            checkpoint_path=checkpoint_path,
            master=self._config.evaluation_master,
            scaffold=estimator_spec.scaffold,
            eval_ops=update_op,
            final_ops=eval_dict,
            hooks=all_hooks,
            config=self._session_config)

        # 5. Write the results as summaries under eval_dir.
        _write_dict_to_summary(
            output_dir=eval_dir,
            dictionary=eval_results,
            current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])

    return eval_results

def _extract_metric_update_ops(eval_dict):
    """Split `{name: (value_op, update_op)}` into a grouped update op and a value dict.

    Args:
        eval_dict: Mapping of metric name to a pair whose element 0 is the
            value op and element 1 is the update op.

    Returns:
        `(update_op, value_ops)`: `update_op` groups every update op (or is
        None when there are no metrics); `value_ops` maps name -> value op.
    """
    value_ops = {}
    update_ops = []
    # Visit metrics in sorted name order so the graph is identical every time.
    for metric_name in sorted(eval_dict):
        metric_ops = eval_dict[metric_name]
        value_ops[metric_name] = metric_ops[0]
        update_ops.append(metric_ops[1])

    grouped_update = control_flow_ops.group(*update_ops) if update_ops else None
    return grouped_update, value_ops
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: