
A "Hello World" for MOD

2014-03-08 20:49
First, to be clear: MOD is not the modulo function! MOD (Method of Optimal Directions) is an algorithm for dictionary learning and sparse coding. I have recently been reading about K-SVD, whose simplified version is essentially MOD: K-SVD and MOD optimize the same objective function, and MOD counts as the simplified version because K-SVD adds, on top of MOD, a sequential column-by-column update of the dictionary. For the theory behind K-SVD and MOD, see the one-page note I give below and the paper in the Reference. This post presents the basic idea and my code; the code has been tested, but please point out any bugs you find.

Reference

A. M. Bruckstein, D. L. Donoho, and M. Elad, "From Sparse Solutions of Systems of Equations to Sparse Modeling of Signals and Images", pages 68-70.

KSVD & MOD: Principle & Objective Function

Principle:

Simply put, the optimization is an iteration between OMP (orthogonal matching pursuit) and regression, so the code consists of an OMP.m and a Regression.m; the two alternating steps are written out below.
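Concretely, with data Y (M×P), dictionary D (M×N), and coefficients X (N×P), the two alternating steps are, in standard notation:

$$X \leftarrow \arg\min_{X} \lVert Y - DX\rVert_F^2 \quad \text{s.t. } \lVert x_i\rVert_0 \le K \text{ for every column } x_i \qquad \text{(sparse coding, OMP.m)}$$

$$D \leftarrow \arg\min_{D} \lVert Y - DX\rVert_F^2 \qquad \text{(dictionary update, Regression.m)}$$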

Objective Function & the variation from MOD to KSVD:
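Both methods minimize the same objective, $\min_{D,X} \lVert Y - DX\rVert_F^2$ subject to $\lVert x_i\rVert_0 \le K$ for every column $x_i$ of $X$; they differ only in the dictionary-update step. In standard notation:

$$\text{MOD:}\quad D \leftarrow Y X^{\top}\,(X X^{\top})^{-1} \qquad \text{(all atoms updated at once, by least squares)}$$

$$\text{K-SVD:}\quad \text{for each atom } d_j:\ \ E_j = Y - \sum_{i \ne j} d_i\, x_T^i,\qquad (d_j,\ x_T^j) \leftarrow \text{rank-1 SVD approximation of } E_j \text{ restricted to the signals that use atom } j,$$

where $x_T^i$ denotes the $i$-th row of $X$.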



Code

CODE1. MOD

Run Main (which calls MOD) to learn the dictionary and the sparse representation; MOD iterates, calling Regression to learn the dictionary and OMP to obtain the sparse representation.

Main.m


%% Main.m
clc;
clear;
P = 512;    % number of samples
N = 256;    % number of dictionary atoms
M = 128;    % signal dimension
K = 100;    % sparsity

%% Data Generator Method 1
% sparsity_X = 0.4;
% Y = randi(10,M,P);
% X = floor(sprand(N,P,sparsity_X)*10);

%% Data Generator Method 2
Y = randn(M,P);  % note: Y should be full rank (here rank(Y) = M almost surely)
X = randn(N,P);  % initialization of X

%% Main Iteration
[D,X] = MOD(Y,X,K,1e-4);

MOD.m


% @Function: Method of Optimal Directions (MOD) for 2D signals
% For dictionary and sparse representation learning
% @CreateTime: 2013-2-22
% @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
%
% @Reference: From Sparse Solutions of Systems of Equations to
%             Sparse Modeling of Signals and Images

function [ D , X ] = MOD( Y, X, K, ErrorThreshold )
% Sample data is Y (M*P)
% Coefficient matrix is X (N*P)
% Dictionary is D (M*N)
% Sparsity is K

disp('Run Method of Optimal Directions');
iteration_time = 1;
error = ErrorThreshold + 1;

while error >= ErrorThreshold
    disp(['iteration time = ' num2str(iteration_time)]);
    D = Regression(Y,X);            % dictionary update (least squares)
    X = OMP(Y,D,K);                 % sparse coding step
    iteration_time = iteration_time + 1;
    error = sum(sum(abs(Y-D*X)))    % no semicolon: print the current error
end

end
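The while loop stops on the absolute L1 reconstruction error, which depends on the problem size. A scale-independent alternative is a relative Frobenius error combined with an iteration cap; a minimal sketch (max_iter is my own assumption, not part of the original MOD.m):

% Sketch of an alternative stopping rule (not part of the original code).
max_iter = 50;                                     % assumed iteration cap
rel_err  = norm(Y - D*X, 'fro') / norm(Y, 'fro');  % relative reconstruction error
stop     = (rel_err < ErrorThreshold) || (iteration_time > max_iter);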

OMP.m


% @Function: Orthogonal Matching Pursuit of 2D signals
% Learning Sparse Representation Given Dictionary
% @CreateTime: 2013-2-21
% @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
%
% @Reference: http://www.eee.hku.hk/~wsha/Freecode/freecode.htm

function [ X ] = OMP( Y, D, K )
% Y is the sample data to be recovered, M*P
% D is the dictionary, M*N
% X is the sparse coefficient matrix, N*P
% K is the sparsity

if nargin == 2
    K = size(D,2);
end

M = size(D,1);
P = size(Y,2);
N = size(D,2);
m = K*2;    % maximum number of greedy iterations

for idx = 1:P
    % recover the idx-th column of Y
    y = Y(:,idx);
    residual = y;
    Aug_D = [];
    D1 = D;

    for times = 1:m
        product = abs(D1'*residual);
        [~,pos] = max(product);               % position of the largest projection coefficient
        Aug_D = [Aug_D, D1(:,pos)];
        D1(:,pos) = zeros(M,1);               % remove the selected column
        indx(times) = pos;
        Aug_x = (Aug_D'*Aug_D)^-1*Aug_D'*y;   % least squares to minimize the residual, i.e. x = pinv(Aug_D)*y
        residual = y - Aug_D*Aug_x;

        if sum(residual.^2) < 1e-6
            break;
        end
    end
    temp = zeros(N,1);
    temp(indx(1:times)) = Aug_x;
    X(:,idx) = sparse(temp);
end
end
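A quick way to sanity-check OMP.m on its own is to recover a known sparse vector from a random dictionary with unit-norm columns. The sizes below are arbitrary assumptions chosen for illustration, not values from the post:

% Minimal sanity check for OMP.m (a sketch).
M = 64; N = 128; K = 5;
D = randn(M,N);
D = D ./ repmat(sqrt(sum(D.^2)), M, 1);   % normalize dictionary columns
x_true = zeros(N,1);
support = randperm(N); support = support(1:K);
x_true(support) = randn(K,1);
y = D * x_true;                           % noiseless K-sparse signal
x_hat = OMP(y, D, K);
disp(norm(full(x_hat) - x_true));         % should be close to zero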

Regression.m


% @Function: Dictionary learning & Regression
% Learning Dictionary Given Sparse Representation
% @CreateTime: 2013-2-21
% @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
%
function [ D ] = Regression( Y, X )
% Y is the sample data to be recovered, M*P
% D is the dictionary, M*N
% X is the sparse coefficient matrix, N*P
% P > N > M

% Since X is a wide (fat) matrix, transpose the problem and solve
%   D0 = argmin_D ||Y^T - X^T D^T||,
% i.e. N unknowns and P equations per column.
% Each pass solves one column of D^T; M columns in total.

Y = Y';
X = X';
P = size(Y,1);
N = size(X,2);
M = size(Y,2);
D = zeros(N,M);

for i = 1:M
    y = Y(:,i);
    D(:,i) = regress(y,X);   % least-squares fit of the i-th column of D^T
end
D = D';
end
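Since every column of D^T is fitted against the same design matrix, the loop above is equivalent to a single matrix least-squares solve, which is exactly the closed-form MOD update written earlier. A one-line sketch, with Y and X as passed in (before the transposes) and assuming X has full row rank so that X*X' is invertible:

% Closed-form equivalent of Regression.m: solves min_D ||Y - D*X||_F^2 in one shot.
D = Y * X' / (X * X');   % same result as the column-by-column regress loop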

============================================================================

CODE2. KSVD

The KSVD function below was written by the original authors and is nicely structured, so I paste it here as-is.


function [Dictionary,output] = KSVD(...
    Data,...  % an nXN matrix that contains N signals (Y), each of dimension n.
    param)
% =========================================================================
%                          K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "K-SVD: An Algorithm for
% Designing Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
% =========================================================================
% INPUT ARGUMENTS:
% Data    an nXN matrix that contains N signals (Y), each of dimension n.
% param   structure that includes all required parameters for the K-SVD
%         execution. Required fields are:
%    K, ...              the number of dictionary elements to train.
%    numIteration, ...   number of iterations to perform.
%    errorFlag ...       if =0, a fixed number of coefficients is used for the
%                        representation of each signal; in this case param.L must
%                        be specified as the number of representing atoms. If =1,
%                        an arbitrary number of atoms represents each signal,
%                        until a specific representation error is reached; in this
%                        case param.errorGoal must be specified as the allowed error.
%    preserveDCAtom ...  if =1 the first atom in the dictionary is set to be
%                        constant and never changes. This might be useful for
%                        working with natural images (in this case, only
%                        param.K-1 atoms are trained).
%    (optional, see errorFlag) L, ...          maximum coefficients to use in OMP coefficient calculations.
%    (optional, see errorFlag) errorGoal, ...  allowed representation error for each signal.
%    InitializationMethod, ...  method to initialize the dictionary; one of:
%                        * 'DataElements' (initialization by the signals themselves), or
%                        * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
%    (optional, see InitializationMethod) initialDictionary, ...  if the initialization
%                        method is 'GivenMatrix', this is the matrix that will be used.
%    (optional) TrueDictionary, ...  if specified, in each iteration the difference
%                        between this dictionary and the trained one is measured and displayed.
%    displayProgress, ...  if =1 progress information is displayed. If param.errorFlag==0,
%                        the average representation error (RMSE) is displayed, while if
%                        param.errorFlag==1, the average number of required coefficients for
%                        representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
%  Dictionary   The extracted dictionary of size nX(param.K).
%  output       Struct that contains information about the current run. It may include the following fields:
%    CoefMatrix   The final coefficient matrix (it should hold that Data is
%                 approximately Dictionary*output.CoefMatrix).
%    ratio        If the true dictionary was defined (in synthetic experiments),
%                 this is a vector of length param.numIteration that includes the
%                 detection ratio in each iteration.
%    totalerr     The total representation error after each iteration (defined
%                 only if param.displayProgress=1 and param.errorFlag=0).
%    numCoef      A vector of length param.numIteration that includes the average
%                 number of coefficients required for representation of each
%                 signal in each iteration (defined only if
%                 param.displayProgress=1 and param.errorFlag=1).
% =========================================================================



if (~isfield(param,'displayProgress'))
    param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param,'errorFlag')==0)
    param.errorFlag = 0;
end

if (isfield(param,'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1,1);
    ratio = zeros(param.numIteration+1,1);
else
    displayErrorWithTrueDictionary = 0;
    ratio = 0;
end
if (param.preserveDCAtom>0)
    FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
else
    FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients

if (size(Data,2) < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:,1:size(Data,2));
    return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
% normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
totalErr = zeros(1,param.numIteration);

% the K-SVD algorithm starts here.



for iterNum = 1:param.numIteration
    % find the coefficients
    if (param.errorFlag==0)
        %CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
    else
        %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
        CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
        param.L = 1;
    end

    replacedVectorCounter = 0;
    rPerm = randperm(size(Dictionary,2));
    for j = rPerm
        [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
            [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
            CoefMatrix ,param.L);
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end

    if (iterNum>1 & param.displayProgress)
        if (param.errorFlag==0)
            output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/prod(size(Data)));
            disp(['Iteration ',num2str(iterNum),' Total error is: ',num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
            disp(['Iteration ',num2str(iterNum),' Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary)
        [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
        disp(strcat(['Iteration ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end
    Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);

    if (isfield(param,'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end

output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement,Dictionary];

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  findBetterDictionaryElement
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [betterDictionaryElement,CoefMatrix,NewVectorAdded] = I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed)
if (length(who('numCoefUsed'))==0)
    numCoefUsed = 1;
end
relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that use the j'th dictionary element.
if (length(relevantDataIndices)<1) %(length(relevantDataIndices)==0)
    ErrorMat = Data-Dictionary*CoefMatrix;
    ErrorNormVec = sum(ErrorMat.^2);
    [d,i] = max(ErrorNormVec);
    betterDictionaryElement = Data(:,i); %ErrorMat(:,i); %
    betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
    betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));
    CoefMatrix(j,:) = 0;
    NewVectorAdded = 1;
    return;
end

NewVectorAdded = 0;
tmpCoefMatrix = CoefMatrix(:,relevantDataIndices);
tmpCoefMatrix(j,:) = 0; % the coefficients of the element we now improve are not relevant.
errors = (Data(:,relevantDataIndices) - Dictionary*tmpCoefMatrix); % vector of errors that we want to minimize with the new element
% the better dictionary element and the values of beta are found using svd.
% This is because we would like to minimize || errors - beta*element ||_F^2,
% that is, to approximate the matrix 'errors' with a rank-one matrix. This
% is done using the largest singular value.
[betterDictionaryElement,singularValue,betaVector] = svds(errors,1);
CoefMatrix(j,relevantDataIndices) = singularValue*betaVector'; % *signOfFirstElem



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  findDistanseBetweenDictionaries
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ratio,totalDistances] = I_findDistanseBetweenDictionaries(original,new)
% first, make all the columns in 'original' and 'new' start with positive values.
catchCounter = 0;
totalDistances = 0;
for i = 1:size(new,2)
    new(:,i) = sign(new(1,i))*new(:,i);
end
for i = 1:size(original,2)
    d = sign(original(1,i))*original(:,i);
    distances = sum((new-repmat(d,1,size(new,2))).^2);
    [minValue,index] = min(distances);
    errorOfElement = 1-abs(new(:,index)'*d);
    totalDistances = totalDistances+errorOfElement;
    catchCounter = catchCounter+(errorOfElement<0.01);
end
ratio = 100*catchCounter/size(original,2);



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  I_clearDictionary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)
T2 = 0.99;
T1 = 3;
K = size(Dictionary,2);
Er = sum((Data-Dictionary*CoefMatrix).^2,1); % remove identical atoms
G = Dictionary'*Dictionary; G = G-diag(diag(G));
for jj = 1:1:K,
    if max(G(jj,:))>T2 | length(find(abs(CoefMatrix(jj,:))>1e-7))<=T1,
        [val,pos] = max(Er);
        Er(pos(1)) = 0;
        Dictionary(:,jj) = Data(:,pos(1))/norm(Data(:,pos(1)));
        G = Dictionary'*Dictionary; G = G-diag(diag(G));
    end;
end;
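For reference, a sketch of how this KSVD might be invoked; the parameter values below are my own assumptions, not values from the post. Note that this KSVD calls OMP(D, Data, L) with the dictionary as the first argument, which does not match the OMP(Y, D, K) defined earlier in this post; it assumes the OMP (and OMPerr, when errorFlag = 1) that ship with the original K-SVD toolbox.

% Sketch of a KSVD call (assumed sizes and parameters).
n = 64; N = 1000;
Data = randn(n, N);                       % N signals of dimension n

param.K = 128;                            % number of dictionary atoms to train
param.numIteration = 10;                  % K-SVD iterations
param.errorFlag = 0;                      % fixed number of coefficients per signal
param.L = 5;                              % atoms per signal in the OMP stage
param.preserveDCAtom = 0;
param.InitializationMethod = 'DataElements';
param.displayProgress = 1;

[Dictionary, output] = KSVD(Data, param);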

More machine learning study materials and related discussions will continue to be posted; please follow this blog and my Sina Weibo, Rachel____Zhang.

Reposted from http://blog.csdn.net/abcjennifer/article/details/8603214