
Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames


Since Spark 1.6 you can use the pivot function on GroupedData and provide an aggregate expression.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr)

pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

The list of levels can be omitted, but if provided it can both boost performance and serve as an internal filter.

This method is still relatively slow, but it certainly beats passing data manually between the JVM and Python.
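For reference, here is a minimal sketch (assuming the same df as in the snippet above) of the .agg form mentioned in the comment, with the levels derived from the data instead of hardcoded:

from pyspark.sql import functions as F

# Derive the pivot levels from the data; passing them explicitly still avoids
# the extra pass Spark would otherwise make to discover the distinct values.
countries = [r.Country for r in df.select("Country").distinct().collect()]

pivoted = (df
    .groupBy("ID", "Age")
    .pivot("Country", countries)
    .agg(F.sum("Score")))  # same result as .sum("Score") above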


First up, this is probably not a good idea, because you are not getting any extra information, but you are binding yourself to a fixed schema (i.e., you must know how many countries to expect, and of course an additional country means a code change).

Having said that, this is a SQL problem, which is shown below as Solution 2. But in case you feel that is not "software-like" enough (seriously, I have heard this!), then you can refer to the first solution.

Solution 1:

def reshape(t):
    # Turn (id, age, country, calls) into ((id, age), (ca, uk, us, xx)),
    # placing the call count in the slot matching the broadcast country list
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0], out[1]), (out[2], out[3], out[4], out[5])

def cntryFilter(t):
    # Keep only records whose country is in the broadcast list
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1, t2):
    # Element-wise addition of two tuples
    j = ()
    for k, v in enumerate(t1):
        j = j + (t1[k] + t2[k],)
    return j

def seq(tIntrm, tNext):
    return addtup(tIntrm, tNext)

def comb(tP, tF):
    return addtup(tP, tF)

countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)

reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0, 0, 0, 0), seq, comb, 1)
for i in pivot.collect():
    print(i)

Now, Solution 2: of course better, as SQL is the right tool for this.

from pyspark.sql import Row

callRow = calls.map(lambda t: Row(userid=t[0], age=int(t[1]), country=t[2], nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid, age, max(ca), max(uk), max(us), max(xx) \
                 from (select userid, age, \
                              case when country='CA' then nbrCalls else 0 end ca, \
                              case when country='UK' then nbrCalls else 0 end uk, \
                              case when country='US' then nbrCalls else 0 end us, \
                              case when country='XX' then nbrCalls else 0 end xx \
                         from calls) x \
                group by userid, age")
res.show()
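The same conditional-aggregation idea can also be written with the DataFrame API instead of raw SQL; a rough sketch, assuming the callsDF built above:

from pyspark.sql import functions as F

res_df = (callsDF
    .groupBy("userid", "age")
    .agg(
        F.max(F.when(callsDF.country == 'CA', callsDF.nbrCalls).otherwise(0)).alias("ca"),
        F.max(F.when(callsDF.country == 'UK', callsDF.nbrCalls).otherwise(0)).alias("uk"),
        F.max(F.when(callsDF.country == 'US', callsDF.nbrCalls).otherwise(0)).alias("us"),
        F.max(F.when(callsDF.country == 'XX', callsDF.nbrCalls).otherwise(0)).alias("xx")))
res_df.show()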

Data setup:

data = [('X01', 41, 'US', 3),
        ('X01', 41, 'UK', 1),
        ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4),
        ('X02', 72, 'UK', 6),
        ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]
calls = sc.parallelize(data, 1)
countries = ['CA', 'UK', 'US', 'XX']

Result:

From the 1st solution:

(('X02', 72), (7, 6, 4, 8))
(('X01', 41), (2, 1, 3, 0))

From the 2nd solution:

root
 |-- age: long (nullable = true)
 |-- country: string (nullable = true)
 |-- nbrCalls: long (nullable = true)
 |-- userid: string (nullable = true)

userid age ca uk us xx
X02    72  7  6  4  8
X01    41  2  1  3  0

Kindly let me know if this works, or not :)

Best,
Ayan


Here's a native Spark approach that doesn't hardwire the column names. It's based on aggregateByKey, and uses a dictionary to collect the columns that appear for each key. Then we gather all the column names to create the final DataFrame. [A prior version used jsonRDD after emitting a dictionary for each record, but this is more efficient.] Restricting to a specific list of columns, or excluding ones like XX, would be an easy modification.

The performance seems good even on quite large tables. I'm using a variation which counts the number of times each of a variable number of events occurs for each ID, generating one column per event type. The code is basically the same, except it uses a collections.Counter instead of a dict in the seqFn to count the occurrences (a sketch of that variation follows the output below).

from pyspark.sql.types import *

rdd = sc.parallelize([('X01', 41, 'US', 3),
                      ('X01', 41, 'UK', 1),
                      ('X01', 41, 'CA', 2),
                      ('X02', 72, 'US', 4),
                      ('X02', 72, 'UK', 6),
                      ('X02', 72, 'CA', 7),
                      ('X02', 72, 'XX', 8)])
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])
df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    # Collect the (Country -> Score) pairs seen for each key
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    # Merge the dictionaries built on different partitions
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

# Union of all column names that appeared, sorted for a deterministic column order
columns = sorted(
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s, t: s.union(t))
)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda ku: [ku[0]] + [ku[1].get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] +
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Produces:

ID  CA UK US XX
X02 7  6  4  8
X01 2  1  3  null
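For the Counter variation mentioned above, the skeleton would look roughly like this (the events RDD of (id, event_type) pairs is hypothetical, just to illustrate the shape):

from collections import Counter

# Hypothetical pair RDD, e.g.
# events = sc.parallelize([('X01', 'click'), ('X01', 'view'), ('X02', 'click')])

def seqCount(counter, event_type):
    # Add one occurrence of this event type to the per-ID Counter
    counter[event_type] += 1
    return counter

def cmbCount(c1, c2):
    # Counter.update adds counts together (unlike dict.update, which overwrites)
    c1.update(c2)
    return c1

counts = events.aggregateByKey(Counter(), seqCount, cmbCount)
# Column discovery and DataFrame creation then proceed exactly as above.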