How to select the first row of each group?
Window functions:
Something like this should do the trick:
import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3),
  (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
  (3, "cat8", 35.6)
)).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+
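A side note not from the original answer: row_number keeps exactly one row per group even when several categories tie for the top TotalValue. If tied rows should all be kept, rank can be used over the same window (a minimal sketch):

import org.apache.spark.sql.functions.rank

// rank assigns 1 to every row tied for the highest TotalValue, so ties survive.
val dfTopWithTies = df.withColumn("rnk", rank.over(w)).where($"rnk" === 1).drop("rnk")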
This method will be inefficient in the case of significant data skew. The problem is tracked by SPARK-34775 and might be resolved in the future.
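Not from the original answer, but a quick way to see whether skew is a concern is to look at the per-group row counts before picking an approach:

// Rough skew check: if a handful of hours hold most of the rows, the window
// above funnels that work into a small number of tasks.
df.groupBy($"Hour").count().orderBy($"count".desc).show()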
Plain SQL aggregation followed by join:
Alternatively you can join with the aggregated data frame:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+
It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:
import org.apache.spark.sql.functions.first

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))
Using ordering over structs:
A neat, although not very well tested, trick which doesn't require joins or window functions (it relies on structs comparing field by field, so TotalValue has to be the first field of the struct):
import org.apache.spark.sql.functions.struct

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+
With DataSet API (Spark 1.6+, 2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 or later:
df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
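reduceGroups returns key/value pairs; if only the winning Record per hour is wanted, a small follow-up step can be added (the extra map is an assumption about the desired output shape, not part of the original answer):

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map { case (_, record) => record }  // drop the grouping key, keep the reduced Record
  .show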
The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should exhibit better performance compared to window functions and joins. They can also be used with Structured Streaming in complete output mode.
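The original answer only states that these typed aggregations work with Structured Streaming in complete output mode; a minimal sketch of what that might look like follows (the JSON source path, the reuse of df.schema, and the console sink are assumptions):

// Hypothetical streaming source with the same schema as the batch df (assumption).
val streamingRecords = spark.readStream
  .schema(df.schema)
  .json("/path/to/input")
  .as[Record]

val streamingTop = streamingRecords
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

// Complete output mode re-emits the full result table on every trigger.
streamingTop.writeStream
  .outputMode("complete")
  .format("console")
  .start()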
Don't use:
df.orderBy(...).groupBy(...).agg(first(...), ...)
It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).
The same note applies to
df.orderBy(...).dropDuplicates(...)
which internally uses an equivalent execution plan.
For Spark 2.0.2 with grouping by multiple columns:
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
This is exactly the same as zero323's answer, but written as SQL queries.
Assuming that the dataframe is created and registered as follows:
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+
Window function:
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)//+----+--------+----------+//|Hour|Category|TotalValue|//+----+--------+----------+//|1 |cat67 |28.5 |//|3 |cat8 |35.6 |//|2 |cat56 |39.6 |//|0 |cat26 |30.9 |//+----+--------+----------+
Plain SQL aggregation followed by join:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " + "(select Hour, Category, TotalValue from table tmp1 " + "join " + "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " + "on " + "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " + "group by tmp3.Hour") .show(false)//+----+--------+----------+//|Hour|Category|TotalValue|//+----+--------+----------+//|1 |cat67 |28.5 |//|3 |cat8 |35.6 |//|2 |cat56 |39.6 |//|0 |cat26 |30.9 |//+----+--------+----------+
Using ordering over structs:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)//+----+--------+----------+//|Hour|Category|TotalValue|//+----+--------+----------+//|1 |cat67 |28.5 |//|3 |cat8 |35.6 |//|2 |cat56 |39.6 |//|0 |cat26 |30.9 |//+----+--------+----------+
The Dataset approach and the don't-dos are the same as in the original answer.