
AdaBoost Ensemble Classifier on the Spark Platform

2016-07-11 10:58
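I found a ready-made AdaBoost package on GitHub and used it to test whether it works:
https://github.com/tizfa/sparkboost
When called from Java, the sparkboost learner expects its input as a JavaRDD<MultilabelPoint>. Reading the data with MLUtils.loadLibSVMFile, however, yields an RDD<LabeledPoint>, which can be converted to a JavaRDD<LabeledPoint>.

So the data has to be converted from one format to the other, carrying each point's label and feature vector over to the corresponding MultilabelPoint fields.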
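The per-record conversion can be illustrated without Spark. The following is a minimal sketch under stated assumptions: `SimplePoint` is a hypothetical stand-in invented here (it is not sparkboost's `MultilabelPoint`), and the sketch parses a LibSVM text line directly instead of going through `MLUtils.loadLibSVMFile`. It shows the same mapping: the label becomes a one-element label array, the `index:value` pairs become a sparse feature vector (shifted to 0-based indices, as `loadLibSVMFile` does), and every record gets its own document ID.

```java
import java.util.Arrays;

/**
 * Spark-free sketch of the conversion: one LibSVM line becomes a
 * (docID, labels[], sparse features) triple. SimplePoint is a hypothetical
 * stand-in for sparkboost's MultilabelPoint, not the library class.
 */
public class LibSvmToMultilabel {

    /** Hypothetical stand-in: one document, its labels, and a sparse feature vector. */
    static final class SimplePoint {
        final int docID;
        final int[] labels;
        final int[] indices; // 0-based feature indices
        final double[] values;

        SimplePoint(int docID, int[] labels, int[] indices, double[] values) {
            this.docID = docID;
            this.labels = labels;
            this.indices = indices;
            this.values = values;
        }
    }

    /** Parses "label idx:val idx:val ..." (1-based indices) into a single-label point. */
    static SimplePoint fromLibSvmLine(String line, int docID) {
        String[] tokens = line.trim().split("\\s+");
        int label = (int) Double.parseDouble(tokens[0]); // "1" or "1.0" -> 1
        int n = tokens.length - 1;
        int[] indices = new int[n];
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            String[] kv = tokens[i + 1].split(":");
            indices[i] = Integer.parseInt(kv[0]) - 1; // shift to 0-based, like MLUtils.loadLibSVMFile
            values[i] = Double.parseDouble(kv[1]);
        }
        return new SimplePoint(docID, new int[] {label}, indices, values);
    }

    public static void main(String[] args) {
        String[] lines = {"0 1:0.5 4:2.0", "1 2:1.0 3:3.5"};
        for (int i = 0; i < lines.length; i++) {
            // Each record's position in the input serves as its docID.
            SimplePoint p = fromLibSvmLine(lines[i], i);
            System.out.println("docID=" + p.docID + " labels=" + Arrays.toString(p.labels)
                    + " indices=" + Arrays.toString(p.indices)
                    + " values=" + Arrays.toString(p.values));
        }
    }
}
```

Running `main` prints `docID=0 labels=[0] indices=[0, 3] values=[0.5, 2.0]` followed by `docID=1 labels=[1] indices=[1, 2] values=[1.0, 3.5]`.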

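import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

// sparkboost classes; package name as used in the tizfa/sparkboost repository (verify against your version)
import it.tizianofagni.sparkboost.AdaBoostMHLearner;
import it.tizianofagni.sparkboost.BoostClassifier;
import it.tizianofagni.sparkboost.ClassificationResults;
import it.tizianofagni.sparkboost.MultilabelPoint;

public class ClassifierTask {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("ClassifierTask").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // MLUtils.loadLibSVMFile needs a Scala SparkContext, so unwrap the JavaSparkContext
        SparkContext sc1 = sc.sc();

        String inputFile = "D:\\softs\\spark-1.6.0-bin-hadoop2.6\\data\\mllib\\sample_binary_classification_data.txt";

        // Load the LibSVM file as RDD<LabeledPoint> and convert it to JavaRDD<LabeledPoint>
        JavaRDD<LabeledPoint> labeledPoints = MLUtils.loadLibSVMFile(sc1, inputFile).toJavaRDD();

        // Convert each LabeledPoint into a MultilabelPoint for training.
        // zipWithIndex gives every document a distinct docID, so the per-document
        // results printed below can be told apart.
        JavaRDD<MultilabelPoint> rdd = labeledPoints.zipWithIndex().map(pair -> {
            LabeledPoint point = pair._1();
            int docID = pair._2().intValue();
            int[] labels = {(int) point.label()};
            // loadLibSVMFile produces sparse feature vectors, so this cast is safe
            SparseVector features = (SparseVector) point.features();
            return new MultilabelPoint(docID, features, labels);
        });

        // Split the data: 80% training set, 20% test set
        double[] weights = {0.8, 0.2};
        JavaRDD<MultilabelPoint>[] data = rdd.randomSplit(weights);

        // Configure the AdaBoost.MH learner and train it on the training split
        AdaBoostMHLearner learner = new AdaBoostMHLearner(sc);
        learner.setNumIterations(100);
        learner.setNumDocumentsPartitions(2);
        learner.setNumFeaturesPartitions(2);
        learner.setNumLabelsPartitions(2);
        BoostClassifier classifier = learner.buildModel(data[0]);

        // Classify the test split
        ClassificationResults results = classifier.classifyWithResults(sc, data[1], 1);

        // Collect the results in a StringBuilder and print them
        StringBuilder sb = new StringBuilder();
        sb.append("**** Effectiveness\n");
        sb.append(results.getCt().toString()).append("\n");
        sb.append("********\n");
        for (int i = 0; i < results.getNumDocs(); i++) {
            int docID = results.getDocuments()[i];
            int[] labels = results.getLabels()[i];
            int[] goldLabels = results.getGoldLabels()[i];
            sb.append("DocID: ").append(docID)
              .append(", Labels assigned: ").append(Arrays.toString(labels))
              .append(", Labels scores: ").append(Arrays.toString(results.getScores()[i]))
              .append(", Gold labels: ").append(Arrays.toString(goldLabels))
              .append("\n");
        }
        System.out.print(sb);

        sc.stop();
    }
}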
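The program prints the assigned and gold labels for each test document next to the effectiveness table from `getCt()`. To condense those per-document arrays into a single number, a helper along these lines could be used. This is a sketch: `ExactMatchAccuracy` and `exactMatchAccuracy` are hypothetical names, and exact-match (subset) accuracy is one possible metric, not necessarily what sparkboost's contingency table reports.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: exact-match accuracy over per-document label arrays,
 * e.g. results.getLabels() versus results.getGoldLabels().
 */
public class ExactMatchAccuracy {

    static double exactMatchAccuracy(int[][] assigned, int[][] gold) {
        if (assigned.length != gold.length) {
            throw new IllegalArgumentException("assigned and gold need one entry per document");
        }
        if (assigned.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < assigned.length; i++) {
            // Compare ignoring order: sort copies so the originals are untouched.
            int[] a = assigned[i].clone();
            int[] g = gold[i].clone();
            Arrays.sort(a);
            Arrays.sort(g);
            if (Arrays.equals(a, g)) {
                correct++;
            }
        }
        return (double) correct / assigned.length;
    }

    public static void main(String[] args) {
        int[][] assigned = {{1}, {0}, {1}, {}};
        int[][] gold = {{1}, {1}, {1}, {0}};
        System.out.println(exactMatchAccuracy(assigned, gold)); // prints 0.5
    }
}
```

With the objects from `ClassifierTask`, it would be called as `exactMatchAccuracy(results.getLabels(), results.getGoldLabels())`; both arrays are indexed by the same position `i` as in the printing loop. A document with no assigned label counts as a miss.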