您的位置:首页 > 运维架构 > Apache

Spark ML使用DataFrame进行K-Means

2019-10-18 07:10 1676 查看

1.前言

前一篇文章使用了RDD的方式,进行了K-Means聚类.

从Spark 2.0开始,程序包中基于RDD的API spark.mllib已进入维护模式.现在,用于Spark的主要机器学习API是软件包中基于DataFrame的API spark.ml.

这次使用DataFrame的方式 进行K-Means聚类.

数据集和基本配置查看前一篇文章

2.KMeansDataFrame

package com.htkj.spark.mllib;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;

public class KMeansDataFrame {
public static void main(String[] args) {
//创建SparkSession
SparkSession spark = SparkSession
.builder()
.master("local")
.appName("K-Means-DataFrame")
.getOrCreate();
//读取数据 转换为RDD
JavaRDD<String> clusterRDD = spark.sparkContext()
.textFile("C:\\Users\\Administrator\\Desktop\\cluster.txt", 1)
.toJavaRDD();
//设置列的名称
String schemaString="region x1 x2 x3 x4 x5 x6 x7 x8";
//创建集合 存储字段名称
ArrayList<StructField> fields = new ArrayList<>();
//将schemaString 切割 遍历  存储到字段中
for (int i = 0; i < schemaString.split(" ").length; i++) {
String fieldName = schemaString.split(" ")[i];
if (i==0){
StructField field = DataTypes.createStructField(fieldName, DataTypes.StringType, true);
fields.add(field);
}else {
StructField field = DataTypes.createStructField(fieldName, DataTypes.DoubleType, true);
fields.add(field);
}
}
//设置字段
StructType schema = DataTypes.createStructType(fields);
//存储数据
JavaRDD<Row> rowRDD = clusterRDD.map(record -> {
String[] split = record.split("\t");
return RowFactory.create(split[0]
, Double.valueOf(split[1])
, Double.valueOf(split[2])
, Double.valueOf(split[3])
, Double.valueOf(split[4])
, Double.valueOf(split[5])
, Double.valueOf(split[6])
, Double.valueOf(split[7])
, Double.valueOf(split[8])
);
});
//将RDD转化为dataFrame形式
Dataset<Row> dataFrame = spark.createDataFrame(rowRDD, schema);
//输出表结构
dataFrame.printSchema();
//创建临时视图
dataFrame.createOrReplaceTempView("cluster");
//查询
Dataset<Row> results = spark.sql("select region from cluster");
//单独提取某个字段的内容
Dataset<String> regionsDS = results.map((MapFunction<Row, String>) row -> "region:" + row.getString(0), Encoders.STRING());
//展示
regionsDS.show();
//查询所有
Dataset<Row> dataAll = spark.sql("select * from cluster");
//将region字段 转换为数字的 label
StringIndexerModel labelIndex = new StringIndexer().setInputCol("region").setOutputCol("label").fit(dataAll);
//将x1-x8字段 拼接为向量
VectorAssembler assembler = new VectorAssembler().setInputCols("x1,x2,x3,x4,x5,x6,x7,x8".split(",")).setOutputCol("features");
//如果有离散的数据 就使用OneHotEncoder
// OneHotEncoder x1 = new OneHotEncoder().setInputCol("x1").setOutputCol("x1").setDropLast(false);

//使用Pipeline 转化模型
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{labelIndex, assembler});
PipelineModel model = pipeline.fit(dataAll);
Dataset<Row> dataset = model.transform(dataAll);
//输出
dataset.show(100,100);
//设置聚类数量为5 迭代100次 随机数种子为100
KMeans kMeans = new KMeans().setK(5).setMaxIter(100).setSeed(100L);
//建立模型
KMeansModel kMeansModel = kMeans.fit(dataset).setFeaturesCol("features").setPredictionCol("prediction");
//输出误差平方和
double WSSSE = kMeansModel.computeCost(dataset);
System.out.println("误差平方和 "+WSSSE);
//输出聚类中心
Vector[] centers = kMeansModel.clusterCenters();
System.out.println("聚类中心: ");
for (Vector center : centers) {
System.out.println(center);
}
//输出预测结果
Dataset<Row> transform = kMeansModel.transform(dataset);
transform.foreach(s->{
int length = s.toString().length();
String[] split = s.toString().substring(1, length - 1).split(",");
String region = split[0];
String cluster = split[split.length - 1];
System.out.println(region+"属于聚类:"+cluster);
});
}
}

3.输出内容

3.1表结构

3.2单独提取某个字段的内容

3.3dataset展示内容

3.4误差平方和

3.5聚类中心

3.6预测结果

可以看到,北京 上海 广东 浙江 归为一类,可以认为分类结果基本准确

 

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  Apache Spark