Sparkカスタム関数(UDF、UDAF、UDTF)

40127 ワード

目次
  • 、カスタム標準関数(UDF)
  • 、カスタム集計関数(UDAF)
  • 、カスタムテーブル生成関数(UDTF)
  • Sparkは開発者のために大量の内蔵関数を提供しています。また、ユーザー定義の関数も使用できます。
    Sparkカスタム関数ステップ:1、定義関数2、登録関数SparkSession.udf.register():sql()のみ有効functions.udf():Data Frame APIに対して有効3、関数呼び出し
    1、カスタム標準関数(UDF)
    D:\test\t\目次の下にファイルhobries.txtがあります。ファイルの内容:
    alice	jogging,Coding,cooking
    lina	travel,dance
    
    需要:ユーザーの行為は個数統計要求出力フォーマットが好きです。
    alice	jogging,Coding,cooking	3
    lina	travel,dance			2
    
    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{
         DataFrame, SparkSession}
    
    object SparkUDFDemo {
         
      //   
      case class Hobbies(name:String,hobbies: String)
    
      def main(args: Array[String]): Unit = {
         
        val spark :SparkSession= SparkSession.builder()
          .master("local[1]")
          .appName("SparkUDFDemo")
          .getOrCreate()
     	val sc:SparkContext = spark.sparkContext
    
        //            ,  RDD     DF
        import spark.implicits._
       
        val rdd:RDD[String] = sc.textFile("D:\\test\\t\\hobbies.txt")
        val df:DataFrame = rdd.map(x=>x.split("\t")).map(x=>Hobbies(x(0),x(1))).toDF()
    
        //df.printSchema()
        //df.show()
    
        df.registerTempTable("hobbies")
        //       ,       
        spark.udf.register("hoby_num",(s:String)=>s.split(",").size)
    
        val frame:DataFrame = spark.sql("select name,hobbies,hoby_num(hobbies) as hobnum from hobbies")
        frame.show()
      }
    }
    
    出力:
    +-----+--------------------+------+
    | name|             hobbies|hobnum|
    +-----+--------------------+------+
    |alice|jogging,Coding,co...|     3|
    | lina|        travel,dance|     2|
    +-----+--------------------+------+
    
    
    2、カスタム重合関数(UDAF)
    UDAF(User Defined Agregate Function)とは、ユーザーが定義する重合関数と、集合関数と一般関数の違いは何ですか?普通関数は1行の入力を受けて出力を生成しますが、集合関数は1組の入力を受けて出力を生成します。
    UDAF使用:
    UserDefinedAgregateFunctionを継承します。
    UserDefinedAggateFunctionを使用するステップ:
  • カスタムクラスはUserDefinedAggateFunctionを継承し、各段階の方法に対して
  • を実現する。
  • は、sparkにUDAFを登録し、名前
  • を結びつける。
  • その後、sql文で上に結合された名前を使って
  • を呼び出すことができます。
    D:\test\t\目次の下にファイルuser.jsonがあります。ファイルの内容:
    {"id": 1001, "name": "foo", "sex": "man", "age": 20}
    {"id": 1002, "name": "bar", "sex": "man", "age": 24}
    {"id": 1003, "name": "baz", "sex": "man", "age": 18}
    {"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
    {"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
    {"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
    
    需要:平均年齢を計算する。
    import org.apache.spark.sql.{
         Row, SparkSession, types}
    import org.apache.spark.sql.expressions.{
         MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    object SparkUDAFDemo {
         
      def main(args: Array[String]): Unit = {
         
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("SparkUDAFDemo")
          .getOrCreate()
        import spark.implicits._
        val sc = spark.sparkContext
        val df = spark.read.json("D:\\test\\t\\user.json")
        
        //        udaf  
        val myUdaf=new MyAgeAvgFunction
        spark.udf.register("myAvgAge",myUdaf)
    
        df.createTempView("userinfo")
        val resultDF = spark.sql("select myAvgAge(age) as avg_age from userinfo group by sex")
        resultDF.printSchema()
        resultDF.show()
      }
    }
    
    class MyAgeAvgFunction extends UserDefinedAggregateFunction{
         
      //            
      override def inputSchema: StructType = {
         
        new StructType().add("age",LongType)
        //     
        //StructType(StructField("age",LongType)::Nil)
      }
      //        
      override def bufferSchema: StructType = {
         
        new StructType().add("sum",LongType).add("count",LongType)
        //     
       // StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
      }
      //            
      override def dataType: DataType = DoubleType
    
      //           ,                
      override def deterministic: Boolean =true
      //       
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
         
        buffer(0)=0L
        buffer(1)=0L
      }
      //                 
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
         
        buffer(0)=buffer.getLong(0)+input.getLong(0)
        buffer(1)=buffer.getLong(1)+1
      }
    
      //          
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
         
       //    
        buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
        //  
        buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
      }
      //       
      override def evaluate(buffer: Row): Any = {
         
        buffer.getLong(0).toDouble/buffer.getLong(1)
      }
    }
    
    結果:
    root
     |-- avg_age: double (nullable = true)
    
    +------------------+
    |           avg_age|
    +------------------+
    |20.666666666666668|
    |18.666666666666668|
    +------------------+
    
    ブログを参照してください:https://www.cnblogs.com/cc11001100/p/9471859.html (このブログにはAgregatorを継承する方法も記載されています。)
    3、カスタムテーブル生成関数(UDTF)
    D:\test\t\ディレクトリの下にファイルudtf.txtがあります。ファイルの内容:
    01//zs//Hadoop scala spark hive hbase
    02//ls//Hadoop scala kafka hive hbase Oozie
    03//ww//Hadoop scala spark hive sqoop
    
    需要:lsのHadoop scala kafka hive hbase Oozeを次のように生成します。
      //      type           --(  )
      //      Hadoop
      //      scala
      //      kafaka
      //       hive
      //      hbase
      //      Oozie
    
    import java.util
    
    import org.apache.hadoop.hive.ql.exec.UDFArgumentException
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
    import org.apache.hadoop.hive.serde2.objectinspector.{
         ObjectInspector, ObjectInspectorFactory, PrimitiveObjectInspector, StructObjectInspector}
    import org.apache.spark.sql.SparkSession
    
    object SparkUDTFDemo {
         
      def main(args: Array[String]): Unit = {
         
        val spark = SparkSession.builder()
          .master("local[1]")
          .enableHiveSupport()		//  hive  
          .appName("SparkUDTFDemo")
          .getOrCreate()
        val sc = spark.sparkContext
    
        import spark.implicits._
    
        val lines = sc.textFile("D:\\test\\t\\udtf.txt")
        val stuDF = lines.map(_.split("//")).filter(x => x(1).equals("ls"))
          .map(x => (x(0), x(1), x(2))).toDF("id", "name", "class")
        //stuDF.printSchema()
        //stuDF.show()
    
        stuDF.createTempView("student")
        
        spark.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'kb09.sql.myUDTF'")
        //  AS               !!!
        val resultDF = spark.sql("select myUDTF(class) from student")
    
        resultDF.show()
      }
    }
    
    class myUDTF extends GenericUDTF{
         
    
      override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
         
        if (argOIs.length!=1){
         
          throw new UDFArgumentException("           ")
        }
        if (argOIs(0).getCategory!=ObjectInspector.Category.PRIMITIVE){
         
          throw new UDFArgumentException("       ")
        }
        val fieldNames =new util.ArrayList[String]
        val fieldOIs =new util.ArrayList[ObjectInspector]
    
        fieldNames.add("type")
        //             
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
        ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,fieldOIs)
    
    
      }
    	//   Hadoop scala kafaka hive hbase Oozie
      override def process(objects: Array[AnyRef]): Unit ={
         
        //              
        val strings = objects(0).toString.split(" ")
        println(strings)
        for (str strings){
         
          val tmp = new Array[String](1)
          tmp(0)=str
          forward(tmp)
        }
      }
      override def close(): Unit = {
         }
    }
    
    出力:
    [Ljava.lang.String;@6d0e1408
    +------+
    |  type|
    +------+
    |Hadoop|
    | scala|
    | kafka|
    |  hive|
    | hbase|
    | Oozie|
    +------+