TensorFlow Object Detection API Source Code Analysis: builders/graph_rewriter_builder.py
2018-08-16 22:47
protos/graph_rewriter.proto
syntax = "proto2";

package object_detection.protos;

// Message to configure graph rewriter for the tf graph.
message GraphRewriter {
  optional Quantization quantization = 1;
}

// Message for quantization options. See
// tensorflow/contrib/quantize/python/quantize.py for details.
message Quantization {
  // Number of steps to delay before quantization takes effect during training.
  optional int32 delay = 1 [default = 500000];

  // Number of bits to use for quantizing weights.
  // Only 8 bit is supported for now.
  optional int32 weight_bits = 2 [default = 8];

  // Number of bits to use for quantizing activations.
  // Only 8 bit is supported for now.
  optional int32 activation_bits = 3 [default = 8];
}
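To make the meaning of the delay field concrete, here is a minimal sketch in plain Python. It assumes that quantization switches on once the training step counter reaches the configured delay. The function name quantization_active is hypothetical and is not part of the API; only the default value 500000 comes from the proto above.

```python
# Default taken from Quantization.delay in the proto definition.
DEFAULT_DELAY = 500000


def quantization_active(global_step, delay=DEFAULT_DELAY):
    """Hypothetical helper: True once `delay` training steps have elapsed.

    Assumption: quantization takes effect when global_step >= delay.
    """
    return global_step >= delay


# With the default delay, the first 500000 steps train in full precision.
before = quantization_active(499999)
after = quantization_active(500000)
```

Under this assumption, a delay of 0 would mean quantization is applied from the very first step.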
builders/graph_rewriter_builder.py
# Quantization helper. Only 8-bit quantization is supported.
# build() returns the graph_rewrite_fn function.
"""Functions for quantized training and evaluation."""

import tensorflow as tf


def build(graph_rewriter_config, is_training):
  """Returns a function that modifies default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    is_training: whether in training or eval mode.
  """
  def graph_rewrite_fn():
    """Function to quantize weights and activation of the default graph."""
    if (graph_rewriter_config.quantization.weight_bits != 8 or
        graph_rewriter_config.quantization.activation_bits != 8):
      raise ValueError('Only 8bit quantization is supported')

    # Quantize the graph by inserting quantize ops for weights and activations
    if is_training:
      tf.contrib.quantize.create_training_graph(
          input_graph=tf.get_default_graph(),
          quant_delay=graph_rewriter_config.quantization.delay)
    else:
      tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())

    tf.contrib.layers.summarize_collection('quant_vars')
  return graph_rewrite_fn
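One detail worth noting is that build() only creates a closure: the bit-width check and the graph rewrite run when the returned graph_rewrite_fn is called, not when build() is called. The sketch below reproduces that control flow in plain Python so it can run without TensorFlow. The extra parameters training_rewrite and eval_rewrite, the make_config helper, and the recording lambdas are all hypothetical stand-ins (for the tf.contrib.quantize calls and for a parsed GraphRewriter proto); they are not part of the real module.

```python
from types import SimpleNamespace


def build(graph_rewriter_config, is_training, training_rewrite, eval_rewrite):
    """Same control flow as the builder, with injected rewrite callables.

    training_rewrite / eval_rewrite are hypothetical stand-ins for
    tf.contrib.quantize.create_training_graph / create_eval_graph.
    """
    def graph_rewrite_fn():
        quantization = graph_rewriter_config.quantization
        # Validation happens here, at call time, not inside build().
        if quantization.weight_bits != 8 or quantization.activation_bits != 8:
            raise ValueError('Only 8bit quantization is supported')
        if is_training:
            training_rewrite(quant_delay=quantization.delay)
        else:
            eval_rewrite()
    return graph_rewrite_fn


def make_config(delay=500000, weight_bits=8, activation_bits=8):
    """Hypothetical stand-in for a parsed GraphRewriter proto."""
    return SimpleNamespace(quantization=SimpleNamespace(
        delay=delay, weight_bits=weight_bits, activation_bits=activation_bits))


calls = []
record_training = lambda quant_delay: calls.append(('train', quant_delay))
record_eval = lambda: calls.append(('eval', None))

# Training mode forwards the configured delay; eval mode ignores it.
build(make_config(delay=1000), True, record_training, record_eval)()
build(make_config(), False, record_training, record_eval)()

# An unsupported bit width is accepted by build() and rejected on call.
bad_fn = build(make_config(weight_bits=4), True, record_training, record_eval)
try:
    bad_fn()
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

In practice this means a misconfigured weight_bits or activation_bits value surfaces only when the training or evaluation code invokes the returned function while constructing the graph.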