TensorFlow model file format conversion
2017-09-20 10:32
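The graph structure of a TensorFlow model can be saved as a .pb file, a .pbtxt file, or a .meta file. Of these, only the .pbtxt file is human-readable.
Networks that experts have trained and published online are often saved as a single .pb file, using the method described in my previous post. This file stores not only the network structure and variable names but also the values of all variables. If we want to use someone else's trained model to test on our own data, we usually need to modify the model somewhat (see my next post, 《Tensorflow之迁移学习》, "Transfer Learning in TensorFlow"). To do that we often need to know the names of some tensors in the original model, but .pb and .meta files are not human-readable, so these two file types need to be converted into something we can read.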
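① .meta files
A .meta file is usually accompanied by several other checkpoint files: checkpoint, model.ckpt.index, model.ckpt.data, and so on. The inspect_checkpoint.py script in the TensorFlow source tree (/home/zhaixingzhe/tensorflow/tensorflow/python/tools/inspect_checkpoint.py) prints the names and shapes of the tensors stored in the checkpoint, or their values if requested. Note that it reads the checkpoint data (the .index/.data files), not the graph in the .meta file, so it lists the saved variables rather than every op in the graph.
The full source of inspect_checkpoint.py is shown below: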
```python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

FLAGS = None


def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    if ("Data loss" in str(e) and
        (any([ext in file_name for ext in [".index", ".meta", ".data"]]))):
      proposed_file = ".".join(file_name.split(".")[0:-1])
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name = {}"""
      print(v2_file_error_template.format(proposed_file))


def parse_numpy_printoption(kv_str):
  """Sets a single numpy printoption from a string of the form 'x=y'.

  See documentation on numpy.set_printoptions() for details about what values
  x and y can take. x can be any option listed there other than 'formatter'.

  Args:
    kv_str: A string of the form 'x=y', such as 'threshold=100000'

  Raises:
    argparse.ArgumentTypeError: If the string couldn't be used to set any
        numpy printoption.
  """
  k_v_str = kv_str.split("=", 1)
  if len(k_v_str) != 2 or not k_v_str[0]:
    raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
  k, v_str = k_v_str
  printoptions = np.get_printoptions()
  if k not in printoptions:
    raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
  v_type = type(printoptions[k])
  if v_type is type(None):
    raise argparse.ArgumentTypeError(
        "Setting '%s' from the command line is not supported." % k)
  try:
    v = (v_type(v_str) if v_type is not bool
         else flags.BooleanParser().parse(v_str))
  except ValueError as e:
    # Python 3 exceptions have no `.message`; use str(e) instead.
    raise argparse.ArgumentTypeError(str(e))
  np.set_printoptions(**{k: v})


def main(unused_argv):
  if not FLAGS.file_name:
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]")
    sys.exit(1)
  else:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
                                     FLAGS.all_tensors)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--file_name", type=str, default="", help="Checkpoint filename. "
      "Note, if using Checkpoint V2 format, file_name is the "
      "shared prefix between all files in the checkpoint.")
  parser.add_argument(
      "--tensor_name",
      type=str,
      default="",
      help="Name of the tensor to inspect")
  parser.add_argument(
      "--all_tensors",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="If True, print the values of all the tensors.")
  parser.add_argument(
      "--printoptions",
      nargs="*",
      type=parse_numpy_printoption,
      help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
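As the `--file_name` help text says, for a V2 checkpoint the script expects the shared prefix (e.g. `model.ckpt`), not one of the individual files, and its "Data loss" message suggests stripping the extension. The sketch below reproduces that idea in plain Python, together with reading the most recent checkpoint path from the text-format `checkpoint` file. It assumes the usual V2 naming (`model.ckpt.index`, `model.ckpt.data-00000-of-00001`); the helper names `checkpoint_prefix` and `latest_checkpoint_path` are hypothetical. Inside TensorFlow, `tf.train.latest_checkpoint(checkpoint_dir)` performs the second lookup for you.

```python
import re


def checkpoint_prefix(file_name):
    """Hypothetical helper: map one file of a V2 checkpoint to its prefix.

    Assumes V2 naming: '<prefix>.index', '<prefix>.meta',
    '<prefix>.data-NNNNN-of-NNNNN'. Anything else is returned unchanged,
    on the assumption that it is already a prefix.
    """
    for ext in (".index", ".meta"):
        if file_name.endswith(ext):
            return file_name[: -len(ext)]
    # Data shards look like 'model.ckpt.data-00000-of-00001'.
    m = re.match(r"^(.*)\.data-\d+-of-\d+$", file_name)
    if m:
        return m.group(1)
    return file_name


def latest_checkpoint_path(checkpoint_state_text):
    """Hypothetical helper: read model_checkpoint_path from a `checkpoint` file.

    The `checkpoint` file is a small text-format proto, for example:
        model_checkpoint_path: "model.ckpt-1000"
        all_model_checkpoint_paths: "model.ckpt-500"
    Returns None if the field is missing.
    """
    m = re.search(r'^model_checkpoint_path:\s*"([^"]*)"',
                  checkpoint_state_text, re.M)
    return m.group(1) if m else None


print(checkpoint_prefix("model.ckpt.index"))                # model.ckpt
print(checkpoint_prefix("model.ckpt.data-00000-of-00001"))  # model.ckpt
state = ('model_checkpoint_path: "model.ckpt-1000"\n'
         'all_model_checkpoint_paths: "model.ckpt-500"\n')
print(latest_checkpoint_path(state))                        # model.ckpt-1000
```

The prefix returned here is what you would pass as `--file_name` to inspect_checkpoint.py.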
② .pb files
The code below defines two functions that convert between .pb and .pbtxt files:
```python
# TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.train.write_graph).
import tensorflow as tf
from google.protobuf import text_format


def convert_pb_to_pbtxt(filename):
  """Reads a binary GraphDef (.pb) and writes it as ./protobuf.pbtxt."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    # Deserialize the binary protocol buffer into a GraphDef message.
    graph_def.ParseFromString(f.read())
  # as_text=True writes the human-readable text format.
  tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
  return graph_def


def convert_pbtxt_to_pb(filename):
  """Reads a text-format GraphDef (.pbtxt) and writes it as ./protobuf.pb.

  Args:
    filename: The name of a file containing a GraphDef pbtxt (text-formatted
      `tf.GraphDef` protocol buffer data).

  Returns:
    The parsed `tf.GraphDef`.
  """
  with tf.gfile.FastGFile(filename, 'r') as f:
    graph_def = tf.GraphDef()
    # Merges the human-readable string in the file into `graph_def`.
    text_format.Merge(f.read(), graph_def)
  # as_text=False writes the binary wire format.
  tf.train.write_graph(graph_def, './', 'protobuf.pb', as_text=False)
  return graph_def
```
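Once a .pb has been converted to `protobuf.pbtxt`, the node names can be read off directly. The sketch below is a crude, hypothetical line scanner (`list_nodes` is not a TensorFlow API) that assumes the layout `tf.train.write_graph(..., as_text=True)` produces: each node is a top-level `node {` block whose own `name:` and `op:` fields are indented by exactly two spaces. It is not a real protobuf text parser; for anything serious, parse the file into a `GraphDef` as above and iterate over `graph_def.node`. Keep in mind that a tensor name is the node name plus an output index, so the first output of node `input` is the tensor `input:0`.

```python
def list_nodes(pbtxt_text):
    """Hypothetical helper: (name, op) pairs for top-level GraphDef nodes.

    Assumes write_graph's text layout: 'node {' and its closing '}' at
    column 0, the node's own fields indented by exactly two spaces.
    """
    nodes = []
    in_node = False
    name = op = None
    for raw in pbtxt_text.splitlines():
        line = raw.rstrip()
        if line == "node {":
            in_node, name, op = True, None, None
        elif in_node and line == "}":
            nodes.append((name, op))
            in_node = False
        elif in_node and line.startswith('  name: "'):
            name = line[len('  name: "'):-1]  # drop prefix and closing quote
        elif in_node and line.startswith('  op: "'):
            op = line[len('  op: "'):-1]
    return nodes


SAMPLE = '''node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "output"
  op: "Identity"
  input: "input"
}
versions {
  producer: 24
}
'''

for name, op in list_nodes(SAMPLE):
    print(name, op)
# input Placeholder
# output Identity
```

For a real file, pass the text of `protobuf.pbtxt` (for example `list_nodes(open("protobuf.pbtxt").read())`).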