// spark naive bayes 実験メモ (Spark Naive Bayes experiment notes)


import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.feature.HashingTF

// Hashes each token of a document into a fixed-size (10000-dimensional) term-frequency vector.
val htf = new HashingTF(10000)

/** Loads a text file with one document per line and turns every line into a labeled example.
  *
  * Each line is split on runs of whitespace and the resulting tokens are hashed with `htf`.
  *
  * @param path  path of the text file to read through `sc.textFile`
  * @param label class label assigned to every line of the file
  * @return one `LabeledPoint` per input line
  */
def loadLabeled(path: String, label: Double): org.apache.spark.rdd.RDD[LabeledPoint] =
  sc.textFile(path).map { text =>
    // Splitting a blank line, or a line with leading whitespace, produces an empty-string token;
    // drop it so that it is not hashed as a spurious feature.
    val tokens = text.split("\\s+").filter(_.nonEmpty)
    LabeledPoint(label, htf.transform(tokens.toSeq))
  }

// NOTE(review): the "pos" file gets label 0.0 and the "neg" file gets label 1.0, so the class
// that the evaluation in this script counts as "positive" (label 1.0) is test_neg.txt — confirm
// that this is intended.
val pos_data = loadLabeled("test_pos.txt", label = 0.0)
val neg_data = loadLabeled("test_neg.txt", label = 1.0)

// 60% / 40% train/test split with a fixed seed.
val data = pos_data.union(neg_data)
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
// The test set is consumed by several separate actions later in this script; cache it so it is
// not re-read and re-hashed from the text files for each of them.
val test = splits(1).cache()

// Multinomial Naive Bayes with additive smoothing parameter lambda = 1.0.
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")

// Classifies every test example and tallies the confusion-matrix cells.
// Label 1.0 is treated as the "positive" class: "TP"/"FN" are examples labeled 1.0,
// "TN"/"FP" are examples labeled 0.0. The result maps each cell name that occurred at least
// once to its count; cells that never occurred are absent from the map.
val result = test.map { t =>
  val predicted = model.predict(t.features)

  (predicted, t.label) match {
    case (0.0, 0.0) => "TN"
    case (0.0, 1.0) => "FN"
    case (1.0, 0.0) => "FP"
    case (1.0, 1.0) => "TP"
    // Fail with a descriptive message instead of a bare scala.MatchError if a value outside
    // the binary 0.0/1.0 label set ever shows up.
    case (p, l) =>
      throw new IllegalStateException(
        s"Unexpected (predicted, label) pair: ($p, $l); expected binary labels 0.0 / 1.0")
  }
}.countByValue()

// Number of evaluated test examples. Every test example lands in exactly one confusion-matrix
// cell of `result`, so the cell counts add up to the test-set size without another Spark job.
val totalCount = result.values.sum

// Confusion-matrix cells as Doubles. A cell that never occurred is missing from `result`,
// hence the default of 0.
val truePositiveCount = result.getOrElse("TP", 0L).toDouble
val trueNegativeCount = result.getOrElse("TN", 0L).toDouble
val falsePositiveCount = result.getOrElse("FP", 0L).toDouble
val falseNegativeCount = result.getOrElse("FN", 0L).toDouble

// Evaluation metrics. These are Double divisions: a zero denominator (for example, no
// positive predictions and no positive labels) yields NaN instead of throwing.
val accuracy = (truePositiveCount + trueNegativeCount) / totalCount
// Threat score (critical success index): TP / (TP + FP + FN).
val threatscore = truePositiveCount / (truePositiveCount + falsePositiveCount + falseNegativeCount)
// Precision: TP / (TP + FP). The misspelled identifier and printed label are kept as-is so
// that anything relying on the existing name or output format keeps working.
val percision = truePositiveCount / (truePositiveCount + falsePositiveCount)
// Recall: TP / (TP + FN).
val recall = truePositiveCount / (truePositiveCount + falseNegativeCount)
// F1 score written as TP / (TP + (FP + FN) / 2), which equals 2TP / (2TP + FP + FN).
val f = truePositiveCount / (truePositiveCount + (falsePositiveCount + falseNegativeCount) / 2)

println(s"accuracy: $accuracy")
println(s"threatscore: $threatscore")
println(s"percision: $percision")
println(s"recall: $recall")
println(s"f: $f")

// Brings the whole test set to the driver and prints one "<label>: <prediction>" line per
// example, predicting each example locally with the trained model.
for (example <- test.collect()) {
  println(s"${example.label}: ${model.predict(example.features)}")
}