AdaBoost的java实现
2015-06-11 11:46
537 查看
目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。
bagging和boosting的区别
bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
em=
3 计算该分类器的权重
可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
其中Zm是规范化银子:
5 构建基本分类器
F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
终于不用打公式了。。。。
附上代码:
这里的数据采用的是统计学习方法中的数据
这里是单个特征的,也可以是多维数据,例如
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。
bagging和boosting的区别
bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
em=
3 计算该分类器的权重
可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
其中Zm是规范化银子:
5 构建基本分类器
F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
终于不用打公式了。。。。
附上代码:
import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; class Stump{ public int dim; public double thresh; public String condition; public double error; public ArrayList<Integer> labelList; double factor; public String toString(){ return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList; } } class Utils{ //加载数据集 public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{ ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>(); FileInputStream fis=new FileInputStream(filename); InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); BufferedReader br=new BufferedReader(isr); String line=""; while((line=br.readLine())!=null){ ArrayList<Double> data=new ArrayList<Double>(); String[] s=line.split(" "); for(int i=0;i<s.length-1;i++){ data.add(Double.parseDouble(s[i])); } dataSet.add(data); } return dataSet; } //加载类别 public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{ ArrayList<Integer> labelSet=new ArrayList<Integer>(); FileInputStream fis=new FileInputStream(filename); InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); BufferedReader br=new BufferedReader(isr); String line=""; while((line=br.readLine())!=null){ String[] s=line.split(" "); labelSet.add(Integer.parseInt(s[s.length-1])); } return labelSet; } //测试用的 public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){ for(ArrayList<Double> data:dataSet){ System.out.println(data); } } //获取最大值,用于求步长 public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){ double max=-9999.0; for(ArrayList<Double> data:dataSet){ if(data.get(index)>max){ max=data.get(index); } } return max; } //获取最小值,用于求步长 public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){ double min=9999.0; for(ArrayList<Double> data:dataSet){ if(data.get(index)<min){ min=data.get(index); } } return min; } //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别 public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){ ArrayList<Integer> labelList=new ArrayList<Integer>(); if(condition.compareTo("lt")==0){ for(ArrayList<Double> data:dataSet){ if(data.get(feature)<=thresh){ labelList.add(1); }else{ labelList.add(-1); } } }else{ for(ArrayList<Double> data:dataSet){ if(data.get(feature)>=thresh){ labelList.add(1); }else{ labelList.add(-1); } } } return labelList; } //求预测类别与真实类别的加权误差 public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){ double error=0; int n=real.size(); for(int i=0;i<fake.size();i++){ if(fake.get(i)!=real.get(i)){ error+=weights.get(i); } } return error; } //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。 public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){ int featureNum=dataSet.get(0).size(); int rowNum=dataSet.size(); Stump stump=new Stump(); double minError=999.0; System.out.println("第"+n+"次迭代"); for(int i=0;i<featureNum;i++){ double min=getMin(dataSet,i); double max=getMax(dataSet,i); double step=(max-min)/(rowNum); for(double j=min-step;j<=max+step;j=j+step){ String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类 for(String condition:conditions){ ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); double error=Utils.getError(labelList,labelSet,weights); if(error<minError){ minError=error; stump.dim=i; stump.thresh=j; stump.condition=condition; stump.error=minError; stump.labelList=labelList; stump.factor=0.5*(Math.log((1-error)/error)); } } } } return stump; } public static ArrayList<Double> getInitWeights(int n){ double weight=1.0/n; ArrayList<Double> weights=new ArrayList<Double>(); for(int i=0;i<n;i++){ weights.add(weight); } return weights; } //更新样本权值 public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){ double Z=0; ArrayList<Double> newWeights=new ArrayList<Double>(); int row=labelList.size(); double e=Math.E; double factor=stump.factor; for(int i=0;i<row;i++){ Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i)); } for(int i=0;i<row;i++){ double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z; newWeights.add(weight); } return newWeights; } //对加权误差累加 public static ArrayList<Double> InitAccWeightError(int n){ ArrayList<Double> accError=new ArrayList<Double>(); for(int i=0;i<n;i++){ accError.add(0.0); } return accError; } public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){ ArrayList<Integer> t=stump.labelList; double factor=stump.factor; ArrayList<Double> newAccError=new ArrayList<Double>(); for(int i=0;i<t.size();i++){ double a=accerror.get(i)+factor*t.get(i); newAccError.add(a); } return newAccError; } public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){ ArrayList<Integer> a=new ArrayList<Integer>(); int wrong=0; for(int i=0;i<accError.size();i++){ if(accError.get(i)>0){ if(labelList.get(i)==-1){ wrong++; } }else if(labelList.get(i)==1){ wrong++; } } double error=wrong*1.0/accError.size(); return error; } public static void showStumpList(ArrayList<Stump> G){ for(Stump s:G){ System.out.println(s); System.out.println(" "); } } } public class Adaboost { /** * @param args * @throws IOException */ public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){ int row=labelList.size(); ArrayList<Double> weights=Utils.getInitWeights(row); ArrayList<Stump> G=new ArrayList<Stump>(); ArrayList<Double> accError=Utils.InitAccWeightError(row); int n=1; while(true){ Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树 G.add(stump); weights=Utils.updateWeights(stump,labelList,weights);//更新权值 accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了 double error=Utils.calErrorRate(accError,labelList); if(error<0.001){ break; } n++; } return G; } public static void main(String[] args) throws IOException { // TODO Auto-generated method stub String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt"; ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file); ArrayList<Integer> labelSet=Utils.loadLabelSet(file); ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet); Utils.showStumpList(G); System.out.println("finished"); } }
这里的数据采用的是统计学习方法中的数据
0 1 1 1 2 1 3 -1 4 -1 5 -1 6 1 7 1 8 1 9 -1
这里是单个特征的,也可以是多维数据,例如
1.0 2.1 1 2.0 1.1 1 1.3 1.0 -1 1.0 1.0 -1 2.0 1.0 1
相关文章推荐
- 【转载】eclipse调试arm裸机程序
- java处理json的工具类
- Struts2中动态方法调用
- Java多线程系列--“JUC锁”01之 框架
- 运用spring task定时器发布定时任务
- java中Arrays类
- Java String.Format() 方法及参数说明
- Mac下的eclipse中svn插件使用代理
- 整合hibernate4.2和spring框架,出现No Session found for current threa报错
- Java模式(适配器型号)
- Shiro配置---基于spring框架
- Java位运算符
- struts2配置详解
- JAVA 中两种判断输入的是否是数字的方法
- mark:Eclipse导入workspace存在的项目
- 编译Hadoop-Eclipse插件
- Spring + JdbcTemplate + JdbcDaoSupport + HibernateDaoSupport examples
- java基础学习步骤
- 【深入JAVA】cglib动态代理
- eclipse自动生成方法注释 快捷键