Filter rows in PySpark

By Niraj Zade | 2024-01-17 | Tags: guide sql


Key takeaway

# filter using a Column expression (Dataframe API)
filtered_df = df.where((col("location") == "Toronto") & (col("amount") >= 200))
# or
filtered_df = df.filter((col("location") == "Toronto") & (col("amount") >= 200))

# filter using a SQL expression string
filtered_df = df.where("location = 'Toronto' AND amount >= 200")
# or
filtered_df = df.filter("location = 'Toronto' AND amount >= 200")

Scenario

Fetch rows for the Toronto location where the amount is greater than or equal to 200.

The WHERE clause can be applied using both the where() and filter() methods. where() is an alias of filter(): both take the exact same inputs and give the exact same outputs.

Both methods accept SQL expression strings as well as Column expressions. I'll solve this scenario both ways.

Data used in this scenario:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+

Create dataframe for examples

I'm using Spark in a Jupyter notebook for this example. You can use Databricks or whatever environment you have.

  1. Import modules
  2. Create Spark session (not needed on Databricks)
  3. Create dataframe
# for jupyter notebook
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'

# import relevant python modules
from pyspark.sql import types as T
from pyspark.sql import functions as F

# Create Spark session (not required on Databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "amount"]

df = spark.createDataFrame(data = data, schema = schema)
df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  1|  Alice|  Austin|   100|
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
|  7| Girard| Toronto|   100|
|  8|    Hal|   Tokyo|    50|
|  9|  Ignis|   Tokyo|   100|
| 10|   John|   Tokyo|   100|
+---+-------+--------+------+
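
createDataFrame infers the column types from the Python values. To confirm what Spark inferred, you can call printSchema() (an optional sanity check; with this data, the ints come through as long and the strings as string):

df.printSchema()

Output:

root
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- location: string (nullable = true)
 |-- amount: long (nullable = true)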

Filter using Dataframe API

In this case, we build the condition using Column expressions from the Dataframe API. Make sure the expression evaluates to a boolean Column (True or False for each row).

As a best practice, I'll use the col() function to refer to columns.

from pyspark.sql.functions import col

# wrap each condition in parentheses - & binds tighter than the comparisons
filtered_df = df.where((col("location") == "Toronto") & (col("amount") >= 200))
filtered_df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
+---+-------+--------+------+
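
col("amount") isn't the only way to reference a column. The same filter can be written with bracket or attribute notation; the three forms below are equivalent (a quick sketch using the df from above):

from pyspark.sql.functions import col

# three equivalent ways to reference columns in the same condition
df.where((col("location") == "Toronto") & (col("amount") >= 200))
df.where((df["location"] == "Toronto") & (df["amount"] >= 200))
df.where((df.location == "Toronto") & (df.amount >= 200))

I prefer col() because it isn't tied to a specific dataframe variable, so the expression survives renames and can be reused across dataframes.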

Filter using SQL expression

In this case, we pass the condition as a SQL string, exactly as it would appear after WHERE in a SQL query.

filtered_df = df.where("location = 'Toronto' AND amount >= 200")
filtered_df.show()

Output:

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  4|   Dave| Toronto|   400|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
+---+-------+--------+------+
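
If you'd rather write the whole query in SQL, you can also register the dataframe as a temporary view and use spark.sql(). A minimal sketch, where the view name sales is just an illustrative choice:

# register the dataframe as a temp view (the name "sales" is arbitrary)
df.createOrReplaceTempView("sales")

filtered_df = spark.sql("SELECT * FROM sales WHERE location = 'Toronto' AND amount >= 200")
filtered_df.show()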

More Examples

Amounts between 200 and 300

Get rows from all locations with amounts between 200 and 300, inclusive (200 <= amount <= 300).

DATAFRAME API EXPRESSION

from pyspark.sql.functions import col

between_where_df = df.where(col("amount").between(200, 300))
between_where_df.show()

SQL EXPRESSION

between_where_df = df.where("AMOUNT BETWEEN 200 AND 300")
between_where_df.show()

OUTPUT

+---+-------+--------+------+
| id|   name|location|amount|
+---+-------+--------+------+
|  2|    Bob|  Austin|   200|
|  3|  Chris|  Austin|   300|
|  5|  Elisa| Toronto|   300|
|  6|Fabrice| Toronto|   200|
+---+-------+--------+------+
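
The same two styles extend to membership tests. For example, to keep only the Austin and Toronto rows, isin() is the Dataframe API form and IN is the SQL form (a sketch against the same df):

from pyspark.sql.functions import col

# Dataframe API: membership test with isin()
isin_df = df.where(col("location").isin("Austin", "Toronto"))
isin_df.show()

# SQL expression: the equivalent IN clause
isin_sql_df = df.where("location IN ('Austin', 'Toronto')")
isin_sql_df.show()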