Window functions with pySpark

By Niraj Zade | 2023-12-19 | Tags: guide sql


The flow while using window functions in pySpark is simple (there's a quick sketch right after this list):

  1. Create a window
  2. Apply a function on the window
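
A minimal sketch of both steps, using the example DataFrame df that is set up just below:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Step 1: create a window partitioned by location
# (df is the example DataFrame created in the setup below)
location_window = Window.partitionBy(F.col("location"))

# Step 2: apply a function over the window
df_avg = df.withColumn("avg", F.avg(F.col("sales_amount")).over(location_window))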

I'm using Spark in Jupyter. I used this code block to set things up:

from IPython.core.interactiveshell import InteractiveShell
# Show the output of every expression in a cell, not just the last one
InteractiveShell.ast_node_interactivity = 'all'

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[*]").getOrCreate()

For the examples, I'll be using this data:

data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "sales_amount"]
df = spark.createDataFrame(data=data, schema=schema)

df.show()

Output:

+---+-------+--------+------------+
| id|   name|location|sales_amount|
+---+-------+--------+------------+
|  1|  Alice|  Austin|         100|
|  2|    Bob|  Austin|         200|
|  3|  Chris|  Austin|         300|
|  4|   Dave| Toronto|         400|
|  5|  Elisa| Toronto|         300|
|  6|Fabrice| Toronto|         200|
|  7| Girard| Toronto|         100|
|  8|    Hal|   Tokyo|          50|
|  9|  Ignis|   Tokyo|         100|
| 10|   John|   Tokyo|         100|
+---+-------+--------+------------+

Create window

There are 2 types of windows -

  1. Unordered windows - windows whose rows within each partition are not ordered by any column
  2. Ordered windows - windows whose rows within each partition are ordered by some column(s)

In the Spark engine, only aggregate functions accept unordered windows. All other window functions strictly require ordered windows.
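
For example (a minimal sketch using the df defined above), an aggregate like sum() accepts an unordered window, while a ranking function like row_number() rejects it:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

unordered_window = Window.partitionBy(F.col("location"))

# Aggregate functions work over an unordered window
df.withColumn("total", F.sum(F.col("sales_amount")).over(unordered_window))

# Ranking functions don't - uncommenting this raises an AnalysisException
# df.withColumn("rn", F.row_number().over(unordered_window))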

To create a window, there are 2 steps:

  1. Set the column(s) on which you'll partition the window
  2. If you want an ordered window, set the column(s) to use for ordering the rows within each window-partition

Creating an unordered window

This is a window without orderBy().

Syntax:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = Window.partitionBy(F.col("col1"),F.col("col2"),F.col("col3") ... )

Example:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = Window.partitionBy(F.col("location"))

Creating an ordered window

Simply add orderBy() to the created window.

Syntax:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = Window.partitionBy(F.col("col1"),F.col("col2"),F.col("col3") ... ).orderBy(F.col("orderCol1"),F.col("orderCol2"),F.col("orderCol3") ... )

Example:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = Window.partitionBy(F.col("location")).orderBy(F.col("sales_amount"))

Controlling window bounds

I'm not going in depth into this topic. If you don't know about window bounds, read up on "UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING in SQL window functions" and "ROWS BETWEEN vs RANGE BETWEEN in SQL".

The default frame of an ordered window is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. This can cause problems when using ordered windows with aggregate functions and the last() function.
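
For example (a minimal sketch using the df defined above), last() over an ordered window with the default frame returns the last value seen so far, not the last value of the whole partition:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

ordered_window = Window.partitionBy(F.col("location")).orderBy(F.col("sales_amount"))

# Default frame: last() only sees rows up to the current row (and its ties),
# so it effectively echoes the current sales_amount
df.withColumn("last_default", F.last(F.col("sales_amount")).over(ordered_window)).show()

# Explicit full frame: last() now returns the partition's actual last value
full_window = ordered_window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn("last_full", F.last(F.col("sales_amount")).over(full_window)).show()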

You control the window frame using rowsBetween() and rangeBetween():

  • rowsBetween() - the frame is defined by row positions: a fixed number of rows before and after the current row.
  • rangeBetween() - the frame depends on orderBy(): it covers rows whose ordering value lies within the given offsets of the current row's value, so rows tied with the current row on the ordering column are always included (see the sketch right after this list).
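
The difference shows up when there are ties in the ordering column. A minimal sketch of a running sum, using the df defined above (Tokyo's sales_amount values are 50, 100, 100):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

base = Window.partitionBy(F.col("location")).orderBy(F.col("sales_amount"))

# rowsBetween counts physical rows -> Tokyo running sums: 50, 150, 250
rows_window = base.rowsBetween(Window.unboundedPreceding, Window.currentRow)

# rangeBetween includes all rows tied on sales_amount -> Tokyo running sums: 50, 250, 250
range_window = base.rangeBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("rows_sum", F.sum(F.col("sales_amount")).over(rows_window)) \
  .withColumn("range_sum", F.sum(F.col("sales_amount")).over(range_window)) \
  .show()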

You can set the bounds of a window using these options:

  • Window.unboundedPreceding
  • Window.unboundedFollowing
  • Window.currentRow

You can also use plain integers as bounds (there's a quick sketch after this list):

  • 0 - current row
  • -2 - 2 rows before current row (negative integers = rows before)
  • 2 - 2 rows after current row (positive integers = rows after)
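
For instance, a centered 3-row window - one row before the current row through one row after it (a minimal sketch):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 1 row before the current row through 1 row after it
centered_window = (
    Window
    .partitionBy(F.col("location"))
    .orderBy(F.col("sales_amount"))
    .rowsBetween(-1, 1)
)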

Example: Create a window with UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = (
    Window
    .partitionBy(F.col("location"))
    .orderBy(F.col("sales_amount"))
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

Example: Create a bounded window covering the current row and the 2 rows before it.

This is also called a rolling window or a running window.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

location_window  = (
    Window
    .partitionBy(F.col("location"))
    .orderBy(F.col("sales_amount"))
    .rowsBetween(-2, Window.currentRow) # Can also do .rowsBetween(-2, 0)
)
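
Applying an aggregate over this window gives a rolling result. For example, a rolling sum (a minimal sketch, reusing the location_window defined just above):

# Rolling sum over the current row and the 2 rows before it, per location.
# Toronto (ordered 100, 200, 300, 400) gives 100, 300, 600, 900.
df_rolling = df.withColumn("rolling_sum", F.sum(F.col("sales_amount")).over(location_window))
df_rolling.show()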

Applying the window function

Window functions can be roughly divided into 3 categories. They are:

  1. Aggregate functions (These functions don't require ordered windows)
    1. avg()
    2. sum()
    3. min()
    4. max()
  2. Ranking functions - (These functions require ordered windows)
    1. row_number()
    2. rank()
    3. dense_rank()
    4. percent_rank()
    5. ntile(int)
  3. Analytical functions - (These functions require ordered windows)
    1. cume_dist()
    2. lag(col_name, int)
    3. lead(col_name, int)

Aggregate functions

Aggregate window functions don't require ordered windows, so you can specify the window without .orderBy().

If you do use an ordered window to calculate an aggregate and you don't want a rolling aggregate, make sure you define the window bounds as UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.

Example:

location_window  = (
    Window
    .partitionBy(F.col("location"))
    .orderBy(F.col("sales_amount"))
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
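
As a comparison (a minimal sketch using the df defined above), the same sum() over an ordered window with the default frame gives a running total, while the window above with explicit unbounded bounds gives the partition total:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Same partitioning and ordering, but with the default frame (no explicit bounds)
ordered_window = Window.partitionBy(F.col("location")).orderBy(F.col("sales_amount"))

df_compare = (
    df
    .withColumn("running_sum", F.sum(F.col("sales_amount")).over(ordered_window))     # Austin: 100, 300, 600
    .withColumn("partition_sum", F.sum(F.col("sales_amount")).over(location_window))  # Austin: 600, 600, 600
)
df_compare.show()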

avg()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location"))

# use window function
df_avg = df.withColumn("avg",F.avg(F.col("sales_amount")).over(location_window))
df_avg.show()

Output:

+---+-------+--------+------------+-----------------+
| id|   name|location|sales_amount|              avg|
+---+-------+--------+------------+-----------------+
|  1|  Alice|  Austin|         100|            200.0|
|  2|    Bob|  Austin|         200|            200.0|
|  3|  Chris|  Austin|         300|            200.0|
|  4|   Dave| Toronto|         400|            250.0|
|  5|  Elisa| Toronto|         300|            250.0|
|  6|Fabrice| Toronto|         200|            250.0|
|  7| Girard| Toronto|         100|            250.0|
|  8|    Hal|   Tokyo|          50|83.33333333333333|
|  9|  Ignis|   Tokyo|         100|83.33333333333333|
| 10|   John|   Tokyo|         100|83.33333333333333|
+---+-------+--------+------------+-----------------+

sum()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location"))

# use window function
df_sum = df.withColumn("sum",F.sum(F.col("sales_amount")).over(location_window))
df_sum.show()

Output:

+---+-------+--------+------------+----+
| id|   name|location|sales_amount| sum|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100| 600|
|  2|    Bob|  Austin|         200| 600|
|  3|  Chris|  Austin|         300| 600|
|  4|   Dave| Toronto|         400|1000|
|  5|  Elisa| Toronto|         300|1000|
|  6|Fabrice| Toronto|         200|1000|
|  7| Girard| Toronto|         100|1000|
|  8|    Hal|   Tokyo|          50| 250|
|  9|  Ignis|   Tokyo|         100| 250|
| 10|   John|   Tokyo|         100| 250|
+---+-------+--------+------------+----+

min()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location"))

# use window function
df_min = df.withColumn("min",F.min(F.col("sales_amount")).over(location_window))
df_min.show()

Output:

+---+-------+--------+------------+---+
| id|   name|location|sales_amount|min|
+---+-------+--------+------------+---+
|  1|  Alice|  Austin|         100|100|
|  2|    Bob|  Austin|         200|100|
|  3|  Chris|  Austin|         300|100|
|  4|   Dave| Toronto|         400|100|
|  5|  Elisa| Toronto|         300|100|
|  6|Fabrice| Toronto|         200|100|
|  7| Girard| Toronto|         100|100|
|  8|    Hal|   Tokyo|          50| 50|
|  9|  Ignis|   Tokyo|         100| 50|
| 10|   John|   Tokyo|         100| 50|
+---+-------+--------+------------+---+

max()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location"))

# use window function
df_max = df.withColumn("max",F.max(F.col("sales_amount")).over(location_window))
df_max.show()

Output:

+---+-------+--------+------------+---+
| id|   name|location|sales_amount|max|
+---+-------+--------+------------+---+
|  1|  Alice|  Austin|         100|300|
|  2|    Bob|  Austin|         200|300|
|  3|  Chris|  Austin|         300|300|
|  4|   Dave| Toronto|         400|400|
|  5|  Elisa| Toronto|         300|400|
|  6|Fabrice| Toronto|         200|400|
|  7| Girard| Toronto|         100|400|
|  8|    Hal|   Tokyo|          50|100|
|  9|  Ignis|   Tokyo|         100|100|
| 10|   John|   Tokyo|         100|100|
+---+-------+--------+------------+---+

Ranking functions

Ranking window functions need the window to be ordered. So, while creating a window for ranking functions, you must specify orderBy(). If you don't, Spark SQL will throw an AnalysisException.

Example -

AnalysisException: Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table.

row_number()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_row_number = df.withColumn("row_number",F.row_number().over(location_window))
df_row_number.show()

Output:

+---+-------+--------+------------+----------+
| id|   name|location|sales_amount|row_number|
+---+-------+--------+------------+----------+
|  1|  Alice|  Austin|         100|         1|
|  2|    Bob|  Austin|         200|         2|
|  3|  Chris|  Austin|         300|         3|
|  7| Girard| Toronto|         100|         1|
|  6|Fabrice| Toronto|         200|         2|
|  5|  Elisa| Toronto|         300|         3|
|  4|   Dave| Toronto|         400|         4|
|  8|    Hal|   Tokyo|          50|         1|
|  9|  Ignis|   Tokyo|         100|         2|
| 10|   John|   Tokyo|         100|         3|
+---+-------+--------+------------+----------+

rank()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_rank = df.withColumn("rank",F.rank().over(location_window))
df_rank.show()

Output:

+---+-------+--------+------------+----+
| id|   name|location|sales_amount|rank|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100|   1|
|  2|    Bob|  Austin|         200|   2|
|  3|  Chris|  Austin|         300|   3|
|  7| Girard| Toronto|         100|   1|
|  6|Fabrice| Toronto|         200|   2|
|  5|  Elisa| Toronto|         300|   3|
|  4|   Dave| Toronto|         400|   4|
|  8|    Hal|   Tokyo|          50|   1|
|  9|  Ignis|   Tokyo|         100|   2|
| 10|   John|   Tokyo|         100|   2|
+---+-------+--------+------------+----+

dense_rank()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_dense_rank = df.withColumn("dense_rank",F.dense_rank().over(location_window))
df_dense_rank.show()

Output:

+---+-------+--------+------------+----------+
| id|   name|location|sales_amount|dense_rank|
+---+-------+--------+------------+----------+
|  1|  Alice|  Austin|         100|         1|
|  2|    Bob|  Austin|         200|         2|
|  3|  Chris|  Austin|         300|         3|
|  7| Girard| Toronto|         100|         1|
|  6|Fabrice| Toronto|         200|         2|
|  5|  Elisa| Toronto|         300|         3|
|  4|   Dave| Toronto|         400|         4|
|  8|    Hal|   Tokyo|          50|         1|
|  9|  Ignis|   Tokyo|         100|         2|
| 10|   John|   Tokyo|         100|         2|
+---+-------+--------+------------+----------+

percent_rank()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_percent_rank = df.withColumn("percent_rank",F.percent_rank().over(location_window))
df_percent_rank.show()

Output:

+---+-------+--------+------------+------------------+
| id|   name|location|sales_amount|      percent_rank|
+---+-------+--------+------------+------------------+
|  1|  Alice|  Austin|         100|               0.0|
|  2|    Bob|  Austin|         200|               0.5|
|  3|  Chris|  Austin|         300|               1.0|
|  7| Girard| Toronto|         100|               0.0|
|  6|Fabrice| Toronto|         200|0.3333333333333333|
|  5|  Elisa| Toronto|         300|0.6666666666666666|
|  4|   Dave| Toronto|         400|               1.0|
|  8|    Hal|   Tokyo|          50|               0.0|
|  9|  Ignis|   Tokyo|         100|               0.5|
| 10|   John|   Tokyo|         100|               0.5|
+---+-------+--------+------------+------------------+

ntile(int)

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_ntile = df.withColumn("ntile",F.ntile(2).over(location_window))
df_ntile.show()

Output:


+---+-------+--------+------------+-----+
| id|   name|location|sales_amount|ntile|
+---+-------+--------+------------+-----+
|  1|  Alice|  Austin|         100|    1|
|  2|    Bob|  Austin|         200|    1|
|  3|  Chris|  Austin|         300|    2|
|  7| Girard| Toronto|         100|    1|
|  6|Fabrice| Toronto|         200|    1|
|  5|  Elisa| Toronto|         300|    2|
|  4|   Dave| Toronto|         400|    2|
|  8|    Hal|   Tokyo|          50|    1|
|  9|  Ignis|   Tokyo|         100|    1|
| 10|   John|   Tokyo|         100|    2|
+---+-------+--------+------------+-----+

Analytical functions

Analytical window functions need the window to be ordered. So, while creating a window for analytical functions, you must specify orderBy(). If you don't, Spark SQL will throw an AnalysisException.

Example -

AnalysisException: Window function cume_dist() requires window to be ordered, please add ORDER BY clause. For example SELECT cume_dist()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table.

cume_dist()

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_cume_dist = df.withColumn("cume_dist",F.cume_dist().over(location_window))
df_cume_dist.show()

Output:

+---+-------+--------+------------+------------------+
| id|   name|location|sales_amount|         cume_dist|
+---+-------+--------+------------+------------------+
|  1|  Alice|  Austin|         100|0.3333333333333333|
|  2|    Bob|  Austin|         200|0.6666666666666666|
|  3|  Chris|  Austin|         300|               1.0|
|  7| Girard| Toronto|         100|              0.25|
|  6|Fabrice| Toronto|         200|               0.5|
|  5|  Elisa| Toronto|         300|              0.75|
|  4|   Dave| Toronto|         400|               1.0|
|  8|    Hal|   Tokyo|          50|0.3333333333333333|
|  9|  Ignis|   Tokyo|         100|               1.0|
| 10|   John|   Tokyo|         100|               1.0|
+---+-------+--------+------------+------------------+

lag(col_name, int)

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_lag = df.withColumn("lag",F.lag(F.col("sales_amount"),1).over(location_window))
df_lag.show()

Output:

+---+-------+--------+------------+----+
| id|   name|location|sales_amount| lag|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100|NULL|
|  2|    Bob|  Austin|         200| 100|
|  3|  Chris|  Austin|         300| 200|
|  7| Girard| Toronto|         100|NULL|
|  6|Fabrice| Toronto|         200| 100|
|  5|  Elisa| Toronto|         300| 200|
|  4|   Dave| Toronto|         400| 300|
|  8|    Hal|   Tokyo|          50|NULL|
|  9|  Ignis|   Tokyo|         100|  50|
| 10|   John|   Tokyo|         100| 100|
+---+-------+--------+------------+----+

lead(col_name, int)

Code:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# create window
location_window  = Window.partitionBy(F.col("location")).orderBy("sales_amount")

# use window function
df_lead = df.withColumn("lead",F.lead(F.col("sales_amount"),1).over(location_window))
df_lead.show()

Output:

+---+-------+--------+------------+----+
| id|   name|location|sales_amount|lead|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100| 200|
|  2|    Bob|  Austin|         200| 300|
|  3|  Chris|  Austin|         300|NULL|
|  7| Girard| Toronto|         100| 200|
|  6|Fabrice| Toronto|         200| 300|
|  5|  Elisa| Toronto|         300| 400|
|  4|   Dave| Toronto|         400|NULL|
|  8|    Hal|   Tokyo|          50| 100|
|  9|  Ignis|   Tokyo|         100| 100|
| 10|   John|   Tokyo|         100|NULL|
+---+-------+--------+------------+----+