UDF using a Python function in PySpark

By Niraj Zade | 2024-10-03 | Tags: workflow


The steps are:

  1. Define a Python function
  2. Register the function as a UDF in Spark
  3. Use the UDF as we would use any other PySpark function

Example Scenario

Filter out rows with invalid mobile numbers from the data below. A mobile number is valid when:

  • It is exactly 10 characters long
  • It contains only digits

+-----+-------------+
| name|mobile_number|
+-----+-------------+
|Alice|   9876543210|
|  Bob|         1234|
|Chris|         abcd|
+-----+-------------+
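Both rules can also be checked with a single regular expression. The sketch below is a plain-Python alternative, not the approach used in this post; is_valid_mobile_number and MOBILE_NUMBER_PATTERN are hypothetical names. Writing [0-9] restricts the check to ASCII digits, whereas str.isdigit() also returns True for characters such as the superscript '²'.

```python
import re

# Hypothetical pattern: exactly 10 ASCII digits (fullmatch anchors both ends)
MOBILE_NUMBER_PATTERN = re.compile(r"[0-9]{10}")


def is_valid_mobile_number(value):
    """Hypothetical regex-based variant of the two validity rules."""
    if value is None:  # Spark passes SQL NULL to a Python UDF as None
        return False
    # Strip surrounding whitespace, then require the whole string to match
    return MOBILE_NUMBER_PATTERN.fullmatch(str(value).strip()) is not None


for number in ["9876543210", "1234", "abcd", " 9876543210 ", None]:
    print(repr(number), is_valid_mobile_number(number))
```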

Setup

(Not required on Databricks)

# For Jupyter notebooks: show the result of every expression in a cell, not only the last one
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

# Create a local Spark session with one worker thread per logical core (not required on Databricks)
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

Define and use the UDF

Create a DataFrame

data = [
    ("Alice", "9876543210"),
    ("Bob", "1234"),
    ("Chris", "abcd"),
]
# Column names only; Spark infers the column types (both strings here) from the data
schema = ["name", "mobile_number"]

contact_df = spark.createDataFrame(data=data, schema=schema)
contact_df.show()

Define the Python function

# A plain Python function: nothing Spark-specific here
def validate_mobile_number(mobile_number):
    """Return True if the value, with surrounding whitespace removed, is exactly 10 digits."""
    cleaned_mobile_number = str(mobile_number).strip()
    return len(cleaned_mobile_number) == 10 and cleaned_mobile_number.isdigit()
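Because validate_mobile_number is an ordinary Python function, it can be checked in plain Python before Spark is involved. A quick sketch using the sample values plus two edge cases; the copy of the function is equivalent to the one above and is repeated only so the snippet runs on its own. The None case matters because Spark passes SQL NULLs to a Python UDF as None, and str(None) is the string "None", so the function returns False instead of raising.

```python
def validate_mobile_number(mobile_number):
    """Equivalent to the function above, repeated so this snippet runs on its own."""
    cleaned_mobile_number = str(mobile_number).strip()
    return len(cleaned_mobile_number) == 10 and cleaned_mobile_number.isdigit()


# Sample values from the table, plus edge cases a DataFrame column may contain
cases = {
    "9876543210": True,
    "1234": False,
    "abcd": False,
    " 9876543210 ": True,  # surrounding whitespace is stripped
    None: False,  # a SQL NULL arrives in a Python UDF as None
}

for value, expected in cases.items():
    assert validate_mobile_number(value) is expected, value
print("all checks passed")
```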

Register the function as a UDF, and use it

from pyspark.sql import types as T
from pyspark.sql import functions as F

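# Register the Python function as a UDF that returns a boolean column
# (no lambda wrapper needed: F.udf accepts the function directly)
validate_mobile_number_udf = F.udf(validate_mobile_number, T.BooleanType())

# Use the UDF like any other column function
df = contact_df.withColumn("is_number_valid", validate_mobile_number_udf(F.col("mobile_number")))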
df.show()
+-----+-------------+---------------+
| name|mobile_number|is_number_valid|
+-----+-------------+---------------+
|Alice|   9876543210|           true|
|  Bob|         1234|          false|
|Chris|         abcd|          false|
+-----+-------------+---------------+

Filter on the UDF result

# Keep only the rows with a valid mobile number (the column is already boolean)
valid_numbers_df = df.where(F.col("is_number_valid"))
valid_numbers_df.show()
+-----+-------------+---------------+
| name|mobile_number|is_number_valid|
+-----+-------------+---------------+
|Alice|   9876543210|           true|
+-----+-------------+---------------+

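When to use UDFs

UDFs are for tasks that we cannot do with built-in Spark SQL or PySpark DataFrame functions. When a built-in function exists, prefer it: a Python UDF runs in a separate Python worker process, so every row has to be serialized between the JVM and Python, and Spark's optimizer cannot see inside the function. The mobile number check above, for instance, is simple enough to express with the built-in rlike function as well.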

Some example scenarios are:

  • Validate email addresses
  • Validate IP addresses
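For the IP address case, Python's standard ipaddress module already does the parsing, so the function to register as a UDF can stay small. A sketch; is_valid_ip_address is a hypothetical name, and it would be registered the same way as validate_mobile_number above, for example F.udf(is_valid_ip_address, T.BooleanType()).

```python
import ipaddress


def is_valid_ip_address(value):
    """Hypothetical UDF body: True if value parses as an IPv4 or IPv6 address."""
    if value is None:  # SQL NULL arrives in a Python UDF as None
        return False
    try:
        # Raises ValueError for anything that is not a valid IPv4/IPv6 address
        ipaddress.ip_address(str(value).strip())
    except ValueError:
        return False
    return True


for value in ["192.168.0.1", "256.1.1.1", "::1", "not-an-ip", None]:
    print(repr(value), is_valid_ip_address(value))
```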