Spark MLlib特徴処理:OneHotEncoder OneHot符号化---原理と実戦
4223 ワード
原理
1)String文字列をインデックスIndexDoubleに変換
2)インデックスをSparseVectorに変換
まとめ:OneHot符号化 = String →(StringIndexer)→ IndexDouble →(OneHotEncoder)→ SparseVector
コード実戦
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}
/**
 * Minimal end-to-end demonstration of one-hot encoding a string column with Spark ML.
 *
 * The example runs two stages: a [[StringIndexer]] maps each category string to a
 * Double index, then a [[OneHotEncoder]] turns that index into a sparse vector.
 */
object OneHotEncoderExample {

  /**
   * Builds a six-row DataFrame, indexes and one-hot encodes its `category` column,
   * and prints the `category`, `categoryIndex` and `categoryVec` columns to stdout.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[8]")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)

      // Build the input DataFrame from a local Seq of (id, category) tuples.
      val df: DataFrame = sqlContext.createDataFrame(Seq(
        (0, "a"),
        (1, "b"),
        (2, "c"),
        (3, "a"),
        (4, "a"),
        (5, "c")
      )).toDF("id", "category")

      // Step 1: String => IndexDouble.
      val indexer = new StringIndexer()
        .setInputCol("category")
        .setOutputCol("categoryIndex")
      val indexed = indexer.fit(df).transform(df)

      // Step 2: IndexDouble => SparseVector.
      // Spark's OneHotEncoder does not include the last category by default, which
      // differs from scikit-learn's OneHotEncoder (it keeps all categories).
      // setDropLast(false) keeps a vector slot for every category.
      val encoder = new OneHotEncoder()
        .setInputCol("categoryIndex")
        .setOutputCol("categoryVec")
        .setDropLast(false)

      // Apply the encoder and display only the columns of interest.
      val encoded = encoder.transform(indexed)
      encoded.select("category", "categoryIndex", "categoryVec").show()
    } finally {
      // Always release the SparkContext, even if one of the stages above throws.
      sc.stop()
    }
  }
}
// Output of the show() call above (dropLast = false, so every vector has size 3):
// +--------+-------------+-------------+
// |category|categoryIndex| categoryVec|
// +--------+-------------+-------------+
// | a| 0.0|(3,[0],[1.0])|
// | b| 2.0|(3,[2],[1.0])|
// | c| 1.0|(3,[1],[1.0])|
// | a| 0.0|(3,[0],[1.0])|
// | a| 0.0|(3,[0],[1.0])|
// | c| 1.0|(3,[1],[1.0])|
// +--------+-------------+-------------+