
Data Mining Notes - Classification - Decision Trees - MapReduce Implementation - 2

2014-05-28 16:38
Although the previous post used MapReduce on Hadoop to compute the gain ratio of each attribute, the program as a whole still did not really run in parallel. After reading some more material online and thinking it over again, I re-implemented the decision tree. The overall approach is:

1. Split the large data-set file into N small data-set files, preprocess the data, and upload them to HDFS.

2. For each small data-set file on HDFS, compute its best split attribute and split point.

3. Collect the best splits found on the N small data sets and vote for the overall best split (a minimal sketch of this voting step follows the list).

4. Each of the N small data sets partitions its own data according to the winning split, uploads the partitions to HDFS, and the process goes back to step 2.
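A minimal sketch of the voting in step 3, assuming each sub-job reports the (attribute, split point) pair it found locally best; the names CandidateSplit and voteBestSplit are made up for illustration and are not part of the repository. (The implementation below actually skips explicit voting and just keeps the candidate with the smallest Gini value across all job outputs.)

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical holder for the locally best split reported by one sub-dataset. */
class CandidateSplit {
    final String attribute;   // attribute chosen by one sub-job
    final String splitPoint;  // split point chosen by that sub-job

    CandidateSplit(String attribute, String splitPoint) {
        this.attribute = attribute;
        this.splitPoint = splitPoint;
    }
}

public class SplitVoter {

    /** Majority vote over the N locally best splits: the (attribute, splitPoint)
     *  pair reported by the most sub-datasets wins; ties go to the first one seen. */
    public static CandidateSplit voteBestSplit(List<CandidateSplit> candidates) {
        Map<String, Integer> votes = new HashMap<String, Integer>();
        CandidateSplit winner = null;
        int bestCount = 0;
        for (CandidateSplit c : candidates) {
            String key = c.attribute + "|" + c.splitPoint;
            int count = votes.containsKey(key) ? votes.get(key) + 1 : 1;
            votes.put(key, count);
            if (count > bestCount) {
                bestCount = count;
                winner = c;
            }
        }
        return winner;
    }
}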

 

The implementation code is given below. It uses JobControl to drive the execution of multiple jobs and also involves several other MR programs; the code has not been tidied up, so please bear with it.
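Before the full listing, here is a minimal, self-contained sketch of the JobControl pattern it relies on (Hadoop's new mapreduce API); only the Hadoop classes are real, the wrapper class JobControlSketch is made up:

import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.jobcontrol.ControlledJob;
import org.apache.hadoop.mapreduce.lib.jobcontrol.JobControl;

public class JobControlSketch {

    /** Wraps each prepared Job in a ControlledJob, runs the JobControl on its own
     *  thread, and blocks until every job has reached a terminal state. */
    public static void runAll(List<Job> jobs) throws Exception {
        JobControl jobControl = new JobControl("CalculateGini");
        for (Job job : jobs) {
            ControlledJob controlledJob = new ControlledJob(new Configuration());
            controlledJob.setJob(job);
            jobControl.addJob(controlledJob); // inter-job dependencies could be declared here
        }
        Thread thread = new Thread(jobControl);
        thread.setDaemon(true);
        thread.start();
        while (!jobControl.allFinished()) {
            Thread.sleep(500); // poll instead of busy-waiting
        }
        jobControl.stop();
    }
}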

public class DecisionTreeSprintBJob extends AbstractJob {

    private Map<String, Map<Object, Integer>> attributeValueStatistics = null;

    private Map<String, Set<String>> attributeNameToValues = null;

    private Set<String> allAttributes = null;

    /** Data splitting: break the large data file into small data files so that a job can be started per split. */
    private List<String> split(String input, String splitNum) {
        String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();
        String[] args = new String[]{input, output, splitNum};
        DataFileSplitMR.main(args);
        List<String> inputs = new ArrayList<String>();
        Path outputPath = new Path(output);
        try {
            FileSystem fs = outputPath.getFileSystem(conf);
            Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);
            for (Path path : paths) {
                System.out.println("split input path: " + path);
                InputStream in = fs.open(path);
                BufferedReader reader = new BufferedReader(new InputStreamReader(in));
                String line = reader.readLine();
                while (null != line && !"".equals(line)) {
                    inputs.add(line);
                    line = reader.readLine();
                }
                IOUtils.closeQuietly(reader);
                IOUtils.closeQuietly(in);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("inputs size: " + inputs.size());
        return inputs;
    }

    /** Initialization: collect the feature attribute set and per-value counts, mainly used later to fill in default values. */
    private void initialize(String input) {
        System.out.println("initialize start.");
        allAttributes = new HashSet<String>();
        attributeNameToValues = new HashMap<String, Set<String>>();
        attributeValueStatistics = new HashMap<String, Map<Object, Integer>>();
        String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();
        String[] args = new String[]{input, output};
        AttributeStatisticsMR.main(args);
        Path outputPath = new Path(output);
        SequenceFile.Reader reader = null;
        try {
            FileSystem fs = outputPath.getFileSystem(conf);
            Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);
            for (Path path : paths) {
                reader = new SequenceFile.Reader(fs, path, conf);
                AttributeKVWritable key = (AttributeKVWritable)
                        ReflectionUtils.newInstance(reader.getKeyClass(), conf);
                IntWritable value = new IntWritable();
                while (reader.next(key, value)) {
                    String attributeName = key.getAttributeName();
                    allAttributes.add(attributeName);
                    Set<String> values = attributeNameToValues.get(attributeName);
                    if (null == values) {
                        values = new HashSet<String>();
                        attributeNameToValues.put(attributeName, values);
                    }
                    String attributeValue = key.getAttributeValue();
                    values.add(attributeValue);
                    Map<Object, Integer> valueStatistics =
                            attributeValueStatistics.get(attributeName);
                    if (null == valueStatistics) {
                        valueStatistics = new HashMap<Object, Integer>();
                        attributeValueStatistics.put(attributeName, valueStatistics);
                    }
                    valueStatistics.put(attributeValue, value.get());
                    value = new IntWritable();
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            IOUtils.closeQuietly(reader);
        }
        System.out.println("initialize end.");
    }

    /** Preprocessing: fill in default values for each split file and upload the result back to HDFS. */
    private List<String> preHandle(List<String> inputs) throws IOException {
        List<String> fillInputs = new ArrayList<String>();
        for (String input : inputs) {
            Data data = null;
            try {
                Path inputPath = new Path(input);
                FileSystem fs = inputPath.getFileSystem(conf);
                FSDataInputStream fsInputStream = fs.open(inputPath);
                data = DataLoader.load(fsInputStream, true);
            } catch (IOException e) {
                e.printStackTrace();
            }
            DataHandler.computeFill(data.getInstances(),
                    allAttributes.toArray(new String[0]),
                    attributeValueStatistics, 1.0);
            OutputStream out = null;
            BufferedWriter writer = null;
            String outputDir = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();
            fillInputs.add(outputDir);
            String output = outputDir + File.separator + IdentityUtils.generateUUID();
            try {
                Path outputPath = new Path(output);
                FileSystem fs = outputPath.getFileSystem(conf);
                out = fs.create(outputPath);
                writer = new BufferedWriter(new OutputStreamWriter(out));
                StringBuilder sb = null;
                for (Instance instance : data.getInstances()) {
                    sb = new StringBuilder();
                    sb.append(instance.getId()).append("\t");
                    sb.append(instance.getCategory()).append("\t");
                    Map<String, Object> attrs = instance.getAttributes();
                    for (Map.Entry<String, Object> entry : attrs.entrySet()) {
                        sb.append(entry.getKey()).append(":");
                        sb.append(entry.getValue()).append("\t");
                    }
                    writer.write(sb.toString());
                    writer.newLine();
                }
                writer.flush();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                IOUtils.closeQuietly(writer);
                IOUtils.closeQuietly(out);
            }
        }
        return fillInputs;
    }

    /** Create the Gini-calculation job for one input split. */
    private Job createJob(String jobName, String input, String output) {
        Configuration conf = new Configuration();
        conf.set("mapred.job.queue.name", "q_hudong");
        Job job = null;
        try {
            job = new Job(conf, jobName);

            FileInputFormat.addInputPath(job, new Path(input));
            FileOutputFormat.setOutputPath(job, new Path(output));

            job.setJarByClass(DecisionTreeSprintBJob.class);

            job.setMapperClass(CalculateGiniMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(AttributeWritable.class);

            job.setReducerClass(CalculateGiniReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(AttributeGiniWritable.class);

            job.setInputFormatClass(TextInputFormat.class);
            job.setOutputFormatClass(SequenceFileOutputFormat.class);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return job;
    }

    /** Choose the best attribute from the job output paths on HDFS: the candidate with the smallest Gini value. */
    private AttributeGiniWritable chooseBestAttribute(String... outputs) {
        AttributeGiniWritable minSplitAttribute = null;
        double minSplitPointGini = 1.0;
        try {
            for (String output : outputs) {
                System.out.println("choose output: " + output);
                Path outputPath = new Path(output);
                FileSystem fs = outputPath.getFileSystem(conf);
                Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);
                ShowUtils.print(paths);
                SequenceFile.Reader reader = null;
                for (Path path : paths) {
                    reader = new SequenceFile.Reader(fs, path, conf);
                    Text key = (Text) ReflectionUtils.newInstance(
                            reader.getKeyClass(), conf);
                    AttributeGiniWritable value = new AttributeGiniWritable();
                    while (reader.next(key, value)) {
                        double gini = value.getGini();
                        System.out.println(value.getAttribute() + " : " + gini);
                        if (gini <= minSplitPointGini) {
                            minSplitPointGini = gini;
                            minSplitAttribute = value;
                        }
                        value = new AttributeGiniWritable();
                    }
                    IOUtils.closeQuietly(reader);
                }
                System.out.println("delete hdfs file start: " + outputPath.toString());
                HDFSUtils.delete(conf, outputPath);
                System.out.println("delete hdfs file end: " + outputPath.toString());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (null == minSplitAttribute) {
            System.out.println("minSplitAttribute is null");
        }
        return minSplitAttribute;
    }

    private Data obtainData(String input) {
        Data data = null;
        Path inputPath = new Path(input);
        try {
            FileSystem fs = inputPath.getFileSystem(conf);
            Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, inputPath);
            FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]);
            data = DataLoader.load(fsInputStream, true);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return data;
    }

    /** Build the decision tree recursively. */
    private Object build(List<String> inputs) throws IOException {
        List<String> outputs = new ArrayList<String>();
        JobControl jobControl = new JobControl("CalculateGini");
        for (String input : inputs) {
            System.out.println("split path: " + input);
            String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL + IdentityUtils.generateUUID();
            outputs.add(output);
            Configuration conf = new Configuration();
            ControlledJob controlledJob = new ControlledJob(conf);
            controlledJob.setJob(createJob(input, input, output));
            jobControl.addJob(controlledJob);
        }
        Thread jcThread = new Thread(jobControl);
        jcThread.start();
        while (true) {
            if (jobControl.allFinished()) {
//                System.out.println(jobControl.getSuccessfulJobList());
                jobControl.stop();
                AttributeGiniWritable bestAttr = chooseBestAttribute(
                        outputs.toArray(new String[0]));
                String attribute = bestAttr.getAttribute();
                System.out.println("best attribute: " + attribute);
                System.out.println("isCategory: " + bestAttr.isCategory());
                if (bestAttr.isCategory()) {
                    return attribute;
                }
                TreeNode treeNode = new TreeNode(attribute);
                Map<String, List<String>> splitToInputs =
                        new HashMap<String, List<String>>();
                for (String input : inputs) {
                    Data data = obtainData(input);
                    String splitPoint = bestAttr.getSplitPoint();
//                    Map<String, Set<String>> attrName2Values =
//                            DataHandler.attributeValueStatistics(data.getInstances());
                    Set<String> attributeValues = attributeNameToValues.get(attribute);
                    System.out.println("attributeValues:");
                    ShowUtils.print(attributeValues);
                    if (attributeNameToValues.size() == 0 || null == attributeValues) {
                        continue;
                    }
                    attributeValues.remove(splitPoint);
                    StringBuilder sb = new StringBuilder();
                    for (String attributeValue : attributeValues) {
                        sb.append(attributeValue).append(",");
                    }
                    if (sb.length() > 0) sb.deleteCharAt(sb.length() - 1);
                    String[] names = new String[]{splitPoint, sb.toString()};
                    DataSplit dataSplit = DataHandler.split(new Data(
                            data.getInstances(), attribute, names));
                    for (DataSplitItem item : dataSplit.getItems()) {
                        if (item.getInstances().size() == 0) continue;
                        String path = item.getPath();
                        String name = path.substring(path.lastIndexOf(File.separator) + 1);
                        String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;
                        HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);
                        String split = item.getSplitPoint();
                        List<String> nextInputs = splitToInputs.get(split);
                        if (null == nextInputs) {
                            nextInputs = new ArrayList<String>();
                            splitToInputs.put(split, nextInputs);
                        }
                        nextInputs.add(hdfsPath);
                    }
                }
                for (Map.Entry<String, List<String>> entry :
                        splitToInputs.entrySet()) {
                    treeNode.setChild(entry.getKey(), build(entry.getValue()));
                }
                return treeNode;
            }
            if (jobControl.getFailedJobList().size() > 0) {
//                System.out.println(jobControl.getFailedJobList());
                jobControl.stop();
            }
            try {
                Thread.sleep(500); // avoid busy-waiting while the controlled jobs run
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Classify the test set with the trained tree and write the results to HDFS. */
    private void classify(TreeNode treeNode, String testSet, String output) {
        OutputStream out = null;
        BufferedWriter writer = null;
        try {
            Path testSetPath = new Path(testSet);
            FileSystem testFS = testSetPath.getFileSystem(conf);
            Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath);
            FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]);
            Data testData = DataLoader.load(fsInputStream, true);

            DataHandler.computeFill(testData.getInstances(),
                    allAttributes.toArray(new String[0]),
                    attributeValueStatistics, 1.0);
            Object[] results = (Object[]) treeNode.classifySprint(testData);
            ShowUtils.print(results);
            DataError dataError = new DataError(testData.getCategories(), results);
            dataError.report();
            String path = FileUtils.obtainRandomTxtPath();
            out = new FileOutputStream(new File(path));
            writer = new BufferedWriter(new OutputStreamWriter(out));
            StringBuilder sb = null;
            for (int i = 0, len = results.length; i < len; i++) {
                sb = new StringBuilder();
                sb.append(i + 1).append("\t").append(results[i]);
                writer.write(sb.toString());
                writer.newLine();
            }
            writer.flush();
            Path outputPath = new Path(output);
            FileSystem fs = outputPath.getFileSystem(conf);
            if (!fs.exists(outputPath)) {
                fs.mkdirs(outputPath);
            }
            String name = path.substring(path.lastIndexOf(File.separator) + 1);
            HDFSUtils.copyFromLocalFile(conf, path, output +
                    File.separator + name);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            IOUtils.closeQuietly(writer);
            IOUtils.closeQuietly(out);
        }
    }

    public void run(String[] args) {
        try {
            if (null == conf) conf = new Configuration();
            String[] inputArgs = new GenericOptionsParser(
                    conf, args).getRemainingArgs();
            if (inputArgs.length != 4) {
                System.out.println("error, please provide four arguments:");
                System.out.println("1. trainset path.");
                System.out.println("2. testset path.");
                System.out.println("3. result output path.");
                System.out.println("4. data split number.");
                System.exit(2);
            }
            List<String> splitInputs = split(inputArgs[0], inputArgs[3]);
            initialize(inputArgs[0]);
            List<String> inputs = preHandle(splitInputs);
            TreeNode treeNode = (TreeNode) build(inputs);
            TreeNodeHelper.print(treeNode, 0, null);
            classify(treeNode, inputArgs[1], inputArgs[2]);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        DecisionTreeSprintBJob job = new DecisionTreeSprintBJob();
        long startTime = System.currentTimeMillis();
        job.run(args);
        long endTime = System.currentTimeMillis();
        System.out.println("spend time: " + (endTime - startTime));
    }

}
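The CalculateGiniMapper and CalculateGiniReducer classes referenced above are not shown in this post; chooseBestAttribute simply keeps the candidate whose reported Gini value is smallest. As a reference for what that value means, here is a standalone sketch of the Gini index of a binary split; the class GiniSketch and its methods are illustrative only and not taken from the repository.

import java.util.Map;

public class GiniSketch {

    /** Gini impurity of one partition from its class counts: 1 - sum(p_i^2). */
    public static double gini(Map<String, Integer> classCounts, int total) {
        double sum = 0.0;
        for (int count : classCounts.values()) {
            double p = (double) count / total;
            sum += p * p;
        }
        return 1.0 - sum;
    }

    /** Gini index of a two-way split: the size-weighted average impurity of the two partitions. */
    public static double giniSplit(Map<String, Integer> leftCounts, int leftTotal,
                                   Map<String, Integer> rightCounts, int rightTotal) {
        int total = leftTotal + rightTotal;
        return ((double) leftTotal / total) * gini(leftCounts, leftTotal)
                + ((double) rightTotal / total) * gini(rightCounts, rightTotal);
    }
}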


Code repository: https://github.com/fighting-one-piece/repository-datamining.git

 

 

 