您的位置:首页 > 数据库

SparkSQL 自定义算子UDF、UDAF、UDTF

2019-02-13 18:10 561 查看
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/laksdbaksjfgba/article/details/87162906

背景

我根据算子输入输出之间的关系来理解算子分类:

UDF——输入一行,输出一行
UDAF——输入多行,输出一行
UDTF——输入一行,输出多行

本文主要是整理这三种自定义算子的具体实现方式

使用的数据集——用户行为日志user_log.csv,csv中自带首行列头信息,字段定义如下:
1. user_id | 买家id
2. item_id | 商品id
3. cat_id | 商品类别id
4. merchant_id | 卖家id
5. brand_id | 品牌id
6. month | 交易时间:月
7. day | 交易事件:日
8. action | 行为
9. age_range | 买家年龄分段
10. gender | 性别
11. province| 收获地址省份

新手上路,有任何搞错的地方,或者走了弯路,还请大家不吝指出,帮我进步



SparkSQL算子分类


1. UDF

通过匿名函数的方式注册自定义算子

object UserAnalysis {
def main(args:Array[String]): Unit ={

//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport()      //启用hive
.getOrCreate()

//sparksession直接读取csv,可设置分隔符delimitor.
val userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)

//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")

//通过匿名函数的方式注册自定义算子:将0和1分别转换成female和male
sparkSession.udf.register("getGender",(gender:Integer)=>{
var result="unknown"
if (gender==0){
result="female"
}else if(gender==1){
result="male"
}
result
})

val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")

//显示DataFrame内容
genderDF.show(10)
}
}

通过实名函数的方式注册自定义算子

object UserAnalysis {
def main(args:Array[String]): Unit ={

//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport()      //启用hive
.getOrCreate()

//sparksession直接读取csv,可设置分隔符delimitor.
val userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)

//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")

/*
通过实名函数的方式注册自定义算子
Scala中方法和函数是两个不同的概念,方法无法作为参数进行传递,
也无法赋值给变量,但是函数是可以的。在Scala中,利用下划线可以将方法转换成函数:
*/
sparkSession.udf.register("getGender",getGender _)

val genderDF = sparkSession.sql("select getGender(gender) as A from userDF")

//显示DataFrame内容
genderDF.show(10)
}

//将0和1分别转换成female和male
def getGender(gender:Integer): String ={
var result="unknown"
if (gender==0){
result="female"
}else if(gender==1){
result="male"
}
result
}
}

通过以上两种方式实现相同算子,得到相同的结果:


2. UDAF

通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来自定义UDAF算子

class UserDefinedMax extends UserDefinedAggregateFunction{

//定义输入数据的类型,两种写法都可以
//override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)

//定义聚合过程中所处理的数据类型
override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))

//定义输入数据的类型
override def dataType: DataType = IntegerType

//规定一致性
override def deterministic: Boolean = true

//在聚合之前,每组数据的初始化操作
override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}

//每组数据中,当新的值进来的时候,如何进行聚合值的计算
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if(input.getInt(0)> buffer.getInt(0))
buffer(0)=input.getInt(0)
}

//合并各个分组的结果
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
if(buffer2.getInt(0)> buffer1.getInt(0)){
buffer1(0)=buffer2.getInt(0)
}
}

//返回最终结果
override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
}

测试代码

object UserAnalysis {
def main(args:Array[String]): Unit ={

//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/small_user_log.csv"

//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport()      //启用hive
.getOrCreate()

//sparksession直接读取csv,可设置分隔符delimitor.
var userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)

//转换dataframe字段类型或字段名
import org.apache.spark.sql.functions._
userDF = userDF .withColumn("item_id", col("item_id").cast(IntegerType))

//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")

//注册算子,如果UserDefinedMax是object,不用new
sparkSession.udf.register("UserDefinedMax", new UserDefinedMax)

//测试sparksql内嵌max算子结果
val MaxDF = sparkSession.sql("select max(item_id) from userDF")

MaxDF.show

//测试用户自定义max算子结果
val UserDefinedMaxDF = sparkSession.sql("select UserDefinedMax(item_id) from userDF")

UserDefinedMaxDF.show
}
}

可以看到两个max算子的输出相同:



3. UDTF

通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来自定义UDTF算子

class UserDefinedUDTF extends GenericUDTF{

//这个方法的作用:1.输入参数校验  2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
if (args.length != 1) {
throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
}
if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
}

val fieldNames = new util.ArrayList[String]
val fieldOIs = new util.ArrayList[ObjectInspector]

//这里定义的是输出列默认字段名称
fieldNames.add("col1")
//这里定义的是输出列字段类型
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}

//这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
override def process(args: Array[AnyRef]): Unit = {
//将字符串切分成单个字符的数组
val strLst = args(0).toString.split("")
for(i <- strLst){
var tmp:Array[String] = new Array[String](1)
tmp(0) = i
//调用forward方法,必须传字符串数组,即使只有一个元素
forward(tmp)
}
}

override def close(): Unit = {}
}

测试代码

object UserAnalysis {
def main(args:Array[String]): Unit ={

//测试数据所在的本地路径
val userDataPath = "file:///home/hadoop/data_format/zxc/small1.csv"

//创建sparksession
val sparkSession = SparkSession
.builder
.master("local")
.appName("UserAnalysis")
.enableHiveSupport()      //启用hive
.getOrCreate()

//sparksession直接读取csv,可设置分隔符delimitor.
var userDF = sparkSession.read
.option("header","true")
.csv(userDataPath)

//将DataFrame注册成视图,然后即可使用hql访问
userDF.createOrReplaceTempView("userDF")

//注册utdf算子,这里无法使用sparkSession.udf.register()
sparkSession.sql("CREATE TEMPORARY FUNCTION UserDefinedUDTF as 'com.zxc.sparkAppTest.udtf.UserDefinedUDTF'")

//使用UDTF算子处理原表userDF
val UserDefinedUDTFDF = sparkSession.sql(
"select " +
"user_id," +
"item_id," +
"cat_id," +
"merchant_id," +
"brand_id," +
"month," +
"day," +
"action," +
"age_range," +
"gender," +
"UserDefinedUDTF(province) " +
"from " +
"userDF"
)

UserDefinedUDTFDF.show
}
}

对比原表和经UDTF算子处理之后的结果表:



► 小结

  • 关于UDF
    简单粗暴的理解,它就是输入一行输出一行的自定义算子
    我们可以通过实名函数或匿名函数的方式来实现,并使用sparkSession.udf.register()注册
    需要注意,截至目前(spark2.4)最多只支持22个输入参数的UDF

    另外还有一种实现方案(基于spark1.5,spark2.4待测试):
    继承org.apache.hadoop.hive.ql.exec.UDF


  • 关于UDAF
    简单粗暴的理解,它就是输入多行输出一行的自定义算子,比UDF的功能强大一些
    通过实现抽象类org.apache.spark.sql.expressions.UserDefinedAggregateFunction来实现UDAF算子,并使用sparkSession.udf.register()注册

    另外还有一种实现方案(基于spark1.5,spark2.4待测试):
    先继承org.apache.hadoop.hive.ql.exec.UDAF
    内部静态类实现org.apache.hadoop.hive.ql.exec.UDAFEvaluator


  • 关于UDTF
    简单粗暴的理解,它就是输入一行输出多行的自定义算子,可输出多行多列,又被称为 “表生成函数”
    通过实现抽象类org.apache.hadoop.hive.ql.udf.generic.GenericUDTF来实现UDTF算子,但是似乎无法使用sparkSession.udf.register()注册。注册方法如下:

sparkSession.sql("CREATE TEMPORARY FUNCTION 自定义算子名称 as '算子实现类全限定名称'")
实现UDTFf还需要注意(基于spark1.5,可能已过时):
udtf,process方法中对参数需要使用toString,String强转没用
sparksql子查询必须要有别名
算子内部使用竖线切分字符串时,需要转义
udtf调用forward方法,必须传字符串数组,即使只有一个元素

以上就是全部内容,持续更新,请多提宝贵意见

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: