SCENARIO:
Fetch rows of the Toronto location where amount is greater than or equal to 200
The WHERE clause can be applied using either the where() or the filter() method. The two methods are aliases: they take exactly the same inputs and return exactly the same outputs. Both accept either a SQL expression (a string) or a Column expression. I'm solving this scenario both ways.
Data used in this scenario:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 1| Alice| Austin| 100|
| 2| Bob| Austin| 200|
| 3| Chris| Austin| 300|
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
| 7| Girard| Toronto| 100|
| 8| Hal| Tokyo| 50|
| 9| Ignis| Tokyo| 100|
| 10| John| Tokyo| 100|
+---+-------+--------+------+
Initial Setup
I'm using a Jupyter notebook for my setup.
It is a two-step setup process:
- Import modules and create a Spark session
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'
# import relevant python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F
# Create spark session (not required on databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
- Create dataframe:
data = [
(1, "Alice", "Austin", 100),
(2, "Bob", "Austin", 200),
(3, "Chris", "Austin", 300),
(4, "Dave", "Toronto", 400),
(5, "Elisa", "Toronto", 300),
(6, "Fabrice", "Toronto", 200),
(7, "Girard", "Toronto", 100),
(8, "Hal", "Tokyo", 50),
(9, "Ignis", "Tokyo", 100),
(10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]
df = spark.createDataFrame(data=data, schema=schema)
df.show()
Output:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 1| Alice| Austin| 100|
| 2| Bob| Austin| 200|
| 3| Chris| Austin| 300|
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
| 7| Girard| Toronto| 100|
| 8| Hal| Tokyo| 50|
| 9| Ignis| Tokyo| 100|
| 10| John| Tokyo| 100|
+---+-------+--------+------+
WHERE clause using SQL expression
In this case, we write the condition as a SQL string, using the same syntax as a SQL WHERE clause. Make sure that the expression evaluates to a boolean result (True or False).
filtered_df = df.where("location = 'Toronto' AND amount >= 200")
filtered_df.show()
Output:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
+---+-------+--------+------+
WHERE clause using Dataframe API expression
In this case, we build the condition with Column expressions from the Dataframe API. Make sure that the expression evaluates to a boolean result (True or False).
As a best practice, I'll be using the col() function to refer to column names. When combining conditions, wrap each one in parentheses and join them with & (AND) or | (OR).
from pyspark.sql.functions import col
filtered_df = df.where((col("location") == "Toronto") & (col("amount") >= 200))
filtered_df.show()
Output:
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 4| Dave| Toronto| 400|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
+---+-------+--------+------+
More Examples
Amounts between 200 and 300
Get rows from all locations with amounts between 200 and 300 (200 <= amount <= 300). Note that BETWEEN is inclusive on both ends.
SQL EXPRESSION
between_where_df = df.where("amount BETWEEN 200 AND 300")
between_where_df.show()
DATAFRAME API EXPRESSION
from pyspark.sql.functions import col
between_where_df = df.where(col("amount").between(200, 300))
between_where_df.show()
OUTPUT
+---+-------+--------+------+
| id| name|location|amount|
+---+-------+--------+------+
| 2| Bob| Austin| 200|
| 3| Chris| Austin| 300|
| 5| Elisa| Toronto| 300|
| 6|Fabrice| Toronto| 200|
+---+-------+--------+------+