Tensorflow Object Detection API 源码分析之 utils/variables_helper.py
2018-08-15 15:42
716 查看
Tensorflow Object Detection API 源码分析之 utils/variables_helper.py
# Blog annotation (translated from Chinese): model_lib.py uses the
# get_variables_available_in_checkpoint function, which returns the variables
# contained in a checkpoint; model_lib.py then restores them from that
# checkpoint.
"""Helper functions for manipulating collections of variables during training.
"""
import logging
import re

import tensorflow as tf

slim = tf.contrib.slim


# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Filters out the variables matching the filter_regex.

  Filter out the variables whose name matches any of the regular expressions
  in filter_regex_list and returns the remaining variables. Optionally, if
  invert=True, the complement set is returned.

  Matching uses `re.match` against each variable's op name, so a pattern is
  anchored at the start of the name (but not at the end). Empty or None
  entries in filter_regex_list are ignored.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean). If True, returns the complement of the filter set; that
      is, all variables matching filter_regex are kept and all others
      discarded.

  Returns:
    a list of filtered variables, in the same order as `variables`.
  """
  # Drop falsy patterns: an empty regex would otherwise match every name.
  patterns = [pattern for pattern in filter_regex_list if pattern]
  kept_vars = []
  for var in variables:
    matched = any(re.match(pattern, var.op.name) for pattern in patterns)
    # Normally keep the non-matching variables; when inverted, keep only the
    # matching ones.
    if matched == invert:
      kept_vars.append(var)
  return kept_vars


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Multiply gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples).
  """
  variables = [pair[1] for pair in grads_and_vars]
  # invert=True keeps exactly the variables that match one of the regexes.
  matching_vars = filter_variables(variables, regex_list, invert=True)
  for var in matching_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, var.op.name)
  grad_multipliers = {var: float(multiplier) for var in matching_vars}
  return slim.learning.multiply_gradients(grads_and_vars, grad_multipliers)


def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Freeze gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.
  """
  variables = [pair[1] for pair in grads_and_vars]
  # invert=True keeps exactly the variables that match one of the regexes.
  matching_vars = filter_variables(variables, regex_list, invert=True)
  # Compare by object identity via a set: O(1) lookups, and it does not rely
  # on Variable.__eq__ (which may be elementwise rather than identity-based
  # in newer TensorFlow versions).
  frozen_ids = {id(var) for var in matching_vars}
  kept_grads_and_vars = [pair for pair in grads_and_vars
                         if id(pair[1]) not in frozen_ids]
  for var in matching_vars:
    logging.info('Freezing variable [%s]', var.op.name)
  return kept_grads_and_vars


# Blog annotation (translated from Chinese): used by model_lib.py; returns the
# variables that are available in the checkpoint.
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the subset of variables available in the checkpoint.

  Inspects given checkpoint and returns the subset of variables that are
  available in it. A variable counts as available only if its name is present
  in the checkpoint and its shape equals the checkpoint's shape for that
  name; missing or shape-mismatched variables are logged as warnings and
  skipped.

  TODO(rathodv): force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in checkpoint. List
      entries are looked up by their op name; dict entries by their key.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include `global_step` variable, if it
      exists. Default True.

  Returns:
    A list (if `variables` is a list) or dictionary (if `variables` is a dict)
    of variables.

  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, list):
    variable_names_map = {variable.op.name: variable for variable in variables}
  elif isinstance(variables, dict):
    variable_names_map = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')
  ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
  if not include_global_step:
    # Removing the name from the checkpoint map makes global_step count as
    # "not available", so it is excluded from the result.
    ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
  vars_in_ckpt = {}
  for variable_name, variable in sorted(variable_names_map.items()):
    if variable_name not in ckpt_vars_to_shape_map:
      logging.warning('Variable [%s] is not available in checkpoint',
                      variable_name)
      continue
    ckpt_shape = ckpt_vars_to_shape_map[variable_name]
    model_shape = variable.shape.as_list()
    if ckpt_shape == model_shape:
      vars_in_ckpt[variable_name] = variable
    else:
      logging.warning('Variable [%s] is available in checkpoint, but has an '
                      'incompatible shape with model variable. Checkpoint '
                      'shape: [%s], model variable shape: [%s]. This '
                      'variable will not be initialized from the checkpoint.',
                      variable_name, ckpt_shape, model_shape)
  if isinstance(variables, list):
    # Materialize the values: on Python 3, dict.values() is a view object,
    # not the list this function documents.
    return list(vars_in_ckpt.values())
  return vars_in_ckpt
相关文章推荐
- Tensorflow Object Detection API 源码分析之 core/standard_fields.py
- Tensorflow Object Detection API 源码分析之 builders/model_builder.py
- Tensorflow Object Detection API 源码分析之 builders/optimizer_builder.py
- Tensorflow Object Detection API 源码分析之 builders/graph_rewriter_builder.py
- Tensorflow Object Detection API 源码分析之 inputs.py
- Tensorflow object detection API 源码阅读笔记:Mask R-CNN
- 初窥Tensorflow Object Detection API 源码
- 初窥Tensorflow Object Detection API 源码之(2.1.1)FasterRCNNMetaArch.predict
- 初窥Tensorflow Object Detection API 源码之(2.4)BoxPredictor
- 学习 train.py ( TensorFlow Object Detection API)
- 初窥Tensorflow Object Detection API 源码之(2.1)FasterRCNNMetaArch
- 初窥Tensorflow Object Detection API 源码之(1.2)FeatureExtractor.Config
- 初窥Tensorflow Object Detection API 源码之(2.5)target_assigner
- Tensorflow object detection API 源码阅读笔记:Fast r-cnn
- 初窥Tensorflow Object Detection API 源码之(2.6)matcher
- Tensorflow object detection API 源码阅读笔记:架构
- 初窥Tensorflow Object Detection API 源码之(1.1) Resnet
- Tensorflow object detection API 源码阅读笔记:RFCN
- python eval.py under object detection API of TensorFlow
- tensorflow object detection api 更新