您的位置:首页 > 编程语言

Weka算法Classifier-tree-RandomForest源码分析(二)代码实现

2017-01-14 11:16 375 查看
                           Weka算法Classifier-tree-RandomForest源码分析(二)代码实现

RandomForest的实现异常的简单,简单的超出博主的预期,Weka在实现方式上组合了Bagging和RandomTree。

一、RandomForest的训练

构建RandomForest的代码如下:

[java] view
plain copy

public void buildClassifier(Instances data) throws Exception {  

  

  // can classifier handle the data?  

  getCapabilities().testWithFail(data);  

  

  // remove instances with missing class  

  data = new Instances(data);  

  data.deleteWithMissingClass();  

  

  m_bagger = new Bagging();  

  RandomTree rTree = new RandomTree();  

  

  // set up the random tree options  

  m_KValue = m_numFeatures;  

  if (m_KValue < 1)  

    m_KValue = (int) Utils.log2(data.numAttributes()) + 1;  

  rTree.setKValue(m_KValue);  

  rTree.setMaxDepth(getMaxDepth());  

  

  // set up the bagger and build the forest  

  m_bagger.setClassifier(rTree);  

  m_bagger.setSeed(m_randomSeed);  

  m_bagger.setNumIterations(m_numTrees);  

  m_bagger.setCalcOutOfBag(true);  

  m_bagger.buildClassifier(data);  

}  

通过这段代码很直观的可以看出首先把无效数据去掉,然后建立了一个Bag,设置随机森林中每棵树所用到的属性的值,设置最大深度,接着把这棵RandomTree当做基分类器传递给Bagging,最后调用bagging的训练方法进行训练。

二、RandomForest分类

看完训练过程看具体的分类过程,也就是classifyInstance函数,值得注意的是,RandomForest继承自Classifier,却没有队classifyInstance方法进行重载,使用的是基类Classifier的classifyInstance函数,但却重载了distributionForInstance,而distributionForInstance却是Classifier的classifyInstance函数所用到的一个函数,返回一个instance在所有类上的概率。代码如下:

[java] view
plain copy

public double[] distributionForInstance(Instance instance) throws Exception {  

  

  return m_bagger.distributionForInstance(instance);  

}  

可以看到,算出给定instance在各class上的分布是委托给bagger去做的(真懒),所以这里也不做详细分析,详细分析留到分析bagger的时候再说。

接下来看基类Classifier是如何使用distribution来给出分类结果的。

[java] view
plain copy

public double classifyInstance(Instance instance) throws Exception {  

  

  double[] dist = distributionForInstance(instance);  

  if (dist == null) {  

    throw new Exception("Null distribution predicted");  

  }  

  switch (instance.classAttribute().type()) {  

  case Attribute.NOMINAL:  

    double max = 0;  

    int maxIndex = 0;  

  

    for (int i = 0; i < dist.length; i++) {  

      if (dist[i] > max) {  

        maxIndex = i;  

        max = dist[i];  

      }  

    }  

    if (max > 0) {  

      return maxIndex;  

    } else {  

      return Instance.missingValue();  

    }  

  case Attribute.NUMERIC:  

  case Attribute.DATE:  

    return dist[0];  

  default:  

    return Instance.missingValue();  

  }  

}  

可以很直观的看到,如果要是一个分类,则给出概率最大值,如果是一个回归(即classIndex对应的属性是数值),则返回dist[0],这里是使用了一个约定,第一个元素代表回归值。

三、总结

对于RandomForest的代码分析差不多就结束了,基本没什么实质内容,因为算法的主要工作都交由Bagging和RandomForest去做了,值得注意的是,当没有指定抽样属性的数量时,Weka使用的log2(K)作为经验值。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐