
OneHotEncoder: Introduction and Scala Implementation for Single and Multiple Attributes

2017-06-07 21:00
Because my project needed to vectorize the attributes of database tables before running machine learning on them, I went to the Spark website to learn about OneHotEncoder. The official documentation is fairly brief and mainly covers single-attribute processing, while my project required handling multiple attributes. After digging through a lot of material online and spending the better part of a day on it, I finally integrated it into my own project. Below I share what I learned; where my explanation falls short, corrections are very welcome.

What it does: maps each category to a binary vector in which at most one element is 1 (the rest are 0). The resulting encoding can be fed to algorithms that expect continuous features, such as logistic regression and other classifiers.

Benefits:
1. It works around the fact that classifiers handle categorical attributes poorly (classifiers usually assume the data is continuous and ordered).
2. To some extent it also expands the feature space.

How it works:
1. A String value is converted to a Double index.
2. The index is converted to a SparseVector.

In short: OneHotEncoder = String -> index (Double) -> SparseVector
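
Before looking at the official example below, here is a minimal sketch of the same two stages chained in a single Pipeline, using the Spark 2.x API this article is based on; the column name "color" and the toy data are made up for illustration:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("onehot-sketch").master("local").getOrCreate()
import spark.implicits._

// a single string column named "color"
val df = Seq("red", "green", "blue", "red").toDF("color")

// stage 1: String -> index (Double); stage 2: index -> SparseVector
val indexer = new StringIndexer().setInputCol("color").setOutputCol("colorIndex")
val encoder = new OneHotEncoder().setInputCol("colorIndex").setOutputCol("colorVec")

val model = new Pipeline().setStages(Array(indexer, encoder)).fit(df)
model.transform(df).show(false)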

Single-attribute example from the official documentation:

package com.iflytek.features

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.SparkSession

object OneHotEncoder {

  val spark = SparkSession.builder().appName("pca").master("local").getOrCreate()

  def main(args: Array[String]): Unit = {
    val df = spark.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    // StringIndexer maps the values of a string column to numeric indices
    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)

    val indexed = indexer.transform(df)
    indexed.select("category", "categoryIndex").show()

    // OneHotEncoder turns the index column into a sparse binary vector
    val encoder = new OneHotEncoder()
      .setInputCol("categoryIndex")
      .setOutputCol("categoryVec")

    val encoded = encoder.transform(indexed)
    encoded.select("id", "categoryIndex", "categoryVec").show()

    // print each one-hot vector as a dense, space-separated array
    encoded.select("categoryVec").foreach { row =>
      println(row.getAs[SparseVector]("categoryVec").toArray.mkString(" "))
    }
  }
}

The output is as follows:
+--------+-------------+
|category|categoryIndex|
+--------+-------------+
|       a|          0.0|
|       b|          2.0|
|       c|          1.0|
|       a|          0.0|
|       a|          0.0|
|       c|          1.0|
+--------+-------------+

+---+-------------+-------------+
| id|categoryIndex|  categoryVec|
+---+-------------+-------------+
|  0|          0.0|(2,[0],[1.0])|
|  1|          2.0|    (2,[],[])|
|  2|          1.0|(2,[1],[1.0])|
|  3|          0.0|(2,[0],[1.0])|
|  4|          0.0|(2,[0],[1.0])|
|  5|          1.0|(2,[1],[1.0])|
+---+-------------+-------------+

1.0 0.0
0.0 0.0
0.0 1.0
1.0 0.0
1.0 0.0
0.0 1.0
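
A note on the second table: the row whose categoryIndex is 2.0 (category "b") gets the empty vector (2,[],[]). By default OneHotEncoder drops the last category, so a column with three distinct values produces vectors of length 2 and the dropped category is represented by all zeros. If you prefer one slot per category, this can be turned off with setDropLast(false); a minimal sketch, reusing the indexed DataFrame from the example above:

// assumes "indexed" from the single-attribute example above
val fullEncoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVecFull")
  .setDropLast(false) // keep a slot for every category, e.g. "b" -> (3,[2],[1.0])

fullEncoder.transform(indexed).select("id", "categoryIndex", "categoryVecFull").show()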

For the multi-attribute case I had to search through a lot of material; real business requirements usually involve multiple attributes:

// imports this fragment relies on
import org.apache.spark.ml
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.linalg.SparseVector

// sc is the SparkSession of the surrounding project; dataRDD holds
// (label, Array[String]) records, and enum2Double maps an enumerated
// string value ("是否已流失" = churned or not) to a Double
import sc.implicits._

val vectorData = dataRDD
  // convert the enumerated label to a Double
  .map(x => (enum2Double("是否已流失", x._1), x._2(0), x._2(1), x._2(2), x._2(3)))
  .toDF("loss", "gender", "age", "grade", "region")

// index every string column: <name> -> <name>_index
val stringColumns = Array("gender", "age", "grade", "region")
val index_transformers: Array[PipelineStage] = stringColumns.map(
  cname => new StringIndexer()
    .setInputCol(cname)
    .setOutputCol(s"${cname}_index")
)

val index_pipeline = new Pipeline().setStages(index_transformers)
val index_model = index_pipeline.fit(vectorData)
val df_indexed = index_model.transform(vectorData)

// one-hot encode every index column: <name>_index -> <name>_index_vec
val indexColumns = df_indexed.columns.filter(x => x contains "index")
val one_hot_encoders: Array[PipelineStage] = indexColumns.map(
  cname => new OneHotEncoder()
    .setInputCol(cname)
    .setOutputCol(s"${cname}_vec")
)

// run indexing and encoding as a single pipeline
// (add the rest of your pipeline, e.g. VectorAssembler and the algorithm, as needed)
val pipeline = new Pipeline().setStages(index_transformers ++ one_hot_encoders)
val model = pipeline.fit(vectorData)

// concatenate the per-column one-hot vectors into one dense feature vector per row
model.transform(vectorData)
  .select("loss", "gender_index_vec", "age_index_vec", "grade_index_vec", "region_index_vec")
  .map(x =>
    ml.feature.LabeledPoint(
      x.apply(0).toString().toDouble,
      ml.linalg.Vectors.dense(
        x.getAs[SparseVector]("gender_index_vec").toArray ++
          x.getAs[SparseVector]("age_index_vec").toArray ++
          x.getAs[SparseVector]("grade_index_vec").toArray ++
          x.getAs[SparseVector]("region_index_vec").toArray))
  )
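
Instead of pulling each sparse vector out with getAs and concatenating the arrays by hand, a VectorAssembler stage (mentioned in the comment above) can be appended to the same pipeline to merge the encoded columns into a single features column. A rough sketch, assuming the vectorData, index_transformers and one_hot_encoders values defined above:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler

// merge the four one-hot columns into one "features" vector column
val assembler = new VectorAssembler()
  .setInputCols(Array("gender_index_vec", "age_index_vec", "grade_index_vec", "region_index_vec"))
  .setOutputCol("features")

val fullPipeline = new Pipeline().setStages(index_transformers ++ one_hot_encoders ++ Array(assembler))
val assembled = fullPipeline.fit(vectorData).transform(vectorData)

// "loss" is the label column, "features" the assembled one-hot vector
assembled.select("loss", "features").show(false)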