tensorflow的python离线训练java在线预测方案
2017-05-27 14:53
831 查看
tensorflow目前主要的使用语言主要还是python,但是有相当一部分互联网应用是用java开发的,那么java应用如何使用tensorflow开发深度学习相关的功能呢?虽然google开源了tensorflow serving用于生产环境部署训练好的模型,但需要自己实现集群功能和健康检查,同时和java应用中间还隔着一个网络通讯的开销。所以最好还是java应用内部直接调用模型。tensorflow 1.1版本已经推出了java接口,不过我看了一下目前接口数量还是比较少,跟python丰富的各类接口没法比。因此完全使用java接口来构建模型不太现实,而且我估计模型训练效率可能也没python好。另一方面,网上开源的tensorflow模型基本都是用python的,用java重新构建费时费力。基于上述原因,python构建并训练模型+java在线预测是比较合理的方案。
在python训练代码里,模型训练好以后,要用tf.train.write_graph把整个图的protobuf写到文件中,但是tf.train.write_graph只能保存图的定义和constant参数,variable会被忽略掉,可以使用tf.graph_util.convert_variables_to_constants把variable转成constant再写到文件中,这样学习到的参数就不会丢失。相关代码:
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output/logits"])
tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.1.0</version>
</dependency>
可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:
String modelDir = ".";
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "graph.pb"));
Graph g = new Graph();
g.importGraphDef(graphDef);
Session s = new Session(g);
Tensor input = constructTensor(data);
Tensor result = s.runner().feed("input", input).fetch("output/logits").run().get(0);
long[] rshape = result.shape();
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);
其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。
在python训练代码里,模型训练好以后,要用tf.train.write_graph把整个图的protobuf写到文件中,但是tf.train.write_graph只能保存图的定义和constant参数,variable会被忽略掉,可以使用tf.graph_util.convert_variables_to_constants把variable转成constant再写到文件中,这样学习到的参数就不会丢失。相关代码:
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output/logits"])
tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.1.0</version>
</dependency>
可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:
String modelDir = ".";
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "graph.pb"));
Graph g = new Graph();
g.importGraphDef(graphDef);
Session s = new Session(g);
Tensor input = constructTensor(data);
Tensor result = s.runner().feed("input", input).fetch("output/logits").run().get(0);
long[] rshape = result.shape();
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);
其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。
相关文章推荐
- Python下的数据处理和机器学习,对数据在线及本地获取、解析、预处理和训练、预测、交叉验证、可视化
- Python下的数据处理和机器学习,对数据在线及本地获取、解析、预处理和训练、预测、交叉验证、可视化
- Linux官方提供的Bash Bug在线方案and离线修复方案
- python selenium的在线安装及离线安装
- TensorFlow的训练模型在Android和Java的应用及调用
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)【转】
- ubuntukylin 16.04离线/在线安装tensorflow环境
- 在线的IDE(Ideone)支持Java/Python/Go/D
- 华为机试在线训练–牛客网(python)
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)
- Java各版本在线及离线JDK API——如何制作CHM文档
- it 自学编程在线网站-----java python js node.js c c++ android ios
- 在线编程语言模拟(Java,C,Python,R语言,Ruby,PHP,Perl,Go等)
- lightgbm_predict4j:LightGBM在线预测的java实现
- tensorflow保存网络参数 使用训练好的网络参数进行数据的预测
- python selenium的在线安装及离线安装
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- Java机器学习库ML之四模型训练和预测示例