
Spark UDAF

package cn.spark.study.udf;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StringCount extends UserDefinedAggregateFunction {

    private static final long serialVersionUID = 1L;

    /**
     * inputSchema: the type of the input data.
     */
    @Override
    public StructType inputSchema() {
        StructField[] fields = {DataTypes.createStructField("str", DataTypes.StringType, true)};
        return DataTypes.createStructType(fields);
    }

    /**
     * bufferSchema: the type of the data held in the intermediate
     * aggregation buffer.
     */
    @Override
    public StructType bufferSchema() {
        StructField[] fields = {DataTypes.createStructField("count", DataTypes.IntegerType, true)};
        return DataTypes.createStructType(fields);
    }

    /**
     * dataType: the return type of the function.
     */
    @Override
    public DataType dataType() {
        return DataTypes.IntegerType;
    }

    /**
     * The function is deterministic: the same input always produces the
     * same result.
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Initializes the aggregation buffer for each group.
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
    }

    /**
     * Called once per new input value of a group: computes the group's
     * updated aggregate value (here, increments the running count).
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        Integer bf = buffer.<Integer>getAs(0);
        buffer.update(0, bf + 1);
    }

    /**
     * Merges two intermediate buffers (partial aggregates computed on
     * different partitions) into one.
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        Integer bf1 = buffer1.<Integer>getAs(0);
        Integer bf2 = buffer2.<Integer>getAs(0);
        buffer1.update(0, bf1 + bf2);
    }

    /**
     * Produces a group's final aggregate value from its intermediate
     * buffer.
     */
    @Override
    public Object evaluate(Row buffer) {
        return buffer.<Integer>getAs(0);
    }
}
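Besides being called by name from SQL (as the usage example below does), a registered UDAF can also be invoked through the DataFrame API. A minimal sketch, assuming Spark 1.5+ (where org.apache.spark.sql.functions.callUDF is available) and the rowDF DataFrame with its name column from the usage example:

    import static org.apache.spark.sql.functions.callUDF;
    import static org.apache.spark.sql.functions.col;

    // Sketch only: assumes "strCount" has already been registered via
    // sqlContext.udf().register("strCount", new StringCount()), and that
    // rowDF is the DataFrame with a string column "name" built below.
    DataFrame grouped = rowDF.groupBy(col("name"))
            .agg(callUDF("strCount", col("name")).as("mycount"));
    grouped.show();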

Usage:

package cn.spark.study.udf;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

@SuppressWarnings(value = {"unused"})
public class UdafSql {

    /**
     * The following example registers a Scala closure as a UDF:
     *
     *   sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
     *
     * The following example registers a UDF in Java:
     *
     *   sqlContext.udf().register("myUDF",
     *       new UDF2<Integer, String, String>() {
     *           @Override
     *           public String call(Integer arg1, String arg2) {
     *               return arg2 + arg1;
     *           }
     *       }, DataTypes.StringType);
     *
     * Or, to use Java 8 lambda syntax:
     *
     *   sqlContext.udf().register("myUDF",
     *       (Integer arg1, String arg2) -> arg2 + arg1,
     *       DataTypes.StringType);
     *
     * @param args
     */
    public static void main(String[] args) {
        firstUdf();
    }

    private static void firstUdf() {
        SparkConf conf = new SparkConf().setAppName("UdfSql").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlct = new SQLContext(jsc);

        String[] str = {"Hpf99", "Leo", "Marray", "Jack", "Tom", "Tom", "Tom",
                "Leo", "Leo", "Marray", "Marray", "Jack"};
        List<String> lis = Arrays.asList(str);
        JavaRDD<String> strRdd = jsc.parallelize(lis);

        // Wrap each string in a Row so a DataFrame can be built from the RDD.
        JavaRDD<Row> rowRdd = strRdd.mapPartitions(new FlatMapFunction<Iterator<String>, Row>() {

            private static final long serialVersionUID = 1L;

            @Override
            public Iterable<Row> call(Iterator<String> t) throws Exception {
                List<Row> rows = new ArrayList<Row>();
                while (t.hasNext()) {
                    rows.add(RowFactory.create(t.next()));
                }
                return rows;
            }
        });

        StructField[] fields = {DataTypes.createStructField("name", DataTypes.StringType, true)};
        StructType schema = DataTypes.createStructType(fields);
        DataFrame rowDF = sqlct.createDataFrame(rowRdd, schema);

        rowDF.registerTempTable("names");

        // Register the UDAF so it can be called from SQL by name.
        sqlct.udf().register("strCount", new StringCount());

        DataFrame sql = sqlct.sql(
                "SELECT name, strCount(name) AS mycount, count(name) FROM names GROUP BY name");
        sql.show();

        jsc.close();
    }
}
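For reference, with the twelve input strings above the custom strCount aggregate should match the built-in count(name) for every group, so show() is expected to print something like the following (row order and the auto-generated name of the unaliased count(name) column vary by Spark version):

    +------+-------+---+
    |  name|mycount|_c2|
    +------+-------+---+
    | Hpf99|      1|  1|
    |  Jack|      2|  2|
    |   Leo|      3|  3|
    |Marray|      3|  3|
    |   Tom|      3|  3|
    +------+-------+---+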