您的位置:首页 > 其它

Spark MLlib特征处理:OneHotEncoder OneHot编码 ---原理及实战

2016-11-12 14:06 351 查看

原理

1)String字符串转换成索引IndexDouble

2)索引转换成SparseVector

总结:OneHotEncoder = String > IndexDouble > SparseVector

代码实战

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}

object OneHotEncoderExample {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[8]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)

// 将Seq集合转换成DataFrame
// Seq是一个有先后次序的序列(也可以叫集合),Vector Range List Array都属于Seq类型
val df: DataFrame = sqlContext.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category")

// String => IndexDouble
val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)

// IndexDouble => SparseVector
// OneHotEncode:实际上是转换成了稀疏向量
// Spark源码: The last category is not included by default 最后一个种类默认不包含
// 和python scikit-learn's OneHotEncoder不同,scikit-learn's OneHotEncoder包含所有
val encoder = new OneHotEncoder().setInputCol("categoryIndex").setOutputCol("categoryVec")
// 设置最后一个是否包含
.setDropLast(false)
//transform 转换成稀疏向量
val encoded = encoder.transform(indexed)
encoded.select("category","categoryIndex", "categoryVec").show()
sc.stop()
}

}
// 输出
// +--------+-------------+-------------+
// |category|categoryIndex|  categoryVec|
// +--------+-------------+-------------+
// |       a|          0.0|(3,[0],[1.0])|
// |       b|          2.0|(3,[2],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// +--------+-------------+-------------+
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: