
Spark Streaming updateStateByKey usage

2015-08-14 19:12
updateStateByKey explained:

updateStateByKey reduces the data in a DStream by key and then accumulates the results across batches.

As new data arrives or existing data is updated, updateStateByKey lets you maintain arbitrary per-key state. Using it requires two steps:

1) Define the state: it can be of any data type.

2) Define the state update function: a function that specifies how to compute the new state from the previous state of a key and the new values for that key in the input stream. A minimal sketch follows below.
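
A hedged sketch of the two steps, assuming an existing DStream[(String, Int)] named pairs (the names are illustrative; complete programs follow in the examples below):

// Step 1: the state is an Int (a running count per key).
// Step 2: the update function combines this batch's values with the previous state.
val updateFunc = (newValues: Seq[Int], runningCount: Option[Int]) =>
  Some(newValues.sum + runningCount.getOrElse(0))

// Apply it; 'pairs' is assumed to be a DStream of (word, 1) pairs.
val runningCounts = pairs.updateStateByKey[Int](updateFunc)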

For stateful operations, the RDDs of the current time slice and of all historical time slices are accumulated on every batch, so as time passes the amount of data involved in the computation keeps growing; this is also why Spark requires a checkpoint directory to be configured before updateStateByKey can be used.

updateStateByKey source code (one of the overloads in PairDStreamFunctions):

/**
 * Return a new "state" DStream where the state for each key is updated by applying
 * the given function on the previous state of the key and the new values of the key.
 * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
 * @param updateFunc State update function. If `this` function returns None, then
 *                   corresponding state key-value pair will be eliminated.
 * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
 *                    DStream.
 * @param initialRDD initial state value of each key.
 * @tparam S State type
 */
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    partitioner: Partitioner,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)] = {
  val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
    iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
  }
  updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
}
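
As the scaladoc notes, returning None from the update function eliminates that key's state. A hedged sketch of an update function that uses this to drop keys which received no new values in the current batch (the policy and names here are illustrative, not part of the Spark API):

// Illustrative policy: keys with no new values in this batch have their state removed.
val dropIdleKeys = (values: Seq[Int], state: Option[Int]) =>
  if (values.isEmpty) None                      // returning None eliminates this key's state
  else Some(values.sum + state.getOrElse(0))    // otherwise accumulate as usual

It would be passed to updateStateByKey in the same way as a simple accumulating update function.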

Code examples

StatefulNetworkWordCount

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StatefulNetworkWordCount {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
      System.exit(1)
    }

    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)

    // Per-key update function: add this batch's count to the previous state
    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
      val currentCount = values.sum
      val previousCount = state.getOrElse(0)
      Some(currentCount + previousCount)
    }

    // Lift it to the iterator-based form required by the overload used below
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
    }

    // use at least 2 local threads: one for the socket receiver, one for processing
    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount").setMaster("local[2]")
    // Create the context with a 1 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    ssc.checkpoint(".")

    // Initial RDD input to updateStateByKey
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    // Create a ReceiverInputDStream on target ip:port and count the
    // words in input stream of \n delimited text (eg. generated by 'nc')
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))

    // Update the cumulative count using updateStateByKey
    // This will give a DStream made of state (which is the cumulative count of the words)
    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
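
For a quick local test, you would typically start a text source with netcat on the target host (for example nc -lk 9999) and run the program with matching <hostname> <port> arguments such as localhost 9999; each one-second batch then prints entries of the state DStream, i.e. the running cumulative count per word, seeded by the initialRDD entries for "hello" and "world".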


NetworkWordCount

import org.apache.spark.SparkConf
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object NetworkWordCount {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: NetworkWordCount <hostname> <port>")
      System.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("NetworkWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(10))
    // a checkpoint directory must be set before using updateStateByKey
    ssc.checkpoint("hdfs://master:8020/spark/checkpoint")

    val addFunc = (currValues: Seq[Int], prevValueState: Option[Int]) => {
      // Spark has already grouped this key's values for the current batch into currValues
      // (a Seq/List), so summing it gives the count for this batch
      val currentCount = currValues.sum
      // the count accumulated so far for this key
      val previousCount = prevValueState.getOrElse(0)
      // return the accumulated result as an Option[Int]
      Some(currentCount + previousCount)
    }

    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))

    //val currWordCounts = pairs.reduceByKey(_ + _)
    //currWordCounts.print()

    val totalWordCounts = pairs.updateStateByKey[Int](addFunc)
    totalWordCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
}
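
Because the running state is stored in the checkpoint directory, a commonly used companion pattern, sketched here under the assumption of the same HDFS path (the object name and structure are illustrative, not part of the example above), is to create the context via StreamingContext.getOrCreate so that a restarted driver recovers its accumulated state from the checkpoint instead of starting over:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object RecoverableNetworkWordCount {
  val checkpointPath = "hdfs://master:8020/spark/checkpoint"  // same directory as above

  def createContext(): StreamingContext = {
    val conf = new SparkConf().setAppName("NetworkWordCount")
    val ssc = new StreamingContext(conf, Seconds(10))
    ssc.checkpoint(checkpointPath)
    // ... build the DStream graph here (socketTextStream, updateStateByKey, print) ...
    ssc
  }

  def main(args: Array[String]) {
    // recovers the context (and the per-key state) from the checkpoint if one exists,
    // otherwise builds a fresh context with createContext()
    val ssc = StreamingContext.getOrCreate(checkpointPath, createContext _)
    ssc.start()
    ssc.awaitTermination()
  }
}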


WebPagePopularityValueCalculator

import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.streaming.{Duration, Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object WebPagePopularityValueCalculator {

  private val checkpointDir = "popularity-data-checkpoint"
  private val msgConsumerGroup = "user-behavior-topic-message-consumer-group"

  def main(args: Array[String]) {

    if (args.length < 2) {
      println("Usage:WebPagePopularityValueCalculator zkserver1:2181, zkserver2: 2181, zkserver3: 2181 consumeMsgDataTimeInterval (secs) ")
      System.exit(1)
    }

    val Array(zkServers, processingInterval) = args
    val conf = new SparkConf().setAppName("Web Page Popularity Value Calculator")

    val ssc = new StreamingContext(conf, Seconds(processingInterval.toInt))
    //using updateStateByKey asks for enabling checkpoint
    ssc.checkpoint(checkpointDir)

    val kafkaStream = KafkaUtils.createStream(
      //Spark streaming context
      ssc,
      //zookeeper quorum. e.g zkserver1:2181,zkserver2:2181,...
      zkServers,
      //kafka message consumer group ID
      msgConsumerGroup,
      //Map of (topic_name -> numPartitions) to consume. Each partition is consumed in its own thread
      Map("user-behavior-topic" -> 3))
    val msgDataRDD = kafkaStream.map(_._2)

    //for debug use only
    //println("Coming data in this interval...")
    //msgDataRDD.print()
    // e.g page37|5|1.5119122|-1
    val popularityData = msgDataRDD.map { msgLine => {
      val dataArr: Array[String] = msgLine.split("\\|")
      val pageID = dataArr(0)
      //calculate the popularity value
      val popValue: Double = dataArr(1).toFloat * 0.8 + dataArr(2).toFloat * 0.8 + dataArr(3).toFloat * 1
      (pageID, popValue)
     }
    }

    //sum the previous popularity value and the value computed for the current batch
    //to get the page's latest popularity value
    val updatePopularityValue = (iterator: Iterator[(String, Seq[Double], Option[Double])]) => {
      iterator.flatMap { t =>
        val newValue: Double = t._2.sum               // popularity computed from this batch
        val stateValue: Double = t._3.getOrElse(0.0)  // popularity accumulated so far
        Some(newValue + stateValue).map(sumedValue => (t._1, sumedValue))
      }
    }

    val initialRDD = ssc.sparkContext.parallelize(List(("page1", 0.00)))

    //call updateStateByKey with the anonymous function defined above to update each page's popularity value
    val stateDStream = popularityData.updateStateByKey[Double](updatePopularityValue,
      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)

    //set the checkpoint interval to avoid checkpointing data too frequently, which
    //may significantly reduce operation throughput
    stateDStream.checkpoint(Duration(8 * processingInterval.toInt * 1000))

    //after calculation, sort the result and only show the top 10 hottest pages
    stateDStream.foreachRDD { rdd => {
      val sortedData = rdd.map { case (k, v) => (v, k) }.sortByKey(false)
      val topKData = sortedData.take(10).map { case (v, k) => (k, v) }
      topKData.foreach(x => {
        println(x)
      })
     }
    }

    ssc.start()
    ssc.awaitTermination()
  }
}
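
To make the popularity calculation concrete: for the sample record page37|5|1.5119122|-1 shown in the comment above, the value computed for that batch is 5 * 0.8 + 1.5119122 * 0.8 + (-1) * 1 ≈ 4.21, and updateStateByKey then adds each subsequent batch's value to the page's running total.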


References:

http://blog.cloudera.com/blog/2014/11/how-to-do-near-real-time-sessionization-with-spark-streaming-and-apache-hadoop/

https://github.com/apache/spark/blob/branch-1.3/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala

https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala

http://stackoverflow.com/questions/28998408/spark-streaming-example-calls-updatestatebykey-with-additional-parameters

http://stackoverflow.com/questions/27535668/spark-streaming-groupbykey-and-updatestatebykey-implementation
