
Spark dataframe transform multiple rows to column


Using zero323's dataframe,

df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")],
    ("a", "cnt", "major"))

you could also use the built-in pivot (available since Spark 1.6):

reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
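To make the semantics of that one-liner concrete, here is a minimal plain-Python sketch (no Spark required) of what groupby('a').pivot('major').max('cnt').fillna(0) computes: rows are grouped by a, each distinct major becomes a column holding the largest cnt for that group, and missing combinations are filled with 0. The helper name pivot_max is hypothetical, and the sketch deliberately ignores Spark details such as partitioning and shuffles.

```python
from collections import defaultdict

# Same (a, cnt, major) tuples as the example DataFrame.
DATA = [
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"), ("a", 3, "m4"),
    ("b", 4, "m1"), ("b", 1, "m2"), ("b", 2, "m3"),
    ("c", 3, "m1"), ("c", 4, "m3"), ("c", 5, "m4"),
    ("d", 6, "m1"), ("d", 1, "m2"), ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"), ("e", 1, "m4"), ("e", 1, "m5"),
]


def pivot_max(rows, fill=0):
    """Hypothetical helper: group by key, spread `major` into columns,
    keep the max `cnt` per cell, and fill empty cells with `fill`."""
    # Like pivot() without explicit values: find the distinct pivot values first.
    columns = sorted({major for _, _, major in rows})
    cells = defaultdict(dict)
    for key, cnt, major in rows:
        current = cells[key].get(major)
        # Equivalent of .max('cnt') within each (key, major) cell.
        cells[key][major] = cnt if current is None else max(current, cnt)
    # Equivalent of .fillna(fill) for cells that never received a value.
    table = {key: [cells[key].get(c, fill) for c in columns] for key in sorted(cells)}
    return columns, table


if __name__ == "__main__":
    columns, table = pivot_max(DATA)
    print(["a"] + columns)
    for key, values in table.items():
        print([key] + values)
```

Because the distinct pivot values have to be discovered before the output schema is known, Spark's pivot also accepts an explicit list of values as its second argument, which skips that extra pass over the data.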


Let's start with some example data:

df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")],
    ("a", "cnt", "major"))

Please note that I've renamed count to cnt. COUNT is a reserved keyword in most SQL dialects, so it is not a good choice for a column name.

There are at least two ways to reshape this data:

  • aggregating over the DataFrame

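    # Note: this max shadows Python's built-in max.
    from pyspark.sql.functions import col, when, max

    # Distinct values of "major"; each one becomes an output column.
    # DataFrame.map was removed in Spark 2.0, so go through .rdd.
    majors = sorted(df.select("major")
        .distinct()
        .rdd.map(lambda row: row[0])
        .collect())

    # One column per major: cnt where the row matches that major, null otherwise.
    cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m)
        for m in majors]
    # Per-group max of each column; max ignores nulls.
    maxs = [max(col(m)).alias(m) for m in majors]

    reshaped1 = (df
        .select(col("a"), *cols)
        .groupBy("a")
        .agg(*maxs)
        .na.fill(0))  # groups with no value for a major get 0

    reshaped1.show()
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## |  e|  4|  5|  1|  1|  1|
    ## +---+---+---+---+---+---+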
  • groupByKey over the RDD

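    from pyspark.sql import Row

    # Key each record by "a" and collect its (major, cnt) pairs.
    # DataFrame.map was removed in Spark 2.0, so go through .rdd.
    grouped = (df.rdd
        .map(lambda row: (row.a, (row.major, row.cnt)))
        .groupByKey())

    def make_row(kv):
        k, vs = kv
        # Map each major to its cnt and store the key itself under "a".
        tmp = dict(list(vs) + [("a", k)])
        # Missing majors become 0; `majors` comes from the previous snippet.
        return Row(**{c: tmp.get(c, 0) for c in ["a"] + majors})

    reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))

    reshaped2.show()
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  e|  4|  5|  1|  1|  1|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## +---+---+---+---+---+---+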
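To see what each approach does without a cluster, here is a minimal plain-Python sketch that mirrors both on the example tuples. It is an illustration under simplifying assumptions (in-memory lists instead of distributed data; the helper names conditional_agg and group_then_row are hypothetical), not Spark code. conditional_agg follows the DataFrame version: one wide row per input row with None outside the matching column (the when(...).otherwise(None) step), a per-group max that skips None, then a fill with 0. group_then_row follows the RDD version: collect each group's (major, cnt) pairs, then build one record with missing majors set to 0, as make_row does.

```python
from collections import defaultdict

# Same (a, cnt, major) tuples as the example DataFrame.
DATA = [
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"), ("a", 3, "m4"),
    ("b", 4, "m1"), ("b", 1, "m2"), ("b", 2, "m3"),
    ("c", 3, "m1"), ("c", 4, "m3"), ("c", 5, "m4"),
    ("d", 6, "m1"), ("d", 1, "m2"), ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"), ("e", 1, "m4"), ("e", 1, "m5"),
]
MAJORS = sorted({major for _, _, major in DATA})


def conditional_agg(rows, majors):
    """Hypothetical mirror of the DataFrame approach."""
    # Step 1 (when/otherwise): one wide row per input row, None outside the matching column.
    wide = [(key, [cnt if major == m else None for m in majors]) for key, cnt, major in rows]
    # Step 2 (groupBy + max): per-group column max that skips None, like Spark's max skips nulls.
    acc = {}
    for key, values in wide:
        best = acc.setdefault(key, [None] * len(majors))
        for i, v in enumerate(values):
            if v is not None and (best[i] is None or v > best[i]):
                best[i] = v
    # Step 3 (na.fill(0)): replace the remaining None values with 0.
    return {key: [0 if v is None else v for v in vals] for key, vals in acc.items()}


def group_then_row(rows, majors):
    """Hypothetical mirror of the RDD approach (groupByKey + make_row)."""
    grouped = defaultdict(list)
    for key, cnt, major in rows:
        grouped[key].append((major, cnt))
    result = {}
    for key, pairs in grouped.items():
        tmp = dict(pairs)  # as in make_row, a later duplicate major overwrites an earlier one
        result[key] = [tmp.get(m, 0) for m in majors]
    return result
```

One behavioral difference shows up only with duplicate (a, major) pairs: the DataFrame version keeps the maximum cnt, while make_row builds a dict, so whichever duplicate comes last wins (and groupByKey does not guarantee the order of values within a group). The example data has no such duplicates, so both produce the same table.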