在上一篇 文章 中我们介绍了一些 Flink SQL 的基础内容,以及与 Spark SQL 对比,有兴趣的小伙伴可以点连接进去看看。这篇文章,我们来说说UDF(User-Defined Functions)——用户自定义函数。
其实,关于UDF这部分官方文档就写的挺好的,简单明了,而且配有DEMO,有兴趣的同学,可以到 参考文档 里去找到连接。
首先,如果想使用自定义函数,那么必须在之前来注册这个函数,使用TableEnvironment的registerFunction()方法来注册。注册之后自定义函数会被插入到 TableEnvironment的函数目录中,以便API或SQL正确解析并执行它。在 Flink 中,UDF分为三类:标量函数(ScalarFunction)、表函数( TableFunction) 、聚合函数(
AggregateFunction)。
标量函数(ScalarFunction)
简单的说,标量函数,就是你输入几个数(0个或几个都行),经过一系列的处理,再返回给你几个数,这个案例咱们还使用上一篇文章中使用的意甲射手榜的案例,一般来说,总进球数=主场进球数+客场进球数,但是今年的规则有变,客场进球按两个球计算( 本文案例和前文有区别,使用scala,大家注意一下)。
import org.apache.flink.table.functions.ScalarFunctionclass TotalScores extends ScalarFunction{ private var wight:Int = 1 ; def this(wight:Int){ this() this.wight = wight } def eval(home:Int,visit:Int): Int = home+visit*this.wight}
首先,需要继承ScalarFunction该类,这里我们添加了一个构造器,传入的参数作为客场进球权重,然后实现eval方法,输入参数为主客场进球数,输出则为总进球数。
接下来,我们来写测试类:
import org.apache.flink.api.scala.ExecutionEnvironmentimport org.apache.flink.api.scala.typeutils.Typesimport org.apache.flink.table.api.TableEnvironmentimport org.apache.flink.table.sources.CsvTableSourceimport org.apache.flink.types.Rowimport org.apache.flink.api.scala._object TestScalarFunction { def main(args: Array[String]): Unit = { val filePath = "E:\\devlop\\workspace\\streaming1\\src\\main\\resources\\testdata.csv" val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env) val csvtable = CsvTableSource .builder .path(filePath) .ignoreFirstLine .fieldDelimiter(",") .field("rank", Types.INT) .field("player", Types.STRING) .field("club", Types.STRING) .field("matches", Types.INT) .field("red_card", Types.INT) .field("total_score", Types.INT) .field("total_score_home", Types.INT) .field("total_score_visit", Types.INT) .field("pass", Types.INT) .field("shot", Types.INT) .build tableEnv.registerTableSource("goals", csvtable) tableEnv.registerFunction("ts",new TotalScores(2)) val tableTest = tableEnv.sqlQuery("select player,total_score_home,total_score_visit,ts(total_score_home,total_score_visit) from goals where total_score > 10")//.scan("test").where("id='5'").select("id,sources,targets") tableEnv.toDataSet[Row](tableTest).print() }}
首先别忘记引用
import org.apache.flink.api.scala._
否则会有奇怪事情发生。
然后,注册函数,默认构造客场进球权重为2
tableEnv.registerFunction("ts",new TotalScores(2))
"select player,total_score_home,total_score_visit,ts(total_score_home,total_score_visit) from goals where total_score > 10"
在SQL中使用函数 ts(total_score_home,total_score_visit) 就这么简单
我们来看下输出:
C-罗纳尔多,5,7,19
夸利亚雷拉,5,5,15
萨帕塔,1,4,9
米利克,0,1,2
皮亚特克,2,0,2
因莫比莱,3,3,9
卡普托,2,4,10
表函数(TableFunction)
简单的说,表函数,就是你输入几个数(0个或几个都行),经过一系列的处理,再返回给你行数,返回的行可以包含一列或是多列值。这里我们使用一套新的数据案例来做一个说明。
假设这是某年四个直辖市四个季度GDP的一张透视表(说到透视表,想了解的同学可以异步到我之前的 文章 去看看)
provice,s1,s2,
s3,s4天津,10,
11,13,14北京,
13,16,17,
18重庆,14,12,
13,14上海,15,
11,15,17
我们来将这张透视表,还原成一张列表,接下来,我们来看代码
import org.apache.flink.table.functions.TableFunctionclass UnPivotFunction(separator: String) extends TableFunction[(String)] {@scala.annotation.varargs def eval(strs:String*): Unit = { strs.foreach(x=>collect(x)) }}
函数要继承TableFunction,后面泛型需要输入返回列的类型,这里为了方便,我们就使用了字符串。我们计划在查询里面把四个季度的值都输入进来,转换成列表。collect是TableFunction提供的函数,用于添加列,eval方法的参数,可以根据你的需要自行扩展,注意在使用不确定参数值的时候,加上注解@scala.annotation.varargs
接下来,我们来测试一下
import org.apache.flink.api.common.typeinfo.Typesimport org.apache.flink.api.scala.{ExecutionEnvironment, _}import org.apache.flink.table.api.TableEnvironmentimport org.apache.flink.table.sources.CsvTableSourceimport org.apache.flink.types.Rowimport wang.datahub.udf.UnPivotFunctionobject TestMyTableFunction2 { def main(args: Array[String]): Unit = { val filepath = "E:\\devlop\\workspace\\testsbtflink\\src\\main\\resources\\GDP.csv" val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env) tableEnv.registerFunction("mtf2", new UnPivotFunction("@")) val cts = CsvTableSource.builder().ignoreFirstLine() //provice,s1,s2,s3,s4 .field("provice",Types.STRING) .field("s1",Types.STRING) .field("s2",Types.STRING) .field("s3",Types.STRING) .field("s4",Types.STRING) .path(filepath) .build() tableEnv.registerTableSource("m",cts) val tableTest = tableEnv.sqlQuery("select provice,word from m , LATERAL TABLE(mtf2(s1,s2,s3,s4)) as T(word)") val stream = tableEnv.toDataSet[Row](tableTest) stream.print() }}
在SQL我使用了 JOIN LATERAL ,有兴趣了解的同学,可以看下云栖的文章,我放在参考文档里了。
我们来看下输出结果:
天津,10
天津,11
天津,13
天津,14
北京,13
北京,16
北京,17
北京,18
上海,15
上海,11
上海,15
上海,17
重庆,14
重庆,12
重庆,13
重庆,14
这个案例也许并不是那么恰当,其实,也可以利用到邮件切分等场景,这里算是抛砖引玉把。
聚合函数(AggregateFunction)
关于聚合函数,官方文档上的这张图,就充分的解释了其工作原理,主要计算通过
-
createAccumulator() -
accumulate() -
getValue()
这几个方法来完成,首先我们createAccumulator创建累加器,然后调用accumulate累加计算,最后getValue获取值。
当然这只是完成了初步工作,
-
retract() -
merge() -
resetAccumulator()
我们还需要回滚,合并,重置累加器等操作以适应不同的计算场景。
好了,我们的案例,再次来到了大家喜闻乐见的意甲联赛,这次我们统计俱乐部的进球数,还是使用了一个更靠谱的规则,就是给客场进球加了一个权重,然后来计算加权场均进球数。
先来创建累加器
class WeightedAvgAccum { var sum = 0 var count = 0}
然后创建计算函数
import java.lang.{Integer => JInteger,String => JString}import org.apache.flink.table.functions._class WeightedAvg(iWeight:Int) extends AggregateFunction[JInteger, WeightedAvgAccum] { override def createAccumulator(): WeightedAvgAccum = { new WeightedAvgAccum } override def getValue(acc: WeightedAvgAccum): JInteger = { if (acc.count == 0) { null } else { acc.sum / acc.count } } def accumulate(acc: WeightedAvgAccum,club:JString, home: JInteger, visit: JInteger): Unit = { acc.sum += home + visit * iWeight acc.count += 1 } def resetAccumulator(acc: WeightedAvgAccum): Unit = { acc.count = 0 acc.sum = 0 }}
接下来,我们来测试一下:
import org.apache.flink.api.scala.typeutils.Typesimport org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}import org.apache.flink.table.api.TableEnvironmentimport org.apache.flink.table.sources.CsvTableSourceimport org.apache.flink.types.Rowimport org.apache.flink.api.scala._object Testf { def main(args: Array[String]): Unit = { val filePath = "E:\\devlop\\workspace\\streaming1\\src\\main\\resources\\testdata.csv" val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env) val csvtable = CsvTableSource .builder .path(filePath) .ignoreFirstLine .fieldDelimiter(",") .field("rank", Types.INT) .field("player", Types.STRING) .field("club", Types.STRING) .field("matches", Types.INT) .field("red_card", Types.INT) .field("total_score", Types.INT) .field("total_score_home", Types.INT) .field("total_score_visit", Types.INT) .field("pass", Types.INT) .field("shot", Types.INT) .build tableEnv.registerTableSource("test", csvtable) tableEnv.registerFunction("myf",new MyFunction("111")) tableEnv.registerFunction("wag",new WeightedAvg(2)) val tableTest3 = tableEnv.sqlQuery("select club,wag(club,total_score_home,total_score_visit) as ag from test group by club") tableEnv.toDataSet[Row](tableTest3).print() }}
查看下结果:
切沃,2
拉齐奥,3
斯帕尔,1
博洛尼亚,1
国际米兰,3
帕尔马,2
恩波利,2
桑普多利亚,4
那不勒斯,4
都灵,2
AC米兰,3
亚特兰大,5
佛罗伦萨,2
卡利亚里,2
罗马,3
乌迪内斯,2
弗罗西诺内,2
尤文图斯,4
热那亚,3
萨索洛,2
最后(敲黑板),大家在聚合表的案例里,应该发现我使用了Java的基础类型,而不是Scala的数据类型,这是因为在UDF执行过程中,数据的创建,转换以及装箱拆箱都会带来额外的消耗,所以 Flink 官方,其实推荐UDF进来使用Java编写。
UDF其实是一个很神奇的东西,值得我们去探索与研究,下一期写点什么呢?如果您有建议或意见,欢迎与我联系,探讨。
参考文档:
https://ci.apache.org/projects/flink/flink-docs-stable/dev/table/udfs.html
https://yq.aliyun.com/articles/674345