efficiently implementing takeByKey for spark


Your current solution is a step in the right direction, but it is still quite inefficient for at least three reasons (see the sketch after the list):

  • mapValues(List(_)) creates a huge number of temporary List objects
  • length for a linear Seq like List is O(N)
  • x ++ y once again creates a large number of temporary objects
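
To make these costs concrete, here is a small illustrative snippet (my own example, not code from your job) contrasting List with ArrayBuffer:

import scala.collection.mutable.ArrayBuffer

// Illustration only: List.length walks every cons cell, so each call is O(N),
// while ArrayBuffer.length just reads a stored count, so it is O(1).
val xs  = List.fill(1000000)(0)
val buf = ArrayBuffer.fill(1000000)(0)

xs.length   // traverses one million cells
buf.length  // constant time

// Similarly, xs ++ other copies xs into a brand-new list on every call,
// whereas buf ++= other appends in place and reuses the existing storage.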

The simplest improvement you can make is to replace List with a mutable buffer that has constant-time length. One possible example would be something like this:

import scala.collection.mutable.ArrayBuffer

rdd.aggregateByKey(ArrayBuffer[Int]())(
  (acc, x) => if (acc.length >= n) acc else acc += x,
  (acc1, acc2) => {
    val (xs, ys) = if (acc1.length > acc2.length) (acc1, acc2) else (acc2, acc1)
    val toTake = Math.min(n - xs.length, ys.length)
    for (i <- 0 until toTake) {
      xs += ys(i)
    }
    xs
  })
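
Note the two places this saves work: the per-partition function stops appending as soon as a buffer already holds n elements, and the merge function copies from the smaller buffer into the larger one, so each merge moves at most min(n - xs.length, ys.length) elements and never allocates a new collection.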

On a side note, you can replace:

.flatMap(t => t._2.map(v => (t._1, v)))

with

.flatMapValues(x => x)  // identity[Seq[V]]

It won't affect performance, but it is slightly cleaner.
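
Putting the two pieces together, a complete takeByKey along these lines might look roughly like the sketch below (the ClassTag bounds and the generic V in place of Int are my additions, not code from the question):

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Sketch: cap each key's buffer at n while aggregating, then flatten the
// buffers back into an RDD of (key, value) pairs.
def takeByKey[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], n: Int): RDD[(K, V)] =
  rdd
    .aggregateByKey(ArrayBuffer.empty[V])(
      (acc, x) => if (acc.length >= n) acc else acc += x,  // within a partition: stop at n
      (acc1, acc2) => {                                    // across partitions: top up the larger buffer
        val (xs, ys) = if (acc1.length > acc2.length) (acc1, acc2) else (acc2, acc1)
        xs ++= ys.take(n - xs.length)
        xs
      })
    .flatMapValues(x => x)                                 // expand buffers back to (key, value) pairs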


Here is the best solution I have come up with so far:

import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

def takeByKey[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], n: Int): RDD[(K, V)] = {
  rdd.mapValues(List(_))
     .reduceByKey((x, y) => if (x.length >= n) x
                            else if (y.length >= n) y
                            else (x ++ y).take(n))
     .flatMap(t => t._2.map(v => (t._1, v)))
}

It doesn't run out of memory and die like the groupByKey approach does, but it is still slow.
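
For reference, the intended behaviour (as I understand it) is to keep at most n values per key. A hypothetical driver snippet, with made-up data and assuming sc is an existing SparkContext, would be:

// Hypothetical usage, not part of the original question: keep at most 2 values per key.
val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 3), ("b", 4)))
takeByKey(pairs, 2).collect()
// expected shape: two pairs for key "a" and the single pair for key "b"
// (which two "a" values survive is not deterministic, since no ordering is imposed)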