您的位置:首页 > 编程语言

深度学习 8. MatConvNet 相关函数解释说明,MatConvNet 代码理解(四)cnn_train.m 的注释

2017-01-18 15:59 666 查看
本文为原创文章，转载必须注明本文出处，并附上本文地址超链接以及博主博客地址（http://blog.csdn.net/qq_20259459）和作者邮箱（jinweizhi93@gmai.com）。

(如果喜欢本文,欢迎大家关注我的博客或者动手点个赞,有需要可以邮件联系我)

接上一篇文章(阅读上一篇文章:http://blog.csdn.net/qq_20259459/article/details/54600368 )

(四)cnn_train.m 

% Example call (as used by the calling script):
% [ net, info ] = cnn_train(net, imdb, @getBatch, opts.train, 'val', find(imdb.images.set == 3)) ;

function [net, stats] = cnn_train(net, imdb, getBatch, varargin)
%% --------------------------------------------------------------
% Function: cnn_train
% Purpose: 1. runs the training loop
%          2. uses stochastic gradient descent (SGD) with momentum
% ------------------------------------------------------------------------
%CNN_TRAIN An example implementation of SGD for training CNNs
% CNN_TRAIN() is an example learner implementing stochastic
% gradient descent with momentum to train a CNN. It can be used
% with different datasets and tasks by providing a suitable
% getBatch function.
%
% The function automatically restarts after each training epoch by
% checkpointing.
%
% The function supports training on CPU or on one or more GPUs
% (specify the list of GPU IDs in the `gpus` option).

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).
% ------------------------------------------------------------------------
% Notes (translated from the annotator's summary):
% - cnn_train is an example learner that trains a CNN with SGD.
% - Given a suitable getBatch function it can train on different
%   datasets and for different tasks.
% - It detects the last saved checkpoint in opts.expDir and resumes
%   training from there.
% - It supports training on a GPU, including several GPUs in parallel.
%
% Inputs:
%   net      - SimpleNN network struct (net.layers{...}).
%   imdb     - image database; imdb.images.set selects the default
%              training (set == 1) and validation (set == 2) samples.
%   getBatch - function handle, called as [im, labels] = getBatch(imdb, batch).
%   varargin - name/value pairs overriding the opts defaults below.
% Outputs:
%   net      - the trained network (a single copy when several GPUs are used).
%   stats    - stats.train(epoch) and stats.val(epoch), per-epoch statistics.
% ------------------------------------------------------------------------

% NOTE(review): subsetSize is not read anywhere in this file.
opts.subsetSize = 1e4;

opts.expDir = fullfile('data','exp') ; % output directory for checkpoints and the statistics plot
opts.continue = true ; % resume from the last checkpoint found in expDir
opts.batchSize = 256 ; % mini-batch size (default 256)
opts.numSubBatches = 1 ; % split each batch into this many sub-batches whose gradients are accumulated (1 = no split)
opts.train = [] ; % training sample indices (defaults to imdb.images.set == 1 when empty)
opts.val = [] ; % validation sample indices (defaults to imdb.images.set == 2 when empty)
opts.gpus = [] ; % GPU IDs to use; empty = CPU
opts.prefetch = false ; % if true, call getBatch on the next batch ahead of time (default: no)
opts.numEpochs = 300 ; % number of epochs (default 300)
opts.learningRate = 0.001 ; % learning rate; may be a vector indexed by epoch (last value reused)
opts.weightDecay = 0.0005 ; % global weight decay factor (default 0.0005)
opts.momentum = 0.9 ; % momentum coefficient (default 0.9)
opts.saveMomentum = true ; % keep the momentum in the saved state (otherwise it is cleared)
opts.nesterovUpdate = false ; % use the Nesterov momentum update (default: no)
opts.randomSeed = 0 ; % base random seed; each epoch seeds rng with epoch + randomSeed
% NOTE(review): memoryMapFile is not read anywhere in this file.
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ; % memory-mapped file path
opts.profile = false ; % run the MATLAB profiler (mpiprofile with several GPUs) (default: no)
opts.parameterServer.method = 'mmap' ; % parameter-server options, passed to ParameterServer with several GPUs
opts.parameterServer.prefix = 'mcn' ; % parameter-server name prefix

opts.conserveMemory = true ; % forwarded to vl_simplenn as 'conserveMemory' (default: true)
opts.backPropDepth = +inf ; % forwarded to vl_simplenn as 'backPropDepth' (+inf: whole network)
opts.sync = false ; % forwarded to vl_simplenn as 'sync' (default: false)
opts.cudnn = true ; % forwarded to vl_simplenn as 'cudnn' (default: true)
opts.errorFunction = 'multiclass' ; % 'multiclass', 'binary', 'none', or a function handle
opts.errorLabels = {} ; % names of the error metrics; defaults depend on errorFunction
opts.plotDiagnostics = false ; % plot per-layer diagnostics during training (figure 2) (default: no)
opts.plotStatistics = true; % plot per-epoch statistics (figure 1) and save them as a PDF (default: yes)
opts = vl_argparse(opts, varargin) ; % override the defaults above with the name/value pairs in varargin

% ------------------------------------------------------------------------
% Preparation
% ------------------------------------------------------------------------
if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end % create the output directory if it does not exist
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end % default training set: images with set == 1
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end % default validation set: images with set == 2
if isnan(opts.train), opts.train = [] ; end % passing NaN disables the training set (if on a vector is true only when every element is NaN)
if isnan(opts.val), opts.val = [] ; end % passing NaN disables the validation set (same rule)

% -------------------------------------------------------------------------
% Initialization
% -------------------------------------------------------------------------

net = vl_simplenn_tidy(net); % fill in some eventually missing values
net.layers{end-1}.precious = 1; % do not remove predictions, used for error
vl_simplenn_display(net, 'batchSize', opts.batchSize) ; % print the network summary for this batch size

evaluateMode = isempty(opts.train) ; % no training samples: evaluate only (no checkpoints are saved)
if ~evaluateMode % training: give every layer per-parameter learning-rate/decay multipliers of 1 unless set
for i=1:numel(net.layers)
J = numel(net.layers{i}.weights) ;
if ~isfield(net.layers{i}, 'learningRate')
net.layers{i}.learningRate = ones(1, J) ;
end
if ~isfield(net.layers{i}, 'weightDecay')
net.layers{i}.weightDecay = ones(1, J) ;
end
end
end

% setup error calculation function
% A string selects one of the built-in metrics below; any other value
% (a function handle) is used as-is.
% NOTE(review): isstr is deprecated; ischar is the modern equivalent.
hasError = true ;
if isstr(opts.errorFunction)
switch opts.errorFunction % pick the error type
case 'none' % no error metric
opts.errorFunction = @error_none ;
hasError = false ;
case 'multiclass' % top-1 / top-5 classification error
opts.errorFunction = @error_multiclass ;
if isempty(opts.errorLabels), opts.errorLabels = {'top1err', 'top5err'} ; end
case 'binary' % sign-based binary error
opts.errorFunction = @error_binary ;
if isempty(opts.errorLabels), opts.errorLabels = {'binerr'} ; end
otherwise % unknown name
error('Unknown error function ''%s''.', opts.errorFunction) ;
end
end

state.getBatch = getBatch ;
stats = [] ;

% -------------------------------------------------------------------------
% Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep)); % checkpoint file (net, state, stats) for epoch ep
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ; % PDF of the training statistics plot

start = opts.continue * findLastCheckpoint(opts.expDir) ; % last saved epoch (0 when not resuming or no checkpoint exists)
if start >= 1 % resume from the saved checkpoint
fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
[net, state, stats] = loadState(modelPath(start)) ;
else
state = [] ;
end

for epoch=start+1:opts.numEpochs

% Set the random seed based on the epoch and opts.randomSeed.
% This is important for reproducibility, including when training
% is restarted from a checkpoint.

rng(epoch + opts.randomSeed) ;
prepareGPUs(opts, epoch == start+1) ;

% Train for one epoch.
params = opts ;
params.epoch = epoch ;
params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
params.val = opts.val(randperm(numel(opts.val))) ;
params.imdb = imdb ;
params.getBatch = getBatch ;

if numel(params.gpus) <= 1
% CPU or a single GPU: run in this process.
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
else
% Several GPUs: one replica per worker; worker 1 writes the checkpoint
% and the per-worker statistics are merged afterwards.
spmd
[net, state] = processEpoch(net, state, params, 'train') ;
[net, state] = processEpoch(net, state, params, 'val') ;
if labindex == 1 && ~evaluateMode
saveState(modelPath(epoch), net, state) ;
end
lastStats = state.stats ;
end
lastStats = accumulateStats(lastStats) ;
end

% Record this epoch's statistics and add them to the checkpoint file.
stats.train(epoch) = lastStats.train ;
stats.val(epoch) = lastStats.val ;
clear lastStats ;
saveStats(modelPath(epoch), stats) ;

if params.plotStatistics
% One subplot per statistic (except num and time), train and val curves.
switchFigure(1) ; clf ;
plots = setdiff(...
cat(2,...
fieldnames(stats.train)', ...
fieldnames(stats.val)'), {'num', 'time'}) ;
for p = plots
p = char(p) ;
values = zeros(0, epoch) ;
leg = {} ;
for f = {'train', 'val'}
f = char(f) ;
if isfield(stats.(f), p)
tmp = [stats.(f).(p)] ;
values(end+1,:) = tmp(1,:)' ;
leg{end+1} = f ;
end
end
subplot(1,numel(plots),find(strcmp(p,plots))) ;
plot(1:epoch, values','o-') ;
xlabel('epoch') ;
title(p) ;
legend(leg{:}) ;
grid on ;
end
drawnow ;
print(1, modelFigPath, '-dpdf') ;
end
end

% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end

% -------------------------------------------------------------------------
function err = error_multiclass(params, labels, res)
% -------------------------------------------------------------------------
% ERROR_MULTICLASS  Weighted top-1 and top-k (k = min(5, #classes)) errors.
%   Class scores are read from res(end-1).x and ranked along dimension 3.
%   LABELS holds one class index per sample; a flat vector with one entry
%   per sample is also accepted. Samples whose label is <= 0 get zero
%   weight. If LABELS has a second channel along dimension 3, that channel
%   multiplies the per-sample weights.
%   Returns ERR = [weighted top-1 error count; weighted top-k error count].
% -------------------------------------------------------------------------
scores = gather(res(end-1).x) ;
[~, ranked] = sort(scores, 3, 'descend') ;

% Accept a flat label vector with one entry per sample.
if numel(labels) == size(ranked, 4)
  labels = reshape(labels, 1, 1, 1, []) ;
end

% Zero weight for null (non-positive) labels.
weight = single(labels(:,:,1,:) > 0) ;
if size(labels, 3) == 2
  % Second channel = per-sample weights; keep only the class channel.
  weight = weight .* labels(:,:,2,:) ;
  labels = labels(:,:,1,:) ;
end

k = min(5, size(ranked, 3)) ;
wrong = ~bsxfun(@eq, ranked, labels) ;
top1 = sum(sum(sum(weight .* wrong(:,:,1,:)))) ;
topk = sum(sum(sum(weight .* min(wrong(:,:,1:k,:), [], 3)))) ;
err = [top1 ; topk] ;

% -------------------------------------------------------------------------
function err = error_binary(params, labels, res)
% -------------------------------------------------------------------------
% ERROR_BINARY  Number of sign disagreements between scores and labels.
%   Counts the elements where res(end-1).x .* labels is strictly negative;
%   zero products are not counted.
% -------------------------------------------------------------------------
signDisagrees = bsxfun(@times, gather(res(end-1).x), labels) < 0 ;
err = sum(reshape(signDisagrees, [], 1)) ;

% -------------------------------------------------------------------------
function err = error_none(params, labels, res)
% -------------------------------------------------------------------------
% ERROR_NONE  Placeholder metric that reports no errors.
%   Always returns an empty 0x1 double column, so no error rows are added
%   to the accumulated statistics.
% -------------------------------------------------------------------------
err = zeros(0, 1, 'double') ;

% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% PROCESSEPOCH  Run one pass over params.(mode), where mode is 'train' or 'val'.
%   Walks the (already shuffled) sample indices in batches of
%   params.batchSize. Each batch is optionally split into
%   params.numSubBatches sub-batches. Each batch is run through
%   vl_simplenn: forward only for 'val', and forward plus backward for
%   'train', in which case accumulateGradients applies an update per batch.
%   The objective and the error metrics are accumulated and printed after
%   every batch. The final averages are stored in state.stats.(mode).
%   When GPUs are configured, the network is moved to the GPU at the start;
%   it is always moved back to the CPU at the end.
%
% Note that net is not strictly needed as an output argument as net
% is a handle class. However, this fixes some aliasing issue in the
% spmd caller.
% -------------------------------------------------------------------------

% initialize with momentum 0
if isempty(state) || isempty(state.momentum)
for i = 1:numel(net.layers)
for j = 1:numel(net.layers{i}.weights)
state.momentum{i}{j} = 0 ;
end
end
end

% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
net = vl_simplenn_move(net, 'gpu') ;
for i = 1:numel(state.momentum)
for j = 1:numel(state.momentum{i})
state.momentum{i}{j} = gpuArray(state.momentum{i}{j}) ;
end
end
end
% With several GPUs, gradients are exchanged through a ParameterServer.
if numGpus > 1
parserv = ParameterServer(params.parameterServer) ;
vl_simplenn_start_parserv(net, parserv) ;
else
parserv = [] ;
end

% profile
if params.profile
if numGpus <= 1
profile clear ;
profile on ;
else
mpiprofile reset ;
mpiprofile on ;
end
end

subset = params.(mode) ; % sample indices for this pass (params.train or params.val)
num = 0 ; % number of samples processed so far
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ; % timing correction, set once the first batches are done (see below)
res = [] ;
error = [] ; % running sums: [objective; error metrics...]

start = tic ;
for t=1:params.batchSize:numel(subset)
fprintf('%s: epoch %02d: %3d/%3d:', mode, params.epoch, ...
fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
batchSize = min(params.batchSize, numel(subset) - t + 1) ;

for s=1:params.numSubBatches
% get this image batch and prefetch the next
% This worker/sub-batch takes every (numSubBatches*numlabs)-th sample
% of the batch, starting from its own offset.
batchStart = t + (labindex-1) + (s-1) * numlabs ;
batchEnd = min(t+params.batchSize-1, numel(subset)) ;
batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
num = num + numel(batch) ;
if numel(batch) == 0, continue ; end

[im, labels] = params.getBatch(params.imdb, batch) ;

if params.prefetch
% Request the next sub-batch (or the first one of the next batch) in
% advance; the outputs of this call are discarded.
if s == params.numSubBatches
batchStart = t + (labindex-1) + params.batchSize ;
batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
else
batchStart = batchStart + numlabs ;
end
nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
params.getBatch(params.imdb, nextBatch) ;
end

if numGpus >= 1
im = gpuArray(im) ;
end

% Training runs the backward pass (dzdy = 1); validation is forward only.
if strcmp(mode, 'train')
dzdy = 1 ;
evalMode = 'normal' ;
else
dzdy = [] ;
evalMode = 'test' ;
end
% Labels for the last layer, whose output res(end).x is summed as the objective below.
net.layers{end}.class = labels ;
% Gradients are accumulated across sub-batches ('accumulate' for s > 1).
res = vl_simplenn(net, im, dzdy, res, ...
'accumulate', s ~= 1, ...
'mode', evalMode, ...
'conserveMemory', params.conserveMemory, ...
'backPropDepth', params.backPropDepth, ...
'sync', params.sync, ...
'cudnn', params.cudnn, ...
'parameterServer', parserv, ...
'holdOn', s < params.numSubBatches) ;

% accumulate errors
error = sum([error, [...
sum(double(gather(res(end).x))) ;
reshape(params.errorFunction(params, labels, res),[],1) ; ]],2) ;
end

% accumulate gradient (one parameter update per full batch)
if strcmp(mode, 'train')
if ~isempty(parserv), parserv.sync() ; end
[net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv) ;
end

% get statistics
time = toc(start) + adjustTime ;
batchTime = time - stats.time ;
stats = extractStats(net, params, error / num) ; % per-sample averages so far
stats.num = num ;
stats.time = time ;
currentSpeed = batchSize / batchTime ;
averageSpeed = (t + batchSize - 1) / time ;
if t == 3*params.batchSize + 1
% compensate for the first three iterations, which are outliers
adjustTime = 4*batchTime - time ;
stats.time = time + adjustTime ;
end

fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
for f = setdiff(fieldnames(stats)', {'num', 'time'})
f = char(f) ;
fprintf(' %s: %.3f', f, stats.(f)) ;
end
fprintf('\n') ;

% collect diagnostic statistics
% (res(l).stats is filled in by accumulateGradients when plotDiagnostics is set)
if strcmp(mode, 'train') && params.plotDiagnostics
switchFigure(2) ; clf ;
diagn = [res.stats] ;
diagnvar = horzcat(diagn.variation) ;
diagnpow = horzcat(diagn.power) ;
subplot(2,2,1) ; barh(diagnvar) ;
set(gca,'TickLabelInterpreter', 'none', ...
'YTick', 1:numel(diagnvar), ...
'YTickLabel',horzcat(diagn.label), ...
'YDir', 'reverse', ...
'XScale', 'log', ...
'XLim', [1e-5 1], ...
'XTick', 10.^(-5:1)) ;
grid on ;
subplot(2,2,2) ; barh(sqrt(diagnpow)) ;
set(gca,'TickLabelInterpreter', 'none', ...
'YTick', 1:numel(diagnpow), ...
'YTickLabel',{diagn.powerLabel}, ...
'YDir', 'reverse', ...
'XScale', 'log', ...
'XLim', [1e-5 1e5], ...
'XTick', 10.^(-5:5)) ;
grid on ;
subplot(2,2,3); plot(squeeze(res(end-1).x)) ;
drawnow ;
end
end

% Save back to state.
state.stats.(mode) = stats ;
if params.profile
if numGpus <= 1
state.prof.(mode) = profile('info') ;
profile off ;
else
state.prof.(mode) = mpiprofile('info');
mpiprofile off ;
end
end
% Drop the momentum, or bring it back to CPU memory so it can be saved.
if ~params.saveMomentum
state.momentum = [] ;
else
for i = 1:numel(state.momentum)
for j = 1:numel(state.momentum{i})
state.momentum{i}{j} = gather(state.momentum{i}{j}) ;
end
end
end

net = vl_simplenn_move(net, 'cpu') ;

% -------------------------------------------------------------------------
function [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv)
% -------------------------------------------------------------------------
% ACCUMULATEGRADIENTS  Apply one SGD-with-momentum step to every layer.
%   Layers are visited last to first. For each parameter j of layer l:
%     * The gradient is pulled from PARSERV (tag 'l<l>_<j>') when a
%       parameter server is used. Otherwise it is read from res(l).dzdw{j}.
%     * The 3rd parameter of a 'bnorm' layer (its moments) is updated as a
%       running average, using only the layer's own learningRate(3):
%         w = (1 - lr) * w + (lr / batchSize) * g
%     * Every other parameter gets momentum SGD with weight decay. The step
%       is skipped when both the effective learning rate and the decay
%       are zero:
%         g = g / batchSize + decay * w
%         m = momentum * m - g
%         w = w + lr * delta,  delta = m  (or momentum * m - g if Nesterov)
%   These formulas assume vl_taccum(a, A, b, B) computes a*A + b*B
%   (MatConvNet convention; confirm against vl_taccum).
%   When params.plotDiagnostics is set, per-parameter statistics are stored
%   in res(l).stats (variation, power, powerLabel, label) for plotting.
% -------------------------------------------------------------------------

for l = numel(net.layers):-1:1
  for j = numel(res(l).dzdw):-1:1

    % Gradient of parameter j of layer l: from the parameter server when
    % training on several GPUs, otherwise straight from the backward pass.
    if ~isempty(parserv)
      tag = sprintf('l%d_%d',l,j) ;
      parDer = parserv.pull(tag) ;
    else
      parDer = res(l).dzdw{j} ;
    end

    if j == 3 && strcmp(net.layers{l}.type, 'bnorm')
      % special case for learning bnorm moments
      thisLR = net.layers{l}.learningRate(j) ;
      net.layers{l}.weights{j} = vl_taccum(...
        1 - thisLR, ...
        net.layers{l}.weights{j}, ...
        thisLR / batchSize, ...
        parDer) ;
    else
      % Standard gradient training.
      thisDecay = params.weightDecay * net.layers{l}.weightDecay(j) ;
      thisLR = params.learningRate * net.layers{l}.learningRate(j) ;

      if thisLR>0 || thisDecay>0
        % Normalize gradient and incorporate weight decay.
        parDer = vl_taccum(1/batchSize, parDer, ...
                           thisDecay, net.layers{l}.weights{j}) ;

        % Update momentum.
        state.momentum{l}{j} = vl_taccum(...
          params.momentum, state.momentum{l}{j}, ...
          -1, parDer) ;

        % Nesterov update (aka one step ahead).
        if params.nesterovUpdate
          delta = vl_taccum(...
            params.momentum, state.momentum{l}{j}, ...
            -1, parDer) ;
        else
          delta = state.momentum{l}{j} ;
        end

        % Update parameters.
        net.layers{l}.weights{j} = vl_taccum(...
          1, net.layers{l}.weights{j}, ...
          thisLR, delta) ;
      end
    end

    % if requested, collect some useful stats for debugging
    if params.plotDiagnostics
      % NaN placeholders ensure layers other than conv/convt (e.g. bnorm)
      % still record one value per parameter and one per layer. Without
      % them, the output power would be unassigned or stale from an
      % earlier layer, and 'variation(j) = []' would be a deletion rather
      % than an assignment.
      variation = NaN ;
      outputPower = NaN ;
      label = '' ;
      switch net.layers{l}.type
        case {'conv','convt'}
          variation = thisLR * mean(abs(state.momentum{l}{j}(:))) ;
          outputPower = mean(res(l+1).x(:).^2) ;
          if j == 1 % filters
            base = mean(net.layers{l}.weights{j}(:).^2) ;
            label = 'filters' ;
          else % biases
            base = sqrt(outputPower) ;%mean(abs(res(l+1).x(:))) ;
            label = 'biases' ;
          end
          variation = variation / base ;
          label = sprintf('%s_%s', net.layers{l}.name, label) ;
      end
      res(l).stats.variation(j) = variation ;
      res(l).stats.power = outputPower ;
      res(l).stats.powerLabel = net.layers{l}.name ;
      res(l).stats.label{j} = label ;
    end
  end
end

% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
% ACCUMULATESTATS  Merge per-worker epoch statistics into a single struct.
%   STATS_ has one entry per worker, indexed with {} (a cell array or a
%   Composite). Each entry has 'train' and 'val' sub-structs holding a
%   sample count 'num' plus per-sample averages. Every averaged field is
%   combined as a num-weighted mean, and 'num' becomes the total count.
%   Field order follows the first worker's struct.
% -------------------------------------------------------------------------
nWorkers = numel(stats_) ;
for phase = {'train', 'val'}
  phase = char(phase) ;

  % Zero-initialised struct with the first worker's fields, in order.
  first = stats_{1} ;
  names = fieldnames(first.(phase))' ;
  init = cat(1, names, num2cell(zeros(1, numel(names)))) ;
  merged = struct(init{:}) ;

  % Weighted sums over all workers.
  count = 0 ;
  for w = 1:nWorkers
    workerStats = stats_{w} ;
    ws = workerStats.(phase) ;
    count = count + ws.num ;
    for name = setdiff(fieldnames(ws)', 'num')
      name = char(name) ;
      merged.(name) = merged.(name) + ws.(name) * ws.num ;
    end
  end

  % Turn the weighted sums into weighted means.
  lastStats = stats_{nWorkers} ;
  for name = setdiff(fieldnames(lastStats.(phase))', 'num')
    name = char(name) ;
    merged.(name) = merged.(name) / count ;
  end
  merged.num = count ;
  stats.(phase) = merged ;
end

% -------------------------------------------------------------------------
function stats = extractStats(net, params, errors)
% -------------------------------------------------------------------------
% EXTRACTSTATS  Name the entries of an accumulated error vector.
%   errors(1) becomes stats.objective. errors(k+1) is stored under
%   params.errorLabels{k} for each configured error label.
% -------------------------------------------------------------------------
stats = struct('objective', errors(1)) ;
names = params.errorLabels ;
for k = 1:numel(names)
  stats.(names{k}) = errors(k + 1) ;
end

% -------------------------------------------------------------------------
function saveState(fileName, net, state)
% -------------------------------------------------------------------------
% SAVESTATE  Write the network and training state to a checkpoint file.
%   (Over)writes FILENAME with the variables 'net' and 'state'. The epoch
%   statistics are added to the same file later by saveStats.
% -------------------------------------------------------------------------
varsToSave = {'net', 'state'} ;
save(fileName, varsToSave{:}) ;

% -------------------------------------------------------------------------
function saveStats(fileName, stats)
% -------------------------------------------------------------------------
% SAVESTATS  Store the statistics struct in a checkpoint file.
%   If FILENAME already exists (e.g. written by saveState), 'stats' is
%   appended and the variables already in it are kept. Otherwise the file
%   is created with just 'stats'.
% -------------------------------------------------------------------------
% Restrict exist() to files: without a type argument it also reports
% workspace variables, functions and classes with a matching name.
if exist(fileName, 'file')
  save(fileName, 'stats', '-append') ;
else
  save(fileName, 'stats') ;
end

% -------------------------------------------------------------------------
function [net, state, stats] = loadState(fileName)
% -------------------------------------------------------------------------
% LOADSTATE  Restore a checkpoint written by saveState and saveStats.
%   Loads 'net', 'state' and 'stats' from FILENAME and re-tidies NET with
%   vl_simplenn_tidy. Raises an error if 'stats' is missing. This happens
%   when the checkpoint was written by saveState but saveStats never
%   appended the statistics.
% -------------------------------------------------------------------------
load(fileName, 'net', 'state', 'stats') ;
net = vl_simplenn_tidy(net) ;
statsLoaded = exist('stats', 'var') == 1 ;
if ~statsLoaded
  error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
    fileName) ;
end

% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
% FINDLASTCHECKPOINT  Highest epoch number among saved checkpoints.
%   Scans MODELDIR for files named net-epoch-<N>.mat and returns the
%   largest N, or 0 when there are none. Files that match the dir() glob
%   but whose middle part is not a plain integer (e.g. net-epoch-best.mat)
%   are ignored; previously they made the token indexing throw.
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
% Anchored pattern with an escaped dot; 'once' yields {} for non-matches.
tokens = regexp({list.name}, '^net-epoch-(\d+)\.mat$', 'tokens', 'once') ;
tokens = tokens(~cellfun(@isempty, tokens)) ;
epoch = cellfun(@(x) sscanf(x{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;

% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
% SWITCHFIGURE  Make figure N the current figure, creating it if needed.
%   Selects an existing figure N without raising it. If figure N does not
%   exist, set() fails and figure(N) creates it.
% -------------------------------------------------------------------------
current = get(0, 'CurrentFigure') ;
% With no figure open, CurrentFigure is [] and '[] ~= n' is empty, which
% 'if' treats as false; the isempty test makes sure figure n is selected.
if isempty(current) || current ~= n
  try
    set(0, 'CurrentFigure', n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function clearMex()
% -------------------------------------------------------------------------
% CLEARMEX  Unload MEX files from memory.
%   Prints a notice, clears all MEX functions, then explicitly clears
%   vl_tmove and vl_imreadjpeg. Called by prepareGPUs before the GPU
%   device is (re)selected.
% -------------------------------------------------------------------------
disp('Clearing mex files') ;
clear('mex') ;
clear('vl_tmove', 'vl_imreadjpeg') ;

% -------------------------------------------------------------------------
function prepareGPUs(params, cold)
% -------------------------------------------------------------------------
% PREPAREGPUS  Make sure the requested GPUs (and parallel pool) are ready.
%   With more than one GPU in params.gpus, a local parallel pool with
%   exactly one worker per GPU is required. A pool of the wrong size is
%   deleted, and creating a new pool forces a cold start. On a cold start
%   with at least one GPU, MEX files are cleared and the GPU device is
%   selected and printed, on every worker when several GPUs are used.
% -------------------------------------------------------------------------
gpuList = params.gpus ;
nGpu = numel(gpuList) ;

if nGpu > 1
  % check parallel pool integrity as it could have timed out
  existing = gcp('nocreate') ;
  if ~isempty(existing) && existing.NumWorkers ~= nGpu
    delete(existing) ;
  end
  if isempty(gcp('nocreate'))
    parpool('local', nGpu) ;
    cold = true ;
  end
end

% Nothing to reset on CPU or on a warm start.
if nGpu < 1 || ~cold
  return ;
end

fprintf('%s: resetting GPU\n', mfilename) ;
clearMex() ;
if nGpu == 1
  disp(gpuDevice(gpuList)) ;
else
  spmd
    clearMex() ;
    disp(gpuDevice(gpuList(labindex))) ;
  end
end

本文为原创文章，转载必须注明本文出处，并附上本文地址超链接以及博主博客地址（http://blog.csdn.net/qq_20259459）和作者邮箱（jinweizhi93@gmai.com）。

(如果喜欢本文,欢迎大家关注我的博客或者动手点个赞,有需要可以邮件联系我)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐