Spark3.x UDF-UDAF函数

624 阅读1分钟

Spark SQL UDF函数计算元素个数

package com.ruozedata.saprk3.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * @theme 自定义UDF函数
 * @author 阿左
 * @create 2022-05-01
 * */
/**
 * Demonstrates a custom Spark SQL UDF that counts the number of
 * comma-separated students in a string column, used both from SQL
 * and from the DataFrame API.
 */
object UDFAPP {
    def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("UDFAPP")
        val spark = SparkSession.builder().config(conf).getOrCreate()

        import spark.implicits._
        val classDF = spark.sparkContext
                .parallelize(Array(("一班", "a,b,c"), ("二班", "d,e")))
                .toDF("class", "student")

        classDF.createTempView("tmp")

        // UDF body: number of comma-separated entries in the student list.
        val countStudents: String => Int = _.split(",").length

        // Register for SQL use; the returned handle can also be called from the DataFrame API.
        val my_udf = spark.udf.register("stu_num", countStudents)

        spark.sql(
            """
              |
              |select
              |class,student,stu_num(student)
              |from
              |tmp
              |""".stripMargin).show()

        // Same UDF invoked through the DataFrame API.
        import org.apache.spark.sql.functions._
        classDF.select(
            $"class",
            $"student",
            my_udf($"student").as("nums")
        ).show()

        spark.stop()
    }
}

执行结果:(截图见原文)

Spark SQL UDAF函数计算平均值

测试main

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * @theme 自定义UDAF函数计算平均值
 * @author 阿左
 * @create 2022-05-01
 * */
/**
 * Demonstrates a custom Spark 3.x UDAF (built from an Aggregator via
 * `functions.udaf`) that computes the per-subject average score.
 */
object UDAFApp {
    def main(args: Array[String]): Unit = {
        // Fix: app name was copy-pasted as "UDFAPP" from the UDF example;
        // it should match this class.
        val sparkConf = new SparkConf().setMaster("local[2]").setAppName("UDAFApp")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()

        // (subject, student, score-as-string) sample rows.
        val array = Array(("语文", "a", "89.0"),
            ("语文", "b", "67.9"),
            ("数学", "a", "59"),
            ("数学", "b", "90.0"))

        import spark.implicits._
        val df = spark.sparkContext.parallelize(array)
                .toDF("object","student","score")
        df.createTempView("tmp")
        df.show()
        df.printSchema()

        // Helper UDF converting the string score column to Double so it can
        // feed the Double-typed aggregator. (The returned handle was unused,
        // so we register without binding it.)
        spark.udf.register("double2Int", (score: String) => {
            score.toDouble
        })

        // Wrap the Aggregator as an untyped UDAF for SQL use.
        import org.apache.spark.sql.functions._
        spark.udf.register("avg_score", udaf(AvgScore))

        spark.sql(
            """
              |
              |select
              |object, avg_score(double2Int(score)) as avg_score
              |from tmp
              |group by object
              |
              |""".stripMargin).show

        spark.stop()
    }
}

AvgScore <UDAF聚合函数,spark3.x>

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * @theme 计算平均值
 * @author 阿左
 * @create 2022-05-01
 * */
/**
 * Aggregator computing the mean of Double scores.
 * Buffer is (runningSum, runningCount); the final value is sum / count.
 */
object AvgScore extends Aggregator[Double,(Double, Long),Double]{
    // Empty buffer: zero sum, zero count.
    override def zero: (Double, Long) = (0.0D, 0L)
    // Within-partition step: add the value, bump the count.
    override def reduce(b: (Double, Long), a: Double): (Double, Long) = (b._1 + a, b._2+1)
    // Cross-partition merge.
    // Fix: original summed b2._2 + b2._2, doubling one partition's count and
    // dropping the other's, which skews the average whenever data spans partitions.
    override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = (b1._1+b2._1, b1._2+b2._2)
    // Final result: mean. NOTE(review): divides by count without an empty-buffer
    // guard; Spark only calls finish after at least one reduce for grouped input.
    override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
    // Encoder for the tuple buffer.
    override def bufferEncoder: Encoder[(Double, Long)] = Encoders.product
    // Encoder for the Double result.
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

执行结果:(截图见原文)