随机梯度下降和批量梯度下降的简单代码实现
2016-04-16 21:36
453 查看
最近刚刚开始看斯坦福的机器学习公开课,第一堂课讲到了梯度下降,然后就简单实现了一下。本人学渣一枚,如有错误,敬请指出。
/** * Created by Administrator on 2016/4/16 0016. */ public class GradientDescent { private static double[][] data = { {3.8, 192.0314202}, {3.5, 194.1168421}, {4, 195.1114837}, {4.4, 197.7640977}, {4.1, 196.8811122}, {4.6, 202.9643527}, {3.6, 191.245283}, {3.2, 189.2631579}, {3.4, 189.9758454}, {3, 187.6717949}, {3.9, 193.5243902}, {3.1, 189.2704403}, {2.2, 177.248366}, {3.7, 189.296875}, {3.3, 189.5043478}, {4.2, 199.6857143}, }; //根据excel得到的回归方程:y = 9.3581x + 158.3,数据来自日常的一个项目 public static void main(String[] args) { stochastic(data); batch(data); } /* * 当rate = 0.01时 * 循环2000左右的时候值就不变化了 * parameter is 157.90981024717982 9.482991891267803 * error is 47.897064097242335 * * 当rate = 0.001时 * 循环30000,最后结果几乎不变 * parameter is 158.25947581125462 9.36980772136795 * error is 47.7491293901012 * */ private static void stochastic(double[][] data) { double[] p = {0, 0};//初始化参数为0 double rate = 0.001; for (int i = 0; i < 30000; i++) { for (double[] aData : data) { double h = 0, err; h += p[0] + p[1] * aData[0]; err = aData[1] - h; //根据每一条数据更新参数 p[0] += rate * err * 1; p[1] += rate * err * aData[0]; } } System.out.println("parameter is " + p[0] + " " + p[1]); double error = 0; for (double[] aData : data) { error += Math.pow(aData[1] - (p[0] + p[1] * aData[0]), 2); } System.out.println("error is " + error); } /* * rate = 0.001, 循环次数等于30000时,所计算的结果和excel计算的几乎完全一致 * * parameter is 158.299201608832 9.358074090590318 * error is 47.74825830393555 * * 批量梯度下经确实更加准确 * */ private static void batch(double[][] data) { double[] p = {0, 0}; double rate = 0.001; for (int i=0;i<50000;i++){ double err1 = 0; double err2 = 0; for (double[] aData:data){ double h=0; h=p[0]+p[1]*aData[0]; err1 += aData[1] - h; err2 += (aData[1]-h)*aData[0]; } //遍历整个数据集之后再更新参数 p[0] += rate*err1; p[1] += rate*err2; } System.out.println("parameter is " + p[0] + " " + p[1]); double error = 0; for (double[] aData : data) { error += Math.pow(aData[1] - (p[0] + p[1] * aData[0]), 2); } System.out.println("error is " + error); } }
相关文章推荐
- 安装java之后,找不到tools.jar和dt.jar(dos下javac命令无法执行)
- python-MySQL学习笔记-第四章-利用Connector/Python来查询数据
- struts2入门
- Java中字符串两种等于的方法的对比
- vb6.0陈伟教学视频总结
- JAVA中最常见到的exception
- window下的github使用步骤
- [LeetCode]189. Rotate Array
- 破解Zend Studio步骤
- 快速排序(C语言)
- struts2中配置文件加载的顺序是什么?
- php7连不上mysql求帮忙!!!
- 使用springmvc遇到的问题
- C++实现单链表的创建和打印
- php输入输出
- 【Python】统计个人新浪微博词频并给出相应的柱状图
- java实现 用数组实现循环队列
- VB.NET章鱼哥出品—怎样解决MDI子窗口被父窗口中的控件覆盖的问题
- 深入理解Java内存模型
- Java 泛型类