您的位置:首页 > 移动开发 > Android开发

Tensorflow手写数字识别在android中的实现

2017-07-14 16:25 627 查看

说明

下载TensorFlow Android Demo

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

生成模型

运行附件压缩包里的python脚本convnet.py生成mnist_model_graph_convnet.pb文件和graph_label_strings.txt文件:
文件

编译jar包和so库

1. 下载TensorFlow Android Demo

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
备注:

--recurse-submodules
是为了避免一些protobuf 编译问题.

2. 修改WORKSPACE文件,指定SDK、NDK的版本和路径,请务必使用NDK r12b,下载路径为:
https://developer.android.com/ndk/downloads/older_releases.html  #ndk-12b-downloads

例如,我是这样配置的:

android_sdk_repository(
name = "androidsdk",
api_level = 25,
# Ensure that you have the build_tools_version below installed in the
# SDK manager as it updates periodically.
build_tools_version = "25.0.3",
# Replace with path to Android SDK on your system
path = "/home/ckt/work/Android/Sdk",
)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
name="androidndk",
path="/home/ckt/work/Android/ndk-r12b/",
# This needs to be 14 or higher to compile TensorFlow.
# Note that the NDK version is not the API level.
api_level=14)


3. 编译jar包和so库

编译jar包和so库需要构建工具Bazel,Ubuntu环境下如何安装Bazel请参考网页:

https://bazel.build/versions/master/docs/install-ubuntu.html

编译jar包命令:

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java


编译完成后,可以在以下路径找到libandroid_tensorflow_inference_java.jar文件:

bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

编译so库命令:

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=armeabi-v7a

###cpu一定要适配自己的手机,否则找不到so文件###

编译完成后,可以在以下路径找到libtensorflow_inference.so文件:

bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

编写应用

1. 打开Android Studio,新建一个android工程

将jar包放入libs目录,将so库放入src/main/jniLibs/armeabi-v7a目录,将之前生成的pb文件和text文件放入src/main/assets目录

2. 将TensorFlow Android Demo中的Classifier.java和TensorFlowImageClassifier.java复制到工程,这2个文件在TensorFlow Android Demo中的的路径为:

/tensorflow/examples/android/src/org/tensorflow/demo

注意:

需要将这2个类的包名修改为自己工程的包名。

3.为了简便操作,我们将下面的mnist_test.png(一张灰度图,28×28像素,白字黑底)放到src/main/assets目录下



备注:

IMAGE_MEAN和IMAGE_STD的值在本项目没有实际意义,可以随便设置。

4.在activity中调用TensorFlowImageClassifier.create()方法创建分类器:



5. 将mnist_test.png图片转换成相应的bitmap(28x28),通过classifier.recognizeImage(bitmap)来取得预测结果



注意:

因为我们的输入数据是28x28的灰度图,原始代码用到了rgb三个通道,我们只需要一个通道,所以需要修改TensorFlowImageClassifier类的recognizeImage方法来适应模型,代码如下:



bitmapToFloatArray()方法如下:
/**
* 将bitmap转为(按行优先)一个float数组。其中的每个像素点都归一化到0~1之间。
* @param bitmap 灰度图,r,g,b分量都相等。
* @return
*/
public static float[] bitmapToFloatArray(Bitmap bitmap){
int height = bitmap.getHeight();
int width = bitmap.getWidth();
float[] result = new float[height * width];

int k = 0;
for (int j = 0; j < height; j++) {
for (int i = 0; i < width; i++) {
int argb = bitmap.getPixel(i, j);
// 由于是灰度图,所以r,g,b分量是相等的。
int r = Color.red(argb);
result[k++] = r / 255.0f;
}
}
return result;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: