
Building a GBDT Model with Spark MLlib

2017-05-02 11:40

The GBDT model

For an introduction to GBDT, I mainly referred to this blog post: http://blog.csdn.net/w28971023/article/details/8240756

Here I summarize the following key points:

1. All trees in GBDT are regression trees;

2. The best split for a regression-tree node is the one that minimizes the squared error (variance) of the resulting subsets; limits such as a maximum number of leaves are stopping conditions, not split criteria;

3. The core of GBDT is that each tree learns the residual of the sum of all previous trees' predictions; the residual is the amount that, when added to the current prediction, yields the true value;

4. GB stands for Gradient Boosting. Boosting's biggest benefit is that computing residuals at each step effectively increases the weight of incorrectly predicted instances, while the residuals of correctly predicted instances shrink toward zero;

5. GBDT uses a shrinkage strategy: essentially, shrinkage assigns each tree a weight, and each tree's output is multiplied by this weight when accumulating; this weight is separate from the gradient itself.
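Points 3–5 above can be sketched in a few lines of plain Java. The "tree" below is deliberately trivial, a depth-0 tree that predicts the mean of the current residuals, so it only illustrates the residual-fitting loop and the shrinkage weight, not real tree learning; all data and names here are made up for illustration:

```java
public class BoostingSketch {
    // Gradient boosting loop where each "tree" is a constant equal to the
    // mean of the current residuals (a depth-0 regression tree).
    static double[] boost(double[] y, double shrinkage, int iterations) {
        double[] pred = new double[y.length]; // ensemble prediction, starts at 0
        for (int iter = 0; iter < iterations; iter++) {
            // Point 3: fit the next tree to the residuals (true value - current prediction)
            double residualMean = 0.0;
            for (int i = 0; i < y.length; i++) {
                residualMean += y[i] - pred[i];
            }
            residualMean /= y.length;
            // Point 5: accumulate the new tree's output scaled by the shrinkage weight
            for (int i = 0; i < y.length; i++) {
                pred[i] += shrinkage * residualMean;
            }
        }
        return pred;
    }

    public static void main(String[] args) {
        // A constant learner can only recover the global mean of y (here 5.0),
        // so every prediction converges toward 5.0 as iterations increase.
        double[] pred = boost(new double[]{3.0, 5.0, 7.0}, 0.5, 20);
        for (double p : pred) {
            System.out.println(p);
        }
    }
}
```

With shrinkage = 1.0 this toy learner would converge in a single step; smaller weights take more, smaller steps, which is exactly the trade-off the shrinkage strategy in point 5 describes.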

Building the GBDT model with Spark

Training the GBDT model

// Required imports (Spark MLlib RDD API):
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;

public void trainModel() {
    // Initialize Spark
    SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
    conf.set("spark.testing.memory", "2147480000");
    SparkContext sc = new SparkContext(conf);

    // Load the training file (LIBSVM format) via the MLUtils helper
    JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.trainsetFile).toJavaRDD();

    // Train the model; by default, variance is used as the impurity measure
    int numIteration = 10;  // number of boosting iterations (trees)
    int maxDepth = 3;       // maximum depth of each regression tree
    BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
    boostingStrategy.setNumIterations(numIteration);
    boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);

    // An empty map means all features are treated as continuous
    Map<Integer, Integer> categoricalFeaturesInfoMap = new HashMap<Integer, Integer>();
    boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfoMap);

    // Train and save the GBDT model
    final GradientBoostedTreesModel model = GradientBoostedTrees.train(lpdata, boostingStrategy);
    model.save(sc, modelpath);
    sc.stop();
}
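As a reference for the loadLibSVMFile call above: the LIBSVM text format has one instance per line, a numeric label followed by space-separated `index:value` pairs with 1-based feature indices (zero-valued features are omitted). A hypothetical three-instance regression training file might look like:

```
1.0 1:0.5 3:1.2
0.0 2:0.8
2.5 1:0.1 2:0.4 3:0.9
```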


Predicting data

public void predict() {
    // Initialize Spark
    SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
    conf.set("spark.testing.memory", "2147480000");
    SparkContext sc = new SparkContext(conf);

    // Load the saved GBDT model
    final GradientBoostedTreesModel model = GradientBoostedTreesModel.load(sc, this.modelpath);

    // Load the test file
    JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictFile).toJavaRDD();
    testData.cache();

    // Predict: map each point to a (prediction, label) pair
    JavaRDD<Tuple2<Double, Double>> predictionAndLabel = testData.map(new Prediction(model));

    // Compute the mean squared error over all test points
    Double testMSE = predictionAndLabel.map(new CountSquareError()).reduce(new ReduceSquareError()) / testData.count();
    System.out.println("testData's MSE is: " + testMSE);
    sc.stop();
}

// Maps a labeled point to a (prediction, label) pair
static class Prediction implements Function<LabeledPoint, Tuple2<Double, Double>> {
    GradientBoostedTreesModel model;

    public Prediction(GradientBoostedTreesModel model) {
        this.model = model;
    }

    public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
        Double score = model.predict(p.features());
        return new Tuple2<Double, Double>(score, p.label());
    }
}

// Squared error of a single (prediction, label) pair
static class CountSquareError implements Function<Tuple2<Double, Double>, Double> {
    public Double call(Tuple2<Double, Double> pl) {
        double diff = pl._1() - pl._2();
        return diff * diff;
    }
}

// Sums the squared errors across the RDD
static class ReduceSquareError implements Function2<Double, Double, Double> {
    public Double call(Double a, Double b) {
        return a + b;
    }
}
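The CountSquareError/ReduceSquareError pair above is just a distributed mean squared error: square each (prediction, label) difference, sum them, then divide by the count. In plain Java (with made-up data) the same computation is:

```java
public class MseSketch {
    // MSE = sum((prediction - label)^2) / n, over an array of
    // {prediction, label} pairs.
    static double mse(double[][] predictionAndLabel) {
        double sum = 0.0;
        for (double[] pl : predictionAndLabel) {
            double diff = pl[0] - pl[1]; // prediction minus label
            sum += diff * diff;          // per-pair squared error (CountSquareError)
        }
        // ReduceSquareError sums; dividing by the count gives the mean
        return sum / predictionAndLabel.length;
    }

    public static void main(String[] args) {
        double[][] pairs = {{2.5, 3.0}, {5.0, 5.0}, {6.0, 7.0}};
        // squared errors are 0.25, 0.0, 1.0, so MSE = 1.25 / 3
        System.out.println("MSE = " + mse(pairs));
    }
}
```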


The full code is available on my GitHub: https://github.com/Quincy1994/MachineLearning
Tags: spark gbdt