tensorflow serving 服务部署与访问(Python + Java)
2017-11-21 15:00
1436 查看
我的目标是使用tensorflow serving 用docker部署模型后,将服务暴露出来,分别在Python和Java中对模型进行访问,因为tensorflow serving的文档较少,grpc使用花了不少时间,不过总算是可以用了。
后续优化:这样简单地部署的Serving服务,,所以每次调用都需要花比较多的时间(感觉像是每次都需要加载模型,本地加载完模型后单预测只需要十几毫秒),需要后续找时间看看有没有办法让模型预加载,服务调用时使用预测方法。
我的总体环境:
tensorflow 1.3.0
python 3.5
java 1.8
可以使用python命令生成模型文件夹,里面包含了saved_model.pb文件和variables文件夹
接着在container中可以新建一个文件夹,如serving-models,在文件夹下新建该模型文件夹classify_data,用来存放的模型文件夹,使用docker拷贝的命令拷贝模型到模型文件夹中:
启动模型服务,监听9000端口:
我们可以定义自己的proto文件,并使用tenserflow/serving/api中的proto来生成代码,这里我不打算如此做,而是用pip install tensorflow-serving-client安装了一个第三方提供的库来访问tensorflow serving服务,python代码如下:
最终输出,例如:
在pom.xml下加入依赖:
Java代码如下:
结果打印如下:
后续优化:这样简单地部署的Serving服务,,所以每次调用都需要花比较多的时间(感觉像是每次都需要加载模型,本地加载完模型后单预测只需要十几毫秒),需要后续找时间看看有没有办法让模型预加载,服务调用时使用预测方法。
Tensorflow Serving 服务部署
我直接用tensorflow serving docker部署的,直接按照官方的文档即可,唯一可能不同的是国内的网络问题,可以将下载和安装的步骤从dockerfile里面转移到登陆docker container去手动做。我的总体环境:
tensorflow 1.3.0
python 3.5
java 1.8
Tensorflow Serving 服务编写
这里我训练了一个分类器,主要有三个分类,主要代码如下:#设置导出时的目录特征名 export_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) #为了接收平铺开的图片数组(Java处理比较麻烦) 150528 = 224*224*3 x = tf.placeholder(tf.float32, [None, 150528]) x2 = tf.reshape(x, [-1, 224, 224, 3]) #我自己的网络预测 prob = net.network(x2) sess = tf.Session() #恢复模型参数 saver = tf.train.Saver() module_file = tf.train.latest_checkpoint(weights_path) saver.restore(sess, module_file) #获取top 1预测 values, indices = tf.nn.top_k(prob, 1) #创建模型输出builder builder = tf.saved_model.builder.SavedModelBuilder(exporter_path + export_time) #转化tensor到模型支持的格式tensor_info,下面的reshape是因为只想输出单个结果数组,否则是二维的 tensor_info_x = tf.saved_model.utils.build_tensor_info(x) tensor_info_pro = tf.saved_model.utils.build_tensor_info(tf.reshape(values, [1])) tensor_info_classify = tf.saved_model.utils.build_tensor_info(tf.reshape(indices, [1])) #定义方法名和输入输出 signature_def_map = { "predict_image": tf.saved_model.signature_def_utils.build_signature_def( inputs={"image": tensor_info_x}, outputs={ "pro": tensor_info_pro, "classify": tensor_info_classify }, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME )} builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], signature_def_map=signature_def_map) builder.save()
可以使用python命令生成模型文件夹,里面包含了saved_model.pb文件和variables文件夹
接着在container中可以新建一个文件夹,如serving-models,在文件夹下新建该模型文件夹classify_data,用来存放的模型文件夹,使用docker拷贝的命令拷贝模型到模型文件夹中:
docker cp 本机模型文件夹 containerId:/serving-models/classify_data/模型版本号
启动模型服务,监听9000端口:
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=classify_data --model_base_path=/serving-models/classify_data/
Python 客户端
编写Python访问客户端,可以运行看看之前保存模型时,signature_def_map的输入输出:inputs { key: "image" value { name: "Placeholder:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 224 } dim { size: 224 } dim { size: 3 } } } } outputs { key: "classify" value { name: "ToFloat_1:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 1 } } } } outputs { key: "pro" value { name: "TopKV2:0" dtype: DT_FLOAT tensor_shape { dim { size: -1 } dim { size: 1 } } } }
我们可以定义自己的proto文件,并使用tenserflow/serving/api中的proto来生成代码,这里我不打算如此做,而是用pip install tensorflow-serving-client安装了一个第三方提供的库来访问tensorflow serving服务,python代码如下:
import sys sys.path.insert(0, "./") from tensorflow_serving_client.protos import predict_pb2, prediction_service_pb2 import cv2 from grpc.beta import implementations import tensorflow as tf from tensorflow.python.framework import dtypes import time #注意,如果在windows下测试,文件名可能需要写成:im_name = r"测试文件目录\文件名" im_name = "测试文件目录/文件名" if __name__ == '__main__': #文件读取和处理 im = cv2.imread(im_name) re_im = cv2.resize(im, (224, 224), interpolation=cv2.INTER_CUBIC) #记个时 start_time = time.time() #建立连接 channel = implementations.insecure_channel("你的ip", 9000) stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) request = predict_pb2.PredictRequest() #这里由保存和运行时定义,第一个是运行时配置的模型名,第二个是保存时输入的方法名 request.model_spec.name = "classify_data" #入参参照入参定义 request.inputs["image"].ParseFromString(tf.contrib.util.make_tensor_proto(re_im, dtype=dtypes.float32, shape=[1, 224, 224, 3]).SerializeToString()) #第二个参数是最大等待时间,因为这里是block模式访问的 response = stub.Predict(request, 10.0) results = {} for key in response.outputs: tensor_proto = response.outputs[key] nd_array = tf.contrib.util.make_ndarray(tensor_proto) results[key] = nd_array print("cost %ss to predict: " % (time.time() - start_time)) print(results["pro"]) print(results["classify"])
最终输出,例如:
cost 5.115269899368286s to predict: [ 1.] [2]
Java 访问
Java和Python一样,可以选择自己编译proto文件,也可以像我一样用第三方库,我是用的是这个http://mvnrepository.com/artifact/com.yesup.oss/tensorflow-client/1.4-2在pom.xml下加入依赖:
<dependency> <groupId>com.yesup.oss</groupId> <artifactId>tensorflow-client</artifactId> <version>1.4-2</version> </dependency> <!-- 这个库是做图像处理的 --> <dependency> <groupId>net.coobird</groupId> <artifactId>thumbnailator</artifactId> <version>0.4.8</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty</artifactId> <version>1.7.0</version> </dependency> <dependency> <groupId>io.netty</groupId> <artifactId>netty-tcnative-boringssl-static</artifactId> <version>2.0.7.Final</version> </dependency>
Java代码如下:
String file = "文件地址" //读取文件,强制修改图片大小,设置输出文件格式bmp(模型定义时输入数据是无编码的) BufferedImage im = Thumbnails.of(file).forceSize(224, 224).outputFormat("bmp").asBufferedImage(); //转换图片到图片数组,匹配输入数据类型为Float Raster raster = im.getData(); List<Float> floatList = new ArrayList<>(); float [] temp = new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()]; float [] pixels = raster.getPixels(0,0,raster.getWidth(),raster.getHeight(),temp); for(float pixel: pixels) { floatList.add(pixel); } #记个时 long t = System.currentTimeMillis(); #创建连接,注意usePlaintext设置为true表示用非SSL连接 ManagedChannel channel = ManagedChannelBuilder.forAddress("192.168.2.24", 9000).usePlaintext(true).build(); //这里还是先用block模式 PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel); //创建请求 Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder(); //模型名称和模型方法名预设 Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder(); modelSpecBuilder.setName("classify_data"); modelSpecBuilder.setSignatureName("predict_image"); predictRequestBuilder.setModelSpec(modelSpecBuilder); //设置入参,访问默认是最新版本,如果需要特定版本可以使用tensorProtoBuilder.setVersionNumber方法 TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder(); tensorProtoBuilder.setDtype(DataType.DT_FLOAT); TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1)); #150528 = 224 * 224 * 3 tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(150528)); tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build()); tensorProtoBuilder.addAllFloatVal(floatList); predictRequestBuilder.putInputs("image", tensorProtoBuilder.build()); //访问并获取结果 Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build()); System.out.println("classify is: " + predictResponse.getOutputsOrThrow("classify").getIntVal(0)); System.out.println("prob is: " + predictResponse.getOutputsOrThrow("pro").getFloatVal(0)); System.out.println("cost time: " + (System.currentTimeMillis() - t));
结果打印如下:
classify is: 2 prob is: 1.0 cost time: 6911
相关文章推荐
- Tensorflow Serving 模型部署和服务
- tensorflow serving:bazel方式部署模型+docker方式部署模型及提供服务以及使用该服务介绍(总有一款适合你)
- Tensorflow Serving 模型部署和服务
- Docker使用tensorflow serving部署mnist模型
- python + tensorflow tensorboard HTTP://0.0.0.0:6006 无法访问 解决方法
- Tensorflow Serving介绍及部署安装
- 138、Tensorflow serving 实现模型的部署
- tensorflow serving 安装报错:java.io.IOException: Cannot run program "patch"
- [AI开发]Python+Tensorflow打造自己的计算机视觉API服务
- Win10下基于Docker使用tensorflow serving部署模型
- TensorFlow Serving和Kubernetes 服务Inception模型
- TensorFlow Serving-TensorFlow 服务
- 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口
- Java 客户端和WCF服务端访问原理和部署步骤(英文) 分享(转)
- 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口
- 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口
- win7配置数据源和ODBC数据源部署类型和访问权限(windows 服务无法访问数据源的问题)
- 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口
- 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口