Pyspark: Split multiple array columns into rows
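The answers below assume a DataFrame `df` with scalar columns `a` and `d` and equal-length array columns `b` and `c`. For concreteness, a toy `df` consistent with the sample output at the end can be built like this (the `d` value `"foo"` is just a placeholder, not taken from the original question):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# One row: scalar a, equal-length arrays b and c, scalar d
df = spark.createDataFrame(
    [(1, [1, 2, 3], [7, 8, 9], "foo")],
    ("a", "b", "c", "d"),
)
```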
Spark >= 2.4
You can replace the `zip_` udf with the `arrays_zip` function:
```python
from pyspark.sql.functions import arrays_zip, col, explode

(df
    .withColumn("tmp", arrays_zip("b", "c"))
    .withColumn("tmp", explode("tmp"))
    .select("a", col("tmp.b"), col("tmp.c"), "d"))
```
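Assuming the toy `df` sketched at the top, the result (here and for the equivalent snippets below) should look roughly like:

```
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  1|  7|foo|
|  1|  2|  8|foo|
|  1|  3|  9|foo|
+---+---+---+---+
```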
Spark < 2.4
With DataFrames and UDF:
```python
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType
from pyspark.sql.functions import col, udf, explode

zip_ = udf(
    lambda x, y: list(zip(x, y)),
    ArrayType(StructType([
        # Adjust types to reflect data types
        StructField("first", IntegerType()),
        StructField("second", IntegerType())
    ])))

(df
    .withColumn("tmp", zip_("b", "c"))
    # UDF output cannot be directly passed to explode
    .withColumn("tmp", explode("tmp"))
    .select("a", col("tmp.first").alias("b"), col("tmp.second").alias("c"), "d"))
```
With RDDs:
```python
(df
    .rdd
    .flatMap(lambda row: [(row.a, b, c, row.d) for b, c in zip(row.b, row.c)])
    .toDF(["a", "b", "c", "d"]))
```
Both solutions are inefficient due to Python communication overhead. If the array size is fixed, you can do something like this:
```python
from functools import reduce
from pyspark.sql import DataFrame

# Length of array
n = 3

# For legacy Python you'll need a separate function
# in place of method accessor
reduce(
    DataFrame.unionAll,
    (df.select("a", col("b").getItem(i), col("c").getItem(i), "d")
        for i in range(n))
).toDF("a", "b", "c", "d")
```
or even:
```python
from pyspark.sql.functions import array, struct

# SQL level zip of arrays of known size
# followed by explode
tmp = explode(array(*[
    struct(col("b").getItem(i).alias("b"), col("c").getItem(i).alias("c"))
    for i in range(n)]))

(df
    .withColumn("tmp", tmp)
    .select("a", col("tmp").getItem("b"), col("tmp").getItem("c"), "d"))
```
This should be significantly faster than the UDF or RDD versions. Generalized to support an arbitrary number of columns:
```python
# This uses keyword only arguments
# If you use legacy Python you'll have to change signature
# Body of the function can stay the same
def zip_and_explode(*colnames, n):
    return explode(array(*[
        struct(*[col(c).getItem(i).alias(c) for c in colnames])
        for i in range(n)
    ]))

df.withColumn("tmp", zip_and_explode("b", "c", n=3))
```
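Note that `zip_and_explode` returns the exploded struct column itself, so the result still has to be flattened back into top-level columns, as in the previous snippet. A minimal sketch using the same toy `df`:

```python
(df
    .withColumn("tmp", zip_and_explode("b", "c", n=3))
    .select("a",
            col("tmp").getItem("b").alias("b"),
            col("tmp").getItem("c").alias("c"),
            "d"))
```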
You'd need to use `flatMap`, not `map`, as you want to make multiple output rows out of each input row.
```python
from pyspark.sql import Row

def dualExplode(r):
    rowDict = r.asDict()
    bList = rowDict.pop('b')
    cList = rowDict.pop('c')
    for b, c in zip(bList, cList):
        newDict = dict(rowDict)
        newDict['b'] = b
        newDict['c'] = c
        yield Row(**newDict)

df_split = sqlContext.createDataFrame(df.rdd.flatMap(dualExplode))
```
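`sqlContext` here is the legacy pre-2.0 entry point; on Spark >= 2.0 the same call can go through a `SparkSession` (assuming one is available as `spark`):

```python
# Spark >= 2.0: SparkSession replaces sqlContext
df_split = spark.createDataFrame(df.rdd.flatMap(dualExplode))
df_split.show()
```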
One-liner (for Spark >= 2.4.0):
df.withColumn("bc", arrays_zip("b","c")) .select("a", explode("bc").alias("tbc")) .select("a", col"tbc.b", "tbc.c").show()
Imports required:

```python
from pyspark.sql.functions import arrays_zip, col, explode
```
Steps:

- Create a column `bc` which is an `arrays_zip` of columns `b` and `c`.
- Explode `bc` to get a struct `tbc`.
- Select the required columns `a`, `b`, and `c` (all exploded as required).
Output:
```
> df.withColumn("bc", arrays_zip("b","c")).select("a", explode("bc").alias("tbc")).select("a", "tbc.b", col("tbc.c")).show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  1|  7|
|  1|  2|  8|
|  1|  3|  9|
+---+---+---+
```