Applied Scala -- UDF: User-Defined Functions


Install Hadoop (the winutils build) on Windows 10 and create a Maven project in IntelliJ IDEA. The relevant parts of the pom.xml:

    <properties>
        <spark.version>2.2.0</spark.version>
        <scala.version>2.11</scala.version>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-yarn_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.16</version>
        </dependency>
    </dependencies>

    <build>
        <finalName>learnspark</finalName>
        <plugins>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.2</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.0.0</version>
                <configuration>
                    <archive>
                        <manifest>
                            <mainClass>learn</mainClass>
                        </manifest>
                    </archive>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>

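With this build section, running mvn package compiles the Scala sources through scala-maven-plugin, and the assembly plugin emits a fat jar (target/learnspark-jar-with-dependencies.jar) that can be handed to spark-submit.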
  
Prepare the data: save the following file as data/input/user/user.json (spark.read.json expects one JSON object per line; the names here are placeholders):

{"name":"user3", "age":20}
{"name":"user4", "age":20}
{"name":"user5", "age":20}
{"name":"user6", "age":20}

The complete example:
package com.zouxxyy.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}

/**
 * UDF: user-defined functions and user-defined aggregate functions
 */

object UDF {

  def main(args: Array[String]): Unit = {
    System.setProperty("hadoop.home.dir","D:\\gitworkplace\\winutils\\hadoop-2.7.1" )
    // Path to a local Hadoop installation (winutils); without a Hadoop environment Spark errors out on Windows
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("UDF")

    // Create the SparkSession
    val spark: SparkSession = SparkSession.builder.config(sparkConf).getOrCreate()

    import spark.implicits._

    // Read the JSON file into a DataFrame
    val frame: DataFrame = spark.read.json("data/input/user/user.json")

    frame.createOrReplaceTempView("user")

    // Case 1: a simple user-defined function
    spark.udf.register("addName", (x:String)=> "Name:"+x)

    spark.sql("select addName(name) from user").show()
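
    // Sketch: the same logic can also be used DSL-style by wrapping the function
    // with org.apache.spark.sql.functions.udf instead of registering it by name
    // (addNameUdf is just an illustrative name)
    import org.apache.spark.sql.functions.udf
    val addNameUdf = udf((x: String) => "Name:" + x)
    frame.select(addNameUdf($"name")).show()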

    // Case 2: a weakly typed user-defined aggregate function

    val udaf1 = new MyAgeAvgFunction

    spark.udf.register("avgAge", udaf1)

    spark.sql("select avgAge(age) from user").show()
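
    // Sketch: a UserDefinedAggregateFunction can also be applied DSL-style,
    // because calling it with columns yields a Column
    frame.select(udaf1($"age")).show()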

    // Case 3: a strongly typed user-defined aggregate function

    val udaf2 = new MyAgeAvgClassFunction

    // Turn the aggregate function into a queryable column
    val avgCol: TypedColumn[UserBean, Double] = udaf2.toColumn.name("aveAge")

    // Convert the DataFrame to a Dataset and query it DSL-style
    val userDS: Dataset[UserBean] = frame.as[UserBean]

    userDS.select(avgCol).show()

    spark.stop()
  }
}

/**
 * Custom aggregate function (weakly typed)
 */

class MyAgeAvgFunction extends UserDefinedAggregateFunction{

  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Data type of the result
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // The buffer fields are addressed by position: (0) = sum, (1) = count
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update the buffer with each incoming row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge buffers computed on different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
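
// Note: Spark 3.x deprecates UserDefinedAggregateFunction. An Aggregator over the
// input column type can be registered for SQL instead, roughly:
//   spark.udf.register("avgAge", org.apache.spark.sql.functions.udaf(columnAggregator))
// where columnAggregator would be an Aggregator[java.lang.Long, AvgBuffer, Double]
// (functions.udaf exists from Spark 3.0 on; columnAggregator is hypothetical here).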


/**
 * Custom aggregate function (strongly typed)
 */

case class UserBean (name : String, age : BigInt) // numbers read from JSON are decoded as BigInt
case class AvgBuffer(var sum: BigInt, var count: Int)

class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {

  // Initial (zero) value of the buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  // Fold one input record into the buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    // return the updated buffer
    b
  }

  // Merge two buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count

    b1
  }

  // Compute the final result
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
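
The strongly typed aggregator can be sanity-checked without a Spark session by driving its lifecycle methods by hand; a minimal sketch with made-up values:

val agg = new MyAgeAvgClassFunction
val b1 = agg.reduce(agg.zero, UserBean("a", 20))
val b2 = agg.reduce(agg.zero, UserBean("b", 30))
println(agg.finish(agg.merge(b1, b2))) // (20 + 30) / 2 = 25.0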