R-CNN笔记2:rcnn_train.m文件
2016-04-19 10:36
288 查看
rcnn_train.m
function [rcnn_model, rcnn_k_fold_model] = ...
    rcnn_train(imdb, varargin)
% [rcnn_model, rcnn_k_fold_model] = rcnn_train(imdb, varargin)
%   Trains an R-CNN detector for all classes in the imdb.
%
%   Keys that can be passed in (defaults in parentheses):
%
%   svm_C             SVM regularization parameter (10^-3)
%   bias_mult         Bias feature value (for liblinear) (10)
%   pos_loss_weight   Cost factor on hinge loss for positives (2)
%   layer             Feature layer to use (either 5, 6 or 7) (7)
%   k_folds           Train on folds of the imdb (2)
%   checkpoint        Save the rcnn_model every checkpoint images (0 = off)
%   crop_mode         Crop mode (either 'warp' or 'square') ('warp')
%   crop_padding      Amount of padding in crop (16)
%   net_file          Path to the Caffe CNN to use
%   cache_name        Path to the precomputed feature cache
%
%   Returns:
%   rcnn_model          Model with one linear detector per class
%                       (detectors.W(:, j), detectors.B(j)).
%   rcnn_k_fold_model   Copy of rcnn_model whose detectors(f) holds the
%                       weights trained with fold f held out, or [] when
%                       k_folds <= 0.
%
%   Side effects visible in this file: starts a diary log in
%   conf.cache_dir, writes the ground-truth pool5 feature cache under
%   ./feat_cache/, and saves 'rcnn_model' / 'rcnn_k_fold_model' into
%   conf.cache_dir.

% AUTORIGHTS
% ---------------------------------------------------------
% Copyright (c) 2014, Ross Girshick
%
% This file is part of the R-CNN code and is available
% under the terms of the Simplified BSD License provided in
% LICENSE. Please retain this notice and LICENSE if you use
% this file (or any portion of it) in your project.
% ---------------------------------------------------------

% TODO:
%  - allow training just a subset of classes

ip = inputParser;
ip.addRequired('imdb', @isstruct);
ip.addParamValue('svm_C',           10^-3,  @isscalar);
ip.addParamValue('bias_mult',       10,     @isscalar);
ip.addParamValue('pos_loss_weight', 2,      @isscalar);
ip.addParamValue('layer',           7,      @isscalar);
ip.addParamValue('k_folds',         2,      @isscalar);
ip.addParamValue('checkpoint',      0,      @isscalar);
% ischar replaces the deprecated isstr validator.
ip.addParamValue('crop_mode',       'warp', @ischar);
ip.addParamValue('crop_padding',    16,     @isscalar);
ip.addParamValue('net_file', ...
    './data/caffe_nets/finetune_voc_2007_trainval_iter_70k', ...
    @ischar);
ip.addParamValue('cache_name', ...
    'v1_finetune_voc_2007_trainval_iter_70000', @ischar);

ip.parse(imdb, varargin{:});
opts = ip.Results;

opts.net_def_file = './model-defs/rcnn_batch_256_output_fc7.prototxt';

conf = rcnn_config('sub_dir', imdb.name);

% Record a log of the training and test procedure
% NOTE(review): the diary is never switched off in this file; presumably
% callers rely on it to keep logging the test phase -- confirm before
% adding a 'diary off'.
timestamp = datestr(datevec(now()), 'dd.mmm.yyyy:HH.MM.SS');
diary_file = [conf.cache_dir 'rcnn_train_' timestamp '.txt'];
diary(diary_file);
fprintf('Logging output in %s\n', diary_file);

fprintf('\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n');
fprintf('Training options:\n');
disp(opts);
fprintf('~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n');

% ------------------------------------------------------------------------
% Create a new rcnn model
rcnn_model = rcnn_create_model(opts.net_def_file, opts.net_file, opts.cache_name);
rcnn_model = rcnn_load_model(rcnn_model, conf.use_gpu);
rcnn_model.detectors.crop_mode = opts.crop_mode;
rcnn_model.detectors.crop_padding = opts.crop_padding;
rcnn_model.classes = imdb.classes;
% ------------------------------------------------------------------------

% ------------------------------------------------------------------------
% Get the average norm of the features
opts.feat_norm_mean = rcnn_feature_stats(imdb, opts.layer, rcnn_model);
fprintf('average norm = %.3f\n', opts.feat_norm_mean);
rcnn_model.training_opts = opts;
% ------------------------------------------------------------------------

% ------------------------------------------------------------------------
% Get all positive examples
% We cache only the pool5 features and convert them on-the-fly to
% fc6 or fc7 as required
save_file = sprintf('./feat_cache/%s/%s/gt_pos_layer_5_cache.mat', ...
    rcnn_model.cache_name, imdb.name);
try
  % Load into a struct and pull out exactly the two variables we need, so
  % that a stale cache file can neither shadow local variables nor be
  % reported as loaded while lacking X_pos / keys_pos (in that case the
  % field access fails and we fall through to recomputation).
  ld = load(save_file, 'X_pos', 'keys_pos');
  X_pos = ld.X_pos;
  keys_pos = ld.keys_pos;
  fprintf('Loaded saved positives from ground truth boxes\n');
catch
  % No usable cache: extract the pool5 features of the positives and
  % cache them for next time.
  [X_pos, keys_pos] = get_positive_pool5_features(imdb, opts);
  save(save_file, 'X_pos', 'keys_pos', '-v7.3');
end

% Init training caches
caches = {};
for i = imdb.class_ids
  fprintf('%14s has %6d positive instances\n', ...
      imdb.classes{i}, size(X_pos{i},1));
  % Convert the pool5 features to the requested layer and rescale them
  X_pos{i} = rcnn_pool5_to_fcX(X_pos{i}, opts.layer, rcnn_model);
  X_pos{i} = rcnn_scale_features(X_pos{i}, opts.feat_norm_mean);
  caches{i} = init_cache(X_pos{i}, keys_pos{i});
end
% ------------------------------------------------------------------------

% ------------------------------------------------------------------------
% Train with hard negative mining
first_time = true;
% one pass over the data is enough
max_hard_epochs = 1;

for hard_epoch = 1:max_hard_epochs
  for i = 1:length(imdb.image_ids)
    fprintf('%s: hard neg epoch: %d/%d image: %d/%d\n', ...
            procid(), hard_epoch, max_hard_epochs, i, length(imdb.image_ids));

    % Get hard negatives for all classes at once (avoids loading feature cache
    % more than once).
    % X{j} holds the negative features for class j; keys{j} has one row
    % [image_index box_index] per sampled negative.
    [X, keys] = sample_negative_features(first_time, rcnn_model, caches, ...
        imdb, i);

    % Add sampled negatives to each classes training cache, removing
    % duplicates
    for j = imdb.class_ids
      if ~isempty(keys{j})
        if ~isempty(caches{j}.keys_neg)
          % The sampler already de-duplicates; this is a sanity check.
          [~, ~, dups] = intersect(caches{j}.keys_neg, keys{j}, 'rows');
          assert(isempty(dups));
        end
        % Append the new negatives of class j to its cache
        caches{j}.X_neg = cat(1, caches{j}.X_neg, X{j});
        caches{j}.keys_neg = cat(1, caches{j}.keys_neg, keys{j});
        caches{j}.num_added = caches{j}.num_added + size(keys{j},1);
      end

      % Update model if
      %  - first time seeing negatives
      %  - more than retrain_limit negatives have been added
      %  - it's the final image of the final epoch
      is_last_time = (hard_epoch == max_hard_epochs && i == length(imdb.image_ids));
      hit_retrain_limit = (caches{j}.num_added > caches{j}.retrain_limit);
      if (first_time || hit_retrain_limit || is_last_time) && ...
          ~isempty(caches{j}.X_neg)
        fprintf('>>> Updating %s detector <<<\n', imdb.classes{j});
        fprintf('Cache holds %d pos examples %d neg examples\n', ...
                size(caches{j}.X_pos,1), size(caches{j}.X_neg,1));
        % Retrain the SVM for class j on the current cache
        [new_w, new_b] = update_model(caches{j}, opts);
        rcnn_model.detectors.W(:, j) = new_w;
        rcnn_model.detectors.B(j) = new_b;
        caches{j}.num_added = 0;

        % Scores of the cached examples under the new model, and the
        % resulting objective value split into its three terms
        z_pos = caches{j}.X_pos * new_w + new_b;
        z_neg = caches{j}.X_neg * new_w + new_b;

        caches{j}.pos_loss(end+1) = opts.svm_C * opts.pos_loss_weight * ...
                                    sum(max(0, 1 - z_pos));
        caches{j}.neg_loss(end+1) = opts.svm_C * sum(max(0, 1 + z_neg));
        caches{j}.reg_loss(end+1) = 0.5 * new_w' * new_w + ...
                                    0.5 * (new_b / opts.bias_mult)^2;
        caches{j}.tot_loss(end+1) = caches{j}.pos_loss(end) + ...
                                    caches{j}.neg_loss(end) + ...
                                    caches{j}.reg_loss(end);

        for t = 1:length(caches{j}.tot_loss)
          fprintf(' %2d: obj val: %.3f = %.3f (pos) + %.3f (neg) + %.3f (reg)\n', ...
                  t, caches{j}.tot_loss(t), caches{j}.pos_loss(t), ...
                  caches{j}.neg_loss(t), caches{j}.reg_loss(t));
        end

        % store negative support vectors for visualizing later
        SVs_neg = find(z_neg > -1 - eps);
        rcnn_model.SVs.keys_neg{j} = caches{j}.keys_neg(SVs_neg, :);
        rcnn_model.SVs.scores_neg{j} = z_neg(SVs_neg);

        % evict easy examples (score below evict_thresh, set in init_cache)
        easy = find(z_neg < caches{j}.evict_thresh);
        caches{j}.X_neg(easy,:) = [];
        caches{j}.keys_neg(easy,:) = [];
        fprintf(' Pruning easy negatives\n');
        fprintf(' Cache holds %d pos examples %d neg examples\n', ...
                size(caches{j}.X_pos,1), size(caches{j}.X_neg,1));
        fprintf(' %d pos support vectors\n', numel(find(z_pos < 1 + eps)));
        fprintf(' %d neg support vectors\n', numel(find(z_neg > -1 - eps)));
      end
    end
    first_time = false;

    if opts.checkpoint > 0 && mod(i, opts.checkpoint) == 0
      save([conf.cache_dir 'rcnn_model'], 'rcnn_model');
    end
  end
end
% save the final rcnn_model
save([conf.cache_dir 'rcnn_model'], 'rcnn_model');
% ------------------------------------------------------------------------

% ------------------------------------------------------------------------
if opts.k_folds > 0
  rcnn_k_fold_model = rcnn_model;
  [W, B, folds] = update_model_k_fold(rcnn_model, caches, imdb);
  rcnn_k_fold_model.folds = folds;
  for f = 1:length(folds)
    rcnn_k_fold_model.detectors(f).W = W{f};
    rcnn_k_fold_model.detectors(f).B = B{f};
  end
  save([conf.cache_dir 'rcnn_k_fold_model'], 'rcnn_k_fold_model');
else
  rcnn_k_fold_model = [];
end
% ------------------------------------------------------------------------


% ------------------------------------------------------------------------
function [X_neg, keys] = sample_negative_features(first_time, rcnn_model, ...
                                                  caches, imdb, ind)
% Samples negative training features for every class from image `ind`.
%
%   first_time   true  -> take every box whose overlap with the class is
%                         below neg_ovr_thresh;
%                false -> take only boxes that additionally score above
%                         the class's hard_thresh and are not already in
%                         the class's negative cache.
%   X_neg{c}     features of the selected boxes for class c
%   keys{c}      one row [ind box_index] per selected box
% ------------------------------------------------------------------------
opts = rcnn_model.training_opts;

d = rcnn_load_cached_pool5_features(opts.cache_name, ...
    imdb.name, imdb.image_ids{ind});
% Fields of d as observed by the original annotator on a 20-class imdb
% (TODO confirm against rcnn_load_cached_pool5_features):
%   d.gt      [n x 1    logical]
%   d.overlap [n x 20   single]
%   d.boxes   [n x 4    double]
%   d.feat    [n x 9216 single]
%   d.class   [n x 1    uint8]

class_ids = imdb.class_ids;

% Preallocate once so that every return path yields the same cell shape.
X_neg = cell(max(class_ids), 1);
keys = cell(max(class_ids), 1);

if isempty(d.feat)
  return;
end

d.feat = rcnn_pool5_to_fcX(d.feat, opts.layer, rcnn_model);
d.feat = rcnn_scale_features(d.feat, opts.feat_norm_mean);

neg_ovr_thresh = 0.3;

if first_time
  for cls_id = class_ids
    % No model yet: every box with low overlap is a negative
    I = find(d.overlap(:, cls_id) < neg_ovr_thresh);
    X_neg{cls_id} = d.feat(I,:);
    keys{cls_id} = [ind*ones(length(I),1) I];
  end
else
  % Score all boxes against all class detectors at once
  zs = bsxfun(@plus, d.feat*rcnn_model.detectors.W, rcnn_model.detectors.B);
  for cls_id = class_ids
    z = zs(:, cls_id);
    % Hard negatives: score above hard_thresh AND overlap below the
    % negative-overlap threshold
    I = find((z > caches{cls_id}.hard_thresh) & ...
             (d.overlap(:, cls_id) < neg_ovr_thresh));

    % Avoid adding duplicate features
    keys_ = [ind*ones(length(I),1) I];
    if ~isempty(caches{cls_id}.keys_neg) && ~isempty(keys_)
      % dups indexes the rows of keys_ that are already cached
      [~, ~, dups] = intersect(caches{cls_id}.keys_neg, keys_, 'rows');
      % keep every row of keys_ that is not a duplicate
      keep = setdiff(1:size(keys_,1), dups);
      I = I(keep);
    end

    % Unique hard negatives
    X_neg{cls_id} = d.feat(I,:);
    keys{cls_id} = [ind*ones(length(I),1) I];
  end
end


% ------------------------------------------------------------------------
function [w, b] = update_model(cache, opts, pos_inds, neg_inds)
% Trains a linear SVM for one class from its cache.
%
%   cache      struct with X_pos / X_neg feature matrices (one row each)
%   opts       training options (svm_C, pos_loss_weight, bias_mult)
%   pos_inds   optional row subset of cache.X_pos; missing or empty
%              means "use all positives"
%   neg_inds   optional row subset of cache.X_neg; missing or empty
%              means "use all negatives"
%   w, b       single-precision weight vector and bias
% ------------------------------------------------------------------------
solver = 'liblinear';
liblinear_type = 3;  % l2 regularized l1 hinge loss
%liblinear_type = 5; % l1 regularized l2 hinge loss

if ~exist('pos_inds', 'var') || isempty(pos_inds)
  num_pos = size(cache.X_pos, 1);
  pos_inds = 1:num_pos;
else
  num_pos = length(pos_inds);
  fprintf('[subset mode] using %d out of %d total positives\n', ...
      num_pos, size(cache.X_pos,1));
end
if ~exist('neg_inds', 'var') || isempty(neg_inds)
  num_neg = size(cache.X_neg, 1);
  neg_inds = 1:num_neg;
else
  num_neg = length(neg_inds);
  fprintf('[subset mode] using %d out of %d total negatives\n', ...
      num_neg, size(cache.X_neg,1));
end

switch solver
  case 'liblinear'
    ll_opts = sprintf('-w1 %.5f -c %.5f -s %d -B %.5f', ...
                      opts.pos_loss_weight, opts.svm_C, ...
                      liblinear_type, opts.bias_mult);
    fprintf('liblinear opts: %s\n', ll_opts);
    % Build the sparse input with one example per column ('col' mode):
    % positives first, then negatives.
    X = sparse(size(cache.X_pos,2), num_pos+num_neg);
    X(:,1:num_pos) = cache.X_pos(pos_inds,:)';
    X(:,num_pos+1:end) = cache.X_neg(neg_inds,:)';
    % Labels: +1 for positives, -1 for negatives
    y = cat(1, ones(num_pos,1), -ones(num_neg,1));
    llm = liblinear_train(y, X, ll_opts, 'col');
    % The '-B' option makes liblinear learn one extra trailing weight for
    % a constant bias feature of value bias_mult (see the LIBLINEAR
    % README): the leading entries are the feature weights, and the bias
    % is the trailing weight times bias_mult.
    w = single(llm.w(1:end-1)');
    b = single(llm.w(end)*opts.bias_mult);

  otherwise
    error('unknown solver: %s', solver);
end


% ------------------------------------------------------------------------
function [W, B, folds] = update_model_k_fold(rcnn_model, caches, imdb)
% Trains, for every class, one model per fold with that fold's images
% held out.
%
%   W{f}(:, i), B{f}(i)   weights / bias of class i with fold f held out
%   folds                 fold assignment returned by create_folds
% ------------------------------------------------------------------------
opts = rcnn_model.training_opts;
num_images = length(imdb.image_ids);
folds = create_folds(num_images, opts.k_folds);
W = cell(opts.k_folds, 1);
B = cell(opts.k_folds, 1);

fprintf('Training k-fold models\n');
for i = imdb.class_ids
  fprintf('\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n');
  fprintf('Training folds for class %s\n', imdb.classes{i});
  fprintf('~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n');
  for f = 1:length(folds)
    fprintf('Holding out fold %d\n', f);
    [pos_inds, neg_inds] = get_cache_inds_from_fold(caches{i}, folds{f});
    [new_w, new_b] = update_model(caches{i}, opts, ...
        pos_inds, neg_inds);
    W{f}(:,i) = new_w;
    B{f}(i) = new_b;
  end
end


% ------------------------------------------------------------------------
function [pos_inds, neg_inds] = get_cache_inds_from_fold(cache, fold)
% Returns the row indices of the cached positives / negatives whose image
% index (first key column) is NOT in the held-out fold.
% ------------------------------------------------------------------------
pos_inds = find(~ismember(cache.keys_pos(:,1), fold));
neg_inds = find(~ismember(cache.keys_neg(:,1), fold));


% ------------------------------------------------------------------------
function [X_pos, keys] = get_positive_pool5_features(imdb, opts)
% Collects, per class, the cached pool5 features of all boxes whose
% d.class equals that class.
%
%   X_pos{j}   features of class j (one row per box)
%   keys{j}    matching rows [image_index box_index]
% ------------------------------------------------------------------------
X_pos = cell(max(imdb.class_ids), 1);
keys = cell(max(imdb.class_ids), 1);

for i = 1:length(imdb.image_ids)
  tic_toc_print('%s: pos features %d/%d\n', ...
                procid(), i, length(imdb.image_ids));

  d = rcnn_load_cached_pool5_features(opts.cache_name, ...
      imdb.name, imdb.image_ids{i});

  for j = imdb.class_ids
    if isempty(X_pos{j})
      X_pos{j} = single([]);
      keys{j} = [];
    end
    sel = find(d.class == j);
    if ~isempty(sel)
      X_pos{j} = cat(1, X_pos{j}, d.feat(sel,:));
      keys{j} = cat(1, keys{j}, [i*ones(length(sel),1) sel]);
    end
  end
end


% ------------------------------------------------------------------------
function cache = init_cache(X_pos, keys_pos)
% Creates the per-class training cache: the fixed positives, an empty
% negative set, the mining thresholds and empty loss histories.
% ------------------------------------------------------------------------
cache.X_pos = X_pos;
cache.X_neg = single([]);
cache.keys_neg = [];
cache.keys_pos = keys_pos;
cache.num_added = 0;          % negatives added since the last retrain
cache.retrain_limit = 2000;   % retrain once num_added exceeds this
cache.evict_thresh = -1.2;    % negatives scoring below this are evicted
cache.hard_thresh = -1.0001;  % negatives must score above this to be added
cache.pos_loss = [];
cache.neg_loss = [];
cache.reg_loss = [];
cache.tot_loss = [];
相关文章推荐
- poj 1995Raising Modulo Numbers
- 模拟器安装apk提示:INSTALL_FAILED_NO_MATCHING_ABIS
- MySQL实现批量检查表并进行repair与optimize的方法
- Linux 上使用 Gmail SMTP 服务器发送邮件通知
- FOJ 2216 The Longest Straight 第六届福建省大学生程序设计竞赛 E 尺取法乱搞
- 人工智能概述
- shrio 权限管理filterChainDefinitions过滤器配置
- 服务器中很多的CLOSE_WAIT
- 自定义控件三部曲之绘图篇(八)——Paint之ColorMatrix与滤镜效果
- URAL 2060 Subpalindrome Pairs Manacher
- 人工智能60年:纪念活动宣传什么?
- A failed job change
- MySQL实现批量检查表并进行repair与optimize的方法
- 对Run-Time Check Failure #3 - The variable 'a' is being used without being initialized.的处理以及理解
- CommandInvokationFailure: Failed to re-package resources. See the Console for details.
- Raid创建 LVM创建 扩展
- editor does not contain a main type的解决方案
- IBM Watson正致力于将AI和区…
- Smart Token Chain宣布…
- BlockChain:Internet问世以来最具…