您的位置:首页 > 编程语言 > MATLAB

SVM (支持向量机)

2015-07-11 19:10 288 查看

1. 介绍

SVM (Support Vector Machine,支持向量机)是一种有监督的统计学习方法,能最小化经验误差和最大化几何边缘,被称为最大间隔分类器,可用于分类与回归分析。



如上图所述的线性分类问题可以使用 PLA 或 pocket 方法求解。得到下式的线性分类器:

h(x)=sign(wTx)

由于分类超平面不一定过原点,所以线性分类器可以表示为:

h(x)=sign(wTx+b)

然而可能会存在多个 w 都满足条件,如下图所示,三个线性分类器都能准确的把蓝色的圆和红色的叉分开,但是哪一个更好呢?PLA 或 pocket 就不能给出答案了。



直观的感觉是第三个好,那么为什么第三个好呢?如下图所示,输入的数据点可能存在测量误差(噪声),而第三个分类器能最大的容忍测量误差,也就是说即使测量数据有一定的误差,在第三个分类器下依然能获得最好的分类效果。



那么,应该如何找到上述最好的分类器呢?

2. SVM 推导



如果说线有宽度的话,那么寻找最好的线性分类器可以看作寻找最宽的线,使其恰好能把不同的类别分开,而这个宽度又等价于该线到所有点的最小距离,综上,上述问题可以转化为:

maxw,bmargin(w,b)s.t.ynwTxn+b>0,n=1,...,Nmargin(w,b)=minn=1,...,Ndistance(xn,w,b)

其中,x 为输入数据,也就是分类超平面上的点,y 为输入数据所属的类别,w 为垂直于分类超平面的向量,b 表示截距。



如上图所示,假设 x′,x′′ 为超平面上的点,满足:

1. wTx′=−b,wTx′′=−b′

2. w⊥超平面

3. distance = (x-x’) 在 w 上的投影的长度

所以:

distance(x,w,b)=|wT||w||(x−x′)|=1||w|||wTx+b|

所以,目标函数和约束条件可以转化为:

maxw,bmargin(w,b)s.t.yn(wTxn+b)>0,n=1,...,Nmargin(w,b)=minn=1,...,N1||w|||wTx+b|

上式又可以转换为:

maxw,b1||w||s.t.minn=1,...,Nyn(wTx+b)=1

上式又可以转换为:

minw,b12wTws.t.yn(wTx+b)≥1,n=1,..N

可以通过反证法证明,将约束条件的等式约束转换为不等式约束是等价的。

3. 求解

对 SVM 的求解可以转化为对 QP(Quadratic Programming,二次规划)的求解,看如下的二次规划问题的标准形式:

u←QP(Q,p,A,c)minu12uTQu+pTus.t.aTmu≥cm,m=1,2,...,M

而 SVM 的目标函数和约束问题:

minw,b12wTws.t.yn(wTx+b)≥1,n=1,..N

等价于:

minw12[bw]T[00d0TdId][bw]+0Td+1us.t.yn[1xTn][bw]≥1,n=1,..,N

也就是说:

u←QP(Q,p,A,c)minu12uTQu+pTus.t.aTmu≥cm,m=1,2,...,M

其中:

u=[bw]Q=[00d0TdId]p=0d+1aTn=yn[1xTn]cn=1M=N

matlab 中的 quadprog 函数可用于求解该问题。

4. 示例

% 功能:演示SVM算法
% 基于 SVM 实现特征分类;
% 时间:2015-07-11

clc
clear all
close all

%% 测试样本
dataLength = 2;
dataNumber = [100, 100];

% 第一类
x1 = randn(dataLength, dataNumber(1));
y1 = ones(1, dataNumber(1));

% 第二类
x2 = 5 + randn(dataLength, dataNumber(2));
y2 = -ones(1, dataNumber(2));

% 显示
figure(1);
plot(x1(1,:), x1(2,:), 'bx', x2(1,:), x2(2,:), 'k.');
axis([-3 8 -3 8]);
title('SVM')
hold on

% 合并样本
X = [x1, x2];
Y = [y1, y2];

% 打乱样本顺序
index = randperm(sum(dataNumber));
X(:, index) = X;
Y(:, index) = Y;

%% SVM 训练
% line : w1x1 + w2x2 + b = 0
% weight = [b, w1, w2]
weight = svmTrainMine(X, Y);

%% 测试输出

% y = kx + b
k = -weight(2) / weight(3);
b = weight(1) / weight(3);

xLine = -2:0.1:7;
yLine = k .* xLine - b;
plot(xLine, yLine, 'r')
hold on

%% 查找支持向量
epsilon = 1e-5;
dist = abs(k .* X(1, :) - X(2,:) - b);
i_sv = find(dist <= min(dist(:)) + epsilon);
plot(X(1,i_sv), X(2,i_sv),'ro');


结果:



5. 完整代码

GitHub

6. 参考

《视觉机器学习20讲》第九讲

《Coursera 机器学习技法(林轩田 - 台湾大学 》01 Linear Support Vector Machine
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息