Efficiently implementing `takeByKey` for Spark
Your current solution is a step in the right direction, but it is still quite inefficient for at least three reasons:

- `mapValues(List(_))` creates a huge number of temporary `List` objects
- `length` for a linear `Seq` like `List` is O(N)
- `x ++ y` once again creates a large number of temporary objects
The simplest improvement you can make is to replace `List` with a mutable buffer that has constant-time `length`. One possible example would be something like this:

```scala
import scala.collection.mutable.ArrayBuffer

rdd.aggregateByKey(ArrayBuffer[Int]())(
  (acc, x) => if (acc.length >= n) acc else acc += x,
  (acc1, acc2) => {
    val (xs, ys) = if (acc1.length > acc2.length) (acc1, acc2) else (acc2, acc1)
    val toTake = Math.min(n - xs.length, ys.length)
    for (i <- 0 until toTake) {
      xs += ys(i)
    }
    xs
  }
)
```
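To see what the merge step does, here is the same combine logic run on plain `ArrayBuffer`s outside Spark; the cap `n = 3` and the sample inputs are hypothetical, chosen only for illustration:

```scala
import scala.collection.mutable.ArrayBuffer

object MergeDemo extends App {
  val n = 3 // hypothetical per-key cap

  // Same merge logic as the combiner above: append from the shorter
  // buffer into the longer one, never exceeding n elements in total.
  def merge(acc1: ArrayBuffer[Int], acc2: ArrayBuffer[Int]): ArrayBuffer[Int] = {
    val (xs, ys) = if (acc1.length > acc2.length) (acc1, acc2) else (acc2, acc1)
    val toTake = math.min(n - xs.length, ys.length)
    for (i <- 0 until toTake) xs += ys(i)
    xs
  }

  // The longer buffer already holds n elements, so nothing is copied:
  println(merge(ArrayBuffer(1, 2, 3), ArrayBuffer(4, 5))) // ArrayBuffer(1, 2, 3)
}
```

Because the buffers are mutated in place and capped at `n`, the merge never allocates a new collection the way `x ++ y` does.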
On a side note, you can replace:

```scala
.flatMap(t => t._2.map(v => (t._1, v)))
```

with

```scala
.flatMapValues(x => x)  // identity[Seq[V]]
```
It won't affect performance but it is slightly cleaner.
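Putting both suggestions together, a complete `takeByKey` along these lines might look like the sketch below. The generic signature and the `ClassTag` bounds (which Spark's pair-RDD conversion requires) are my additions, not part of the answer above:

```scala
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// Sketch: aggregateByKey-based takeByKey; the generic signature is an assumption.
def takeByKey[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], n: Int): RDD[(K, V)] = {
  rdd
    .aggregateByKey(ArrayBuffer.empty[V])(
      // seqOp: stop appending once the buffer already holds n elements
      (acc, x) => if (acc.length >= n) acc else acc += x,
      // combOp: top up the longer buffer from the shorter one, capped at n
      (acc1, acc2) => {
        val (xs, ys) = if (acc1.length > acc2.length) (acc1, acc2) else (acc2, acc1)
        val toTake = math.min(n - xs.length, ys.length)
        for (i <- 0 until toTake) xs += ys(i)
        xs
      }
    )
    .flatMapValues(x => x)
}
```

Unlike the `reduceByKey` version, this builds at most one buffer of size `n` per key per partition and mutates it in place, so both the allocation churn and the O(N) `length` calls disappear.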
Here is the best solution I came up with so far:

```scala
def takeByKey(rdd: RDD[(K, V)], n: Int): RDD[(K, V)] = {
  rdd.mapValues(List(_))
    .reduceByKey((x, y) =>
      if (x.length >= n) x
      else if (y.length >= n) y
      else (x ++ y).take(n))
    .flatMap(t => t._2.map(v => (t._1, v)))
}
```
It doesn't run out of memory and die like the `groupByKey` approach does, but it is still slow.