MatConNet源代码解读(2)
2016-07-26 10:04
274 查看
example/cnn_mnist.m
function [net, info] = cnn_mnist(varargin) %很多人看到varargin就吓住了,其实可以没有参数的 %CNN_MNIST Demonstrates MatConvNet on MNIST %执行vl_setupnn,这么麻烦? run(fullfile(fileparts(mfilename('fullpath')),... '..', '..', 'matlab', 'vl_setupnn.m')) ; opts.batchNormalization = false ; opts.networkType = 'simplenn' ; [opts, varargin] = vl_argparse(opts, varargin) ; %生成实验中途记录文件的名称,每epoch记录一次。记录数据放在data目录下面 sfx = opts.networkType ; if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end opts.expDir = fullfile(vl_rootnn, 'data', ['mnist-baseline-' sfx]) ; [opts, varargin] = vl_argparse(opts, varargin) ; %图像数据库位置 opts.dataDir = fullfile(vl_rootnn, 'data', 'mnist') ; opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); opts.train = struct() ; opts = vl_argparse(opts, varargin) ; if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end; % -------------------------------------------------------------------- % Prepare data % -------------------------------------------------------------------- %网络初始化 net = cnn_mnist_init('batchNormalization', opts.batchNormalization, ... 'networkType', opts.networkType) ; %如果有mnist数据库就直接加载,没有就从网上下。有没有觉得matlab一下变得好高档 if exist(opts.imdbPath, 'file') imdb = load(opts.imdbPath) ; else imdb = getMnistImdb(opts) ; mkdir(opts.expDir) ; save(opts.imdbPath, '-struct', 'imdb') ; end %arrayfun以数组的元素作为函数@x的输入,UniformOutput指输出结果的类型是否都相同,为什么要是false呢?没看明白 net.meta.classes.name = arrayfun(@(x)sprintf('%d',x),1:10,'UniformOutput',false) ; % -------------------------------------------------------------------- % Train % -------------------------------------------------------------------- %开始干正事了 switch opts.networkType case 'simplenn', trainfn = @cnn_train ; case 'dagnn', trainfn = @cnn_train_dag ; end %trainfn就是cnn_train啦,val参数有什么用 [net, info] = trainfn(net, imdb, getBatch(opts), ... 'expDir', opts.expDir, ... net.meta.trainOpts, ... opts.train, ... 'val', find(imdb.images.set == 3)) ; % 取batch数据,不会吧,分割batch还要自己来。仔细看输出是个函数指针,说明实际batch是自动抽取的 %-------------------------------------------------------------------- function fn = getBatch(opts) % -------------------------------------------------------------------- switch lower(opts.networkType) case 'simplenn' fn = @(x,y) getSimpleNNBatch(x,y) ; case 'dagnn' bopts = struct('numGpus', numel(opts.train.gpus)) ; fn = @(x,y) getDagNNBatch(bopts,x,y) ; end % -------------------------------------------------------------------- function [images, labels] = getSimpleNNBatch(imdb, batch) % -------------------------------------------------------------------- images = imdb.images.data(:,:,:,batch) ; labels = imdb.images.labels(1,batch) ; % -------------------------------------------------------------------- function inputs = getDagNNBatch(opts, imdb, batch) % ------------------ 4000 -------------------------------------------------- images = imdb.images.data(:,:,:,batch) ; labels = imdb.images.labels(1,batch) ; if opts.numGpus > 0 images = gpuArray(images) ; end inputs = {'input', images, 'label', labels} ; %下载MnistImdb,话说Lecun也就是靠这个数据库一战成名,1998年那篇文章到底做了多少乱七八糟的实验啊~ % -------------------------------------------------------------------- function imdb = getMnistImdb(opts) % -------------------------------------------------------------------- % Preapre the imdb structure, returns image data with mean image subtracted files = {'train-images-idx3-ubyte', ... 'train-labels-idx1-ubyte', ... 't10k-images-idx3-ubyte', ... 't10k-labels-idx1-ubyte'} ; if ~exist(opts.dataDir, 'dir') mkdir(opts.dataDir) ; end for i=1:4 if ~exist(fullfile(opts.dataDir, files{i}), 'file') url = sprintf('http://yann.lecun.com/exdb/mnist/%s.gz',files{i}) ; fprintf('downloading %s\n', url) ; gunzip(url, opts.dataDir) ; end end f=fopen(fullfile(opts.dataDir, 'train-images-idx3-ubyte'),'r') ; x1=fread(f,inf,'uint8'); fclose(f) ; x1=permute(reshape(x1(17:end),28,28,60e3),[2 1 3]) ; f=fopen(fullfile(opts.dataDir, 't10k-images-idx3-ubyte'),'r') ; x2=fread(f,inf,'uint8'); fclose(f) ; x2=permute(reshape(x2(17:end),28,28,10e3),[2 1 3]) ; f=fopen(fullfile(opts.dataDir, 'train-labels-idx1-ubyte'),'r') ; y1=fread(f,inf,'uint8'); fclose(f) ; y1=double(y1(9:end)')+1 ; f=fopen(fullfile(opts.dataDir, 't10k-labels-idx1-ubyte'),'r') ; y2=fread(f,inf,'uint8'); fclose(f) ; y2=double(y2(9:end)')+1 ; %训练集为1,测试集为3 set = [ones(1,numel(y1)) 3*ones(1,numel(y2))]; data = single(reshape(cat(3, x1, x2),28,28,1,[])); dataMean = mean(data(:,:,:,set == 1), 4); %这个函数牛逼了,可以理解为将dataMean扩展至与data同维数,然后逐点执行minus操作。实际可能分布式计算,好牛叉! data = bsxfun(@minus, data, dataMean) ; imdb.images.data = data ; imdb.images.data_mean = dataMean; imdb.images.labels = cat(2, y1, y2) ; imdb.images.set = set ; %最难的在这里,‘val’是什么意思一直没搞懂! imdb.meta.sets = {'train', 'val', 'test'} ; imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ;
相关文章推荐
- CUDA搭建
- 稀疏自动编码器 (Sparse Autoencoder)
- 白化(Whitening):PCA vs. ZCA
- softmax回归
- 卷积神经网络初探
- 深入理解CNN的细节
- TensorFlow人工智能引擎入门教程之九 RNN/LSTM循环神经网络长短期记忆网络使用
- TensorFlow人工智能引擎入门教程之十 最强网络 RSNN深度残差网络 平均准确率96-99%
- TensorFlow人工智能入门教程之十一 最强网络DLSTM 双向长短期记忆网络(阿里小AI实现)
- TensorFlow人工智能入门教程之十四 自动编码机AutoEncoder 网络
- TensorFlow人工智能引擎入门教程所有目录
- 如何用70行代码实现深度神经网络算法
- 近200篇机器学习&深度学习资料分享(含各种文档,视频,源码等)
- 安装caffe过程记录
- DIGITS的安装与使用记录
- 我对ltsm的学习,从rnn的问题讲起
- 图像识别和图像搜索
- 卷积神经网络
- 51CTO学院优质新课抢先体验-5折好课帮你技能提升、升职加薪