您的位置:首页 > 编程语言 > MATLAB

坐标下降法(坐标上升法)matlab程序

2017-06-12 20:28 225 查看

起因

因为求解SVM的最牛算法SMO算法,使用的时坐标下降法的思路,所以学习一下这个算法。

方法

某函数包含多个自变量,需要求这个函数的最大或最小值时,可以应用此坐标下降法(最小值)或坐标上升法(最大值)。

其过程是,对每个自变量求偏导,交替的对每个自变量进行梯度下降(或上升法)。

案例

我们使用以下函数作为案例

z=f(x,y)=xe−(x2+y2)

其函数图为:



可以看到,这个函数优一个最大值和最小值。

对其求关于x,y的偏导。

∂z∂x=e−(x2+y2)+xe−(x2+y2)(−2x)=e−(x2+y2)(1−2x2)

∂z∂y=xe−(x2+y2)(−2y)

这个问题比较简单,我们直接令上述两个导数等于0,就可以求出 x=+−2√/2

y=0

但很多实际问题是难以求解的,这时就应使用迭代的算法。

对每个参数交替使用梯度下降。

求解

1 先对x和y赋随机的初值。

2 随后对x使用梯度上升(我们要求最大值),a 为学习因子

x=x+a∗∂z∂x

3 随后对y使用梯度上升

y=y+a∗∂z∂y

4 重复2-3步,直至收敛。

代码

function z = f(x,y)
z = x.*exp(-x.^2 - y.^2);

end


clc;clear;
x = 0.2;
y = 0.7; %初始值
a = 0.2; %学习率
xa = [];
ya = [];
za = [];
oldx = 1;
oldy = 1;
while abs(oldx-x)+abs(oldy - y) > 1e-7
z= f(x,y);
dx = exp(-x.^2 - y.^2) -2*x*z;
dy = z.*(-2*y);
oldx=x;
oldy=y;
x = x + a*dx;
y = y + a*dy;
xa = [xa x];
ya = [ya y];
za = [za f(x,y)];
fprintf('x = %f ,y = %f , cha = %f\n',x,y,abs(oldx-x)+abs(oldy - y));

end

fw = -2:0.1:2;
[x,y] = meshgrid(fw,fw);
z = f(x,y);
hold off;
mesh(x,y,z);
xlabel('x');
ylabel('y');
zlabel('z');
pause
hold on;
plot3(xa,ya,za,'LineWidth',2);


结果

matlab输出

x = 0.308303 ,y = 0.667038 , cha = 0.141265
x = 0.402698 ,y = 0.619101 , cha = 0.142332
x = 0.481018 ,y = 0.561303 , cha = 0.136119
x = 0.543232 ,y = 0.498771 , cha = 0.124746
x = 0.590809 ,y = 0.435857 , cha = 0.110491
x = 0.626028 ,y = 0.375773 , cha = 0.095303
x = 0.651398 ,y = 0.320559 , cha = 0.080583
x = 0.669268 ,y = 0.271252 , cha = 0.067178
x = 0.681635 ,y = 0.228145 , cha = 0.055474
x = 0.690075 ,y = 0.191040 , cha = 0.045545
x = 0.695776 ,y = 0.159460 , cha = 0.037281
x = 0.699596 ,y = 0.132798 , cha = 0.030482
x = 0.702141 ,y = 0.110417 , cha = 0.024926
x = 0.703830 ,y = 0.091705 , cha = 0.020401
x = 0.704947 ,y = 0.076105 , cha = 0.016718
x = 0.705685 ,y = 0.063124 , cha = 0.013718
x = 0.706171 ,y = 0.052338 , cha = 0.011272
x = 0.706492 ,y = 0.043384 , cha = 0.009274
x = 0.706702 ,y = 0.035955 , cha = 0.007639
x = 0.706841 ,y = 0.029795 , cha = 0.006299
x = 0.706932 ,y = 0.024688 , cha = 0.005198
x = 0.706992 ,y = 0.020455 , cha = 0.004293
x = 0.707031 ,y = 0.016948 , cha = 0.003547
x = 0.707057 ,y = 0.014041 , cha = 0.002932
x = 0.707074 ,y = 0.011633 , cha = 0.002425
x = 0.707085 ,y = 0.009637 , cha = 0.002007
x = 0.707093 ,y = 0.007984 , cha = 0.001661
x = 0.707098 ,y = 0.006615 , cha = 0.001374
x = 0.707101 ,y = 0.005480 , cha = 0.001138
x = 0.707103 ,y = 0.004540 , cha = 0.000942
x = 0.707104 ,y = 0.003761 , cha = 0.000780
x = 0.707105 ,y = 0.003116 , cha = 0.000646
x = 0.707106 ,y = 0.002581 , cha = 0.000535
x = 0.707106 ,y = 0.002138 , cha = 0.000443
x = 0.707106 ,y = 0.001772 , cha = 0.000367
x = 0.707106 ,y = 0.001468 , cha = 0.000304
x = 0.707107 ,y = 0.001216 , cha = 0.000252
x = 0.707107 ,y = 0.001007 , cha = 0.000209
x = 0.707107 ,y = 0.000835 , cha = 0.000173
x = 0.707107 ,y = 0.000691 , cha = 0.000143
x = 0.707107 ,y = 0.000573 , cha = 0.000119
x = 0.707107 ,y = 0.000474 , cha = 0.000098
x = 0.707107 ,y = 0.000393 , cha = 0.000081
x = 0.707107 ,y = 0.000326 , cha = 0.000067
x = 0.707107 ,y = 0.000270 , cha = 0.000056
x = 0.707107 ,y = 0.000224 , cha = 0.000046
x = 0.707107 ,y = 0.000185 , cha = 0.000038
x = 0.707107 ,y = 0.000153 , cha = 0.000032
x = 0.707107 ,y = 0.000127 , cha = 0.000026
x = 0.707107 ,y = 0.000105 , cha = 0.000022
x = 0.707107 ,y = 0.000087 , cha = 0.000018
x = 0.707107 ,y = 0.000072 , cha = 0.000015
x = 0.707107 ,y = 0.000060 , cha = 0.000012
x = 0.707107 ,y = 0.000050 , cha = 0.000010
x = 0.707107 ,y = 0.000041 , cha = 0.000009
x = 0.707107 ,y = 0.000034 , cha = 0.000007
x = 0.707107 ,y = 0.000028 , cha = 0.000006
x = 0.707107 ,y = 0.000023 , cha = 0.000005
x = 0.707107 ,y = 0.000019 , cha = 0.000004
x = 0.707107 ,y = 0.000016 , cha = 0.000003
x = 0.707107 ,y = 0.000013 , cha = 0.000003
x = 0.707107 ,y = 0.000011 , cha = 0.000002
x = 0.707107 ,y = 0.000009 , cha = 0.000002
x = 0.707107 ,y = 0.000008 , cha = 0.000002
x = 0.707107 ,y = 0.000006 , cha = 0.000001
x = 0.707107 ,y = 0.000005 , cha = 0.000001
x = 0.707107 ,y = 0.000004 , cha = 0.000001
x = 0.707107 ,y = 0.000004 , cha = 0.000001
x = 0.707107 ,y = 0.000003 , cha = 0.000001
x = 0.707107 ,y = 0.000002 , cha = 0.000001
x = 0.707107 ,y = 0.000002 , cha = 0.000000
x = 0.707107 ,y = 0.000002 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000000 , cha = 0.000000


图形化显示



其中蓝线为x和y的变化曲线。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息