ID3、C4.5算法介绍以及java代码实现
2015-09-28 21:42
603 查看
一、ID3算法介绍
0、引言套用俗语,决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:
女儿:多大年纪了? 母亲:26。 女儿:长的帅不帅? 母亲:挺帅的。 女儿:收入高不? 母亲:不算很高,中等情况。 女儿:是公务员不? 母亲:是,在税务局上班呢。 女儿:那好,我去见见。
其中,分支节点都是数据的属性,而叶节点是数据的类。
构造决策树的方法简单直观。当我们想要用机器学习的方法区构造这样一颗决策树的时候则需要考虑两个问题:
(1)如何选择分支节点上的特征。
(2)设置什么样的终止条件。
1、概述
1.1 ID3算法
ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。
1.2 信息熵
1948 年,香农提出了“信息熵”(shāng) 的概念,才解决了对信息的量化度量问题。香农用信息熵表示信息的不确定度,一个信息的不确定度越大,需要把他弄明白需要的信息量就越大。并从概率的角度给出了信息熵的计算公式。
1.3 信息增益度
信息增益表示两个信息熵之间的差值。他的计算公式为
ID3算法使用信息增益度衡量用哪个属性作为分支节点。Gain表示分类前后信息不确定性的差值,Gain值越大,表示该特征用于分类的效果越好。
算法实现思路如下:
2.实例
2.1 分类解决用户是否会购买电脑
共14条记录,目标属性是,是否买电脑,共有两个情况,yes或者no。参考属性有4种情况,分别是,age,income,student,credit_rating。属性age有3种取值情况,分别是,youth,middle_aged,senior,属性income有3种取值情况,分别是,high,medium,low,属性student有2种取值情况,分别是,no,yes,属性credit_rating有2种取值情况,分别是fair,excellent。我们先求参考属性的信息熵:
式中的5表示5个no,9表示9个yes,14是总的记录数。接下来我们求各个参考属性在取各自的值对应目标属性的信息熵,以属性age为例,有3种取值情况,分别是youth,middle_aged,senior,先考虑youth,youth共出现5次,3次no,2次yes,于是信息熵:
类似得到middle_aged和senior的信息熵,分别是:0和0.971。整个属性age的信息熵应该是它们的加权平均值:
计算得到信息增益度为:
Gain(income)=0.029,Gain(stduent)=0.151,Gain(credit_rating)=0.048。最大值为Gain(age),所以首先按照参考属性age,将数据分为3类,如下:
3.总结
3.1 优点
(1):生成的决策树简单直观。分类规则易于理解。
(2):全盘考虑数据,从而抵抗噪声。
(3):不存在无解,计算量少。
3.2 缺点
(1):采用信息增益率作为衡量分类所用特征的好坏,容易偏向选择较多的属性,但是这样的属性往往不能提供有用的信息(规范化)。引进其他先验知识,或者限制决策树为二叉树。C4.5算法用信息增益率衡量特征可以改进这个问题。信息增益率:
(2):需要多次遍历数据库,效率不高。
(3):如果不进行剪枝,容易过拟合。
(4):单变量决策树,在分类时候每次只考虑一个变量。
(5):难以处理连续型数据。虽然预排序或者离散化可以改进这个问题。
(6):不能回溯,每次增加一个新的特征都需要重新扫描数据库,拓展性差。
3.3 其他注意事项
(1):若所有特征属性已经用于分类后,仍然没有得到完全纯正的分类结果。可以使用多数表决,将叶节点中大部分数据所属类作为分类结果。
(2):为了防止过拟合,可以设定阈值。比如当某个节点的数据表中多数类的比例超过70%时停止分类。
4 代码实现(在luowen3405代码基础上改动)
http://blog.csdn.net/luowen3405/article/details/6250731
用于表示每个树节点
import java.util.ArrayList; public class TreeNode { private String name; // 节点名(分裂属性的名称) private ArrayList<String> rule; // 结点的分裂规则 二分属性 ArrayList<TreeNode> child; // 子结点集合 private ArrayList<ArrayList<String>> datas; // 划分到该结点的训练元组 private ArrayList<String> candAttr; // 划分到该结点的候选属性 public TreeNode() { this.name = ""; this.rule = new ArrayList<String>(); this.child = new ArrayList<TreeNode>(); this.datas = null; this.candAttr = null; } public ArrayList<TreeNode> getChild() { return child; } public void setChild(ArrayList<TreeNode> child) { this.child = child; } public ArrayList<String> getRule() { return rule; } public void setRule(ArrayList<String> rule) { this.rule = rule; } public String getName() { return name; } public void setName(String name) { this.name = name; } public ArrayList<ArrayList<String>> getDatas() { return datas; } public void setDatas(ArrayList<ArrayList<String>> datas) { this.datas = datas; } public ArrayList<String> getCandAttr() { return candAttr; } public void setCandAttr(ArrayList<String> candAttr) { this.candAttr = candAttr; } }
提供准确计算的工具类
import java.math.BigDecimal; public class DecimalCalculate { /** * 由于Java的简单类型不能够精确的对浮点数进行运算,这个工具类提供精 确的浮点数运算,包括加减乘除和四舍五入。 */ // 默认除法运算精度 private static final int DEF_DIV_SCALE = 10; // 这个类不能实例化 private DecimalCalculate() { } /** * 提供精确的加法运算。 * * @param v1 * 被加数 * @param v2 * 加数 * @return 两个参数的和 */ public static double add(double v1, double v2) { BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.add(b2).doubleValue(); } /** * 提供精确的减法运算。 * * @param v1 * 被减数 * @param v2 * 减数 * @return 两个参数的差 */ public static double sub(double v1, double v2) { BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.subtract(b2).doubleValue(); } /** * 提供精确的乘法运算。 * * @param v1 * 被乘数 * @param v2 * 乘数 * @return 两个参数的积 */ public static double mul(double v1, double v2) { BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.multiply(b2).doubleValue(); } /** * 提供(相对)精确的除法运算,当发生除不尽的情况时,精确到 小数点以后10位,以后的数字四舍五入。 * * @param v1 * 被除数 * @param v2 * 除数 * @return 两个参数的商 */ public static double div(double v1, double v2) { return div(v1, v2, DEF_DIV_SCALE); } /** * 提供(相对)精确的除法运算。当发生除不尽的情况时,由scale参数指 定精度,以后的数字四舍五入。 * * @param v1 * 被除数 * @param v2 * 除数 * @param scale * 表示表示需要精确到小数点以后几位。 * @return 两个参数的商 */ public static double div(double v1, double v2, int scale) { if (scale < 0) { throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.divide(b2, scale, BigDecimal.ROUND_HALF_UP).doubleValue(); } /** * 提供精确的小数位四舍五入处理。 * * @param v * 需要四舍五入的数字 * @param scale * 小数点后保留几位 * @return 四舍五入后的结果 */ public static double round(double v, int scale) { if (scale < 0) { throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b = new BigDecimal(Double.toString(v)); BigDecimal one = new BigDecimal("1"); return b.divide(one, scale, BigDecimal.ROUND_HALF_UP).doubleValue(); } /** * 提供精确的类型转换(Float) * * @param v * 需要被转换的数字 * @return 返回转换结果 */ public static float convertsToFloat(double v) { BigDecimal b = new BigDecimal(v); return b.floatValue(); } /** * 提供精确的类型转换(Int)不进行四舍五入 * * @param v * 需要被转换的数字 * @return 返回转换结果 */ public static int convertsToInt(double v) { BigDecimal b = new BigDecimal(v); return b.intValue(); } /** * 提供精确的类型转换(Long) * * @param v * 需要被转换的数字 * @return 返回转换结果 */ public static long convertsToLong(double v) { BigDecimal b = new BigDecimal(v); return b.longValue(); } /** * 返回两个数中大的一个值 * * @param v1 * 需要被对比的第一个数 * @param v2 * 需要被对比的第二个数 * @return 返回两个数中大的一个值 */ public static double returnMax(double v1, double v2) { BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.max(b2).doubleValue(); } /** * 返回两个数中小的一个值 * * @param v1 * 需要被对比的第一个数 * @param v2 * 需要被对比的第二个数 * @return 返回两个数中小的一个值 */ public static double returnMin(double v1, double v2) { BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.min(b2).doubleValue(); } /** * 精确对比两个数字 * * @param v1 * 需要被对比的第一个数 * @param v2 * 需要被对比的第二个数 * @return 如果两个数一样则返回0,如果第一个数比第二个数大则返回1,反之返回-1 */ public static int compareTo(double v1, double v2) { BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.compareTo(b2); } }
计算信息增益度的量
import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import classify_id3.DecimalCalculate; public class Gain { private ArrayList<ArrayList<String>> D = null; // 训练元组 private ArrayList<String> attrList = null; // 候选属性集 public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) { this.D = datas; this.attrList = attrList; } /** * 获取最佳侯选属性列上的值域(假定所有属性列上的值都是有限的名词或分类类型的) * * @param attrIndex * 指定的属性列的索引 * @return 值域集合 */ public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas, int attrIndex) { ArrayList<String> values = new ArrayList<String>(); String r = ""; for (int i = 0; i < datas.size(); i++) { r = datas.get(i).get(attrIndex); if (!values.contains(r)) { values.add(r); } } return values; } /** * 获取指定数据集中指定属性列索引的域值及其计数 * * @param d * 指定的数据集 * @param attrIndex * 指定的属性列索引 * @return 类别及其计数的map */ public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas, int attrIndex) { Map<String, Integer> valueCount = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(attrIndex); if (valueCount.containsKey(c)) { valueCount.put(c, valueCount.get(c) + 1); } else { valueCount.put(c, 1); } } return valueCount; } /** * 求对datas中元组分类所需的期望信息,即datas的熵 * * @param datas * 训练元组 * @return datas的熵值 */ public double infoD(ArrayList<ArrayList<String>> datas) { double info = 0.000; int total = datas.size(); Map<String, Integer> classes = valueCounts(datas, attrList.size()); Iterator<Map.Entry<String, Integer>> iter = classes.entrySet().iterator(); Integer[] counts = new Integer[classes.size()]; for (int i = 0; iter.hasNext(); i++) { Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>) iter.next(); Integer val = (Integer) entry.getValue(); counts[i] = val; } for (int i = 0; i < counts.length; i++) { double base = DecimalCalculate.div(counts[i], total, 3); info += (-1) * base * Math.log(base); } return info; } /** * 获取指定属性列上指定值域的所有元组 * * @param attrIndex * 指定属性列索引 * @param value * 指定属性列的值域 * @return 指定属性列上指定值域的所有元组 */ public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value) { ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>(); ArrayList<String> t = null; for (int i = 0; i < D.size(); i++) { t = D.get(i); if (t.get(attrIndex).equals(value)) { Di.add(t); } } return Di; } /** * 基于按指定属性划分对D的元组分类所需要的期望信息 * * @param attrIndex * 指定属性的索引 * @return 按指定属性划分的期望信息值 */ public double infoAttr(int attrIndex) { double info = 0.000; ArrayList<String> values = getValues(D, attrIndex); for (int i = 0; i < values.size(); i++) { ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex, values.get(i)); info += DecimalCalculate.mul( DecimalCalculate.div(dv.size(), D.size(), 3), infoD(dv)); } return info; } /** * 获取最佳分裂属性的索引 * * @return 最佳分裂属性的索引 */ public int bestGainAttrIndex() { int index = 0; double gain = 0.000; double tempGain = 0.000; for (int i = 0; i < attrList.size(); i++) { tempGain = infoD(D) - infoAttr(i); if (tempGain > gain) { gain = tempGain; index = i; } } return index; } }
生成决策树
import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; public class DecisionTree { private Integer attrSelMode; // 最佳分裂属性选择模式,1表示以信息增益度量,2表示以信息增益率度量。暂未实现2 public DecisionTree() { this.attrSelMode = 1; } public DecisionTree(int attrSelMode) { this.attrSelMode = attrSelMode; } public void setAttrSelMode(Integer attrSelMode) { this.attrSelMode = attrSelMode; } /** * 获取指定数据集中的类别及其计数 * * @param datas * 指定的数据集 * @return 类别及其计数的map */ public static Map<String, Integer> classOfDatas( ArrayList<ArrayList<String>> datas) { Map<String, Integer> classes = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(tuple.size() - 1); if (classes.containsKey(c)) { classes.put(c, classes.get(c) + 1); } else { classes.put(c, 1); } } return classes; } /** * 获取具有最大计数的类名,即求多数类 * * @param classes * 类的键值集合 * @return 多数类的类名 */ public static String maxClass(Map<String, Integer> classes) { String maxC = ""; int max = -1; Iterator iter = classes.entrySet().iterator(); for (int i = 0; iter.hasNext(); i++) { Map.Entry entry = (Map.Entry) iter.next(); String key = (String) entry.getKey(); Integer val = (Integer) entry.getValue(); if (val > max) { max = val; maxC = key; } } return maxC; } /** * 构造决策树 * * @param datas * 训练元组集合 * @param attrList * 候选属性集合 * @return 决策树根结点 */ public static TreeNode buildTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) { System.out.print("候选属性列表: "); for (int i = 0; i < attrList.size(); i++) { System.out.print(" " + attrList.get(i) + " "); } System.out.println(); TreeNode node = new TreeNode(); node.setDatas(datas); node.setCandAttr(attrList); Map<String, Integer> classes = classOfDatas(datas); System.out.println(classes);// # String maxC = maxClass(classes); System.out.println(maxC);// # System.out.println("存放分类类型的个数是" + classes.size()); System.out.println("剩余特征数为" + attrList.size()); if (classes.size() == 1 || attrList.size() == 1) { node.setName(maxC); return node; } Gain gain = new Gain(datas, attrList); int bestAttrIndex = gain.bestGainAttrIndex(); System.out.println("最佳分类特征索引为" + bestAttrIndex);// # ArrayList<String> rules = gain.getValues(datas, bestAttrIndex); System.out.println("分类规则为" + rules);// # node.setRule(rules); node.setName(attrList.get(bestAttrIndex)); attrList.remove(bestAttrIndex); ArrayList<ArrayList<ArrayList<String>>> allDatas=new ArrayList<ArrayList<ArrayList<String>>>() ; for (int i = 0; i < rules.size(); i++) { String rule = rules.get(i); ArrayList<ArrayList<String>> di = gain.datasOfValue(bestAttrIndex, rule); allDatas.add(di); } for(int i=0;i<allDatas.size();i++){ for (int j = 0; j < allDatas.get(i).size(); j++) { allDatas.get(i).get(j).remove(bestAttrIndex); } System.out.println("剩余分类特征为" + attrList);// # System.out.println(); if (allDatas.get(i).size() == 0 || attrList.size() == 0) { TreeNode leafNode = new TreeNode(); leafNode.setName(maxC); leafNode.setDatas(allDatas.get(i)); leafNode.setCandAttr(attrList); node.getChild().add(leafNode); } else { TreeNode newNode = buildTree(allDatas.get(i), attrList); node.getChild().add(newNode); } } return node; } }
入口文件
import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; public class Id3 { /** * 文件流读取训练元组 * * @return 训练元组集合 * @throws IOException */ public static ArrayList<ArrayList<String>> readFData(String fileUrl) throws IOException { ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>(); BufferedReader in = new BufferedReader(new InputStreamReader( new FileInputStream(new File(fileUrl)), "UTF-8")); String temp = null; String[] tempArray = null; while ((temp = in.readLine()) != null) { tempArray = temp.split("\\t"); ArrayList<String> s = new ArrayList<String>(); for (int i = 0; i < tempArray.length; i++) { // /////////////// s.add(tempArray[i]); } datas.add(s); } in.close(); return datas; } /** * 文件流读取候选属性 * * @return 候选属性集合 * @throws IOException */ public static ArrayList<String> readFCandAttr(String fileUrl) throws IOException { ArrayList<String> candAttr = new ArrayList<String>(); String temp = null; BufferedReader in = new BufferedReader(new InputStreamReader( new FileInputStream(new File(fileUrl)), "UTF-8")); while ((temp = in.readLine()) != null) { candAttr.add(temp); } in.close(); return candAttr; } public static void main(String[] args) throws IOException { ArrayList<ArrayList<String>> Datas = readFData(".//data//text.txt"); //读入数据文件 长宽不限 需用table键隔开 ArrayList<String> features = readFCandAttr(".//data//feature.txt"); //读入特征文件 长宽不限 需用table键隔开 DecisionTree.buildTree(Datas, features); } }
数据格式如下
二、C4.5算法
1、C4.5算法介绍C4.5,是机器学习算法中的另一个决策树构造算法,也是上文所介绍的ID3的改进算法。决策树构造方法其实就是每次选择一个好的特征以及分裂点作为当前节点的分类条件。
2、C4.5与ID3的区别
用信息增益率来选择属性。ID3选择属性用的是信息增益,这里可以用很多方法来定义信息,ID3使用的是熵(entropy,熵是一种不纯度度量准则)定义信息,用熵的变化值定义信息增益,而C4.5用的是信息增益率。对,区别就在于一个是信息增益,一个是信息增益率。
在树构造过程中进行剪枝,在构造决策树的时候,那些挂着几个元素的节点,不考虑最好,不然容易导致overfitting。
对非离散数据也能处理。
能够对不完整数据进行处理
解释下信息增益和信息增益率的区别:类似出生人口数与出生率的区别,中国的出生人口很多,但是出生率不一定大,用出生率代替出生人口,能避免人口基数大影响结果。同理,用信息增益率代替信息增益能有效解决属性的候选值过多导致信息增益过大的问题。
信息增益率(Gain ratio)是由前面的信息增益(Gain)和分裂信息度量(SplitInformation)来共同决定的。其中:
其中,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
其中S1到Sc是c个值的属性A分割S而形成的c个样例子集。注意分裂信息实际上就是S关于属性A的各值的熵。这与我们前面对熵的使用不同,在那里我们只考虑S关于学习到的树要预测的目标属性的值的熵。
请注意,分裂信息项阻碍选择值为均匀分布的属性。例如,考虑一个含有n个样例的集合被属性A彻底分割(译注:分成n组,即一个样例一组)。这时分裂信息的值为log2n。相反,一个布尔属性B分割同样的n个实例,如果恰好平分两半,那么分裂信息是1。如果属性A和B产生同样的信息增益,那么根据增益比率度量,明显B会得分更高。
3、具体实现过程
Function C4.5(R:包含连续属性的无类别属性集合,C:类别属性,S:训练集) /*返回一棵决策树*/ Begin If S为空,返回一个值为Failure的单个节点; If S是由相同类别属性值的记录组成, 返回一个带有该值的单个节点; If R为空,则返回一个单节点,其值为在S的记录中找出的频率最高的类别属性值; [注意未出现错误则意味着是不适合分类的记录]; For 所有的属性R(Ri) Do If 属性Ri为连续属性,则 Begin 将Ri的最小值赋给A1: 将Rm的最大值赋给Am;/*m值手工设置*/ For j From 2 To m-1 Do Aj=A1+j*(A1Am)/m; 将Ri点的基于{< =Aj,>Aj}的最大信息增益属性(Ri,S)赋给A; End; 将R中属性之间具有最大信息增益的属性(D,S)赋给D; 将属性D的值赋给{dj/j=1,2...m}; 将分别由对应于D的值为dj的记录组成的S的子集赋给{sj/j=1,2...m}; 返回一棵树,其根标记为D;树枝标记为d1,d2...dm; 再分别构造以下树: C4.5(R-{D},C,S1),C4.5(R-{D},C,S2)...C4.5(R-{D},C,Sm);
参考网站“分类算法—–决策树”
相关文章推荐
- 控制逻辑的分离——springMVC
- spring helloworld
- java (web)异常分析java.lang.ClassNotFoundException: Aservlet
- java周期调度几种实现
- JAVA泛型方法的声明与实现
- Java基础知识强化之IO流笔记07:自定义的异常概述和自定义异常实现
- GsonFormat快速实现JavaBean
- Java中的Iterator
- HelloJava
- spring 中 hibernate 的 2种 配置方式(新旧 2种方式)
- NewProductActivity.java
- spring 中 hibernate 的 2种 配置方式(新旧 2种方式)
- java项目结局篇之项目进度
- Java下利用Jackson进行JSON解析和序列化
- JAVA i=i++,i=i+1的误区
- eclipse 使用
- java中Comparator的用法(转载)
- 平时用到的eclipse快捷键
- Java中的二进制、八进制、十六进制和移位运算
- Spring+Hibernate整合配置 --- 比较完整的spring、hibernate 配置