
How to select the first row of each group?


Window functions:

Something like this should do the trick:

    import org.apache.spark.sql.functions.{row_number, max, broadcast}
    import org.apache.spark.sql.expressions.Window
    // Assumes spark.implicits._ (Spark 2.x+) or sqlContext.implicits._ (Spark 1.x) is imported
    // for toDF and the $"..." column syntax; sc is the SparkContext (e.g. from spark-shell).

    val df = sc.parallelize(Seq(
      (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
      (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
      (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
      (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

    // One partition per hour, highest TotalValue first
    val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

    // Number the rows inside each partition and keep only the first one
    val dfTop = df.withColumn("rn", row_number().over(w)).where($"rn" === 1).drop("rn")

    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+

This method can be inefficient when the data is significantly skewed. This problem is tracked by SPARK-34775 and might be resolved in the future.
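To make the semantics of the window approach concrete, here is a minimal sketch on plain Scala collections (not Spark) of what `row_number` over a window partitioned by hour and ordered by `TotalValue` descending, filtered to `rn = 1`, computes. The `TopPerGroupWindow` object and its `Record` case class are hypothetical names used only for illustration.

```scala
// Plain Scala collections, not Spark: a local model of
// row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) = 1.
// TopPerGroupWindow and Record are illustrative names only.
object TopPerGroupWindow {
  final case class Record(hour: Int, category: String, totalValue: Double)

  // "Partition" by hour, sort each partition by totalValue descending,
  // then keep the row that would receive row number 1.
  def topPerHour(rows: Seq[Record]): Map[Int, Record] =
    rows
      .groupBy(_.hour)
      .map { case (hour, part) => hour -> part.sortBy(r => -r.totalValue).head }
}
```

Like `row_number`, this keeps exactly one row per group. When several rows share the maximum, the stable sort here keeps the earliest one in input order; Spark gives no such guarantee for ties.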

Plain SQL aggregation followed by join:

Alternatively, you can join with an aggregated data frame:

    val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

    val dfTopByJoin = df.join(broadcast(dfMax),
        ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
      .drop("max_hour")
      .drop("max_value")

    dfTopByJoin.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+

It will keep duplicates (when more than one category per hour has the same maximum total value). You can remove these as follows:

    import org.apache.spark.sql.functions.first

    dfTopByJoin
      .groupBy($"hour")
      .agg(
        first("category").alias("category"),
        first("TotalValue").alias("TotalValue"))
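Why the join keeps ties can be seen in a small plain-Scala model (not Spark) of "aggregate the maximum per hour, then join back on (hour, value)". The `TopPerGroupJoin` object is a hypothetical name, and the second row for hour 0 (`"cat99"`) is made-up data added only to create a tie.

```scala
// Plain Scala collections, not Spark: a local model of
// groupBy(hour).agg(max(TotalValue)) followed by an inner join on (hour, TotalValue).
// TopPerGroupJoin and Record are illustrative names only.
object TopPerGroupJoin {
  final case class Record(hour: Int, category: String, totalValue: Double)

  // Every row whose value equals its hour's maximum survives the "join", ties included.
  def joinOnMax(rows: Seq[Record]): Seq[Record] = {
    val maxByHour: Map[Int, Double] =
      rows.groupBy(_.hour).map { case (h, rs) => h -> rs.map(_.totalValue).max }
    rows.filter(r => maxByHour(r.hour) == r.totalValue)
  }

  // Rough analogue of the groupBy/first step: keep one row per hour.
  // Here it is simply the first row in input order; Spark's first() after a
  // shuffle makes no such ordering promise.
  def dedupe(rows: Seq[Record]): Map[Int, Record] =
    rows.groupBy(_.hour).map { case (h, rs) => h -> rs.head }
}
```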

Using ordering over structs:

A neat, although not very well tested, trick that doesn't require joins or window functions:

    import org.apache.spark.sql.functions.{max, struct}

    // Structs are compared field by field, so TotalValue decides and Category breaks ties
    val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
      .groupBy($"hour")
      .agg(max("vs").alias("vs"))
      .select($"Hour", $"vs.Category", $"vs.TotalValue")

    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
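The struct trick works because structs are ordered lexicographically, field by field, and Scala tuples use the same kind of ordering by default. The sketch below is a plain-Scala analogue (not Spark) of `max(struct(TotalValue, Category))`; `StructOrdering` is a hypothetical name, and the tied `"cat99"` row is made-up data.

```scala
// Plain Scala collections, not Spark: max over (totalValue, category) tuples
// mirrors max(struct(TotalValue, Category)). Tuples compare the first field,
// then the second only on a tie. StructOrdering and Record are illustrative names.
object StructOrdering {
  final case class Record(hour: Int, category: String, totalValue: Double)

  def topPerHour(rows: Seq[Record]): Map[Int, (Double, String)] =
    rows
      .groupBy(_.hour)
      .map { case (h, rs) => h -> rs.map(r => (r.totalValue, r.category)).max }
}
```

Because `totalValue` is the first field, it dominates the comparison; `category` only breaks ties, so a tie resolves to the lexicographically larger category rather than to an arbitrary row.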

With the Dataset API (Spark 1.6+, 2.0+):

Spark 1.6:

    case class Record(Hour: Integer, Category: String, TotalValue: Double)

    df.as[Record]
      .groupBy($"hour")
      .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
      .show
    // +---+--------------+
    // | _1|            _2|
    // +---+--------------+
    // |[0]|[0,cat26,30.9]|
    // |[1]|[1,cat67,28.5]|
    // |[2]|[2,cat56,39.6]|
    // |[3]| [3,cat8,35.6]|
    // +---+--------------+

Spark 2.0 or later:

    df.as[Record]
      .groupByKey(_.Hour)
      .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should perform better than window functions and joins. They can also be used with Structured Streaming in complete output mode.
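The map-side combine argument can be illustrated with a simplified plain-Scala model (not Spark internals): each partition is reduced locally first, so only one candidate per key per partition has to be moved and merged. The `MapSideCombine` object and the two-partition split in the usage are assumptions made for illustration; the reducer has the same shape as in the Dataset examples.

```scala
// Simplified plain-Scala model (not Spark internals) of two-phase aggregation.
// MapSideCombine, Record, and the partition layout are illustrative only.
object MapSideCombine {
  final case class Record(hour: Int, category: String, totalValue: Double)

  // Same reducer shape as in the reduce / reduceGroups examples.
  val pickMax: (Record, Record) => Record =
    (x, y) => if (x.totalValue > y.totalValue) x else y

  // Phase 1 ("map side"): reduce inside a single partition, keyed by hour.
  def combineLocally(partition: Seq[Record]): Map[Int, Record] =
    partition.groupBy(_.hour).map { case (h, rs) => h -> rs.reduce(pickMax) }

  // Phase 2 (after the "shuffle"): merge the per-partition partial results.
  def mergePartials(partials: Seq[Map[Int, Record]]): Map[Int, Record] =
    partials.flatten
      .groupBy(_._1)
      .map { case (h, kvs) => h -> kvs.map(_._2).reduce(pickMax) }
}
```

The two phases agree with a single global reduction only because the reducer is associative and commutative (apart from which row wins a tie), which is what makes partial aggregation before the shuffle safe.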

Don't use:

    df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credit to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

The same note applies to

    df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.
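A simplified, hypothetical model (not a description of Spark internals) of why `first` after `orderBy` is unreliable: once rows are shuffled for the aggregation, a reducer sees the sorted partitions in whatever order they arrive, and `first` simply takes whichever row came first. The `FirstAfterShuffle` object and the partition contents in the usage are assumptions for illustration.

```scala
// Simplified plain-Scala model: the result of "first per group" depends on the
// order in which already-sorted partitions arrive, while an order-insensitive
// aggregate does not. FirstAfterShuffle and Record are illustrative names only.
object FirstAfterShuffle {
  final case class Record(hour: Int, category: String, totalValue: Double)

  // "first" per hour over partitions concatenated in the given arrival order.
  def firstPerHour(arrived: Seq[Seq[Record]]): Map[Int, Record] =
    arrived.flatten.groupBy(_.hour).map { case (h, rs) => h -> rs.head }

  // Picking the maximum explicitly gives the same answer for any arrival order.
  def maxPerHour(arrived: Seq[Seq[Record]]): Map[Int, Record] =
    arrived.flatten.groupBy(_.hour).map { case (h, rs) => h -> rs.maxBy(_.totalValue) }
}
```

With two globally sorted partitions, swapping their arrival order changes the `firstPerHour` result for a group spanning both partitions, while `maxPerHour` stays the same.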


For Spark 2.0.2, grouping by multiple columns:

    import org.apache.spark.sql.functions.row_number
    import org.apache.spark.sql.expressions.Window

    // One partition per (col1, col2, col3) combination, latest timestamp first
    val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

    val refined_df = df.withColumn("rn", row_number().over(w)).where($"rn" === 1).drop("rn")


This is exactly the same as zero323's answer, but expressed as SQL queries.

Assuming the dataframe has been created and registered as:

    df.createOrReplaceTempView("table")
    //+----+--------+----------+
    //|Hour|Category|TotalValue|
    //+----+--------+----------+
    //|0   |cat26   |30.9      |
    //|0   |cat13   |22.1      |
    //|0   |cat95   |19.6      |
    //|0   |cat105  |1.3       |
    //|1   |cat67   |28.5      |
    //|1   |cat4    |26.8      |
    //|1   |cat13   |12.6      |
    //|1   |cat23   |5.3       |
    //|2   |cat56   |39.6      |
    //|2   |cat40   |29.7      |
    //|2   |cat187  |27.9      |
    //|2   |cat68   |9.8       |
    //|3   |cat8    |35.6      |
    //+----+--------+----------+

Window function:

    sqlContext.sql(
      "select Hour, Category, TotalValue from " +
      "(select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp " +
      "where rn = 1")
      .show(false)
    //+----+--------+----------+
    //|Hour|Category|TotalValue|
    //+----+--------+----------+
    //|1   |cat67   |28.5      |
    //|3   |cat8    |35.6      |
    //|2   |cat56   |39.6      |
    //|0   |cat26   |30.9      |
    //+----+--------+----------+

Plain SQL aggregation followed by join:

    sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
      "(select Hour, Category, TotalValue from table tmp1 " +
      "join " +
      "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
      "on " +
      "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
      "group by tmp3.Hour")
      .show(false)
    //+----+--------+----------+
    //|Hour|Category|TotalValue|
    //+----+--------+----------+
    //|1   |cat67   |28.5      |
    //|3   |cat8    |35.6      |
    //|2   |cat56   |39.6      |
    //|0   |cat26   |30.9      |
    //+----+--------+----------+

Using ordering over structs:

    sqlContext.sql("select Hour, vs.Category, vs.TotalValue from " +
      "(select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)")
      .show(false)
    //+----+--------+----------+
    //|Hour|Category|TotalValue|
    //+----+--------+----------+
    //|1   |cat67   |28.5      |
    //|3   |cat8    |35.6      |
    //|2   |cat56   |39.6      |
    //|0   |cat26   |30.9      |
    //+----+--------+----------+

The Dataset approach and the "don't use" notes are the same as in the original answer.