
Caffe Python interface in practice: source-code walkthrough of the official pascal-multilabel-with-datalayer tutorial

2017-12-03 10:59

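This post is one of a series of notes walking through the source code of the official documentation.

Note 1: This post annotates the source of an .ipynb notebook under caffe_root/examples/, with the aim of speeding up beginners' learning through comments on the source code.

Note 2: The English comments in the walkthrough are deliberately left untranslated. Beginners should get used to reading tutorials in the original English; this builds the ability to read technical literature, an essential skill for an advanced programmer.

Note 3: It is recommended to run the program in a Jupyter Notebook while reading the source comments.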

Multilabel classification on PASCAL using python data-layers

In this tutorial we will do multilabel classification on PASCAL VOC 2012.

Multilabel classification is a generalization of multiclass classification, where each instance (image) can belong to many classes. For example, an image may both belong to a “beach” category and a “vacation pictures” category. In multiclass classification, on the other hand, each image belongs to a single class.
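
To make the difference concrete, here is a minimal plain-Python sketch (not part of the original tutorial; the helper names are hypothetical) of the two target encodings for PASCAL's 20 classes: a multiclass target is one-hot, while a multilabel target is multi-hot, with a 1 for every class present in the image.

```python
# PASCAL VOC class names, in the same order as the `classes` array used later.
CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
           'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
           'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']


def one_hot(name, classes=CLASSES):
    """Multiclass target: exactly one entry is 1."""
    return [1 if c == name else 0 for c in classes]


def multi_hot(names, classes=CLASSES):
    """Multilabel target: every class present in the image gets a 1."""
    present = set(names)
    return [1 if c in present else 0 for c in classes]


# An image of a person riding a horse belongs to two classes at once.
label = multi_hot(['person', 'horse'])
print(sum(label))                                # → 2
print([c for c, v in zip(CLASSES, label) if v])  # → ['horse', 'person']
```

The multi-hot vector is exactly the shape of the `label` top that the Python data layer produces later: one 0/1 entry per class.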

Caffe supports multilabel classification through the SigmoidCrossEntropyLoss layer, and we will load data using a Python data layer. Data could also be provided through HDF5 or LMDB data layers, but the python data layer provides endless flexibility, so that’s what we will use.
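
SigmoidCrossEntropyLoss treats every class as an independent yes/no problem: it applies a sigmoid to each raw score and sums the binary cross-entropies. Below is a minimal per-image sketch in plain Python using the numerically stable form of the formula. How Caffe then normalizes the total over the batch depends on the Caffe version and `loss_param`, so that part is deliberately left out.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_cross_entropy(scores, targets):
    """Summed binary cross-entropy over the classes of one image.

    scores  -- raw outputs (logits) of the final InnerProduct layer
    targets -- multi-hot ground truth of the same length (0 or 1)

    Uses max(x, 0) - x*t + log(1 + exp(-|x|)), which equals
    -[t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))] but never overflows.
    """
    loss = 0.0
    for x, t in zip(scores, targets):
        loss += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return loss


print(sigmoid(0.0))                                           # → 0.5
print(round(sigmoid_cross_entropy([0.0] * 20, [0] * 20), 4))  # → 13.8629 (20 * ln 2)
```

A score of 0 means "50/50" for that class, so each class contributes ln 2; confident correct scores drive their term toward 0, confident wrong ones grow roughly linearly.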

1. Preliminaries

First, make sure you compile Caffe with Python layers enabled, i.e. with this line uncommented in Makefile.config:

WITH_PYTHON_LAYER := 1

Second, download PASCAL VOC 2012. It’s available here: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
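
Before going further it can help to check that the download ended up where the notebook expects it (pascal_root, which is set to {caffe_root}/data/pascal/VOC2012 in step four). A minimal sketch: the subdirectory names follow the standard VOC2012 layout, but which of them the data layer actually reads is an assumption here, and `missing_voc_dirs` is a hypothetical helper.

```python
import os.path as osp

# Standard VOC2012 subdirectories: per-image XML annotations, JPEG images,
# and the train/val split lists. (Assumed to be what the data layer needs.)
EXPECTED = ['Annotations', 'JPEGImages', osp.join('ImageSets', 'Main')]


def missing_voc_dirs(pascal_root):
    """Return the expected subdirectories that do not exist under pascal_root."""
    return [d for d in EXPECTED if not osp.isdir(osp.join(pascal_root, d))]


print(missing_voc_dirs('../data/pascal/VOC2012'))  # [] when everything is in place
```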

Third, import modules:

import sys
import os

import numpy as np
import os.path as osp
import matplotlib.pyplot as plt

from copy import copy

%matplotlib inline
plt.rcParams['figure.figsize'] = (6, 6)

caffe_root = '../'  # this file is expected to be in {caffe_root}/examples
sys.path.append(caffe_root + 'python')
import caffe # If you get "No module named _caffe", either you have not built pycaffe or you have the wrong path.

from caffe import layers as L, params as P # Shortcuts to define the net prototxt.

sys.path.append("pycaffe/layers") # the datalayers we will use are in this directory.
sys.path.append("pycaffe") # the tools file is in this folder

import tools #this contains some tools that we need


Fourth, set data directories and initialize Caffe:

# set data root directory, e.g:
pascal_root = osp.join(caffe_root, 'data/pascal/VOC2012')

# these are the PASCAL classes, we'll need them later.
classes = np.asarray(['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])

# make sure we have the caffenet weights downloaded.
if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):
    print("Downloading pre-trained CaffeNet model...")
    !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet

# initialize caffe for gpu mode
caffe.set_mode_gpu()
caffe.set_device(0)


2. Define network prototxts

Let’s start by defining the nets using caffe.NetSpec. Note the SigmoidCrossEntropyLoss layer: this is the right loss for multilabel classification. Also note how the data layer is defined.

# helper function for common structures
def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1):
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
                         num_output=nout, pad=pad, group=group)
    return conv, L.ReLU(conv, in_place=True)

# another helper function
def fc_relu(bottom, nout):
    fc = L.InnerProduct(bottom, num_output=nout)
    return fc, L.ReLU(fc, in_place=True)

# yet another helper function
def max_pool(bottom, ks, stride=1):
    return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)

# main netspec wrapper
def caffenet_multilabel(data_layer_params, datalayer):  # the two arguments select the data source
    # setup the python data layer
    n = caffe.NetSpec()
    # This net uses a python datalayer: 'PascalMultilabelDataLayerSync', which is defined in './pycaffe/layers/pascal_multilabel_datalayers.py'.
    # Example: data_layer_params = dict(batch_size = 128, im_shape = [227, 227], split = 'train', pascal_root = pascal_root)
    n.data, n.label = L.Python(module = 'pascal_multilabel_datalayers', layer = datalayer,  # the module name -- usually the filename -- that needs to be in $PYTHONPATH
                               ntop = 2, param_str=str(data_layer_params))

    # the net itself
    n.conv1, n.relu1 = conv_relu(n.data, 11, 96, stride=4)
    n.pool1 = max_pool(n.relu1, 3, stride=2)
    n.norm1 = L.LRN(n.pool1, local_size=5, alpha=1e-4, beta=0.75)
    n.conv2, n.relu2 = conv_relu(n.norm1, 5, 256, pad=2, group=2)
    n.pool2 = max_pool(n.relu2, 3, stride=2)
    n.norm2 = L.LRN(n.pool2, local_size=5, alpha=1e-4, beta=0.75)
    n.conv3, n.relu3 = conv_relu(n.norm2, 3, 384, pad=1)
    n.conv4, n.relu4 = conv_relu(n.relu3, 3, 384, pad=1, group=2)
    n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1, group=2)
    n.pool5 = max_pool(n.relu5, 3, stride=2)
    n.fc6, n.relu6 = fc_relu(n.pool5, 4096)
    n.drop6 = L.Dropout(n.relu6, in_place=True)
    n.fc7, n.relu7 = fc_relu(n.drop6, 4096)
    n.drop7 = L.Dropout(n.relu7, in_place=True)
    n.score = L.InnerProduct(n.drop7, num_output=20)
    # Note: multilabel classification uses sigmoid cross-entropy (one independent yes/no decision per class),
    # whereas single-label multiclass classification usually uses softmax (SoftmaxWithLoss).
    n.loss = L.SigmoidCrossEntropyLoss(n.score, n.label)

    return str(n.to_proto())
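
The input size of fc6 follows from how the 227x227 crop shrinks through this net. Below is a short worked check using Caffe's output-size rules (convolution rounds down, pooling rounds up); LRN, ReLU and Dropout leave the spatial size unchanged. The layer table is transcribed from `caffenet_multilabel` above; the helper names are my own.

```python
import math

# (name, kind, kernel, stride, pad) for each layer of caffenet_multilabel
# that changes the spatial size.
LAYERS = [
    ('conv1', 'conv', 11, 4, 0),
    ('pool1', 'pool', 3, 2, 0),
    ('conv2', 'conv', 5, 1, 2),
    ('pool2', 'pool', 3, 2, 0),
    ('conv3', 'conv', 3, 1, 1),
    ('conv4', 'conv', 3, 1, 1),
    ('conv5', 'conv', 3, 1, 1),
    ('pool5', 'pool', 3, 2, 0),
]


def conv_out(size, kernel, stride, pad):
    # Convolution: integer division, i.e. rounded down.
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel, stride, pad):
    # Pooling: rounded up.
    return int(math.ceil(float(size + 2 * pad - kernel) / stride)) + 1


def trace_sizes(size=227):
    """Spatial side length after each size-changing layer."""
    sizes = []
    for name, kind, kernel, stride, pad in LAYERS:
        out_fn = conv_out if kind == 'conv' else pool_out
        size = out_fn(size, kernel, stride, pad)
        sizes.append((name, size))
    return sizes


for name, size in trace_sizes():
    print('{}: {}x{}'.format(name, size, size))
```

The trace ends with pool5 at 6x6 and 256 channels, so fc6 receives 256 * 6 * 6 = 9216 inputs, the same as the reference CaffeNet; that is why the pretrained fc6 weights can be copied over unchanged.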


3. Write nets and solver files

Now we can create the net and solver prototxts. For the solver, we use the CaffeSolver class from the “tools” module.

workdir = './pascal_multilabel_with_datalayer'  # holds the generated prototxt files and other outputs
if not os.path.isdir(workdir):
    os.makedirs(workdir)

solverprototxt = tools.CaffeSolver(trainnet_prototxt_path = osp.join(workdir, "trainnet.prototxt"), testnet_prototxt_path = osp.join(workdir, "valnet.prototxt"))
solverprototxt.sp['display'] = "1"
solverprototxt.sp['base_lr'] = "0.0001"
solverprototxt.write(osp.join(workdir, 'solver.prototxt'))  # writes solver.prototxt into workdir

# write train net.
with open(osp.join(workdir, 'trainnet.prototxt'), 'w') as f:
    # provide parameters to the data layer as a python dictionary. Easy as pie!
    data_layer_params = dict(batch_size = 128, im_shape = [227, 227], split = 'train', pascal_root = pascal_root)  # split and pascal_root select the data source
    # This net uses a python datalayer: 'PascalMultilabelDataLayerSync', which is defined in './pycaffe/layers/pascal_multilabel_datalayers.py'.
    # PascalMultilabelDataLayerSync is the custom Python data layer class (a subclass of caffe.Layer) defined in the module
    # passed to L.Python(module = 'pascal_multilabel_datalayers', ...); it is a synchronous datalayer for training a multilabel model on PASCAL.
    f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))

# write validation net.
with open(osp.join(workdir, 'valnet.prototxt'), 'w') as f:
    data_layer_params = dict(batch_size = 128, im_shape = [227, 227], split = 'val', pascal_root = pascal_root)  # same as above, but the 'val' split
    f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))
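
The data layer never sees the dictionary itself, only `param_str`, its string form written into the prototxt; inside the layer it is available as `self.param_str` and has to be parsed back. A minimal round-trip sketch (pascal_root spelled out for illustration). Using `ast.literal_eval` instead of `eval` is a suggestion of mine, not necessarily what the tutorial's layer file does.

```python
import ast

# Same dictionary as in the train-net cell above.
data_layer_params = dict(batch_size=128, im_shape=[227, 227], split='train',
                         pascal_root='../data/pascal/VOC2012')

# What L.Python(..., param_str=str(data_layer_params)) stores in the prototxt.
param_str = str(data_layer_params)

# ast.literal_eval only accepts Python literals (dicts, lists, strings, numbers),
# so a malformed or malicious param_str cannot execute arbitrary code.
params = ast.literal_eval(param_str)
print(params == data_layer_params)  # → True
```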


This net uses a python datalayer: ‘PascalMultilabelDataLayerSync’, which is defined in ‘./pycaffe/layers/pascal_multilabel_datalayers.py’.

Take a look at the code. It’s quite straightforward, and gives you full control over data and labels.
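
The heart of such a layer is turning a VOC annotation into a multi-hot label vector. VOC stores one XML file per image in its Annotations directory, with one `<object>` element (and its `<name>`) per object instance. A minimal sketch with the standard-library XML parser; the sample annotation is made up, and this version ignores the `difficult` flag, which the tutorial's layer may handle differently.

```python
import xml.etree.ElementTree as ET

CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
           'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
           'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']


def labels_from_annotation(xml_text, classes=CLASSES):
    """Multi-hot label vector for one VOC annotation (contents of one XML file)."""
    root = ET.fromstring(xml_text)
    # find('name') only looks at direct children, so part names
    # (e.g. a person's head) are not picked up; repeated classes count once.
    names = set(obj.find('name').text.strip() for obj in root.iter('object'))
    return [1 if c in names else 0 for c in classes]


# Hypothetical annotation: two dogs and a person.
example = """<annotation>
  <filename>example.jpg</filename>
  <object><name>dog</name><difficult>0</difficult></object>
  <object><name>person</name><difficult>0</difficult></object>
  <object><name>dog</name><difficult>1</difficult></object>
</annotation>"""

label = labels_from_annotation(example)
print([c for c, v in zip(CLASSES, label) if v])  # → ['dog', 'person']
```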

Now we can load the caffe solver as usual.

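solver = caffe.SGDSolver(osp.join(workdir, 'solver.prototxt'))
# initialize from the pretrained CaffeNet; layers are matched by name, so the new 'score' layer keeps its fresh initialization
solver.net.copy_from(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel')
solver.test_nets[0].share_with(solver.net)  # the test net shares the train net's parameter blobs, so it always evaluates the current weights
solver.step(1)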


Let’s check the data we have loaded.

transformer = tools.SimpleTransformer()  # This is simply to add back the bias, re-shuffle the color channels to RGB, and so on...
image_index = 0  # First image in the batch.
plt.figure()
plt.imshow(transformer.deprocess(copy(solver.net.blobs['data'].data[image_index, ...])))  # deprocess turns the blob data back into an RGB, HWC image for plotting
gtlist = solver.net.blobs['label'].data[image_index, ...].astype(int)  # ground-truth label vector of the first image
plt.title('GT: {}'.format(classes[np.where(gtlist)]))
plt.axis('off');
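
What `deprocess` undoes can be sketched without NumPy. This minimal pure-Python version assumes the data layer stored each image channels-first (CHW), in BGR order, with a per-channel mean subtracted; the mean of 128 and the absence of any scaling are assumptions for illustration, so check `SimpleTransformer` in tools.py for the real values. The tutorial works on NumPy arrays; nested lists are used here only to keep the sketch self-contained.

```python
MEAN_BGR = [128.0, 128.0, 128.0]  # assumed per-channel mean (B, G, R)


def deprocess(blob_chw, mean=MEAN_BGR):
    """Undo the data-layer preprocessing for display.

    blob_chw -- nested lists indexed [channel][row][col], BGR, mean-subtracted
    returns  -- nested lists indexed [row][col][channel], RGB, ints in 0..255
    """
    channels, height, width = len(blob_chw), len(blob_chw[0]), len(blob_chw[0][0])
    image = []
    for y in range(height):
        row = []
        for x in range(width):
            bgr = [blob_chw[c][y][x] + mean[c] for c in range(channels)]  # add the mean back
            rgb = bgr[::-1]                                                # BGR -> RGB
            row.append([int(min(255, max(0, v))) for v in rgb])           # clip to 0..255
        image.append(row)                                                  # CHW -> HWC
    return image


# One pixel whose mean-subtracted B, G, R values are -128, 0, 127:
print(deprocess([[[-128.0]], [[0.0]], [[127.0]]]))  # → [[[255, 128, 0]]]
```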


NOTE: we are reading the image from the data layer, so the resolution is lower than that of the original PASCAL image.

4. Train a net.

Let’s train the net. First, though, we need some way to measure the accuracy. Hamming distance is commonly used in multilabel problems. We also need a simple test loop. Let’s write that down.

def hamming_distance(gt, est):
    # fraction of labels on which ground truth and prediction agree for one image
    # (1.0 = perfect match; strictly, this is 1 minus the normalized Hamming distance)
    return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))

def check_accuracy(net, num_batches, batch_size = 128):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data  # ground truths for one batch: batch_size x num_classes
        ests = net.blobs['score'].data > 0  # predictions: a class is present if its score > 0, i.e. sigmoid(score) > 0.5
        for gt, est in zip(gts, ests):  # for each ground truth and estimated label vector
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)
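
The two ingredients of `check_accuracy`, thresholding scores at 0 and counting matching label positions, can be tried on made-up numbers without Caffe. A toy sketch with five classes; `predict` and `hamming_accuracy` are hypothetical stand-ins for the tutorial's code, not part of it.

```python
def predict(scores):
    """Score > 0 is the same as sigmoid(score) > 0.5: the class is predicted present."""
    return [1 if s > 0 else 0 for s in scores]


def hamming_accuracy(gt, est):
    """Fraction of label positions where ground truth and prediction agree."""
    return sum(1 for g, e in zip(gt, est) if g == e) / float(len(gt))


scores = [2.3, -1.7, 0.4, -3.0, -0.2]  # toy raw scores for 5 classes
gt = [1, 0, 0, 0, 1]
est = predict(scores)                  # [1, 0, 1, 0, 0]
print(hamming_accuracy(gt, est))       # → 0.6 (positions 2 and 4 are wrong)
```

Because every class position counts equally, one false positive and one missed class cost the same, which is also why sparse labels make this metric look high early on.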


Alright, now let’s train for a while.

for itt in range(6):
    solver.step(100)  # 100 training iterations, i.e. 100 weight updates
    # after every 100 iterations, evaluate on 50 validation batches and print the accuracy
    print('itt:{:3d} accuracy:{:.4f}'.format((itt + 1) * 100, check_accuracy(solver.test_nets[0], 50)))


Great, the accuracy is increasing, and it seems to converge rather quickly. It may seem strange that it starts off so high, but that is because the ground truth is sparse: there are 20 classes in PASCAL, and usually only one or two are present, so predicting all zeros already yields a rather high accuracy. Let’s check to make sure.

def check_baseline_accuracy(net, num_batches, batch_size = 128):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data  # batch_size x num_classes
        ests = np.zeros_like(gts)  # predict "absent" for every class; sparse labels make even this score high
        for gt, est in zip(gts, ests):  # for each ground truth and estimated label vector
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)

# 5823 validation images, 128 per batch
print('Baseline accuracy:{0:.4f}'.format(check_baseline_accuracy(solver.test_nets[0], 5823 // 128)))
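
The arithmetic behind the high baseline: if an image has k positive labels out of 20, predicting all zeros still gets the other 20 - k positions right. A tiny sketch of that per-image number (the helper is hypothetical); the baseline measured on the validation set also depends on how many labels the images actually carry.

```python
def all_zero_accuracy(num_present, num_classes=20):
    """Accuracy of predicting 'absent' for every class on an image where
    num_present of the num_classes labels are actually 1."""
    return (num_classes - num_present) / float(num_classes)


print(all_zero_accuracy(1))  # 19 of 20 labels correct → 0.95
print(all_zero_accuracy(2))  # 18 of 20 labels correct → 0.9
```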


5. Look at some prediction results

test_net = solver.test_nets[0]  # there is only one test net, at index 0
for image_index in range(5):
    plt.figure()
    plt.imshow(transformer.deprocess(copy(test_net.blobs['data'].data[image_index, ...])))  # copy the image out of the blob and deprocess it into a plottable image
    gtlist = test_net.blobs['label'].data[image_index, ...].astype(int)
    estlist = test_net.blobs['score'].data[image_index, ...] > 0
    plt.title('GT: {} \n EST: {}'.format(classes[np.where(gtlist)], classes[np.where(estlist)]))
    plt.axis('off')
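
Reading predictions off the scores, as `classes[np.where(estlist)]` does above, amounts to keeping the classes whose score exceeds 0. A plain-Python sketch that also exposes the threshold as a parameter; the threshold is a hypothetical addition, not in the tutorial, and the class subset and scores are made up.

```python
def decode(scores, classes, threshold=0.0):
    """Names of the classes whose raw score exceeds the threshold.

    threshold=0.0 is the rule used in check_accuracy (sigmoid(score) > 0.5)."""
    return [c for c, s in zip(classes, scores) if s > threshold]


toy_classes = ['cat', 'dog', 'person']  # a small subset, for illustration
toy_scores = [-2.1, 1.4, 0.3]
print(decode(toy_scores, toy_classes))       # → ['dog', 'person']
print(decode(toy_scores, toy_classes, 1.0))  # → ['dog']
```

Raising the threshold makes the net predict fewer classes, trading recall for precision; lowering it does the opposite.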