您的位置:首页 > 编程语言 > MATLAB

【caffe学习笔记之7】caffe-matlab/python训练LeNet模型并应用于mnist数据集(2)

2017-01-11 22:23 876 查看
【案例介绍】

LeNet网络模型是一个用来识别手写数字的最经典的卷积神经网络,是Yann LeCun在1998年设计并提出的,是早期卷积神经网络中最有代表性的实验系统之一,其论文是CNN领域第一篇经典之作。本篇博客详细介绍基于Matlab、Python训练lenet手写模型的案例,作为前几次caffe深度学习框架的阶段性总结。

【生成均值文件】

接上回,mnist数据集生成leveldb数据库之后,需要计算图片均值

在train_leveldb文件夹同级建立mean文件夹,然后在当前目录下打开doc界面,输入以下命令:

compute_image_mean train_leveldb mean/mean.binaryproto --backend leveldb


然后,在mean文件夹下生成mean.binaryproto文件



【训练LeNet网络】
训练网络,有3种方法:

(1)使用可执行程序caffe.exe训练,在命令提示符下运行caffe.exe train命令,参考之前的帖子:

http://blog.csdn.net/lance313/article/details/53964874

(2)利用Matlab接口训练网络,使用solver.solve()  命令,参考之前的帖子

http://blog.csdn.net/lance313/article/details/53968657

(3)利用Python接口训练网络,本节进行相关内容的介绍

首先,修改lenet_solver.prototxt与lenet_train_test.prototxt两个文件的内容,主要是数据库的路径、类型以及求
4000
解模式,CPU/GPU

然后打开python,运行以下脚本:

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

solver = caffe.SGDSolver('./examples/mnist/lenet_solver.prototxt')
solver.solve()


在D:\caffe-master\caffe-master\examples\mnist路径下生成以下4个文件:



【均值文件格式转换】

使用Caffe的C++接口进行操作时,需要的图像均值文件是pb格式,例如常见的均值文件名为mean.binaryproto;但在使用Python接口进行操作时,需要的图像均值文件是numpy格式,例如mean.npy。所以在跨语言进行操作时,需要将mean.binaryproto转换成mean.npy,转换代码如下:

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

import numpy as np

#%%

MEAN_PROTO_PATH = "./examples/mnist/data/mean/mean.binaryproto"
MEAN_NPY_PATH = "./examples/mnist/data/mean/mean.npy"

blob = caffe.proto.caffe_pb2.BlobProto()           # 创建protobuf blob
data = open(MEAN_PROTO_PATH, 'rb' ).read()         # 读入mean.binaryproto文件内容
blob.ParseFromString(data)                         # 解析文件内容到blob

array = np.array(caffe.io.blobproto_to_array(blob))# 将blob中的均值转换成numpy格式,array的shape (mean_number,channel, hight, width)
mean_npy = array[0]                                # 一个array中可以有多组均值存在,故需要通过下标选择其中一组均值
np.save(MEAN_NPY_PATH ,mean_npy)


运行后,在mean文件夹下生成mean.npy文件:



【CPU实现图片分类】

运行以下python命令:

# -*- coding: utf-8 -*-

import caffe
caffe_root = 'D:/caffe-master/caffe-master/'

import os
os.chdir(caffe_root)
print os.getcwd()

import numpy as np
import matplotlib.pyplot as plt
# %%
# Set Caffe to CPU mode and load the net from disk.
caffe.set_mode_cpu()

model_def = caffe_root + 'examples/mnist/lenet.prototxt'  #注意!
model_weights = caffe_root + 'examples/mnist/lenet_iter_5000.caffemodel'

net = caffe.Net(model_def,      # defines the structure of the model
model_weights,  # contains the trained weights
caffe.TEST)     # use test mode (e.g., don't perform dropout)

# load the mean ImageNet image (as distributed with Caffe) for subtraction
mu = np.load(caffe_root + 'examples/mnist/data/mean/mean.npy')
mu = mu.mean(1).mean(1)  # average over pixels to obtain the mean (BGR) pixel values

# %%
# Load an image (that comes with Caffe) and perform the preprocessing we've set up.
# create transformer for the input called 'data'
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})

transformer.set_transpose('data', (2,0,1))  # move image channels to outermost dimension
transformer.set_mean('data', mu)            # subtract the dataset-mean value in each channel
transformer.set_raw_scale('data', 255)      # rescale from [0, 1] to [0, 255]
transformer.set_channel_swap('data', (2,1,0))  # swap channels from RGB to BGR

image = caffe.io.load_image(caffe_root + 'examples/mnist/data/test/TestImage_17.bmp')
transformed_image = transformer.preprocess('data', image)
plt.imshow(image)

# %%
# copy the image data into the memory allocated for the net
net.blobs['data'].data[...] = transformed_image

### perform classification
output = net.forward()
output_prob = output['prob'][0]  # the output probability vector for the first image in the batch
print 'predicted class is:', output_prob.argmax()


结果如下图所示:



程序中有个地方需要注意:

model_def = caffe_root + 'examples/mnist/lenet.prototxt'  #注意!

lenet.prototxt文件需要修改一个地方:

input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } }

需要改成

input_param { shape: { dim: 64 dim: 3 dim: 28 dim: 28 } }

这是因为手写图片虽然是黑白图片,但是上篇帖子在数据转换时,图片已转换为RGB3通道格式
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  深度学习 caffe python
相关文章推荐