CASE in pySpark

By Niraj Zade | 2024-01-17 | Tags: guide, sql


In the pySpark DataFrame API, we implement SQL's CASE expression using the when() and otherwise() functions.
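
As a quick orientation, this is the general shape of the pattern (a minimal sketch; some_col and the labels are just placeholders):

from pyspark.sql import functions as F

# SQL:  CASE WHEN cond1 THEN val1 WHEN cond2 THEN val2 ELSE default END
# API:  F.when(cond1, val1).when(cond2, val2).otherwise(default)
flag_col = (
    F.when(F.col("some_col") > 0, "positive")
    .when(F.col("some_col") < 0, "negative")
    .otherwise("zero")
)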

Scenario

In the dataframe, create a column value_flag. Set its value conditionally, depending on the value in amount:

  • amount < 200 -> low_value
  • 200 <= amount < 300 -> medium_value
  • 300 <= amount -> high_value

There are 2 ways to get this done:

  1. Add a new column to the dataframe using withColumn()
  2. Add a new column to the dataframe using select()

Example data:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

Initial Setup

I'm using a Jupyter notebook for my setup.

It is a 2-step setup process:

  1. Import modules and create the spark session:
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

# import relevant python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F

# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
  2. Create the dataframe (an explicit-schema variant is sketched after the output below):
data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]

df = spark.createDataFrame(data=data, schema=schema)
df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+
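
The schema above is just a list of column names, so Spark infers the data types from the sample data. If you want explicit types, you can build a StructType schema with the types module imported as T earlier (an optional sketch, not required for this scenario):

explicit_schema = T.StructType([
    T.StructField("id", T.IntegerType(), False),
    T.StructField("name", T.StringType(), True),
    T.StructField("location", T.StringType(), True),
    T.StructField("amount", T.IntegerType(), True),
])

typed_df = spark.createDataFrame(data, schema=explicit_schema)
typed_df.printSchema()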

Add new column to dataframe using withColumn()

In this approach, we create a new column value_flag in the dataframe using the withColumn() method.

You can get this done using either a SQL expression or the pySpark SQL functions. Use whichever you prefer. I'm solving this scenario both ways.

IMPLEMENT CASE USING SQL EXPRESSION

withColumn() only accepts column expressions (Column objects). It does not accept SQL expression strings directly, so we convert the SQL expression into a column expression using the expr() function.

from pyspark.sql import functions as F

case_df = df.withColumn(
    "value_flag",
    # withColumn() already names the column, so no AS alias is needed inside expr()
    F.expr(
        "CASE WHEN amount < 200 THEN 'low_value' "
        "WHEN amount >= 200 AND amount < 300 THEN 'medium_value' "
        "WHEN amount >= 300 THEN 'high_value' ELSE NULL END"
    ),
)
case_df.show()

IMPLEMENT CASE USING when() and otherwise()

from pyspark.sql import functions as F

case_df = df.withColumn(
    "value_flag",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None),
)
case_df.show()

OUTPUT

+---+-------+--------+------+------------+
| id|   name|location|amount|  value_flag|
+---+-------+--------+------+------------+
|  1|  Alice|  Austin|   100|   low_value|
|  2|    Bob|  Austin|   200|medium_value|
|  3|  Chris|  Austin|   300|  high_value|
|  4|   Dave| Toronto|   400|  high_value|
|  5|  Elisa| Toronto|   300|  high_value|
|  6|Fabrice| Toronto|   200|medium_value|
|  7| Girard| Toronto|   100|   low_value|
|  8|    Hal|   Tokyo|    50|   low_value|
|  9|  Ignis|   Tokyo|   100|   low_value|
| 10|   John|   Tokyo|   100|   low_value|
+---+-------+--------+------+------------+
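
One behaviour worth knowing: otherwise() is optional. If you leave it out, rows that match none of the when() conditions simply get NULL in the new column. A minimal sketch using the same df (sparse_df is just an illustrative name):

from pyspark.sql import functions as F

# no otherwise() here, so rows with amount >= 200 end up with value_flag = NULL
sparse_df = df.withColumn(
    "value_flag",
    F.when(F.col("amount") < 200, "low_value"),
)
sparse_df.show()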

Add new column to dataframe using select()

Here, we select all existing columns ("*") and add the computed value_flag column alongside them. As before, you can use either a SQL expression or the pySpark SQL functions; I'm solving this scenario both ways.

IMPLEMENT CASE USING SQL EXPRESSION

from pyspark.sql import functions as F

case_df = df.select(
    "*",
    F.expr(
        # the AS alias names the new column in the select() output
        "CASE WHEN amount < 200 THEN 'low_value' "
        "WHEN amount >= 200 AND amount < 300 THEN 'medium_value' "
        "WHEN amount >= 300 THEN 'high_value' ELSE NULL END AS value_flag"
    ),
)
case_df.show()

IMPLEMENT CASE USING when() and otherwise()

from pyspark.sql import functions as F

case_df = df.select(
    "*",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None)
    .alias("value_flag")
)

case_df.show()

OUTPUT

+---+-------+--------+------+------------+
| id|   name|location|amount|  value_flag|
+---+-------+--------+------+------------+
|  1|  Alice|  Austin|   100|   low_value|
|  2|    Bob|  Austin|   200|medium_value|
|  3|  Chris|  Austin|   300|  high_value|
|  4|   Dave| Toronto|   400|  high_value|
|  5|  Elisa| Toronto|   300|  high_value|
|  6|Fabrice| Toronto|   200|medium_value|
|  7| Girard| Toronto|   100|   low_value|
|  8|    Hal|   Tokyo|    50|   low_value|
|  9|  Ignis|   Tokyo|   100|   low_value|
| 10|   John|   Tokyo|   100|   low_value|
+---+-------+--------+------+------------+
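
As a side note, if you prefer the SQL-expression style with select(), selectExpr() lets you skip the expr() wrapper entirely (a sketch of the same CASE expression):

case_df = df.selectExpr(
    "*",
    "CASE WHEN amount < 200 THEN 'low_value' "
    "WHEN amount >= 200 AND amount < 300 THEN 'medium_value' "
    "WHEN amount >= 300 THEN 'high_value' ELSE NULL END AS value_flag",
)
case_df.show()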
