GROUP BY in pySpark

By Niraj Zade | 2024-01-22   | Tags: guide sql


Introduction

Scenario - Group the given data by location, and then calculate statistics on it.

Example data -

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

There are 2 ways to do Group By in pySpark:

  1. Using groupBy() followed by agg() to calculate aggregates - recommended
  2. Using groupBy() directly followed by an aggregation function - not recommended

I'll solve this scenario both ways below.

Initial Setup

I'm using Jupyter notebook for my setup.

It is a 2 step setup process:

  1. Import modules and create the spark session:
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

# import relevant pyspark modules
from pyspark.sql import types as T
from pyspark.sql import functions as F

# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
  2. Create the dataframe:
data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]

df = spark.createDataFrame(data=data, schema=schema)
df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

groupBy() followed by agg() to calculate aggregates - recommended

This is the recommended method.

Using agg(), you can apply as many aggregations as needed on the output of groupBy(), and you can set an alias for each aggregated column. Use agg() whenever you need multiple aggregations on grouped data.

Example - Group by location, and find the total and average amount for each location.

import pyspark.sql.functions as F

statistics_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.sum(F.col("amount")).alias("total_amount"),
        F.avg(F.col("amount")).alias("avg_amount"),
    )
)
statistics_df.show()

Output:

+--------+------------+-----------------+
|location|total_amount|       avg_amount|
+--------+------------+-----------------+
|  Austin|         600|            200.0|
| Toronto|        1000|            250.0|
|   Tokyo|         250|83.33333333333333|
+--------+------------+-----------------+
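Since agg() accepts any number of aggregate expressions, you can extend the same query with more statistics in a single call. Here is a small sketch adding a minimum and a row count (the aliases min_amount and num_rows are my own choices, not from the example above):

extended_stats_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.sum(F.col("amount")).alias("total_amount"),
        F.avg(F.col("amount")).alias("avg_amount"),
        F.min(F.col("amount")).alias("min_amount"),  # smallest amount in each group
        F.count("*").alias("num_rows"),              # number of rows in each group
    )
)
extended_stats_df.show()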

groupBy() directly followed by an aggregation function - not recommended

The groupBy() function returns an object of type GroupedData, which most functions cannot work with directly. You have to follow the group by with an aggregate function, so that you get back a DataFrame object that can be fed to other functions.

There is one major limitation to this method - you cannot set an alias for the generated column. So this method is not recommended.

Example - Group by location, and then calculate the average amount per location.

import pyspark.sql.functions as F

locationwise_amounts_df = df.groupBy(F.col("location")).avg("amount")
locationwise_amounts_df.show()

Output:

+--------+-----------------+
|location|      avg(amount)|
+--------+-----------------+
|  Austin|            200.0|
| Toronto|            250.0|
|   Tokyo|83.33333333333333|
+--------+-----------------+
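If you do end up using this form and want a cleaner column name, you can rename the auto-generated column afterwards. A minimal sketch using withColumnRenamed() (the generated name avg(amount) matches the output above):

locationwise_amounts_df = (
    df.groupBy(F.col("location"))
    .avg("amount")
    .withColumnRenamed("avg(amount)", "avg_amount")  # rename the auto-generated column
)
locationwise_amounts_df.show()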

Aggregate functions you can use

These are the shortcut methods you can call directly on the grouped data:

  • count() - takes no argument and counts the rows in each group (inside agg(), you can use F.count("*") to count all rows)
  • avg()
  • mean() - alias of avg()
  • sum()
  • min()
  • max()
  • pivot() - not an aggregation by itself; it rotates a column's values into columns, and must be followed by an aggregation

Except for count() and pivot(), these methods take one or more column names as arguments.
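A quick sketch of these shortcut methods in action, using the same dataframe as above:

# count rows per location - count() takes no column argument
df.groupBy("location").count().show()

# maximum amount per location
df.groupBy("location").max("amount").show()

# pivot locations into columns, then aggregate over all rows
df.groupBy().pivot("location").sum("amount").show()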

Best practice

Use groupBy() to group the rows, followed by agg() to calculate the aggregates.

Don't follow groupBy() directly with aggregation functions like sum(), min(), etc.

