Rewriting Spark's JdbcRDD to Support Custom Partition Query Conditions

Spark's built-in JdbcRDD only supports partition parameters of type Long: every partition must correspond to a slice of a Long range. In many cases this simply does not fit the query you want to partition.
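For reference, the stock JdbcRDD is driven purely by that numeric range, roughly like the sketch below (the MySQL connection URL and the books table are made up for illustration; the SQL is the example from the JdbcRDD documentation):

    // Stock JdbcRDD: the query must contain two ? placeholders, and the
    // partitions are slices of the [lowerBound, upperBound] Long range.
    val rdd = new org.apache.spark.rdd.JdbcRDD(
      sc,
      () => java.sql.DriverManager.getConnection("jdbc:mysql://localhost/test"),
      "select title, author from books where ? <= id and id <= ?",
      lowerBound = 1,
      upperBound = 20,
      numPartitions = 2,
      mapRow = (rs: java.sql.ResultSet) => (rs.getString(1), rs.getString(2)))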

I rewrote JdbcRDD so that the partition conditions can be fully customized.

The main idea of the implementation: turn the part that binds the query parameters into a function supplied by the caller, so you can define the partition parameters however you like.
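Concretely, the constructor drops the lowerBound / upperBound / numPartitions triple and takes two functions instead (excerpted from the full listing below):

    class CustomizedJdbcRDD[T: ClassTag](
        sc: SparkContext,
        getConnection: () => Connection,
        sql: String,
        getCustomizedPartitions: () => Array[Partition],   // builds the partitions
        prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,  // binds each partition's parameters
        mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)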

Here is the code:

////////////////////////////////////////////////////////////////////////////////

package org.apache.spark.rdd

import java.sql.{Connection, PreparedStatement, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.NextIterator

/** A partition that carries an arbitrary map of parameters instead of a numeric range. */
class CustomizedJdbcPartition(idx: Int, parameters: Map[String, Object]) extends Partition {
  override def index = idx
  val partitionParameters = parameters
}

// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private

/**
 * An RDD that executes an SQL query on a JDBC connection and reads results.
 * For a usage example, see the test code at the end of this post.
 *
 * @param getConnection a function that returns an open Connection.
 *   The RDD takes care of closing the connection.
 * @param sql the text of the query.
 *   Any ? placeholders are bound per partition by prepareStatement.
 *   E.g. "select title, author from books where host = ?"
 * @param getCustomizedPartitions a function that builds the array of partitions; each
 *   CustomizedJdbcPartition carries whatever parameters its query needs.
 * @param prepareStatement a function that binds a partition's parameters onto the
 *   PreparedStatement before it is executed.
 * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
 *   This should only call getInt, getString, etc; the RDD takes care of calling next.
 *   The default maps a ResultSet to an array of Object.
 */

class CustomizedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    getCustomizedPartitions: () => Array[Partition],
    prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
    mapRow: (ResultSet) => T = CustomizedJdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // Delegate partition construction entirely to the caller-supplied function.
    getCustomizedPartitions()
  }

  override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
    context.addTaskCompletionListener { context => closeIfNeeded() }
    val part = thePart.asInstanceOf[CustomizedJdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
    // rather than pulling entire resultset into memory.
    // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
    try {
      if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
        stmt.setFetchSize(Integer.MIN_VALUE)
        logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
      }
    } catch {
      case _: Exception =>
        // Best-effort hint: some drivers reject the MySQL-specific fetch size, so ignore failures.
    }

    // Let the caller bind this partition's parameters onto the statement.
    prepareStatement(stmt, part)
    val rs = stmt.executeQuery()

    override def getNext: T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs && !rs.isClosed()) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt && !stmt.isClosed()) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn && !conn.isClosed()) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}

object CustomizedJdbcRDD {

  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
   *
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query.
   *   Any ? placeholders are bound per partition by prepareStatement.
   * @param getCustomizedPartitions a function that builds the array of partitions.
   * @param prepareStatement a function that binds a partition's parameters onto the
   *   PreparedStatement before it is executed.
   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
   *   The default maps a ResultSet to an array of Object.
   */
  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new CustomizedJdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      getCustomizedPartitions,
      prepareStatement,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  /**
   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
   * converted into an `Object` array.
   *
   * @param connectionFactory a factory that returns an open Connection.
   *   The RDD takes care of closing the connection.
   * @param sql the text of the query.
   *   Any ? placeholders are bound per partition by prepareStatement.
   * @param getCustomizedPartitions a function that builds the array of partitions.
   * @param prepareStatement a function that binds a partition's parameters onto the
   *   PreparedStatement before it is executed.
   */
  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      getCustomizedPartitions: () => Array[Partition],
      prepareStatement: (PreparedStatement, CustomizedJdbcPartition) => PreparedStatement): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, getCustomizedPartitions, prepareStatement, mapRow)
  }
}

////////////////////////////////////////////////////////////////////////////////

Below is a simple piece of test code:

package org.apache.spark

import java.sql.{Connection, DriverManager, PreparedStatement}

import org.apache.spark.rdd.{CustomizedJdbcPartition, CustomizedJdbcRDD}

object HiveRDDTest {

  private val driverName = "org.apache.hive.jdbc.HiveDriver"
  private val tableName = "COLLECT_DATA"
  private var connection: Connection = null

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("HiveRDDTest").setMaster("local[2]")
    val sc = new SparkContext(conf)
    Class.forName(driverName)

    val data = new CustomizedJdbcRDD(sc,
      // Function that opens the JDBC connection
      () => {
        DriverManager.getConnection("jdbc:hive2://192.168.31.135:10000/default", "spark", "")
      },
      // The query to run
      "select * from collect_data where host=?",
      // Function that builds the partitions
      () => {
        val partitions = new Array[Partition](1)
        var parameters = Map[String, Object]()
        parameters += ("host" -> "172.18.26.11")
        val partition = new CustomizedJdbcPartition(0, parameters)
        partitions(0) = partition
        partitions
      },
      // Bind each partition's parameters to the query above
      (stmt: PreparedStatement, partition: CustomizedJdbcPartition) => {
        stmt.setString(1, partition.asInstanceOf[CustomizedJdbcPartition]
          .partitionParameters.get("host").get.asInstanceOf[String])
        stmt
      }
    )

    println(data.count())
  }
}
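Since the partition-building function returns a plain Array[Partition], extending the test to several partitions is straightforward. The sketch below uses made-up host addresses purely for illustration and reuses the same SQL and statement-preparation function as above:

    // One partition per host: each partition carries its own "host" parameter,
    // so each Spark task runs the query for exactly one host.
    val hosts = Seq("172.18.26.11", "172.18.26.12", "172.18.26.13")  // hypothetical hosts
    val partitionsByHost: () => Array[Partition] = () => {
      hosts.zipWithIndex.map { case (host, idx) =>
        new CustomizedJdbcPartition(idx, Map[String, Object]("host" -> host)): Partition
      }.toArray
    }

Passing partitionsByHost in place of the inline partition function above yields one task per host, each executing the query with its own bound parameter.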