Key takeaway
Use groupBy() to create grouped rows, then use agg() to run calculations on those grouped rows.
import pyspark.sql.functions as F
statistics_df = (
df.groupBy(F.col("location"))
.agg(
F.sum(F.col("amount")).alias("total_amount"),
F.avg(F.col("amount")).alias("avg_amount"),
)
)
statistics_df.show()
Output:
+--------+------------+-----------------+
|location|total_amount|       avg_amount|
+--------+------------+-----------------+
|  Austin|         600|            200.0|
| Toronto|        1000|            250.0|
|   Tokyo|         250|83.33333333333333|
+--------+------------+-----------------+
Introduction
Consider a scenario: group the given data by location, then calculate statistics for each group.
Example data -
+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+
There are 2 ways to do a group by in PySpark:
- Using groupBy() followed by agg() to calculate aggregates - recommended
- Using groupBy() directly followed by an aggregation function - not recommended
I'm solving this scenario both ways.
Create dataframe for examples
I'm using a Jupyter notebook with Spark for this example. You can use Databricks or whatever setup you have.
- Import modules
- Create spark session (not needed for databricks)
- Create dataframe
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'
# import relevant Python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F
# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
data = [
(1, "Alice", "Austin", 100),
(2, "Bob", "Austin", 200),
(3, "Chris", "Austin", 300),
(4, "Dave", "Toronto", 400),
(5, "Elisa", "Toronto", 300),
(6, "Fabrice", "Toronto", 200),
(7, "Girard", "Toronto", 100),
(8, "Hal", "Tokyo", 50),
(9, "Ignis", "Tokyo", 100),
(10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]
df = spark.createDataFrame(data=data, schema=schema)
df.show()
Output:
+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+
groupBy() followed by agg() - recommended
This is the recommended method. First create a grouped dataframe using groupBy(), then calculate as many aggregations as needed using agg().
Example - Group by location, and find the total and average amount for each location.
import pyspark.sql.functions as F
statistics_df = (
df.groupBy(F.col("location"))
.agg(
F.sum(F.col("amount")).alias("total_amount"),
F.avg(F.col("amount")).alias("avg_amount"),
)
)
statistics_df.show()
Output:
+--------+------------+-----------------+
|location|total_amount|       avg_amount|
+--------+------------+-----------------+
|  Austin|         600|            200.0|
| Toronto|        1000|            250.0|
|   Tokyo|         250|83.33333333333333|
+--------+------------+-----------------+
groupBy() directly followed by an aggregation function - not recommended
The groupBy() function returns an object of type GroupedData, which cannot be used by most functions. It has to be followed by an aggregate function so that we get back a DataFrame object (which can then be fed to other functions).
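As a quick check (a minimal sketch, assuming the df created above), you can see that groupBy() on its own does not give you back a DataFrame:
# groupBy() alone returns a GroupedData object, not a DataFrame,
# so DataFrame methods like show() are not available until an
# aggregation is applied.
grouped = df.groupBy(F.col("location"))
print(type(grouped))  # <class 'pyspark.sql.group.GroupedData'>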
Example - Group by location, and then calculate the average amount per location.
import pyspark.sql.functions as F
locationwise_amounts_df = df.groupBy(F.col("location")).avg("amount")
locationwise_amounts_df.show()
Output:
+--------+-----------------+
|location|      avg(amount)|
+--------+-----------------+
|  Austin|            200.0|
| Toronto|            250.0|
|   Tokyo|83.33333333333333|
+--------+-----------------+
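One practical drawback of this approach: the aggregated column gets an auto-generated name like avg(amount), which you then have to rename with withColumnRenamed(), whereas agg() lets you alias it inline. A small sketch, continuing from the example above:
# Rename the auto-generated column after the fact
renamed_df = locationwise_amounts_df.withColumnRenamed("avg(amount)", "avg_amount")
renamed_df.show()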
Aggregate functions that can be used
These are the functions you can use for calculations:
- count() - Note: for count, you can also use all columns, e.g. count("*")
- avg()
- sum()
- mean()
- max()
All these functions take a column name as an argument; a short sketch follows below.
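As a small sketch (assuming the df created earlier), here is how a few of these functions look inside agg(), including count("*"):
summary_df = (
    df.groupBy(F.col("location"))
    .agg(
        F.count("*").alias("num_rows"),          # count all rows in each group
        F.sum(F.col("amount")).alias("total_amount"),
        F.max(F.col("amount")).alias("max_amount"),
    )
)
summary_df.show()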
Best practice
Use groupBy() to group rows, followed by agg() to calculate statistics.
Don't use groupBy() directly followed by aggregation functions like sum(), min(), etc.
Performance considerations
groupBy() is a wide transformation, so it causes a shuffle. You can minimize the data shuffled between executors by using bucketing or partitioning. When you apply groupBy() on the bucketed or partitioned column, the rows of each group will already be together on their respective executors. So it will still show up as a shuffle-exchange step in the query plan, but little to no data will actually be moved.
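As a rough sketch of the bucketing idea (the table name sales_bucketed and the bucket count of 8 are just assumptions for illustration, and saveAsTable() needs a metastore-backed table), you could write the data bucketed by location once, so that later group-bys on that column avoid most of the shuffle:
# Write the data bucketed on the groupBy column (bucketBy requires saveAsTable)
(
    df.write
    .bucketBy(8, "location")        # 8 buckets is an arbitrary choice
    .sortBy("location")
    .mode("overwrite")
    .saveAsTable("sales_bucketed")  # hypothetical table name
)

# Later reads that group by the bucketed column can avoid re-shuffling the data
bucketed_df = spark.table("sales_bucketed")
bucketed_df.groupBy(F.col("location")).agg(F.sum(F.col("amount")).alias("total_amount")).show()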