CASE in pySpark

By Niraj Zade | 2024-01-17 | Tags: guide sql


Key takeaway

SQL CASE is done in the pySpark dataframe API using the when() and otherwise() functions (there is no separate then() function - the THEN value is passed as the second argument to when()):

from pyspark.sql import functions as F

case_df = df.withColumn(
    "age_type",
    F.when(F.col("age") <= 1, "baby")
    .when((F.col("age") >= 1) & (F.col("age") < 18), "child")
    .when(F.col("age") >= 18, "adult")
    .otherwise("invalid age value"),
)
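
For reference, here is the same logic written as a raw SQL CASE expression and run through spark.sql(). This is just a sketch - it assumes the dataframe has an age column and is registered as a temp view named people (both names are illustrative):

# Sketch: the equivalent raw SQL CASE, assuming df has an `age` column.
# The temp view name "people" is illustrative.
df.createOrReplaceTempView("people")

sql_case_df = spark.sql("""
    SELECT *,
           CASE
               WHEN age <= 1 THEN 'baby'
               WHEN age > 1 AND age < 18 THEN 'child'
               WHEN age >= 18 THEN 'adult'
               ELSE 'invalid age value'
           END AS age_type
    FROM people
""")
sql_case_df.show()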

Scenario

In the dataframe, create a column value_flag. Set its values conditionally, depending on the value in amount:

  • amount < 200 -> low_value
  • 200 <= amount < 300 -> medium_value
  • amount >= 300 -> high_value

There are 2 ways to get this done:

  1. Add a new column to the dataframe using withColumn()
  2. Add a new column to the dataframe using select()

Example data:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

Initial Setup

I'm using a Jupyter notebook for my setup.

It is a 2-step setup process:

  1. Import modules and create a spark session:
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

# import relevant python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F

# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
  2. Create the dataframe:
data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]

df = spark.createDataFrame(data=data, schema=schema)
df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+
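
The schema above is just a list of column names, so Spark infers the column types from the data. If you want explicit types, you can instead pass a StructType built with the types module imported as T in the setup step. This is a sketch - the chosen types and nullability flags are assumptions:

# Optional: an explicit schema instead of type inference.
# (Sketch - the exact types and nullable flags are assumptions.)
explicit_schema = T.StructType([
    T.StructField("id", T.LongType(), False),
    T.StructField("name", T.StringType(), True),
    T.StructField("location", T.StringType(), True),
    T.StructField("amount", T.LongType(), True),
])

typed_df = spark.createDataFrame(data=data, schema=explicit_schema)
typed_df.printSchema()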

Add new case-based column using withColumn()

Here, we will create a new column value_flag in the dataframe using the withColumn() method.

You can get this done using either a SQL expression or the pySpark SQL functions. Use whichever you prefer. I'm solving this scenario both ways.

IMPLEMENT CASE USING SQL EXPRESSION

withColumn() only accepts column expressions; it does not accept raw SQL expression strings. So we convert the SQL expression into a column expression using the expr() function.

from pyspark.sql import functions as F

case_df = df.withColumn(
    "value_flag",
    F.expr(
        """
        CASE
            WHEN amount < 200 THEN 'low_value'
            WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
            WHEN amount >= 300 THEN 'high_value'
            ELSE NULL
        END
        """
    ),
)
case_df.show()

IMPLEMENT CASE USING when() and otherwise()

from pyspark.sql import functions as F

case_df = df.withColumn(
    "value_flag",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None),
)
case_df.show()

OUTPUT

+---+-------+--------+------+------------+
| id|   name|location|amount|  value_flag|
+---+-------+--------+------+------------+
|  1|  Alice|  Austin|   100|   low_value|
|  2|    Bob|  Austin|   200|medium_value|
|  3|  Chris|  Austin|   300|  high_value|
|  4|   Dave| Toronto|   400|  high_value|
|  5|  Elisa| Toronto|   300|  high_value|
|  6|Fabrice| Toronto|   200|medium_value|
|  7| Girard| Toronto|   100|   low_value|
|  8|    Hal|   Tokyo|    50|   low_value|
|  9|  Ignis|   Tokyo|   100|   low_value|
| 10|   John|   Tokyo|   100|   low_value|
+---+-------+--------+------+------------+
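
A small note on the when()/otherwise() version: if .otherwise() is not invoked, Spark returns NULL for unmatched rows anyway, so .otherwise(None) here mainly documents the intent. The following sketch is equivalent:

from pyspark.sql import functions as F

# Equivalent: without .otherwise(), unmatched rows get NULL.
case_df = df.withColumn(
    "value_flag",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value"),
)
case_df.show()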

Add new case-based column using select()

You can get this done using either a SQL expression or the pySpark SQL functions. Use whichever you prefer. I'm solving this scenario both ways.

Note that unlike withColumn(), select() doesn't take a new-column name as an argument, so the new column has to be named explicitly - either with AS inside the SQL expression, or with .alias() on the column expression.

IMPLEMENT CASE USING SQL EXPRESSION

from pyspark.sql import functions as F

case_df = df.select(
    "*",
    F.expr(
        """
        CASE
            WHEN amount < 200 THEN 'low_value'
            WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
            WHEN amount >= 300 THEN 'high_value'
            ELSE NULL
        END AS value_flag
        """
    ),
)
case_df.show()
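
As a side note, select() combined with expr() can also be written with selectExpr(), which accepts SQL expression strings directly. A sketch of the same query:

# Same result using selectExpr(), which takes SQL expression strings directly.
case_df = df.selectExpr(
    "*",
    """
    CASE
        WHEN amount < 200 THEN 'low_value'
        WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
        WHEN amount >= 300 THEN 'high_value'
        ELSE NULL
    END AS value_flag
    """,
)
case_df.show()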

IMPLEMENT CASE USING when() and otherwise()

from pyspark.sql import functions as F

case_df = df.select(
    "*",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None)
    .alias("value_flag")
)

case_df.show()

OUTPUT

+---+-------+--------+------+------------+
| id|   name|location|amount|  value_flag|
+---+-------+--------+------+------------+
|  1|  Alice|  Austin|   100|   low_value|
|  2|    Bob|  Austin|   200|medium_value|
|  3|  Chris|  Austin|   300|  high_value|
|  4|   Dave| Toronto|   400|  high_value|
|  5|  Elisa| Toronto|   300|  high_value|
|  6|Fabrice| Toronto|   200|medium_value|
|  7| Girard| Toronto|   100|   low_value|
|  8|    Hal|   Tokyo|    50|   low_value|
|  9|  Ignis|   Tokyo|   100|   low_value|
| 10|   John|   Tokyo|   100|   low_value|
+---+-------+--------+------+------------+
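
Once the flag column exists, it behaves like any other column. For example, a quick check of how many rows fall into each bucket:

# Count rows per value_flag bucket.
case_df.groupBy("value_flag").count().show()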