您的位置:首页 > 编程语言

随机梯度下降和批量梯度下降的简单代码实现

2016-04-16 21:36 453 查看
最近刚刚开始看斯坦福的机器学习公开课,第一堂课讲到了梯度下降,然后就简单实现了一下。本人学渣一枚,如有错误,敬请指出。

/**
 * Fits the straight line {@code y = p[0] + p[1] * x} to a small built-in data set using two
 * gradient-descent variants (stochastic and batch), printing the fitted parameters and the
 * resulting sum of squared errors for each.
 *
 * <p>Created 2016/4/16.
 */
public class GradientDescent {

    /** Step size used by {@link #stochastic(double[][])}. */
    private static final double STOCHASTIC_RATE = 0.001;

    /** Number of full passes over the data made by {@link #stochastic(double[][])}. */
    private static final int STOCHASTIC_EPOCHS = 30000;

    /** Step size used by {@link #batch(double[][])}. */
    private static final double BATCH_RATE = 0.001;

    /** Number of parameter updates made by {@link #batch(double[][])}. */
    private static final int BATCH_EPOCHS = 50000;

    /**
     * Training samples, one {@code {x, y}} pair per row.
     *
     * <p>Reference regression obtained from Excel for this data: {@code y = 9.3581x + 158.3}.
     * The data comes from a routine project of the original author.
     */
    private static final double[][] DATA = {
        {3.8, 192.0314202},
        {3.5, 194.1168421},
        {4, 195.1114837},
        {4.4, 197.7640977},
        {4.1, 196.8811122},
        {4.6, 202.9643527},
        {3.6, 191.245283},
        {3.2, 189.2631579},
        {3.4, 189.9758454},
        {3, 187.6717949},
        {3.9, 193.5243902},
        {3.1, 189.2704403},
        {2.2, 177.248366},
        {3.7, 189.296875},
        {3.3, 189.5043478},
        {4.2, 199.6857143}
    };

    /**
     * Runs the stochastic variant and then the batch variant on the built-in data set.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        stochastic(DATA);

        batch(DATA);
    }

    /**
     * Stochastic gradient descent: the parameters are updated after every single sample.
     *
     * <p>Results recorded by the original author:
     * <ul>
     *   <li>rate = 0.01: values stop changing after roughly 2000 passes;
     *       parameter is 157.90981024717982 9.482991891267803, error is 47.897064097242335</li>
     *   <li>rate = 0.001, 30000 passes: the final result barely changes any more;
     *       parameter is 158.25947581125462 9.36980772136795, error is 47.7491293901012</li>
     * </ul>
     *
     * @param data samples as {@code {x, y}} rows
     */
    private static void stochastic(double[][] data) {
        double[] p = {0, 0}; // both parameters start at zero

        for (int epoch = 0; epoch < STOCHASTIC_EPOCHS; epoch++) {
            for (double[] sample : data) {
                double err = sample[1] - predict(p, sample[0]);

                // Update from this one sample; later samples in the same pass already see
                // the adjusted parameters.
                p[0] += STOCHASTIC_RATE * err;
                p[1] += STOCHASTIC_RATE * err * sample[0];
            }
        }

        report(p, data);
    }

    /**
     * Batch gradient descent: residuals are accumulated over the whole data set and the
     * parameters are updated once per pass.
     *
     * <p>Result recorded by the original author for rate = 0.001 (their note says 30000
     * iterations, whereas this method runs {@link #BATCH_EPOCHS}): almost identical to the
     * Excel regression; parameter is 158.299201608832 9.358074090590318,
     * error is 47.74825830393555. The author's conclusion was that batch gradient descent is
     * the more accurate of the two here.
     *
     * @param data samples as {@code {x, y}} rows
     */
    private static void batch(double[][] data) {
        double[] p = {0, 0};

        for (int epoch = 0; epoch < BATCH_EPOCHS; epoch++) {
            double residualSum = 0;
            double weightedResidualSum = 0;

            for (double[] sample : data) {
                double err = sample[1] - predict(p, sample[0]);
                residualSum += err;
                weightedResidualSum += err * sample[0];
            }

            // Parameters change only after the entire data set has been traversed.
            p[0] += BATCH_RATE * residualSum;
            p[1] += BATCH_RATE * weightedResidualSum;
        }

        report(p, data);
    }

    /** Evaluates the linear hypothesis {@code p[0] + p[1] * x}. */
    private static double predict(double[] p, double x) {
        return p[0] + p[1] * x;
    }

    /** Returns the sum over all samples of the squared difference between y and the prediction. */
    private static double sumSquaredError(double[] p, double[][] data) {
        double error = 0;
        for (double[] sample : data) {
            error += Math.pow(sample[1] - predict(p, sample[0]), 2);
        }
        return error;
    }

    /** Prints the fitted parameters followed by the sum of squared errors on {@code data}. */
    private static void report(double[] p, double[][] data) {
        System.out.println("parameter is " + p[0] + " " + p[1]);
        System.out.println("error is " + sumSquaredError(p, data));
    }

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: