GROUP BY in pySpark

By Niraj Zade | 2024-01-22 | Tags: guide sql


Key takeaway

Use groupBy() to create grouped rows, then use agg() to do calculations on those grouped rows.

import pyspark.sql.functions as F

statistics_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.sum(F.col("amount")).alias("total_amount"),
        F.avg(F.col("amount")).alias("avg_amount"),
    )
)
statistics_df.show()

Output:

+--------+------------+-----------------+
|location|total_amount|       avg_amount|
+--------+------------+-----------------+
|  Austin|         600|            200.0|
| Toronto|        1000|            250.0|
|   Tokyo|         250|83.33333333333333|
+--------+------------+-----------------+

Introduction

Consider a scenario: Group the given data by location, and then calculate statistics on it.

Example data -

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

There are 2 ways to do Group By in pySpark:

  1. Using groupBy() followed by agg() to calculate aggregate - recommended
  2. Using groupBy() followed by aggregation function - not recommended

I'll solve this scenario both ways below.

Create dataframe for examples

I'm using Spark in a Jupyter notebook for this example. You can use Databricks or whatever you've got.

  1. Import modules
  2. Create spark session (not needed for databricks)
  3. Create dataframe
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

# import relevant pyspark modules
from pyspark.sql import types as T
from pyspark.sql import functions as F

# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]

df = spark.createDataFrame(data=data, schema=schema)
df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

groupBy() followed by agg() - recommended

This is the recommended method.

First create a grouped dataframe using groupBy(), then calculate as many aggregations as needed using agg().

Example - Group by location, and find the total and average amount for each location.

import pyspark.sql.functions as F

statistics_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.sum(F.col("amount")).alias("total_amount"),
        F.avg(F.col("amount")).alias("avg_amount"),
    )
)
statistics_df.show()

Output:

+--------+------------+-----------------+
|location|total_amount|       avg_amount|
+--------+------------+-----------------+
|  Austin|         600|            200.0|
| Toronto|        1000|            250.0|
|   Tokyo|         250|83.33333333333333|
+--------+------------+-----------------+

groupBy() directly followed by an aggregation function - not recommended

The groupBy() function returns an object of type GroupedData, which most functions cannot work with. It has to be followed by an aggregation function, so that we get back a DataFrame object (which can then be fed to other functions).
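
A quick way to see this, using the example dataframe created above:

grouped = df.groupBy(F.col("location"))
print(type(grouped))  # <class 'pyspark.sql.group.GroupedData'>

# following up with an aggregation function returns a DataFrame again
print(type(grouped.avg("amount")))  # <class 'pyspark.sql.dataframe.DataFrame'>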

Example - Group by location and then calculate the average amount per location.

import pyspark.sql.functions as F

locationwise_amounts_df = df.groupBy(F.col("location")).avg("amount")
locationwise_amounts_df.show()

Output:

+--------+-----------------+
|location|      avg(amount)|
+--------+-----------------+
|  Austin|            200.0|
| Toronto|            250.0|
|   Tokyo|83.33333333333333|
+--------+-----------------+

Aggregate functions that can be used

These are the most commonly used aggregation functions:

  • count() - Note - for count, you can also count all columns, eg count("*")
  • avg()
  • sum()
  • mean()
  • min()
  • max()

All these functions take a column name as an argument.
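
For example, here is a small sketch that combines a few of these in a single agg() call on the example dataframe (the aliases are my own):

import pyspark.sql.functions as F

summary_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.count("*").alias("num_rows"),
        F.min(F.col("amount")).alias("min_amount"),
        F.max(F.col("amount")).alias("max_amount"),
    )
)
summary_df.show()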

Best practice

Use groupBy() to group rows, followed by agg() to calculate statistics.

Don't use groupBy() directly followed by aggregation functions like sum(), min(), etc.

Performance considerations

groupBy() is a wide transformation, so it causes a shuffle. You can minimize the data shuffled between executors by bucketing or partitioning the data on the grouping column. When you apply groupBy() on the bucketed or partitioned column, the rows of each group are already sitting together in their respective executors. The query plan will still show a shuffle-exchange step, but very little data actually needs to move between executors.
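
For example, here is a minimal sketch of the bucketing approach (the table name and bucket count are illustrative choices, not part of the original scenario):

import pyspark.sql.functions as F

# write the data bucketed on the grouping column
(
    df.write
    .mode("overwrite")
    .bucketBy(8, "location")
    .sortBy("location")
    .saveAsTable("transactions_bucketed")
)

# grouping on the bucketed column lets Spark keep each group's rows
# within their existing executors, minimizing the shuffle
bucketed_df = spark.table("transactions_bucketed")
bucketed_stats_df = (
    bucketed_df.groupBy(F.col("location"))
    .agg(F.sum(F.col("amount")).alias("total_amount"))
)
bucketed_stats_df.show()

Note that bucketBy() only works together with saveAsTable(), so the data has to be written out as a table first.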