您的位置:首页 > 其它

weka之Bagging的源码分析及相关知识点

2017-02-13 20:11 239 查看
Bagging的源码分析及相关知识点



1、Bagging的构造函数:

public Bagging() {

m_Classifier = new weka.classifiers.trees.REPTree();

}

2、Bagging的继承关系及父类的主要属性和方法

(以下为逐级单重继承及抽象父类的一些重要属性和方法)

Bagging

(protected intm_BagSizePercent = 100;

protected boolean m_CalcOutOfBag = false;

protected double m_OutOfBagError;

public voidbuildClassifier(Instances
data) throwsException{}


public double[] distributionForInstance(Instance instance) throwsException{}

public static void
main(String[] argv) {


runClassifier(newBagging(), argv);

}



继承——>RandomizableIteratedSingleClassifierEnhancer

(protected int m_Seed = 1;)

[b]继承——>IteratedSingleClassifierEnhancer[/b]

(protected Classifier[]
m_Classifiers;


protected intm_NumIterations = 10;

public voidbuildClassifier(Instancesdata)
throws Exception{


if (m_Classifier == null) {

throw newException("A base classifier has not been specified!");

}

m_Classifiers = Classifier.makeCopies(m_Classifier,m_NumIterations);

}



继承——>SingleClassifierEnhancer

(protected Classifierm_Classifier = new ZeroR();

public void setClassifier(ClassifiernewClassifier){}

public Classifier getClassifier(){}

protected String getClassifierSpec(){}



继承——>Classifier

(protectedboolean m_Debug = false;

publicabstract voidbuildClassifier(Instances
data) throwsException;


public double
classifyInstance(Instance instance) throwsException{}


public double[]
distributionForInstance(Instance instance)throws Exception{}


public static Classifier forName(StringclassifierName, String[] options) throws Exception{}

public static Classifier
makeCopy(Classifier model) throws Exception{}


public static Classifier[] makeCopies(Classifier model, int num) throwsException{}

protected static void
runClassifier(Classifier classifier, String[]options){}




3、父类引用指向子类对象:多态、动态链接,向上转型(插曲)

ZeroR——> Classifier

Protected Classifier m_Classifier = newZeroR();

对于多态,可以总结以下几点:

Ø 使用父类类型的引用指向子类的对象;

Ø 该引用只能调用父类中定义的方法和变量;

Ø 如果子类中重写了父类中的一个方法,那么在调用这个方法的时候,将会调用子类中的这个方法;(动态连接、动态调用)

Ø 变量不能被重写(覆盖),”重写“的概念只针对方法,如果在子类中”重写“了父类中的变量,那么在编译时会报错。

一个父类类型的引用指向一个子类的对象既可以使用子类强大的功能,又可以抽取父类的共性,父类类型的引用可以调用父类中定义的所有属性和方法,而对于子类中定义而父类中没有的方法,父类引用是无法调用的;

那什么是动态链接呢?当父类中的一个方法只有在父类中定义而在子类中没有重写的情况下,才可以被父类类型的引用调用;对于父类中定义的方法,如果子类中重写了该方法,那么父类类型的引用将会调用子类中的这个方法,这就是动态连接。

注:当超类对象引用变量引用子类对象时,被引用对象的类型而不是引用变量的类型决定了调用谁的成员方法,但是这个被调用的方法必须是在超类中定义过的,也就是说被子类覆盖的方法。

4、abstract的用法(插曲)

² abstract修饰类,会使这个类成为一个抽象类,这个类将不能生成对象实例,可以做为对象变量声明的类型,也就是编译时类型,抽象类就像当于一类的半成品,需要子类继承并覆盖其中的抽象方法。

² abstract修饰方法,会使这个方法变成抽象方法,声明(定义)而没有实现,实现部分以";"代替。需要子类继承实现(覆盖)。

² abstract修饰符在修饰类时必须放在类名前。

² abstract修饰方法就是要求其子类覆盖(实现)这个方法。调用时可以以多态方式调用子类覆盖(实现)后的方法,也就是说抽象方法必须在其子类中实现,除非子类本身也是抽象类。

² 父类是抽象类,有抽象方法,子类继承父类,并把父类中的所有抽象方法都实现(覆盖),抽象类中有构造方法,是子类在构造子类对象时需要调用的父类(抽象类)的构造方法。

5、Bagging运行过程剖析:

1)运行主函数runClassifier

public static void main(String[] argv) {
runClassifier(new Bagging(), argv);
}

2)构造函数选择分类器

//Bagging构造函数默认选择分类器REPTree
public Bagging() {
m_Classifier = new weka.classifiers.trees.REPTree();
}

3)buildClassifier

① 处理数据

public void buildClassifier(Instances data) throws Exception {

// can classifier handle the data?
getCapabilities().testWithFail(data);

// remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();

② 调用父类buildClassifier()方法

super.buildClassifier(data);  //调用父类的方法:【IteratedSingleClassifierEnhancer】中的【buildClassifier()】

[此方法可以得到多个分类器m_Classifiers,分类器类型与m_Classifier一致]

//父类IteratedSingleClassifierEnhancer.java中的buildClassifier过程

public void buildClassifier(Instances data) throws Exception {

if (m_Classifier == null) {
throw new Exception("A base classifier has not been specified!");
}
m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);
}


③在m_CalcOutOfBag为真且m_BagSizePercent = 100时,准备计算OOB 及抽样数据

if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
throw new IllegalArgumentException("Bag size needs to be 100% if "
+ "out-of-bag error is to be calculated!");
}

int bagSize = (int) (data.numInstances() * (m_BagSizePercent / 100.0));
Random random = new Random(m_Seed);

boolean[][] inBag = null;    //inBag:一行代表一种分类器的抽样样本情况
if (m_CalcOutOfBag)
inBag = new boolean[m_Classifiers.length][];   //[m_Classifiers.length]为Classifier[]数组长度,即分类器的个数

for (int j = 0; j < m_Classifiers.length; j++) {
Instances bagData = null;

// create the in-bag dataset
if (m_CalcOutOfBag) {       //计算OOB时,inBag为二维数组存取样本采样情况
inBag[j] = new boolean[data.numInstances()];
// bagData = resampleWithWeights(data, random, inBag[j]);
bagData = data.resampleWithWeights(random, inBag[j]);   //[resampleWithWeights]有放回采样数据
} else {                    //不计算OOB时也没必要用inBag了
bagData = data.resampleWithWeights(random);             //[resampleWithWeights]有放回采样数据
if (bagSize < data.numInstances()) {
bagData.randomize(random);
Instances newBagData = new Instances(bagData, 0, bagSize);
bagData = newBagData;
}
}

if (m_Classifier instanceof Randomizable) {   //Randomizable接口:设置seed
((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
}


④ 构建分类树,选择m_Classifier所为分类器的分类方法,默认为REPTree方法

// build the classifier
m_Classifiers[j].buildClassifier(bagData);   //构建分类树,调用m_Classifier所为分类器的buildClassifier()方法
}

⑤ 计算OOB误差情况

// calc OOB error?
if (getCalcOutOfBag()) {
double outOfBagCount = 0.0;
double errorSum = 0.0;
boolean numeric = data.classAttribute().isNumeric();

for (int i = 0; i < data.numInstances(); i++) {
double vote;
double[] votes;
if (numeric)
votes = new double[1];  //数值型求均值,一个数组单元
else
votes = new double[data.numClasses()];  //枚举型需要投票

// determine predictions for instance
int voteCount = 0;
for (int j = 0; j < m_Classifiers.length; j++) {
if (inBag[j][i])  //寻找未被抽到的样本实例,用来计算OOB
continue;

voteCount++;
// double pred = m_Classifiers[j].classifyInstance(data.instance(i));
if (numeric) {   //数值型
// votes[0] += pred;
votes[0] += m_Classifiers[j].classifyInstance(data.instance(i));  //数值型直接把预测结果累加
} else {
// votes[(int) pred]++;
double[] newProbs = m_Classifiers[j].distributionForInstance(data.instance(i));
// average the probability estimates
for (int k = 0; k < newProbs.length; k++) {
votes[k] += newProbs[k];    //枚举型要累加枚举概率
}
}
}

// "vote"
if (numeric) {
vote = votes[0];
if (voteCount > 0) {
vote /= voteCount; // average  算数均值
}
} else {
if (Utils.eq(Utils.sum(votes), 0)) {
} else {
Utils.normalize(votes);   //归一化
}
vote = Utils.maxIndex(votes); // predicted class  选出最大的index
}

// error for instance
outOfBagCount += data.instance(i).weight();  //累加权重
if (numeric) {
errorSum += StrictMath.abs(vote - data.instance(i).classValue())*data.instance(i).weight(); //累加错误偏差
} else {
if (vote != data.instance(i).classValue())
errorSum += data.instance(i).weight();   //枚举型对出错进行计数
}
}

m_OutOfBagError = errorSum / outOfBagCount;
} else {
m_OutOfBagError = 0;   //不计算OOB了
}
}


4)主要训练过程在于bagging的基分类器,默认为REPTree

6、Bagging建立分类树过程:

建分类器完整代码:

/**
* Bagging method.
*
* @param data the training data to be used for generating the bagged
* classifier.
* @throws Exception if the classifier could not be built successfully
*/
@Override
public void buildClassifier(Instances data) throws Exception {

// can classifier handle the data?
getCapabilities().testWithFail(data);

// remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();

super.buildClassifier(data); //调用父类的方法:【IteratedSingleClassifierEnhancer】中的【buildClassifier()】
if (m_CalcOutOfBag && (m_BagSizePercent != 100)) { throw new IllegalArgumentException("Bag size needs to be 100% if " + "out-of-bag error is to be calculated!"); } int bagSize = (int) (data.numInstances() * (m_BagSizePercent / 100.0)); Random random = new Random(m_Seed); boolean[][] inBag = null; //inBag:一行代表一种分类器的抽样样本情况 if (m_CalcOutOfBag) inBag = new boolean[m_Classifiers.length][]; //[m_Classifiers.length]为Classifier[]数组长度,即分类器的个数 for (int j = 0; j < m_Classifiers.length; j++) { Instances bagData = null; // create the in-bag dataset if (m_CalcOutOfBag) { //计算OOB时,inBag为二维数组存取样本采样情况 inBag[j] = new boolean[data.numInstances()]; // bagData = resampleWithWeights(data, random, inBag[j]); bagData = data.resampleWithWeights(random, inBag[j]); //[resampleWithWeights]有放回采样数据 } else { //不计算OOB时也没必要用inBag了 bagData = data.resampleWithWeights(random); //[resampleWithWeights]有放回采样数据 if (bagSize < data.numInstances()) { bagData.randomize(random); Instances newBagData = new Instances(bagData, 0, bagSize); bagData = newBagData; } } if (m_Classifier instanceof Randomizable) { //Randomizable接口:设置seed ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt()); }
// build the classifier m_Classifiers[j].buildClassifier(bagData); //构建分类树,调用m_Classifier所为分类器的buildClassifier()方法 }

// calc OOB error? if (getCalcOutOfBag()) { double outOfBagCount = 0.0; double errorSum = 0.0; boolean numeric = data.classAttribute().isNumeric(); for (int i = 0; i < data.numInstances(); i++) { double vote; double[] votes; if (numeric) votes = new double[1]; //数值型求均值,一个数组单元 else votes = new double[data.numClasses()]; //枚举型需要投票 // determine predictions for instance int voteCount = 0; for (int j = 0; j < m_Classifiers.length; j++) { if (inBag[j][i]) //寻找未被抽到的样本实例,用来计算OOB continue; voteCount++; // double pred = m_Classifiers[j].classifyInstance(data.instance(i)); if (numeric) { //数值型 // votes[0] += pred; votes[0] += m_Classifiers[j].classifyInstance(data.instance(i)); //数值型直接把预测结果累加 } else { // votes[(int) pred]++; double[] newProbs = m_Classifiers[j].distributionForInstance(data.instance(i)); // average the probability estimates for (int k = 0; k < newProbs.length; k++) { votes[k] += newProbs[k]; //枚举型要累加枚举概率 } } } // "vote" if (numeric) { vote = votes[0]; if (voteCount > 0) { vote /= voteCount; // average 算数均值 } } else { if (Utils.eq(Utils.sum(votes), 0)) { } else { Utils.normalize(votes); //归一化 } vote = Utils.maxIndex(votes); // predicted class 选出最大的index } // error for instance outOfBagCount += data.instance(i).weight(); //累加权重 if (numeric) { errorSum += StrictMath.abs(vote - data.instance(i).classValue())*data.instance(i).weight(); //累加错误偏差 } else { if (vote != data.instance(i).classValue()) errorSum += data.instance(i).weight(); //枚举型对出错进行计数 } } m_OutOfBagError = errorSum / outOfBagCount; } else { m_OutOfBagError = 0; //不计算OOB了 } }

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息