您的位置:首页 > 移动开发 > Objective-C

Tensorflow Object Detection API 源码分析之 utils/variables_helper.py

2018-08-15 15:42 716 查看

Tensorflow Object Detection API 源码分析之 utils/variables_helper.py

# model_lib.py uses get_variables_available_in_checkpoint, which returns the
# variables contained in a checkpoint; model_lib.py then restores them from it.
"""Helper functions for manipulating collections of variables during training.
"""
import logging
import re

import tensorflow as tf

slim = tf.contrib.slim

# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
    """Filters out the variables matching the filter_regex.

    Filters out the variables whose name matches any of the regular
    expressions in filter_regex_list and returns the remaining variables.
    Optionally, if invert=True, the complement set is returned.

    Args:
      variables: a list of tensorflow variables (anything exposing
        `var.op.name` works).
      filter_regex_list: a list of string regular expressions. Falsy entries
        ('' or None) are ignored; passing None is treated as an empty list.
        Each pattern is applied with `re.match`, i.e. anchored at the start
        of the variable's op name.
      invert: (boolean). If True, returns the complement of the filter set;
        that is, all variables matching filter_regex are kept and all others
        discarded.

    Returns:
      a list of filtered variables, in the same order as `variables`.
    """
    # Drop falsy entries: an empty pattern would match every name.
    patterns = [pattern for pattern in (filter_regex_list or []) if pattern]
    kept_vars = []
    for var in variables:
        # any() stops at the first matching pattern.
        matches = any(re.match(pattern, var.op.name) for pattern in patterns)
        # Non-matching vars are kept normally; matching ones when inverted.
        if matches == invert:
            kept_vars.append(var)
    return kept_vars

def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
    """Multiply gradients whose variable names match a regular expression.

    Args:
      grads_and_vars: A list of gradient to variable pairs (tuples).
      regex_list: A list of string regular expressions.
      multiplier: A (float) multiplier to apply to each gradient matching the
        regular expression.

    Returns:
      grads_and_vars: A list of gradient to variable pairs (tuples), as
        produced by `slim.learning.multiply_gradients`.
    """
    variables = [pair[1] for pair in grads_and_vars]
    # invert=True keeps exactly the variables whose names match regex_list.
    matching_vars = filter_variables(variables, regex_list, invert=True)
    for var in matching_vars:
        logging.info('Applying multiplier %f to variable [%s]',
                     multiplier, var.op.name)
    # Variable -> multiplier map consumed by slim.learning.multiply_gradients.
    grad_multipliers = {var: float(multiplier) for var in matching_vars}
    return slim.learning.multiply_gradients(grads_and_vars, grad_multipliers)

def freeze_gradients_matching_regex(grads_and_vars, regex_list):
    """Freeze gradients whose variable names match a regular expression.

    Args:
      grads_and_vars: A list of gradient to variable pairs (tuples).
      regex_list: A list of string regular expressions.

    Returns:
      grads_and_vars: A list of gradient to variable pairs (tuples) that do not
        contain the variables and gradients matching the regex.
    """
    variables = [pair[1] for pair in grads_and_vars]
    matching_vars = filter_variables(variables, regex_list, invert=True)
    # matching_vars holds the very objects taken from grads_and_vars, so an
    # identity set gives O(1) lookups (instead of scanning the list for every
    # pair) without relying on Variable.__eq__ / __hash__.
    frozen_ids = {id(var) for var in matching_vars}
    kept_grads_and_vars = [pair for pair in grads_and_vars
                           if id(pair[1]) not in frozen_ids]
    for var in matching_vars:
        logging.info('Freezing variable [%s]', var.op.name)
    return kept_grads_and_vars

# Used by model_lib.py: returns the variables that are available in the
# checkpoint so they can be restored from it.
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
    """Returns the subset of variables available in the checkpoint.

    Inspects the given checkpoint and returns the subset of variables that are
    present in it with a shape equal to the model variable's shape.
    Variables that are missing, or present with a different shape, are
    skipped with a warning.

    TODO(rathodv): force input and output to be a dictionary.

    Args:
      variables: a list or dictionary of variables to find in checkpoint.
        A list is keyed by each variable's `op.name`; a dict is used as-is,
        its keys being the names looked up in the checkpoint.
      checkpoint_path: path to the checkpoint to restore variables from.
      include_global_step: whether to include `global_step` variable, if it
        exists. Default True.

    Returns:
      A list or dictionary of variables, matching the type of `variables`.
      Entries are in ascending variable-name order.

    Raises:
      ValueError: if `variables` is not a list or dict.
    """
    if isinstance(variables, list):
        variable_names_map = {variable.op.name: variable
                              for variable in variables}
    elif isinstance(variables, dict):
        variable_names_map = variables
    else:
        raise ValueError('`variables` is expected to be a list or dict.')
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
        # Treat the global step as absent so it is excluded from the result.
        ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    vars_in_ckpt = {}
    for variable_name, variable in sorted(variable_names_map.items()):
        if variable_name not in ckpt_vars_to_shape_map:
            logging.warning('Variable [%s] is not available in checkpoint',
                            variable_name)
            continue
        ckpt_shape = ckpt_vars_to_shape_map[variable_name]
        model_shape = variable.shape.as_list()
        if ckpt_shape == model_shape:
            vars_in_ckpt[variable_name] = variable
        else:
            logging.warning('Variable [%s] is available in checkpoint, but has an '
                            'incompatible shape with model variable. Checkpoint '
                            'shape: [%s], model variable shape: [%s]. This '
                            'variable will not be initialized from the checkpoint.',
                            variable_name, ckpt_shape, model_shape)
    if isinstance(variables, list):
        # Materialize the view: in Python 3, dict.values() is not a list.
        return list(vars_in_ckpt.values())
    return vars_in_ckpt
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息