您的位置:首页 > 编程语言 > MATLAB

梯度下降(Gradient Descent)简析及matlab实现

2014-03-21 20:51 393 查看
梯度下降用处广泛,既可以用于回归也可以用于分类

给定训练集,方便起见,我们用二维点表示我们的训练数据集





上图中的每一横行代表一对儿平面上的点,我们要找到一条线,来最好的拟合这些点的趋向。

假设这条线的形式为y = w0+w1*x1+w2*x2+......wn*xn

其中wi表示第i个系数,xi表示一个训练样本中第i维的值,在我们的例子中i取值为1,因为只有横坐标x这一维度的取值,如果是在立体空间,那么i自然取2,当然还有更高维的空间......

在这里,y轴对应的点,就是上图右列的值,我们称为目标值,在分类里称为label,标签。

那么回到我们这个例子,我们的直线形式就为y=w0+w1*x1;

为了方便求解,我们这里用X表示上图左列的25个点,Y表示上图右列25个点

再做一点变化,我们把X扩展一下,X变成一个25X2的矩阵,X的第一列全为1,第二列为原来X的值,那么X变成下图这样



这样的话,我们就可以把我们的要求解的方程写成这个样子 Y = X*W,其中W 为 [w0 w1] 矩阵的转置,记W是一个2X1的矩阵。这一步变化应该不成问题吧?至于为什么,程序里用起来就知道为什么了...

那么怎么求出这条直线呢?换句话说,怎么求出W矩阵呢?

换个思路,加入我们找到了一条直线,y=w0+w1*x,这条直线能比较好的拟合原来的数据集,那么什么叫比较好呢?标准比较多,一般我们都会这样想,我们把上面给出的点的x值代入y=w0+w1*x得到一个y值,用这个y值和我们的准确值(给定的值)相比较会有一个差值,当然也可能为0,那是我们梦寐以求的,但是,看看上面的点,这样一条直线是不存在的!每一个训练样本都这样做,就会的到25个差值,我们希望总差值最小就好了!这就是思路!

error = sum(ti - oi)^2 其中i = 1,2,3...25

sum表示求和,ti = w0 + w1*xi,就是用我们暂时找到的一条直线求解一个值, oi就是yi,就是第i个样本对应的真实值。

那么为了方便一会儿求导,我们给上面的error乘以一个1/2

error = 1/2 * sum(ti - oi)^2 i = 1,2,3......25

下面就是求上式的最小值喽,好好看看error怎么表示的

error = 1/2 * ( (w0+w1*x1 - yi)^2 + (w0+w1*x2 - y2)^2 + ......+(w0+w1*x25 - y25)^2 )

那么如何求上面的最小值,我们的数学家告诉我们,如果一个函数可微,那么沿着梯度的方向走,只要步伐合适,你就能走到最小值处~

error的梯度怎么表示?

error对w0和w1分别求偏导,得到的两个数组成的二维向量即为梯度

error对wi求偏导得到的是detawi = sum(tj-oj)*xij

其中,tj和oj和上文类似,表示我们的第j个样本的所求值和目标值, xij表示第j个样本的第i个维度的值,这里我们X是一个两维的但是第一维总为1

/**************************************************************************************************************************************************************************************/

下面给出梯度下降算法步骤

input: <X,Y>

output: W

1. 初始化W中的每一个wi为某个小的随机数 //比如全为0.1

2. 遇到终止条件之前,做一下操作:

初始化detawi为0 //detawi 表示的就是error对wi求得的偏导

对于训练集中的每一个样本做:

把样本x输入到我们当前得到的直线中计算输出o

对于每一个detawi做:

detawi = detawi + k * (t - o) * xi //k是学习速率,即上文说的步伐大小,t为x样本的目标值,xi为样本x中第i维的值

对于W中的每个wi做:

wi = wi +
detawi;

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////说明一点:学习速率如果把握不好有可能求不到最终解!!!,还有终止条件的选择也得合适~

附上matlab代码:

function W = Gradient_descent(X,Y)
tX = X;%tX的作用只是为了创建detaW时使用
W = 0.1*ones(size(X,2)+1,1);
X = [ones(size(X,1),1),X];
k = 0;%迭代次数
while true
detaW = zeros(size(tX,2)+1,1);
O = X*W;
detaW = detaW + 0.00005*X'*(Y - O);%0.00005是学习速率
W = W + detaW;
k = k+1;
if 1/2*( norm(Y - X*W) )^2 < 5 || k>10000 %如果误差小于5或者迭代次数大于10000则停止
break;
end
fprintf('iterator times %d, error %f\n',k,1/2*( norm(Y - X*W) )^2);
end


实验结果:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: