您的位置:首页 > 编程语言 > Java开发

tensorflow训练好的模型中java调用

2017-10-20 23:28 513 查看


最近基于bi-lstm做了一个辱骂识别模型准备部署到线上,之前打算用python 启动一个service 通过http请求来调用,发现公司平台是基于rpc服务的,开发部署起来也较蛋疼,今天下午闲来没事,看到tensorflow中有提供官方例子,通过python中训练好模型,用java来调用,刚刚好摸索了下,动手写了下代码,总算能在java中调用,废话不多说,直接看代码实现情况。


tensorflow版本情况:

In [1]: import  tensorflow as tf


In [2]: tf.__version__

Out[2]: '1.2.1'



java需要1.8的版本


maven依赖:

<dependency>

<groupId>org.tensorflow</groupId>

<artifactId>tensorflow</artifactId>

<version>1.2.1</version>

</dependency>



参考资料:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

http://blog.csdn.net/lyg5623/article/details/72781405


tensorflow训练模型时候要保存的模型参数,主要有是三个,一个是模型输入的tensor大小,一个是dropout参数,一个是模型预测的logits(score/pred_y 表示name_scope下的pred_y)值,也就是y;模型保存为一个二进制文件,可以在java中加载:

if i%500==0  and i>0:

graph = tf.graph_util.convert_variables_to_constants(session, session.graph_def,

["keep_prob", "input_x", "score/pred_y"])

tf.train.write_graph(graph, ".", "/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph/graph.db",

as_text=False)



java代码如下,其中gettexttoid方法参考tensorflow中 tensorflow.contrib.keras.preprocessing.sequence.pad_sequences下的实现,用于做文本预测:

package com.meituan.test;


import java.io.BufferedReader;

import java.io.File;

import java.io.FileInputStream;

import java.io.IOException;

import java.io.InputStreamReader;

import java.nio.ByteBuffer;

import java.nio.ByteOrder;

import java.nio.IntBuffer;

import java.nio.file.Files;

import java.nio.file.Paths;

import java.nio.file.Path;

import java.util.ArrayList;

import java.util.Arrays;

import java.util.Collection;

import java.util.HashMap;

import java.util.List;

import java.util.Map;


import org.apache.commons.io.FileUtils;

import org.apache.commons.lang.StringUtils;

import org.tensorflow.Graph;

import org.tensorflow.Session;

import org.tensorflow.Tensor;


public class TensorflowEx {

private static String path = "/Users/shuubiasahi/Documents/python/credit-tftextclassify-abuse/vocab_cnews.txt";

private static Map<String, Integer> word_to_id = new HashMap<String, Integer>();

static {

try {

BufferedReader buffer = null;

buffer = new BufferedReader(new InputStreamReader(new FileInputStream(path)));

int i=0;

String line=buffer.readLine().trim();

while(line!=null){

word_to_id.put(line, i++);

line=buffer.readLine().trim();

}

buffer.close();


} catch (Exception e) {


}

System.out.println("word_to_id.size is:"+word_to_id.size());


}


public static void main(String[] args) {

byte[] graphDef = readAllBytesOrExit(Paths.get(

"/Users/shuubiasahi/Desktop/tensorflow/modelsavegraph",

"graph.db"));

Graph g = new Graph();

g.importGraphDef(graphDef);

Session sess = new Session(g);

String text="艹你麻痹的垃圾店家,劳资点的香干回锅肉套餐,你他麻痹炒个香干炒肉过来凑数,套餐内所有的东西都没看到,还尼玛口口声声说退款?退你麻痹,留着给你家人买棺材用吧,狗日的东西!";

int[][] arr=gettexttoid(text);

Tensor input = Tensor.create(arr);

Tensor x = Tensor.create(1.0f);

Tensor result = sess.runner().feed("input_x", input).feed("keep_prob", x)

.fetch("score/pred_y").run().get(0);


long[] rshape = result.shape();

/*

* 模型为一个二分类模型,故nlabels=2,模型预测一条数据故batchsize=1

* 预测出来是一个1*2的数组,一条数据有两个概率

*

**/

int nlabels = (int) rshape[1];

int batchSize = (int) rshape[0];


float[][] logits = result.copyTo(new float[batchSize][nlabels]);


System.out.println("辱骂模型识别的概率为:"+logits[0][1]);


System.out.println("sucess");


}


private static byte[] readAllBytesOrExit(Path path) {

try {

return Files.readAllBytes(path);

} catch (IOException e) {

System.err.println("Failed to read [" + path + "]: "

+ e.getMessage());

System.exit(1);

}

return null;

}


/*

 * 序列默人长度为300

 * */

public  static int[][] gettexttoid(String text){

int[][] xpad = new int[1][300];


if(StringUtils.isBlank(text)){

return xpad; 

}


char[] chs=text.trim().toLowerCase().toCharArray();

List<Integer> list=new ArrayList<Integer>();

for(int i=0;i<chs.length;i++){

String element=Character.toString(chs[i]);

if(word_to_id.containsKey(element)){

list.add(word_to_id.get(element));

}

}

if(list.size()==0){

return xpad;

}

int size = list.size();

Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);

    int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();

if(size<=300){

System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);

}else{

System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);


}

return xpad;

}


/*

 * 自定义长度

 * */

public  static int[][] gettexttoid(String text,int maxlen){

if(maxlen<1){

throw new IllegalArgumentException("maxlen长度必须大于等于1");

}


int[][] xpad = new int[1][maxlen];


if(StringUtils.isBlank(text)){

return xpad; 

}


char[] chs=text.trim().toLowerCase().toCharArray();

List<Integer> list=new ArrayList<Integer>();

for(int i=0;i<chs.length;i++){

String element=Character.toString(chs[i]);

if(word_to_id.containsKey(element)){

list.add(word_to_id.get(element));

}

}

if(list.size()==0){

return xpad;

}

int size = list.size();

Integer[] targetInter= (Integer[]) list.toArray(new Integer[size]);

    int[] target= Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();

if(size<=maxlen){

System.arraycopy(target, 0, xpad[0], xpad[0].length-size, target.length);

}else{

System.arraycopy(target, size-xpad[0].length, xpad[0], 0, xpad[0].length);


}


return xpad;

}




}



结果对比:


java结果:






python启动的service结果:






结果一致,下周计划写个java service项目,把模型部署上线。

不过我碰到过问题,在java中做预测,1秒最多只能预测十来条文本,这感觉太慢了,不知道什么原因,我机器用的cpu,不知道是否要用gpu做预测,有知道的告诉我

联系我  xuxu_ge
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: