您的位置:首页 > 编程语言 > Java开发

ID3决策树的Java实现

2016-03-10 15:42 435 查看
package DecisionTree;

import java.io.*;
import java.util.*;

/**
 * ID3 decision-tree learner over a column-oriented, string-valued dataset.
 *
 * <p>A dataset is a map from column name to the list of that column's values. All lists are
 * expected to have the same length; index {@code i} across every list forms row {@code i}. One
 * column, chosen with {@link #setDecisionColumn(String)}, holds the class label to predict; every
 * other column is a candidate split attribute.
 */
public class ID3 {

    /** Natural log of 2, used to turn natural logarithms into base-2 logarithms (bits). */
    private static final double LN2 = Math.log(2);

    /**
     * A decision-tree node. An internal node stores the name of the column it splits on and maps
     * each value of that column to a child node; a leaf stores a decision value and has no
     * children. Declared static because a node needs no reference to the ID3 instance that
     * built it.
     */
    public static class DTNode {
        private String attribute;
        private HashMap<String, DTNode> children = new LinkedHashMap<>();

        /** Returns the split column name (internal node) or the decision value (leaf). */
        public String getAttribute() {
            return attribute;
        }

        /** Sets the split column name (internal node) or the decision value (leaf). */
        public void setAttribute(String attribute) {
            this.attribute = attribute;
        }

        /** Returns the children keyed by attribute value; empty for a leaf. */
        public HashMap<String, DTNode> getChildren() {
            return children;
        }

        /** Replaces the child map. */
        public void setChildren(HashMap<String, DTNode> children) {
            this.children = children;
        }
    }

    /** Name of the column that holds the class label (the value the tree predicts). */
    private String decisionColumn;

    /** Returns the name of the decision (class label) column. */
    public String getDecisionColumn() {
        return decisionColumn;
    }

    /** Sets the name of the decision (class label) column. */
    public void setDecisionColumn(String decisionColumn) {
        this.decisionColumn = decisionColumn;
    }

    /**
     * Counts how many times each distinct value occurs in a column.
     *
     * @param dataset the column values
     * @return value -> occurrence count, iterated in first-occurrence order
     */
    public HashMap<String, Integer> getTypeCounts(ArrayList<String> dataset) {
        HashMap<String, Integer> counts = new LinkedHashMap<>();
        for (String value : dataset) {
            Integer seen = counts.get(value);
            counts.put(value, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /**
     * Returns the row indices at which a column holds the given value.
     *
     * @param key the value to look for ({@code null} matches {@code null} entries)
     * @param dataset the column values
     * @return ascending list of matching indices
     */
    public ArrayList<Integer> getIndex(String key, ArrayList<String> dataset) {
        ArrayList<Integer> indexlist = new ArrayList<>();
        for (int i = 0; i < dataset.size(); i++) {
            if (Objects.equals(key, dataset.get(i))) {
                indexlist.add(i);
            }
        }
        return indexlist;
    }

    /**
     * Selects the entries of a column at the given row indices.
     *
     * @param indexlist row indices, each within {@code dataset}'s bounds
     * @param dataset the column values
     * @return the selected values, in the order of {@code indexlist}
     */
    public ArrayList<String> getSubset(ArrayList<Integer> indexlist, ArrayList<String> dataset) {
        ArrayList<String> subset = new ArrayList<>(indexlist.size());
        for (int index : indexlist) {
            subset.add(dataset.get(index));
        }
        return subset;
    }

    /**
     * Computes the Shannon entropy, in bits, of the value distribution of a column.
     *
     * @param dataset the column values
     * @return the entropy; 0 for an empty or single-valued column
     */
    public double getEntropy(ArrayList<String> dataset) {
        int total = dataset.size();
        double entropy = 0;
        for (int count : getTypeCounts(dataset).values()) {
            double p = (double) count / total;
            entropy -= p * Math.log(p) / LN2;
        }
        return entropy;
    }

    /**
     * Computes the conditional entropy of the decision column given another column. This is the
     * entropy of the decisions within each value group of {@code indexCol}, weighted by the
     * group's share of the rows.
     *
     * @param dataset the column-oriented dataset; must contain {@code indexCol} and the decision column
     * @param indexCol the conditioning column
     * @return the conditional entropy in bits
     */
    public double getConditionEntropy(HashMap<String, ArrayList<String>> dataset, String indexCol) {
        ArrayList<String> attributeValues = dataset.get(indexCol);
        ArrayList<String> decisions = dataset.get(this.decisionColumn);
        int total = attributeValues.size();
        double entropy = 0;
        for (Map.Entry<String, Integer> group : getTypeCounts(attributeValues).entrySet()) {
            double p = (double) group.getValue() / total;
            ArrayList<String> groupDecisions =
                    getSubset(getIndex(group.getKey(), attributeValues), decisions);
            entropy += p * getEntropy(groupDecisions);
        }
        return entropy;
    }

    /**
     * Recursively builds an ID3 decision tree.
     *
     * <p>A node becomes a leaf holding the decision value when all rows share one decision. It
     * becomes a majority-vote leaf when no remaining attribute can split the rows, which covers
     * three cases: no attributes are left, every remaining attribute is constant, or the rows
     * conflict. Otherwise the node splits on the attribute with the highest information gain,
     * the first such column in iteration order on ties. It gets one child per distinct value,
     * built from the matching rows with the split column removed.
     *
     * @param dataset the column-oriented dataset, including the decision column
     * @return the root of the (sub)tree
     * @throws IllegalArgumentException if {@code dataset} is null, lacks the decision column, or has no rows
     */
    public DTNode buildDT(HashMap<String, ArrayList<String>> dataset) {
        if (dataset == null) {
            throw new IllegalArgumentException("dataset must not be null");
        }
        ArrayList<String> decisions = dataset.get(this.decisionColumn);
        if (decisions == null) {
            throw new IllegalArgumentException(
                    "decision column '" + this.decisionColumn + "' not found in dataset");
        }
        if (decisions.isEmpty()) {
            throw new IllegalArgumentException("cannot build a decision tree from an empty dataset");
        }

        DTNode node = new DTNode();
        HashMap<String, Integer> classCounts = getTypeCounts(decisions);

        // Pure node: every row has the same decision, so no further split is useful.
        if (classCounts.size() == 1) {
            node.setAttribute(decisions.get(0));
            return node;
        }

        String bestColumn = chooseBestColumn(dataset, getEntropy(decisions));
        if (bestColumn == null) {
            // Nothing left that can separate the rows: fall back to the most frequent decision.
            node.setAttribute(majority(classCounts));
            return node;
        }

        node.setAttribute(bestColumn);
        ArrayList<String> splitValues = dataset.get(bestColumn);
        for (String value : getTypeCounts(splitValues).keySet()) {
            ArrayList<Integer> rows = getIndex(value, splitValues);
            // Child dataset: the matching rows of every column except the one just used.
            HashMap<String, ArrayList<String>> subset = new LinkedHashMap<>();
            for (Map.Entry<String, ArrayList<String>> column : dataset.entrySet()) {
                if (!column.getKey().equals(bestColumn)) {
                    subset.put(column.getKey(), getSubset(rows, column.getValue()));
                }
            }
            node.getChildren().put(value, buildDT(subset));
        }
        return node;
    }

    /**
     * Picks the non-decision column with the highest information gain. Columns holding a single
     * distinct value are skipped because splitting on them cannot separate any rows.
     *
     * @return the chosen column, or {@code null} if no column can split the rows
     */
    private String chooseBestColumn(HashMap<String, ArrayList<String>> dataset, double baseEntropy) {
        String best = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, ArrayList<String>> column : dataset.entrySet()) {
            String name = column.getKey();
            if (name.equals(this.decisionColumn) || getTypeCounts(column.getValue()).size() < 2) {
                continue;
            }
            double gain = baseEntropy - getConditionEntropy(dataset, name);
            if (gain > bestGain) {
                bestGain = gain;
                best = name;
            }
        }
        return best;
    }

    /** Returns the most frequent key; ties go to the key seen first in iteration order. */
    private static String majority(Map<String, Integer> counts) {
        String best = null;
        int bestCount = -1;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                best = entry.getKey();
            }
        }
        return best;
    }

    /**
     * Prints the tree depth-first to standard output. Each node's attribute goes on its own line,
     * and each child line is prefixed by the branch value that leads to it.
     *
     * @param root the tree to print; {@code null} prints nothing
     */
    public void printDT(DTNode root) {
        if (root == null) {
            return;
        }
        System.out.println(root.getAttribute());
        HashMap<String, DTNode> children = root.getChildren();
        if (children == null) {
            return;
        }
        for (Map.Entry<String, DTNode> child : children.entrySet()) {
            System.out.print(child.getKey() + " ");
            printDT(child.getValue());
        }
    }

    /**
     * Reads a UTF-8, comma-separated file into a column-oriented dataset.
     *
     * <p>The first non-blank line is the header of column names. A leading UTF-8 byte-order mark
     * is stripped. Blank lines are ignored. Rows whose field count differs from the header are
     * reported on standard error and skipped. Names and values are trimmed of surrounding
     * whitespace.
     *
     * @param path path to the data file
     * @return column name -> values, in header order; empty if the file is missing, empty, or unreadable
     */
    public HashMap<String, ArrayList<String>> read(String path) {
        HashMap<String, ArrayList<String>> dataset = new LinkedHashMap<>();
        File file = new File(path);
        if (!file.isFile()) {
            System.err.println("Data file not found: " + path);
            return dataset;
        }

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), "UTF-8"))) {
            String[] header = null;
            List<ArrayList<String>> columns = new ArrayList<>();
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue; // e.g. a trailing newline: carries no record
                }
                // limit -1 keeps trailing empty fields so the column count stays accurate
                String[] fields = line.split(",", -1);
                if (header == null) {
                    header = fields;
                    if (!header[0].isEmpty() && header[0].charAt(0) == '\uFEFF') {
                        header[0] = header[0].substring(1);
                    }
                    for (int i = 0; i < header.length; i++) {
                        columns.add(new ArrayList<String>());
                    }
                    continue;
                }
                if (fields.length != header.length) {
                    System.err.println("Skipping line " + lineNo + " of " + path + ": expected "
                            + header.length + " fields but found " + fields.length);
                    continue;
                }
                for (int i = 0; i < fields.length; i++) {
                    columns.get(i).add(fields[i].trim());
                }
            }
            // Publish only after the whole file was read, so a read error never yields partial columns.
            if (header != null) {
                for (int i = 0; i < header.length; i++) {
                    dataset.put(header[i].trim(), columns.get(i));
                }
            }
        } catch (IOException e) {
            System.err.println("Failed to read data file " + path + ": " + e);
        }
        return dataset;
    }

    /**
     * Builds and prints a tree from a CSV file.
     *
     * @param args optional: {@code args[0]} data file path (defaults to the original sample
     *     location), {@code args[1]} decision column name (defaults to {@code "play"})
     */
    public static void main(String[] args) {
        String path = args.length > 0
                ? args[0]
                : "C:" + File.separator + "Users" + File.separator + "mhua005" + File.separator
                        + "Desktop" + File.separator + "sample.txt";
        String decision = args.length > 1 ? args[1] : "play";

        ID3 tree = new ID3();
        tree.setDecisionColumn(decision);
        HashMap<String, ArrayList<String>> ds = tree.read(path);
        ArrayList<String> labels = ds.get(decision);
        if (labels == null || labels.isEmpty()) {
            System.err.println("No data for decision column '" + decision + "' in " + path);
            return;
        }
        DTNode root = tree.buildDT(ds);
        tree.printDT(root);
    }
}


源文件 sample.txt 的内容(第一行为列名,各字段以英文逗号分隔):

outlook,temperature,humidity,windy,play
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: