您的位置:首页 > 编程语言

C45算法代码实现及其详解

2017-07-10 20:21 1056 查看

1.概述

C45算法在weka已经有具体的实现,即weka中的J48.java。不过J48.java中的具体代码牵扯到较多的类和其他东西,直接看源代码比较容易混乱,且需要了解的东西较多,有比较多和C45算法本身没有太大关系而是为了方便代码实现的类、变量和方法等。

本文是基于C45算法思想和对J48源代码的详细解读,自己写了一个C45算法的代码(之后均称为MyC45)。该代码只含有两个类(99%的代码只在一个类中实现),需要了解的结构相对简单,算法的实现过程相对清晰,算法的效果和J48.java相差无几,写在这里权作参考。

weka中的每个分类器的类文件都实现了buildClassifier和distributionForInstance两个算法,前者是构建分类器,后者是根据构建的分类器对每个实例进行类标记预测。所以,MyC45算法中主要分为建树(buildClassifier)和预测(distributionForInstance)两部分,其中建树(buildClassifier)是主体部分,预测(distributionForInstance)只是简单的利用前面构建好的树对每个实例进行预测。

特别说明,本文所使用的实例集均已进行数据预处理,所以在MyC45中没有J48中对缺失值的处理、判断J48能否处理该实例集等等对实例集的处理。

2.类及其变量说明

MyC45包括两个类,一个是MyC45,另一个是MyClassifierTree,它们的具体功能和变量说明如下:

1.MyC45:实现C45算法的类。

变量:

m_Root:C45决策树的根节点。

2.MyClassifierTree:具体实现C45算法中的建树和预测实例类标记两大部分的功能。

变量:

int[] m_AttributeList: 当前节点的可选分裂属性集合,以整数形式保存.-1表示当前节点不能选择该属性作为分裂属性。

int m_SplitAttribute:当前节点的分裂属性,以整数的形式表示。

int m_NumAttributes:训练实例集中属性的个数(包括类属性)。

MyCLassifierTree[] m_Sons:当前节点的儿子节点。

Instances m_Instances:当前节点对应的实例集。

int m_MinInstances:最小实例数,若当前节点对应的实例集的实例数小于该值,则当前节点只能作为叶节点。

注:以上是关键变量说明,类中其他变量并不是很重要就没有特别说明,在后面的代码中出现时会有说明。


3.伪代码

输入:训练实例集D,可选属性列表m_AttributeList,最小实例数m_MinInstances。

1.创建根节点m_root,m_AttributeList初始化为全0.

2.if D中的所有实例类标记都一样,则将m_root标记为叶节点并返回m_root。

3.if m_AttributeList为空,则将m_root标记为叶节点并返回m_root。

4.if m_Instances的实例数小于m_MinInstances,则将m_root标记为叶节点并返回m_root。

5.根据getBestSplitAttribute方法,在m_AttributeList中找出增益率最大的属性作为当前节点的分裂属性m_SplitAttribute。

6.根据m_SplitAttribute的属性值将当前节点的实例集m_Instances划分成多个子集,也就是当前节点的儿子节点 m_Sons。

7.对m_Sons中的每个节点递归地循环2-6的步骤以构建子树。

8.决策树构建好后,需要对其进行collapse处理和prune处理。前者是折叠子树过程,后者是后剪枝过程,具体说明在后面。

4.函数调用流程图

以下是各个部分的主要函数调用流程图,后面将对这些主要函数进行详细说明。



图1. buildClassifier部分的主要函数调用流程图

在MyC45类中的buildClassifier函数中,先初始化可选属性列表m_AttributeList和最小实例数m_MinInstances,然后用这两个参数初始化m_Root,之后再用m_Root调用MyClassifierTree类中的buildeClassifier函数即可。



图2. distributionForInstance部分的主要函数调用流程图

在MyC45类中的distributionForInstance函数中,直接用m_Root调用MyClassifierTree类中的distributionForInstance函数即可。



图3. 建树(buildTree)的主要函数调用流程图

建树首先根据isAllTheSame和canSplit两个函数和其他条件判断当前节点是否为叶节点;若可以分裂则调用getBestSplitAttribute函数求出最佳的分裂属性(也就是增益率最大的属性),getBestSplitAttribute函数中对每个可选属性分别调用beforeSplit、afterSplit和computeSplitInfo函数以求出该可选属性的增益率;随后将当前节点对应的实例集根据最佳分裂属性(即当前节点的分裂属性)分裂出多个子集,这些子集即为当前节点的子节点的实例集;然后对每个子集调用getNewTree函数以构建相应的子树,getNewTree函数中即调用buildTree函数,从而递归地构建决策树。



图4. 折叠子树(collapse)的主要函数调用流程图

collapse是对非叶节点进行操作,对于非叶节点,调用getCurrentTraningErrors求出当前节点的实例集上误分实例数,再调用getSubTraningErrors求出所有子树上实例集上误分实例数;如果后者大于前者,说明这些子树并不能提高这颗树的准确度,则把这些子树删除。否则在每个子树上递归的collapse。



图5. 后剪枝(prune)的主要函数调用流程图

prune是后剪枝操作,所以需要先递归找到叶节点的上一层的第一个子树开始剪枝。首先,调用getLargetBranch求出当前节点的最大树枝;随后,调用getEstimatedErrorsForBranch求出当前结点实例集在最大树枝上的误分实例数,标记为a;调用getEstimatedErrorsForDistribution求出当前节点的实例集在当前节点上的误分实例数,标记为b;调用getEstimatedErrors求出当前节点的所有子树上的误分实例数,标记为c。接着,判断是否应该把当前结点设置为叶节点,即第一个if语句。若不成立,则判断是否用最大树枝代替当前结点,即第二个if语句。若是,则将最大树枝上的变量信息覆盖当前结点的变量信息,并调用restInstances根据更改后的当前节点对当前实例集进行调整树的结构,并递归地对更改后的当前节点进行prune操作。

5.buildTree

isAllTheSame和canSplit函数比较简单,split就是根据属性值对实例集进行划分,getNewTree只是简单的初始化节点并调用buildTree形成递归,这些都比较简单,略过。

1.buildTree代码

public void buildTree(Instances instances) throws Exception
{
initializePara(instances);//初始化一些变量

if (instances.numInstances()  <= m_MinInstances|| isAllTheSame(instances))
//小于m_MinInstances的实例集只能做叶节点,或当前实例集的类标记都一样也做叶节点,
{
m_IsLeaf = true;                   // 是则该节点为叶节点
m_Instances = instances;// 叶节点对应的实例集为instances
return;
}

if (!canSplit(m_AttributeList)) // 若可选属性集为空,则实例集不能继续分裂,所以当前节点是叶节点
{
m_IsLeaf = true;
m_Instances = instances;
return;
} else
{
int[] sonAttributeList = new int[m_NumAttributes];  //子节点的可选属性列表
for (int i = 0; i < sonAttributeList.length; i++)
{
if (i == m_ClassIndex)
sonAttributeList[i] = -1;
else
sonAttributeList[i] = m_AttributeList[i];
}

m_SplitAttribute = getBestSplitAttribute(instances, m_AttributeList); // 从当前可选的分裂属性集合中获取最佳的分裂属性
m_NameOfCurrentNode = instances.attribute(m_SplitAttribute).name(); // 获取当前分裂属性的名称
if (m_SplitAttribute != -1) // 当前实例集可以划分,且求得最佳分裂属性时
{
sonAttributeList[m_SplitAttribute] = -1; // 当前分裂属性在所有子节点上是不可选的,所以这里进行标记一下
int numOfSubTree = m_NumAttsValues[m_SplitAttribute];// 该节点对应的子节点数量,等于当前分裂属性的属性值个数

Instances[] localInstances;
localInstances = split(instances, numOfSubTree, m_SplitAttribute);// 根据分裂属性的属性值个数将实例集进行划分
m_NameOfLineToSon = new String[numOfSubTree];//每个子节点对应的属性值
for (int i = 0; i < numOfSubTree; i++)
m_NameOfLineToSon[i] = m_Instances.attribute(m_SplitAttribute).value(i);

m_Sons = new MyCLassifierTree[numOfSubTree];
for (int i = 0; i < m_Sons.length; i++)
{// 接着为每一个localInstances构建子树
m_Sons[i] = getNewTree(localInstances[i], sonAttributeList,m_MinInstances);
localInstances[i] = null;
if (m_Sons[i].m_IsLeaf)  //统计当前结点叶节点数
m_NumLeaf ++;
}
} else// 当前实例集不可以划分,说明该节点是叶节点
{
m_IsLeaf = true;
m_Instances = instances;
return;
}
}
}


2.getBestSplitAttribute

/**
* 从当前可选属性列表中求出最佳分裂属性,即增益率最大的属性
* @param instances
* @param attributesList
* @return
*/
public int getBestSplitAttribute(Instances instances, int[] attributesList)
{
int bestSplitAttribute = 0; //标记最佳分裂属性
boolean canSplit = false; //判断该实例集是否可以继续划分
double[] gainRatio = new double[m_NumAttributes ]; //增益率,
double[] infoGain =  new double[m_NumAttributes ]; //信息增益,
double[] splitInfo =  new double[m_NumAttributes ]; //分裂信息,

for (int i = 0; i < m_NumAttributes ; i++) //遍历属性,计算每个属性增益率
{
if (i != m_ClassIndex && attributesList[i] != -1)//对可选的非类属性进行计算增益率
{
infoGain[i] = beforeSplit(instances) - afterSplit(instances,i);
splitInfo[i] = computeSplitInfo(instances,i);
gainRatio[i] = infoGain[i] / splitInfo[i];
canSplit = true; //当进入增益率计算时,说明该实例集可以进行划分
}
else
{
gainRatio[i] = 0;
infoGain[i] = 0;
splitInfo[i] = 0;
}
}

if (canSplit)
{  //若可以分裂,则找出gainRatio数组中最大且attributesList数组中不等于-1的下标,即为最佳分裂属性
bestSplitAttribute = getMaxIndex(gainRatio,attributesList);
}
else {//若当前实例集无法继续分裂,则返回-1作为没有找到最佳分裂属性的标记
bestSplitAttribute = -1;
}

return bestSplitAttribute;
}


getBestSplitAttribute是通过计算增益率求出的,以下下是计算增益率的一些公式:

(1)GainRito(A) = Gain(A)/SplitInfo(A)

GainRito(A):属性A的增益率。

Gain(A):属性A 的信息增益。

SplitInfo(A):属性A的分裂信息量。

(2)Gain(A) = Info(Insts) - Info(Insts,A)

Info(Insts):实例集Insts分裂前的信息量。

Info(Insts,A):实例集Insts根据属性A分裂后的信息量。

(3)


C:实例集Insts中的类标记个数。

Pi:第i种类标记对应的实例数与实例总数的比值。

(4)


nA:属性A的属性值个数。

n:实例集的实例总数。

ni:属性A的第i种属性值对应的实例数。

Instsi:属性A的第i种属性值对应的实例集。

Info(Instsi):根据公式(3)计算属性A的第i种属性值对应的实例集的信息量。

(5)


3.beforeSplit

/**
* 计算分裂前实例集的信息值
* @param instances
* @return
*/
public double beforeSplit(Instances instances)
{
double infoBeforeSplit = 0;

int numClasses = instances.numClasses();
int numInstances = instances.numInstances(); //实例总数
double allWeight = 0;                                                 //实例集instances中的实例数(权重之和)
double[] numInstancesInClass = new double[numClasses];//每个类标记对应的实例数(权重)

for (int i = 0; i < numInstances; i++) //遍历每个实例,统计出每个类标记对应的实例数(权重)
{
int classLable = (int)instances.instance(i).classValue();
numInstancesInClass[classLable] += instances.instance(i).weight();
allWeight += instances.instance(i).weight();
}

if (onlyOneNotZero(numInstancesInClass,allWeight))//1.如果只有一个类标记的实例数(权重)不为0(其他类标记的实例数(权重)为0),则信息值为0
return 0.0;
if (eachEqualAve(numInstancesInClass))//2.如果所有类标记对应的实例数(权重)相等,则信息值最大,这里设置为1
return 1.0;

for (int i = 0; i < numClasses; i++) //3.根据公式(3)求出实例集的信息值
{
double pi = numInstancesInClass[i] / allWeight;
if (pi != 0.0) //注意pi为0时不要纳入计算,因为log0是一个无效值,这会导致整个infoBeforeSplit值无效(NaN)。反正pi等于0时 pi * log2(pi)即为0,所以不纳入计算即可
infoBeforeSplit = infoBeforeSplit + pi * log2(pi) ;
}
return - infoBeforeSplit; //注意加个负号
}


4.afterSplit

/**
* 计算实例集根据属性attribute进行分裂后的信息量
* @param instances
* @param attribute
* @return
*/
public double afterSplit(Instances instances, int attribute)
{
double infoAfterSplit = 0;
int numAttributeValue = instances.attribute(attribute).numValues(); //属性attribute的属性值个数
int numInstances = instances.numInstances(); // 实例总数
double allWeight = 0; ////实例集instances中的实例数权重之和

double[] weightInAttValue = new double[numAttributeValue];//每个属性值对应的实例数(权重之和)
Instances[] instsOfValue = new Instances[numAttributeValue];//每个属性值对应的实例子集
for (int i = 0; i < instsOfValue.length; i++)//初始化
instsOfValue[i] = new Instances(instances, 0);

for (int i = 0; i < numInstances; i++)//遍历实例集,将实例集instances根据属性值划分实例集,统计每个实例集的实例数(权重)
{
int attValue = (int)instances.instance(i).value(attribute); //获取第i个实例在属性attribute中的属性值
weightInAttValue[attValue] += instances.instance(i).weight(); //计算每个属性值对应的实例数(权重之和)
allWeight +=  instances.instance(i).weight();  //计算实
fd49
例集instances中的实例权重之和
instsOfValue[attValue].add(instances.instance(i)); //将实例i放入对应的实例子集之中
}

for (int i = 0; i < numAttributeValue; i++)//根据公式(4)计算根据属性i分裂后的实例集的信息值
{
double value = weightInAttValue[i]/allWeight * beforeSplit(instsOfValue[i]);
infoAfterSplit = infoAfterSplit + value;
}

return infoAfterSplit;
}
}


5.computeSplitInfo

/**
* 计算分裂信息量
* @param instances
* @param attribute
* @return
*/
public double computeSplitInfo(Instances instances, int attribute)
{
double splitInfo = 0;
double allWeight = 0; ////实例集instances中的实例数(权重之和)
int numAttributeValue = instances.attribute(attribute).numValues(); //属性attribute的属性值个数

double[] weightInEachValue = new double[numAttributeValue]; //每个属性值对应的实例数(权重之和)

for (int i = 0; i < instances.numInstances(); i++)
{
int  attValue = (int)instances.instance(i).value(attribute); //获取第i个实例在属性attribute中的属性值
weightInEachValue[attValue] += instances.instance(i).weight(); //计算每个属性值对应的实例数(权重之和)
allWeight += instances.instance(i).weight(); //计算实例集instances中的实例数(权重之和)
}

for (int i = 0; i < numAttributeValue; i++) //根据公式(5)计算分裂信息值
{
double pi = weightInEachValue[i] / allWeight;
if (pi != 0)//注意pi为0时不要纳入计算,因为log0是一个无效值,这会导致整个splitInfo值无效。
splitInfo = splitInfo +pi* log2(pi);
}

return  - splitInfo; //注意加个负号
}


6.collapse

1.collapse

/**
* Collapses a tree to a node if training error doesn't increase.
* 如果当前节点存在很多子节点,但这些子节点并不能提高这颗分类树的准确度,则把这些孩子节点删除。否则在每个孩子上递归的collapse。
* 通过collapse方法可以在不减少精度的前提下减少决策树的深度,进而提高效率。
*/
public final void collapse( )
{
double errorsOfTree;       // 当前节点上训练实例集误分的实例数(权重)
double errorsOfSubtree; // 当前节点的所有子树上训练实例集误分的实例数(权重)
int i;

if (!m_IsLeaf)//只有对非叶节点才进行折叠子树操作
{
errorsOfTree = getCurrentTrainingErrors();
errorsOfSubtree = getSubTrainingErrors();

if (errorsOfSubtree >= errorsOfTree - 1E-3)
//所有子树上误分实例数(权重)大于当前节点误分实例数(权重)时,说明这些孩子节点不好,将他们删除。
//删除的方式是将当前节点 的子树变量设置为空并将该节点设置为叶节点。1E-3是10的-3次方,即0.001
{
m_Sons = null;
m_IsLeaf = true;
}
else
for (i = 0; i < m_Sons.length; i++) // 在每个孩子上递归地进行折叠子树操作
m_Sons[i].collapse();
}
}


2.getCurrentTrainingErrors

/**
* 计算当前结点的误分实例数
* @return
*/
private double getCurrentTrainingErrors()
{
double wrongWeight = 0;

int majorityClassLable = majorityClassLable(m_Instances); //获取当前实例集中的多数类

for (int i = 0; i < m_Instances.numInstances(); i++)  //遍历当前实例集,求出总误分实例数(权重)
{
int classValue = (int)m_Instances.instance(i).classValue();
if (classValue != majorityClassLable)
{
m_Predictions[i] = -1;
wrongWeight += m_Instances.instance(i).weight();
}
}

return wrongWeight;
}


3.getSubTrainingErrors

/**
* 计算当前结点的所有子节点中误分的实例数(权重)
* @return
*/
private double getSubTrainingErrors()
{
double errors = 0;

if (m_IsLeaf)   //对叶节点,直接调用getCurrentTrainingErrors函数求出该叶节点上的误分实例数(权重)
return getCurrentTrainingErrors();
else //对非叶节点,递归调用getSubTrainingErrors以求出所有子节点上的误分实例数(权重)
{
for (int i = 0; i < m_Sons.length; i++)
errors = errors + m_Sons[i].getSubTrainingErrors();
return errors;
}
}


7.prune

1.prune

/**
* 后剪枝操作
* @throws Exception
*/
public final void prune( ) throws Exception
{
double errorsLargestBranch; //当前节点实例集在最大树枝上的误分实例数
double errorsLeaf;                     //假设当前节点是叶节点时,该节点对应的实例集在该节点上的误分实例数
double errorsTree;                     //计算当前节点的所有子树上的误分实例数
int indexOfLargestBranch;     //最大树枝的下标
MyCLassifierTree largestBranch;  //临时保存最大树枝
int i;

if (!m_IsLeaf) //对非叶节点均进行剪枝
{
for (i = 0; i < m_Sons.length; i++)// 对当前节点的子节点递归地进行剪枝,由于是后剪枝,所以从树的最底层开始往上
m_Sons[i].prune();

// 求出当前树上的最大树枝,即当前节点的所有子集中实例数最大的子集下标,
indexOfLargestBranch = getLargetBranch();

// 计算当前节点实例集在最大树枝上的误分实例数
errorsLargestBranch = m_Sons[indexOfLargestBranch].getEstimatedErrorsForBranch(m_Instances);

//计算当前节点对应的实例集在当前节点上的误分实例数
errorsLeaf = getEstimatedErrorsForDistribution(m_Instances);

// 计算当前节点的所有子树上的误分实例数
errorsTree = getEstimatedErrors();

// 判断将该节点设置为叶节点是不是最好的选择,
if (Utils.smOrEq(errorsLeaf, errorsTree + 0.1) && Utils.smOrEq(errorsLeaf, errorsLargestBranch + 0.1))
{
m_Sons = null;
m_IsLeaf = true;
return;
}

// 判断用最大树枝代替当前节点是不是最好的选择
if (Utils.smOrEq(errorsLargestBranch, errorsTree + 0.1))
{
largestBranch = m_Sons[indexOfLargestBranch];  //获取最大树枝
m_Sons = largestBranch.m_Sons;
m_AttributeList = largestBranch.m_AttributeList;   //将最大树枝的可选分裂属性列表覆盖当前节点的可选分裂属性列表
m_AttributeList[m_SplitAttribute] = 0;                        //由于会用最大树枝的分裂属性代替了原先节点的分裂属性,所以原先节点的分裂属性处于可选状态
m_SplitAttribute = largestBranch.m_SplitAttribute;  //用最大树枝的分裂属性代替了原先节点的分裂属性
m_IsLeaf = largestBranch.m_IsLeaf;

resetInstances(m_Instances);  //将当前实例集根据修改后的分裂属性进行划分
prune(); //递归地对修改后的节点进行剪枝
}
}
}


2.getEstimatedErrorsForDistribution

/**
* 求出testInstances实例集在以m_Instances为根据的分类器中的误分实例数(权重)
* @param theDistribution
*            the distribution to use
* @return the estimated errors
*/
private double getEstimatedErrorsForDistribution(Instances testInstances)
{

if (Utils.eq(testInstances.numInstances(), 0)) //若testInstances实例数为0,则误分实例数只能为0
return 0;
else
{
double inCorrectWeight = 0;
double allWeight = 0.0;
int majorityClassLable ;

majorityClassLable = majorityClassLable(m_Instances);//求出当前实例集m_Instances中的多数类
for (int i = 0; i < testInstances.numInstances(); i++)
{
allWeight += testInstances.instance(i).weight(); //统计测试实例集testInstances中的实例总数(权重)
int classVlaue = (int)testInstances.instance(i).classValue();
if (classVlaue != majorityClassLable)
inCorrectWeight += testInstances.instance(i).weight(); //统计测试实例集testInstances中误分实例的实例总数(权重)
}

return inCorrectWeight +  Stats.addErrs(allWeight,inCorrectWeight, 0.25f);
}
}


3.getEstimatedErrorsForBranch

/**
*  求出testInstances实例集在以m_Instances为根据的分类器中的误分实例数(权重)
* @param data
*            the data to work with
* @return the estimated errors
* @throws Exception
*             if something goes wrong
*/
private double getEstimatedErrorsForBranch(Instances testInstances) throws Exception
{
double errors = 0;
int i;

if (m_IsLeaf) //若当前节点是叶节点,则调用getEstimatedErrorsForDistribution求出当前节点上的误分实例数
return getEstimatedErrorsForDistribution(testInstances);
else //若当前节点不是叶节点,则计算testInstances在其所有子节点上的误分实例数之和
{
//将testInstances根据当前节点的分裂属性的属性值,划分成不同的测试实例集
int numSubset = testInstances.attribute(m_SplitAttribute).numValues();
Instances[] localInstances = split(testInstances, numSubset, m_SplitAttribute);//测试实例子集

for (i = 0; i < m_Sons.length; i++) //计算每个测试实例子集在对应子节点上的误分实例数(权重)
errors = errors + m_Sons[i].getEstimatedErrorsForBranch(localInstances[i]);
return errors;
}
}


4.getEstimatedErrors

/**
*计算当前结点的所有子树上的误分实例数(权重)
* @return the estimated errors
*/
private double getEstimatedErrors()
{

double errors = 0;
int i;

if (m_IsLeaf)  //若当前结点是叶节点,则直接计算误分实例树(权重)
return getEstimatedErrorsForDistribution(m_Instances);
else
{
for (i = 0; i < m_Sons.length; i++)  //若当前节点不是叶节点,则递归地计算其所有子树上的误分实数(权重)
errors = errors + m_Sons[i].getEstimatedErrors();
return errors;
}
}


5.resetInstances

/**
* 将当前实例集根据修改后的分裂属性进行划分,并修改预测结果数组m_Prediction[]
* @param instances
* @throws Exception
*/
private void resetInstances(Instances instances) throws Exception
{

m_Instances = instances;

if (!m_IsLeaf) //若当前节点不是叶子节点,则递归地对其及其子树进行重新划分实例集
{
int numSubset = (int)instances.attribute(m_SplitAttribute).numValues();
Instances[] localInstances = split(instances, numSubset, m_SplitAttribute);

for (int i = 0; i < m_Sons.length; i++)//递归地对其子树进行重新划分实例集
{
m_Sons[i].m_Instances = localInstances[i];
m_Sons[i].resetInstances(localInstances[i]);
}
}
else
{//由于实例集发生了变化,所以需要根据新实例集和新的可选分裂属性列表构建树
m_AttributeList[m_SplitAttribute] = -1;  //在当前节点上建树,则当前分裂属性在其子树的构建过程中不可选
m_IsLeaf = false;

int numSubset = (int)instances.attribute(m_SplitAttribute).numValues();   //此时是在各个子节点上建树
Instances[] localInstances = split(instances, numSubset, m_SplitAttribute);
m_Sons = new MyCLassifierTree[numSubset];
for (int i = 0; i < localInstances.length; i++)
{
m_Sons[i] = getNewTree(localInstances[i], m_AttributeList, m_MinInstances);
}
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  weka C45 代码详解