Spark MLlib 特徴量処理: OneHotEncoder による One-Hot 符号化 — 原理と実践

4223 ワード

原理


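1) StringIndexer で、文字列(String)のカテゴリ列を Double 型のインデックスに変換する(出現頻度の高いカテゴリから順に 0.0, 1.0, 2.0, … が割り当てられる)
2) OneHotEncoder で、そのインデックスを SparseVector に変換する
まとめ: One-Hot 符号化の流れ = String → インデックス(Double) → SparseVector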

実践コード

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}

/**
 * Demonstrates one-hot encoding of a string column with Spark ML in two steps:
 * StringIndexer (String -> Double index) followed by OneHotEncoder
 * (Double index -> SparseVector), then prints the intermediate and final columns.
 */
object OneHotEncoderExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[8]")
    val sc = new SparkContext(conf)
    // Release the SparkContext even if one of the stages below throws.
    try {
      val sqlContext = new SQLContext(sc)

      // Build a DataFrame from a local Seq of (id, category) tuples;
      // toDF assigns the column names.
      val df: DataFrame = sqlContext.createDataFrame(Seq(
        (0, "a"),
        (1, "b"),
        (2, "c"),
        (3, "a"),
        (4, "a"),
        (5, "c")
      )).toDF("id", "category")

      // Step 1: String => Double index.
      // StringIndexer orders labels by descending frequency, so for this data
      // "a" (3x) -> 0.0, "c" (2x) -> 1.0, "b" (1x) -> 2.0.
      val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
      val indexed = indexer.fit(df).transform(df)

      // Step 2: Double index => SparseVector.
      // By default Spark's OneHotEncoder drops the last category (dropLast = true),
      // producing vectors of length numCategories - 1. scikit-learn's OneHotEncoder,
      // by contrast, keeps every category by default.
      // setDropLast(false) keeps all categories, so each vector here has length 3.
      // NOTE(review): this is the Spark 1.x/2.x Transformer-style API. In Spark 3.x
      // OneHotEncoder is an Estimator and needs encoder.fit(indexed).transform(indexed).
      // Confirm the target Spark version.
      val encoder = new OneHotEncoder()
        .setInputCol("categoryIndex")
        .setOutputCol("categoryVec")
        .setDropLast(false)

      // Apply the encoder and show the original, indexed and encoded columns.
      val encoded = encoder.transform(indexed)
      encoded.select("category", "categoryIndex", "categoryVec").show()
    } finally {
      sc.stop()
    }
  }

}
// Expected output:
// +--------+-------------+-------------+
// |category|categoryIndex|  categoryVec|
// +--------+-------------+-------------+
// |       a|          0.0|(3,[0],[1.0])|
// |       b|          2.0|(3,[2],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// +--------+-------------+-------------+