
Advanced Analytics with Spark in Practice: Predicting Forest Cover with Decision Trees

This is the second program I've written from this book. I've been busy studying Storm these past few days and haven't had time to post; the first one (a recommender system) didn't go up in time, and I'll come back and add it later. I found some time to write this one by following the book. I hope it helps.

package mllib.tree
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Created by 汪本成 on 2016/7/12.
  */
object trainCovtype {

  // start time
  var beg = System.currentTimeMillis()

  // suppress unnecessary log output on the console
  //Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  //Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  // create the entry-point objects
  val conf = new SparkConf().setAppName("trainCovtype").setMaster("local")
  val sc = new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://node1:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  // convert each line into a LabeledPoint
  val data = rawData.map { line =>
    val values = line.split(",").map(_.toDouble)
    // init returns all values except the last one; the last column is the target
    val featureVector = Vectors.dense(values.init)
    // DecisionTree requires labels to start at 0, so subtract one
    val label = values.last - 1
    LabeledPoint(label, featureVector)
  }
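  // (Added note, not in the original listing: each covtype.data row holds 54
  // comma-separated feature values -- the 10 numeric measurements plus the
  // one-hot encoded wilderness-area and soil-type columns -- followed by the
  // cover-type label 1..7 in the last column, hence values.init / values.last.)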

  // split into training (80%), cross-validation (10%), and test (10%) sets
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()
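  // (Added note: randomSplit is nondeterministic by default; a fixed seed,
  // e.g. data.randomSplit(Array(0.8, 0.1, 0.1), seed = 123), would make the
  // split -- and the numbers reported below -- reproducible across runs.)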

  // decision-tree hyperparameters
  val numClass = 7                               // number of target classes
  val categoricalFeaturesInfo = Map[Int, Int]()  // map of categorical feature index -> category count (empty: treat all features as numeric)
  val impurity = "gini"                          // impurity measure used to compute information gain
  val maxDepth = 4                               // maximum depth of the tree
  val maxBins = 100                              // maximum number of bins used when splitting features

  // train the classification decision-tree model
  val model = DecisionTree.trainClassifier(trainData, numClass, categoricalFeaturesInfo, impurity, maxDepth, maxBins)
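  // (Added note: the learned tree can be inspected with
  // println(model.toDebugString), which prints every split and leaf prediction.)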

  val metrics = getMetrics(model, cvData)
  // overall precision (for multiclass this equals the fraction of correctly classified samples)
  val precision = metrics.precision
  // per-class (precision, recall) pairs; DecisionTreeModel class ids start at 0
  val recall = (0 until 7).map(
    cat => (metrics.precision(cat), metrics.recall(cat))
  )
  // confusion matrix
  val confusionMatrix = metrics.confusionMatrix

  // class prior probabilities in the training set
  val trainPriorProbabilities = classProbabilities(trainData)
  // class prior probabilities in the cv set
  val cvPriorProbabilities = classProbabilities(cvData)
  // pair up each class's probability in the two sets, multiply, and sum:
  // this is the accuracy a classifier that guesses randomly according to the
  // class priors would achieve -- a baseline for the model above
  val two_probabilities = trainPriorProbabilities.zip(cvPriorProbabilities).map {
    case (trainProb, cvProb) => trainProb * cvProb
  }.sum
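  // (Added note, worked example: with two classes whose priors were 0.7 and 0.3
  // in both sets, this baseline would be 0.7 * 0.7 + 0.3 * 0.3 = 0.58.)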

  /** tuning the decision tree: grid search over impurity, depth, and bins **/
  val evaluations1 =
    for (impurities <- Array("gini", "entropy");
         depth <- Array(1, 20);
         bins <- Array(10, 300)
    ) yield {
      val model = DecisionTree.trainClassifier(
        trainData,
        numClass,
        categoricalFeaturesInfo,
        impurities,
        depth,
        bins)
      val predictionAndLabels = cvData.map(
        example =>
          (model.predict(example.features), example.label)
      )
      val accuracy = new MulticlassMetrics(predictionAndLabels).precision
      ((impurities, depth, bins), accuracy)
    }
  // sort by the second element (accuracy) in descending order
  val result1 = evaluations1.sortBy(_._2).reverse

  // retrain the tuned tree on the training and cv sets combined, and evaluate on that union
  val evaluations2 =
    for (impurities <- Array("gini", "entropy");
         depth <- Array(1, 20);
         bins <- Array(10, 30)
    ) yield {
      val model =
        DecisionTree.trainClassifier(
          trainData.union(cvData),
          numClass,
          categoricalFeaturesInfo,
          impurities,
          depth,
          bins
        )
      val predictionAndLabels = trainData.union(cvData).map(
        example =>
          (model.predict(example.features), example.label)
      )
      val accuracy = new MulticlassMetrics(predictionAndLabels).precision
      ((impurities, depth, bins), accuracy)
    }
  // sort by the second element (accuracy) in descending order
  val result2 = evaluations2.sortBy(_._2).reverse

  // end time
  var end = System.currentTimeMillis()

  // elapsed time
  var castTime = end - beg

  def main(args: Array[String]) {

    println("========================================================================================")
    // overall precision (fraction of correctly classified samples)
    println("precision: " + precision)
    println("========================================================================================")
    // per-class (precision, recall) pairs
    println("per-class (precision, recall): ")
    recall.foreach(println)
    println("========================================================================================")
    // baseline accuracy derived from the train/cv class priors
    println("baseline accuracy from train/cv class priors: " + two_probabilities)
    println("========================================================================================")
    // confusion matrix
    println("confusion matrix: ")
    println(confusionMatrix)
    println("========================================================================================")
    // cvData accuracy for each hyperparameter combination, in descending order
    println("cvData accuracy per hyperparameter combination, descending: ")
    result1.foreach(println)
    println("========================================================================================")
    // train + cv accuracy for each hyperparameter combination, in descending order
    println("train + cv accuracy per hyperparameter combination, descending: ")
    result2.foreach(println)
    println("========================================================================================")
    println("elapsed time: " + castTime / 1000 + "s")

  }

  /**
    * Build a MulticlassMetrics for a DecisionTreeModel over a data set
    *
    * @param model trained decision-tree model
    * @param data  labeled data to evaluate on
    * @return multiclass evaluation metrics
    */
  def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(example => (model.predict(example.features), example.label))
    new MulticlassMetrics(predictionsAndLabels)
  }

  /**
    * Compute the proportion of each class in a data set
    *
    * @param data labeled data
    * @return per-class prior probabilities, ordered by class id
    */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    // count the samples per class: (class, count)
    val countsByCategory = data.map(_.label).countByValue()
    // sort by class id and keep just the counts
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }

}
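The listing caches testData but never uses it. A minimal sketch of the missing final step (added here, not in the original program), assuming the best hyperparameters read off the cvData grid search below (entropy, depth 20, 300 bins), could be appended to main: retrain on train + cv and score once on the held-out test set for an unbiased accuracy estimate.

// hypothetical final step, not part of the original listing: retrain with the
// best cvData hyperparameters and measure accuracy on the untouched test set
val bestModel = DecisionTree.trainClassifier(
  trainData.union(cvData), numClass, categoricalFeaturesInfo, "entropy", 20, 300)
val testAccuracy = getMetrics(bestModel, testData).precision
println("test accuracy: " + testAccuracy)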

The results are below. My machine is fairly weak, so the run took a while; on a cluster it would be a different story. Feel free to try it yourselves.

========================================================================================
precision: 0.6980966928106819
========================================================================================
per-class (precision, recall): 
(0.68203722951904,0.6760325934251195)
(0.7190111755099426,0.7894107720433132)
(0.6342031686859273,0.7663288288288288)
(0.4682926829268293,0.35294117647058826)
(0.0,0.0)
(0.7692307692307693,0.029994001199760048)
(0.7007633587786259,0.4443368828654405)
========================================================================================
baseline accuracy from train/cv class priors: 0.37763488623859276
========================================================================================
confusion matrix: 
14436.0  6548.0   8.0     0.0   0.0  3.0   359.0
5615.0   22454.0  319.0   18.0  0.0  5.0   33.0
0.0      755.0    2722.0  68.0  0.0  7.0   0.0
0.0      0.0      176.0   96.0  0.0  0.0   0.0
0.0      899.0    13.0    0.0   0.0  0.0   0.0
0.0      540.0    1054.0  23.0  0.0  50.0  0.0
1115.0   33.0     0.0     0.0   0.0  0.0   918.0
========================================================================================
cvData accuracy per hyperparameter combination, descending: 
((entropy,20,300),0.9146686803851236)
((gini,20,300),0.9035989496627593)
((entropy,20,10),0.8961676420615443)
((gini,20,10),0.8915338012940429)
((gini,1,300),0.6366039095886179)
((gini,1,10),0.6358487651672473)
((entropy,1,300),0.4881665436696586)
((entropy,1,10),0.4881665436696586)
========================================================================================
train + cv accuracy per hyperparameter combination, descending: 
((entropy,20,30),0.9509307620195527)
((gini,20,30),0.9386290152863074)
((entropy,20,10),0.9344984598901834)
((gini,20,10),0.9305190457058677)
((gini,1,30),0.6342325278845969)
((gini,1,10),0.6340756471330999)
((entropy,1,30),0.4871319520174482)
((entropy,1,10),0.4871319520174482)
========================================================================================
elapsed time: 331s
