spark高级数据分析实战---用决策树预测森林植被
2016-07-18 22:37
323 查看
这是我参照这本书写的第二个程序。这几天一直在研究Storm,没时间写;第一个程序(推荐系统)由于时间关系没能及时发布,回头会补充给大家。这个程序是抽时间参照书上写的,希望对大家有帮助。
运行结果如下。因为我的机器配置比较差,运行时间比较长;放到集群上跑会快很多,大家可以试试。
package mllib.tree

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
 * Created by 汪本成 on 2016/7/12.
 *
 * Trains MLlib decision-tree classifiers on the covtype data set, evaluates them on a
 * cross-validation split, runs two small hyper-parameter grids, and prints the results.
 *
 * NOTE(review): the whole pipeline runs in the object initializer, i.e. when the object is
 * first loaded, and `main` only prints the already-computed fields. The layout is kept so the
 * object's members stay accessible under the same names; consider moving the pipeline into
 * `main` if nothing else reads these fields.
 */
object trainCovtype {

  // Start time (ms), used to report the total run time.
  val beg = System.currentTimeMillis()

  // Silence unnecessary log output on the terminal.
  //Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  //Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // Spark entry point.
  val conf = new SparkConf().setAppName("trainCovtype").setMaster("local")
  val sc = new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://node1:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  // Parse every CSV line into a LabeledPoint.
  val data = rawData.map { line =>
    val values = line.split(",").map(_.toDouble)
    // `init` is every value except the last one; the last column is the target.
    val featureVector = Vectors.dense(values.init)
    // The decision tree requires labels to start at 0, so subtract one.
    val label = values.last - 1
    LabeledPoint(label, featureVector)
  }

  // Split into training (80%), cross-validation (10%) and test (10%) sets.
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()

  // Decision-tree parameters.
  val numClass = 7 // number of target classes
  // Maps a categorical feature index to its number of distinct values (empty: none declared).
  val categoricalFeaturesInfo = Map[Int, Int]()
  val impurity = "gini" // impurity measure used to compute information gain
  val maxDepth = 4 // maximum depth of the tree
  val maxBins = 100 // maximum number of bins used when splitting features

  // Train the baseline classifier and evaluate it on the CV set.
  val model = DecisionTree.trainClassifier(
    trainData, numClass, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
  val metrics = getMetrics(model, cvData)

  // Overall precision on the CV set.
  val precision = metrics.precision

  // Per-class (precision, recall) pairs; class ids start at 0.
  // The field keeps its historical name even though each element holds both values.
  val recall = (0 until numClass).map(
    cat => (metrics.precision(cat), metrics.recall(cat))
  )

  // Confusion matrix on the CV set.
  val confusionMatrix = metrics.confusionMatrix

  // Class proportions in the training set and in the CV set.
  val trainPriorProbabilities = classProbabilities(trainData)
  val cvPriorProbabilities = classProbabilities(cvData)

  // Multiply the train and CV proportion of each class and sum the products.
  // NOTE(review): the zip pairs entries by position, which only lines up per class when both
  // sets contain exactly the same labels.
  val two_probabilities = trainPriorProbabilities.zip(cvPriorProbabilities).map {
    case (trainProb, cvProb) => trainProb * cvProb
  }.sum

  /** Decision-tree tuning. */
  // Train on the training set, score on the CV set.
  val evaluations1 = evaluateGrid(trainData, cvData, Array(10, 300))
  // Sort by the second element (accuracy), descending.
  val result1 = evaluations1.sortBy(_._2).reverse

  // Train on training + CV data and score on that same union.
  // NOTE(review): these figures are measured on the data the model was trained on, and
  // `testData` is never used for evaluation anywhere in this object.
  val trainAndCvData = trainData.union(cvData)
  val evaluations2 = evaluateGrid(trainAndCvData, trainAndCvData, Array(10, 30))
  // Sort by the second element (accuracy), descending.
  val result2 = evaluations2.sortBy(_._2).reverse

  // End time and elapsed time (ms).
  val end = System.currentTimeMillis()
  val castTime = end - beg

  // Horizontal rule printed between report sections.
  private val separator =
    "========================================================================================"

  /**
   * Prints the results computed during object initialization, then stops the SparkContext.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]) {
    println(separator)
    // Overall precision.
    println("精确度: " + precision)
    println(separator)
    // Per-class (precision, recall).
    println("准确度: ")
    recall.foreach(println)
    println(separator)
    // Sum of products of the train and CV class proportions.
    println("cv和train的数据集结合对准确度: " + two_probabilities)
    println(separator)
    // Confusion matrix.
    println("混淆矩阵如下: ")
    println(confusionMatrix)
    println(separator)
    // First grid: trained on trainData, scored on cvData, best first.
    println("cvData下的决策树不同条件下的准确度降序排如下: ")
    result1.foreach(println)
    println(separator)
    // Second grid: trained and scored on trainData + cvData, best first.
    println("cvData结合trainData下不同条件下的准确度降序排序如下: ")
    result2.foreach(println)
    println(separator)
    println(" 运行程序耗时: " + castTime / 1000 + "s")
    // Release Spark resources explicitly instead of relying on the JVM shutdown hook.
    sc.stop()
  }

  /**
   * Trains one classifier per (impurity, depth, bins) combination and scores each of them.
   *
   * @param training   data every candidate model is trained on
   * @param evaluation data every candidate model is scored on
   * @param binOptions values tried for the maximum number of bins
   * @return one ((impurity, depth, bins), precision) entry per combination
   */
  private def evaluateGrid(
      training: RDD[LabeledPoint],
      evaluation: RDD[LabeledPoint],
      binOptions: Array[Int]): Array[((String, Int, Int), Double)] =
    for {
      impurities <- Array("gini", "entropy")
      depth <- Array(1, 20)
      bins <- binOptions
    } yield {
      val candidate = DecisionTree.trainClassifier(
        training, numClass, categoricalFeaturesInfo, impurities, depth, bins)
      val accuracy = getMetrics(candidate, evaluation).precision
      ((impurities, depth, bins), accuracy)
    }

  /**
   * Builds multiclass metrics by predicting every example of `data` with `model`.
   *
   * @param model trained decision-tree model
   * @param data  labeled examples to predict
   * @return metrics over the (prediction, label) pairs
   */
  def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels =
      data.map(example => (model.predict(example.features), example.label))
    new MulticlassMetrics(predictionsAndLabels)
  }

  /**
   * Computes the proportion of examples carrying each label that occurs in `data`.
   *
   * @param data labeled examples
   * @return proportions ordered by ascending label; labels absent from `data` get no entry
   */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    // Number of examples per label: (label, count).
    val countsByCategory = data.map(_.label).countByValue()
    // Order by label and keep only the counts.
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    // Sum once instead of once per element.
    val total = counts.sum
    counts.map(_.toDouble / total)
  }
}
======================================================================================== 精确度: 0.6980966928106819 ======================================================================================== 准确度: (0.68203722951904,0.6760325934251195) (0.7190111755099426,0.7894107720433132) (0.6342031686859273,0.7663288288288288) (0.4682926829268293,0.35294117647058826) (0.0,0.0) (0.7692307692307693,0.029994001199760048) (0.7007633587786259,0.4443368828654405) ======================================================================================== cv和train的数据集结合对准确度: 0.37763488623859276 ======================================================================================== 混淆矩阵如下: 14436.0 6548.0 8.0 0.0 0.0 3.0 359.0 5615.0 22454.0 319.0 18.0 0.0 5.0 33.0 0.0 755.0 2722.0 68.0 0.0 7.0 0.0 0.0 0.0 176.0 96.0 0.0 0.0 0.0 0.0 899.0 13.0 0.0 0.0 0.0 0.0 0.0 540.0 1054.0 23.0 0.0 50.0 0.0 1115.0 33.0 0.0 0.0 0.0 0.0 918.0 ======================================================================================== cvData下的决策树不同条件下的准确度降序排如下: ((entropy,20,300),0.9146686803851236) ((gini,20,300),0.9035989496627593) ((entropy,20,10),0.8961676420615443) ((gini,20,10),0.8915338012940429) ((gini,1,300),0.6366039095886179) ((gini,1,10),0.6358487651672473) ((entropy,1,300),0.4881665436696586) ((entropy,1,10),0.4881665436696586) ======================================================================================== cvData结合trainData下不同条件下的准确度降序排序如下: ((entropy,20,30),0.9509307620195527) ((gini,20,30),0.9386290152863074) ((entropy,20,10),0.9344984598901834) ((gini,20,10),0.9305190457058677) ((gini,1,30),0.6342325278845969) ((gini,1,10),0.6340756471330999) ((entropy,1,30),0.4871319520174482) ((entropy,1,10),0.4871319520174482) ======================================================================================== 运行程序耗时: 331s 16/07/18 18:27:11 INFO SparkContext: Invoking stop() from shutdown hook 16/07/18 18:27:12 INFO SparkUI: Stopped Spark web UI at http://192.168.43.1:4040 16/07/18 
18:27:12 INFO MapOutputTrackerMasterEndpoint: MapOutputTrackerMasterEndpoint stopped! 16/07/18 18:27:12 INFO MemoryStore: MemoryStore cleared 16/07/18 18:27:12 INFO BlockManager: BlockManager stopped 16/07/18 18:27:12 INFO BlockManagerMaster: BlockManagerMaster stopped 16/07/18 18:27:12 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint: OutputCommitCoordinator stopped! 16/07/18 18:27:12 INFO SparkContext: Successfully stopped SparkContext 16/07/18 18:27:12 INFO ShutdownHookManager: Shutdown hook called 16/07/18 18:27:12 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-d7aaabfd-bd23-4a4d-b3a4-c21fd38eea98 16/07/18 18:27:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon. Process finished with exit code 0 |
相关文章推荐
- Linux安装Samba共享配置
- 使用Express + Socket.io + MongoDB实现简单的聊天
- Codeforces Round #130 (Div. 2) C - Police Station 最短路+dp
- Zenject——轻量级依赖注入框架 for Unity
- C++ - PAT - L1-028. 判断素数(天梯赛决赛题目)
- 阿里云RDS接口开发笔记
- php-删除非空目录
- 如何高效率学习R?[转自微信:R语言中文社区]
- django app的命名规则
- python核心编程学习笔记-2016-07-18-02-enumerate()函数
- oracle 的绑定变量
- 一些经典的T-SQL语句
- Linux centos 安装jdk
- FASTQ 格式说明
- 基于ITK和VTK实现三维体数据的区域生长分割和可视化
- HDU 1811 Rank of Tetris(并查集+拓扑排序)
- JDBC操作封装
- plsql 安装后database下拉没有东西
- LVS负载均衡DR模式+keepalived
- Linux 用户环境变量