Key takeaway
SQL CASE is expressed in the PySpark DataFrame API using the when() and otherwise() functions:
from pyspark.sql import functions as F
case_df = df.withColumn(
    "age_type",
    F.when(F.col("age") <= 1, "baby")
    .when((F.col("age") > 1) & (F.col("age") < 18), "child")
    .when(F.col("age") >= 18, "adult")
    .otherwise("invalid age value"),
)
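Note that otherwise() is optional: rows that match no when() condition get NULL, just like a SQL CASE without an ELSE branch. A minimal sketch of that default behavior, assuming a Spark session exists and using a throwaway dataframe (ages_df is just an illustrative name):
from pyspark.sql import functions as F
ages_df = spark.createDataFrame([(0,), (10,), (30,)], ["age"])
ages_df.withColumn(
    "age_type",
    # no .otherwise() here, so age 30 falls through to NULL
    F.when(F.col("age") <= 1, "baby")
    .when((F.col("age") > 1) & (F.col("age") < 18), "child"),
).show()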
Scenario
In the dataframe, create a column value_flag. Set its value conditionally, depending on the value in amount:
- amount < 200 -> low_value
- 200 <= amount < 300 -> medium_value
- 300 <= amount -> high_value
There are two ways to get this done:
- Add the new column to the dataframe using withColumn()
- Add the new column to the dataframe using select()
Example data:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 1| Alice| Austin| 100|
| 2| Bob| Austin| 200|
| 3| Chris| Austin| 300|
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
| 7| Girard| Toronto| 100|
| 8| Hal| Tokyo| 50|
| 9| Ignis| Tokyo| 100|
| 10| John| Tokyo| 100|
+---+-------+--------+------+
Initial Setup
I'm using a Jupyter notebook for this setup.
It is a two-step process:
- Import modules and create the Spark session
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'
# import relevant python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F
# Create the Spark session (not required on Databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
- Create dataframe:
data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]
df = spark.createDataFrame(data=data, schema=schema)
df.show()
Output:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 1| Alice| Austin| 100|
| 2| Bob| Austin| 200|
| 3| Chris| Austin| 300|
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
| 7| Girard| Toronto| 100|
| 8| Hal| Tokyo| 50|
| 9| Ignis| Tokyo| 100|
| 10| John| Tokyo| 100|
+---+-------+--------+------+
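A side note on the schema: passing a plain list of column names makes Spark infer the types from the data (here id and amount come out as longs). If you want explicit types instead, the types module imported as T above can build a StructType; a sketch along those lines (explicit_schema and typed_df are illustrative names):
explicit_schema = T.StructType([
    T.StructField("id", T.IntegerType(), nullable=False),
    T.StructField("name", T.StringType(), nullable=True),
    T.StructField("location", T.StringType(), nullable=True),
    T.StructField("amount", T.IntegerType(), nullable=True),
])
typed_df = spark.createDataFrame(data, schema=explicit_schema)
typed_df.printSchema()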
Add new case-based column using withColumn()
Here we create the new column value_flag in the dataframe using the withColumn() method.
You can get this done with either a SQL expression or PySpark SQL functions. Use whichever you prefer; I'm solving the scenario both ways.
IMPLEMENT CASE USING SQL EXPRESSION
withColumn() only accepts a column expression, not a raw SQL string, so we wrap the SQL CASE expression in the expr() function, which parses it into a column expression.
from pyspark.sql import functions as F
case_df = df.withColumn(
    "value_flag",
    F.expr(
        """CASE WHEN amount < 200 THEN 'low_value'
                WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
                WHEN amount >= 300 THEN 'high_value'
                ELSE NULL END"""
    ),
)
case_df.show()
IMPLEMENT CASE USING when() and otherwise()
from pyspark.sql import functions as F
case_df = df.withColumn(
    "value_flag",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None),
)
case_df.show()
OUTPUT
+---+-------+--------+------+------------+
| id| name|location|amount| value_flag|
+---+-------+--------+------+------------+
| 1| Alice| Austin| 100| low_value|
| 2| Bob| Austin| 200|medium_value|
| 3| Chris| Austin| 300| high_value|
| 4| Dave| Toronto| 400| high_value|
| 5| Elisa| Toronto| 300| high_value|
| 6|Fabrice| Toronto| 200|medium_value|
| 7| Girard| Toronto| 100| low_value|
| 8| Hal| Tokyo| 50| low_value|
| 9| Ignis| Tokyo| 100| low_value|
| 10| John| Tokyo| 100| low_value|
+---+-------+--------+------+------------+
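For completeness: since the logic is already plain SQL CASE, you could also register the dataframe as a temporary view and run the whole statement through spark.sql(). A sketch of that route (sales_view is just an example view name):
df.createOrReplaceTempView("sales_view")
sql_case_df = spark.sql("""
    SELECT *,
           CASE WHEN amount < 200 THEN 'low_value'
                WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
                WHEN amount >= 300 THEN 'high_value'
                ELSE NULL
           END AS value_flag
    FROM sales_view
""")
sql_case_df.show()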
Add new case-based column using select()
Again, you can use either a SQL expression or PySpark SQL functions; both approaches are shown below.
IMPLEMENT CASE USING SQL EXPRESSION
from pyspark.sql import functions as F
case_df = df.select(
    "*",
    F.expr(
        """CASE WHEN amount < 200 THEN 'low_value'
                WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
                WHEN amount >= 300 THEN 'high_value'
                ELSE NULL END AS value_flag"""
    ),
)
case_df.show()
IMPLEMENT CASE USING when() and otherwise()
from pyspark.sql import functions as F
case_df = df.select(
    "*",
    F.when(F.col("amount") < 200, "low_value")
    .when((F.col("amount") >= 200) & (F.col("amount") < 300), "medium_value")
    .when(F.col("amount") >= 300, "high_value")
    .otherwise(None)
    .alias("value_flag"),
)
case_df.show()
OUTPUT
+---+-------+--------+------+------------+
| id| name|location|amount| value_flag|
+---+-------+--------+------+------------+
| 1| Alice| Austin| 100| low_value|
| 2| Bob| Austin| 200|medium_value|
| 3| Chris| Austin| 300| high_value|
| 4| Dave| Toronto| 400| high_value|
| 5| Elisa| Toronto| 300| high_value|
| 6|Fabrice| Toronto| 200|medium_value|
| 7| Girard| Toronto| 100| low_value|
| 8| Hal| Tokyo| 50| low_value|
| 9| Ignis| Tokyo| 100| low_value|
| 10| John| Tokyo| 100| low_value|
+---+-------+--------+------+------------+
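Related shortcut: when you go the SQL-expression route with select(), the selectExpr() method accepts the SQL string directly and saves the expr() wrapper. A sketch of the same CASE logic with it:
case_df = df.selectExpr(
    "*",
    """CASE WHEN amount < 200 THEN 'low_value'
            WHEN amount >= 200 AND amount < 300 THEN 'medium_value'
            WHEN amount >= 300 THEN 'high_value'
            ELSE NULL END AS value_flag""",
)
case_df.show()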