Frequently used Dataframe methods

By Niraj Zade | 2023-12-05 | Tags: reference pyspark api


TRANSFORMATIONS

Joins

join

Join 2 dataframes by specifying join keys and join type

# join on single column
joined_df = left_df.join(right_df, "col1")
# join on multiple columns
joined_df = left_df.join(right_df, ["col1", "col2"])

Example - Join on column expressions

joined_df = left_df.join(right_df,
        [
            left_df.name == right_df.name,
            left_df.age == right_df.age
        ],
        'outer'
)

Parameters

  • mandatory
    • other - the other dataframe
  • optional
    • on
      • Values can be:
        • str - a string for the join column name
        • list - a list of string column names
        • Column() expression - a join expression (Column), or a list of Columns.
      • If on is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join.
    • how : str
      • default is inner
      • Values can be:
        • inner
        • cross
        • outer
        • full
        • fullouter or full_outer
        • left
        • leftouter or left_outer
        • right
        • rightouter or right_outer
        • semi
        • leftsemi or left_semi
        • anti
        • leftanti or left_anti
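
For example, the semi and anti variants keep only the left dataframe's columns - a quick sketch using left_df and right_df from above:

# keep rows in left_df that have a match in right_df (right_df's columns are not included)
semi_df = left_df.join(right_df, "col1", "left_semi")

# keep rows in left_df that have no match in right_df
anti_df = left_df.join(right_df, "col1", "left_anti")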

crossJoin

Perform cartesian join of 2 dataframes

joined_df = left_df.crossJoin(right_df)

select

Projects a set of expressions and returns a new DataFrame.

Parameters

  • cols - column names (string) or expressions (Column). Column name can also be '*'

Examples

Select using strings (column names):

select_df = df.select('*')
select_df = df.select("name", "age")

Select using Column expressions:

select_df = df.select(df.name, (df.age + 10).alias('age'))

Select using SQL expressions:

select() doesn't support SQL string expressions. So, use expr() to convert them into column expressions.

select_df = df.select("name", "age", expr("to_date(concat(year, month, day), 'yyyyMMdd') as birthdate"))

selectExpr

Perform select using SQL expression. This is a variant of select()

Exactly the same as using select(), but without having to use expr() to convert SQL expressions into Column expressions.

select_df = df.selectExpr("name", "age", "to_date(concat(year, month, day), 'yyyyMMdd') as birthdate")

colRegex

Selects columns whose names match the specified regex, and returns them as Column.

Example - select all columns except Col1:

df.select(df.colRegex("`(Col1)?+.+`")).show()
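
To make the regex behaviour concrete, here is a small sketch (assuming a dataframe with columns Col1 and Col2 - only Col2 comes back):

df = spark.createDataFrame([("a", 1), ("b", 2)], ["Col1", "Col2"])
df.select(df.colRegex("`(Col1)?+.+`")).show()
# +----+
# |Col2|
# +----+
# |   1|
# |   2|
# +----+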

Sorting

orderBy

Returns a new dataframe sorted by specified columns.

Guarantees total order of the output (across all partitions).

Parameters

  • optional
    • cols - str, Column(), list (column name strings or column expressions)
    • ascending - bool or list. Default True

There are 2 ways to use the orderBy function:

1 Ordering using orderBy() function's arguments. Preferred when sorting on a single column.

# sort on single column
sorted_df = df.orderBy("age", ascending=False)

# sort on multiple columns. Not very readable code.
# length of both lists must be equal
sorted_df = df.orderBy(["age", "name"], ascending=[False, True])

2 Ordering using asc(), desc() functions. Preferred when sorting on multiple columns (makes code more readable).

from pyspark.sql.functions import desc, asc

# sort on single column
sorted_df = df.orderBy(desc("age"))

# sort on multiple columns. Very readable code.
sorted_df = df.orderBy(desc("age"), desc("name"))

sort

Returns a new dataframe sorted by specified columns

Parameters

  • optional
    • cols - str, Column(), list (column name strings or column expressions)
    • ascending - bool or list. Default True

There are 2 ways to use the sort function:

1 Sorting using sort() function's arguments. Preferred when sorting on a single column.

# sort on single column
sorted_df = df.sort("age", ascending=False)

# sort on multiple columns. Not very readable code.
# length of both lists must be equal
sorted_df = df.sort(["age", "name"], ascending=[False, True])

2 Sorting using asc(), desc() functions. Preferred when sorting on multiple columns (makes code more readable).

from pyspark.sql.functions import desc, asc

# sort on single column
sorted_df = df.sort(desc("age"))

# sort on multiple columns. Very readable code.
sorted_df = df.sort(desc("age"), desc("name"))

sortWithinPartitions

Same as sort(), but sorts rows within each partition (doesn't cause a shuffle)
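
It takes the same arguments as sort(). A quick sketch:

from pyspark.sql.functions import desc

# sort rows inside each partition, without shuffling data across partitions
partition_sorted_df = df.sortWithinPartitions(desc("age"))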

limit

Limit the number of rows in the result dataframe

limit_df = df.limit(10)

Aggregation

agg

Aggregate without groups.

Shorthand for df.groupBy().agg(). Notice that groupBy() has no columns specified, so this is a global aggregation over the entire dataframe.

Parameters:

  • exprs - Column expression or dict of key and value strings

agg() using dictionary to specify aggregation

max_age_df = df.agg({"age": "max"})

agg() using spark sql functions

from pyspark.sql import functions as f
min_age_df = df.agg(f.min(df.age))

groupBy

Aggregate with groups

Pair it with an aggregation function such as agg(), count(), max(), etc.

Parameters:

  • cols - string, Column expression, list (list of col-name strings, or list of column expressions)

Aggregate using dictionary to specify aggregation

# string
grouped_df = df.groupBy("name").agg({"age": "max"})

# column expression
grouped_df = df.groupBy(df.name).agg({"age": "max"})

Group by multiple columns

# In the list, you can mix strings and column expressions
grouped_df = df.groupBy(["name", df.age]).count()

NOTE

There is no having() function to filter aggregates.

So, filter after the grouping + aggregation is done using .where().

The template looks like - df.groupBy(<group expression>).agg(<aggregation function>).where(predicate)

Explanation for why we don't need a having() method:

In SQL, HAVING is just a WHERE that is applied after GROUP BY. SQL needs two separate keywords so that we can tell the SQL compiler which filtering happens before the grouping (WHERE) and which happens after the grouping (HAVING).

In the dataframe API, we control the order of operations directly. There is no having() method because we don't need one: call filter()/where() before the grouping for WHERE-style filtering, and after the aggregation for HAVING-style filtering.
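
A sketch of the pattern (the column names are just examples):

from pyspark.sql import functions as f

# WHERE-style filter before grouping, HAVING-style filter after aggregation
result_df = (
    df.where(df.age > 0)
      .groupBy("name")
      .agg(f.max("age").alias("max_age"))
      .where(f.col("max_age") > 18)
)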

cube

Create a multi-dimensional cube for the current DataFrame using the specified columns, so we can run aggregations on them.

df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
result_df = df.cube("name", df.age).count().orderBy("name", "age").show()

+-----+----+-----+
| name| age|count|
+-----+----+-----+
| NULL|NULL|    2|
| NULL|   2|    1|
| NULL|   5|    1|
|Alice|NULL|    1|
|Alice|   2|    1|
|  Bob|NULL|    1|
|  Bob|   5|    1|
+-----+----+-----+

df

age name
2 Alice
5 Bob

result_df

name age count
NULL NULL 2
NULL 2 1
NULL 5 1
Alice NULL 1
Alice 2 1
Bob NULL 1
Bob 5 1

I've never used cube(), and have no idea when it should be used.

rollup

Create a multi-dimensional rollup for the current DataFrame using the specified columns, so we can run aggregation on them.

https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.rollup.html
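
Usage mirrors cube(), but rollup() only produces the hierarchical subtotals (name, then name + age, then the grand total) instead of every combination. A sketch, reusing the dataframe from the cube() example:

result_df = df.rollup("name", df.age).count().orderBy("name", "age")
result_df.show()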

Set operations

intersect

Returns a new dataframe with rows that are present in both dataframes.

Eliminates duplicate rows.

intersection_without_duplicates_df = df1.intersect(df2)

intersectAll

intersect(), but preserves duplicates

intersection_with_duplicates_df = df1.intersectAll(df2)

union

Combines rows in 2 dataframes.

WARNING - resolves columns by position, not by name (like standard SQL behavior)

Dataframes with same schema

union_df = df1.union(df2)

Dataframes with different schema

df1

name id
alice 1
bob 2

df2

id name
3 charlie
4 dan

union_df = df1.union(df2)

union_df

name id
alice 1
bob 2
3 charlie
4 dan

unionAll

Alias for union() - both preserve duplicates (unlike SQL's UNION, which deduplicates).

union_df = df1.unionAll(df2)

unionByName

Preferred method.

Like union(), but resolves columns by name instead of position.

df1

name id
alice 1
bob 2

df2

id name
3 charlie
4 dan

union_df = df1.unionByName(df2)

union_df

name id
alice 1
bob 2
charlie 3
dan 4

subtract

Returns rows in left dataframe, but not in right dataframe.

Removes duplicates (works like SQL's EXCEPT DISTINCT)

subtracted_df = left_df.subtract(right_df)

exceptAll

Like subtract(), but preserves duplicates.

Works like SQL's EXCEPT ALL

WARNING - Resolves columns by position (not by name)

result_df = left_df.exceptAll(right_df)

Data cleanup

dropDuplicates

Returns a dataframe, without duplicate rows

Parameters

  • subset - list of column names. By default uses all columns.

unique_rows_df = df.dropDuplicates()

unique_rows_df = df.dropDuplicates(['name','age'])

dropna

Returns a new dataframe, omitting rows with null values

Parameters

  • optional
    • how - string. Takes values:
      • any - drop row if it contains any null
      • all - drop row if all values are null
    • thresh - int. Default None. Overrides the how parameter. Keeps only rows that have at least thresh non-null values (rows with fewer non-null values are dropped).
    • subset - str, tuple, list. List of columns to check for nulls

result_df = df.na.drop()
result_df = df.na.drop("any")
result_df = df.na.drop("all")

# drop if name or age is null
result_df = df.na.drop(how="any", subset=["name", "age"])

distinct

Returns a dataframe with distinct rows of current dataframe

df_distinct = df.distinct()

where or filter

where() is an alias for filter(). Use whatever you prefer.

Filter rows of the dataframe using SQL or column expressions.

I prefer using SQL expressions.

Filter using column expressions

filtered_df = df.filter(df.age>3)
filtered_df = df.where(df.age>3)

Filter using SQL expressions

filtered_df = df.filter("age>3")
filtered_df = df.where("age>3")

Partition handling

coalesce

Doesn't cause a shuffle.

coalesce can only reduce the number of partitions. It cannot increase the number of partitions.

Example - If you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.

If the requested number of partitions is larger than the current number of partitions, then it will stay at current number of partitions.

However, if you’re doing a drastic coalesce, e.g. to numPartitions = 1, this may result in your computation taking place on fewer nodes than you like (e.g. one node in the case of numPartitions = 1). To avoid this, you can call repartition(). This will add a shuffle step, but means the current upstream partitions will be executed in parallel (per whatever the current partitioning is).

coalesced_df = df.coalesce(1)

repartition

Returns a new DataFrame partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned.

Parameters

  • numPartitions - int. If not specified, default number of partitions will be used.
  • cols - str or column expression. Partitioning columns.

repartitioned_df = df.repartition(10)

# repartition on a column
repartitioned_df = df.repartition(10, "age")

# repartition on multiple columns
repartitioned_df = df.repartition(10, "age", "name")

repartitionByRange

Parameters

  • numPartitions - int - target number of partitions. If not specified, default number of partitions will be used.
  • cols - columns to partition on

Notes:

  • At least one partition-by expression must be specified. When no explicit sort order is specified, “ascending nulls first” is assumed.

  • Due to performance reasons this method uses sampling to estimate the ranges. Hence, the output may not be consistent, since sampling can return different values. The sample size can be controlled by the config spark.sql.execution.rangeExchange.sampleSizePerPartition.

# repartition on a column (uses the default number of partitions)
repartitioned_df = df.repartitionByRange("age")

# repartition on a column, into 10 partitions
repartitioned_df = df.repartitionByRange(10, "age")

# repartition on multiple columns
repartitioned_df = df.repartitionByRange(10, "age", "name")

Restructure dataframe columns

drop

Drop columns from the dataframe

Note

When the input is a column name (a string), it is treated literally, without further interpretation. When the input is a Column expression, Spark tries to match it against the dataframe's columns. So dropping a column by its name, drop(colName), has different semantics from dropping it via an expression, drop(col(colName)).

ageless_df = df.drop('age')
ageless_df = df.drop(df.age)
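
To illustrate the note above, a column can also be dropped through a col() expression, which goes through expression matching instead of a literal name lookup:

from pyspark.sql.functions import col

# drop via a Column expression instead of a literal column name
ageless_df = df.drop(col('age'))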

WARNING

If a dataframe has 2 columns with the same name, drop() will fail and throw an error Can not drop col(‘yourCol’) due to ambiguous reference

We end up in this situation when we join 2 dataframes. Use drop() to remove the duplicate column after the join:

joined_df_without_duplicate_cols = left_df.join(right_df, left_df.name == right_df.name, 'inner').drop(right_df.name)

withColumn

Add/replace a column with new values.

Parameters

  • colName - str - string name of new column
  • col - column expression for new column

Example - add a new column age2

result_df = df.withColumn('age2', df.age + 2)

withColumns

Do multiple withColumn() operations at the same time.

Parameters

  • colsMap - dict - A dict of column name and column expression.

result_df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3})

withColumnRenamed

Rename a single existing column

result_df = df.withColumnRenamed('old_col_name', 'new_col_name')

withColumnsRenamed

Rename multiple columns at the same time

Parameters

  • colsMap - dict - A dict of existing column names and corresponding desired column names.

result_df = df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'})

Sample rows

randomSplit

Randomly splits this DataFrame with the provided weights, and returns a list of dataframes

Useful for generating training and test datasets

Parameters -

  • mandatory
    • weights : list
      • list of doubles as weights with which to split the DataFrame. Weights will be normalized if they don’t sum up to 1.0.
  • optional
    • seed : int - the seed to use for the rng algorithm used by the sampler

split_df_list = df.randomSplit([0.8, 0.2], 24)

train_df = split_df_list[0]
test_df = split_df_list[1]

sample

Returns a sampled subset of this DataFrame.

Use to generate smaller dataset out of larger ones. Usually for quick experiments.

Parameters

  • optional
    • withReplacement : bool, optional - Sample with replacement or not (default False).
    • fraction : float, optional - Fraction of rows to generate, range [0.0, 1.0].
    • seed : int, optional - Seed for sampling (default a random seed).

sample_df = df.sample(fraction=0.5, seed=3)

sample_df = df.sample(withReplacement=True, fraction=0.5, seed=3)

sampleBy

Returns a stratified sample without replacement based on the fraction given on each stratum.

Use it when different groups (strata) of a single dataframe need different sampling fractions - something you cannot do with the simpler sample().

Parameters

  • mandatory
    • col: Column or str - column that defines strata
    • fractions: dict - sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as zero.
  • optional
    • seed : int, optional - Seed for sampling

sampled_df = df.sampleBy("key_column", fractions={0: 0.1, 1: 0.2}, seed=0)

Misc functions

alias

Returns a new DataFrame with an alias set.

Useful when you need to reference a dataframe by name inside column expressions (for example, in self-joins). However, it tends to clutter up the mental flow - often it is cleaner to just refactor your code instead.

from pyspark.sql.functions import col, desc

# set aliases
df_as1 = df.alias("df_as1")
df_as2 = df.alias("df_as2")

# self-join using the aliases
joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
# display
joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).show()

ACTIONS

collect

Returns all the records as a list of Row.

df.collect()

count

Returns the number of rows in this DataFrame.

df.count()
#> 3

describe

Computes basic statistics for numeric and string columns.

df.describe(['age']).show()
#> +-------+----+
#> |summary| age|
#> +-------+----+
#> |  count|   3|
#> |   mean|12.0|
#> | stddev| 1.0|
#> |    min|  11|
#> |    max|  13|
#> +-------+----+

first

Returns the first row as a Row.

df.first()
#> Row(age=2, name='Alice')

foreach

Applies the function f to every Row of this DataFrame. This is a shorthand for df.rdd.foreach().

def func(person):
    print(person.name)

df.foreach(func)

foreachPartition

Applies the f function to each partition of this DataFrame. This is a shorthand for df.rdd.foreachPartition().

def func(itr):
    for person in itr:
        print(person.name)

df.foreachPartition(func)

head

Returns the first n rows.

df.head()
#> Row(age=2, name='Alice')

df.head(1)
#> [Row(age=2, name='Alice')]

show

Prints the first n rows to the console.

Parameters

  • n : int, optional
    • Number of rows to show.
  • truncate : bool or int, optional
    • If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length truncate and align cells right.
  • vertical:bool, optional
    • If set to True, print output rows vertically (one line per column value).

Show only top 2 rows.

df.show(2)

+---+-----+
|age| name|
+---+-----+
| 14|  Tom|
| 23|Alice|
+---+-----+
only showing top 2 rows

Show DataFrame where the maximum number of characters is 3.

df.show(truncate=3)

+---+----+
|age|name|
+---+----+
| 14| Tom|
| 23| Ali|
| 16| Bob|
+---+----+

Show DataFrame vertically.

df.show(vertical=True)

-RECORD 0-----
age  | 14
name | Tom
-RECORD 1-----
age  | 23
name | Alice
-RECORD 2-----
age  | 16
name | Bob

summary

Computes specified statistics for numeric and string columns.

Available statistics are:

  • count
  • mean
  • stddev
  • min
  • max
  • arbitrary approximate percentiles specified as a percentage (e.g., 75%)

If no statistics are given, this function computes count, mean, stddev, min, approximate quartiles (percentiles at 25%, 50%, and 75%), and max.

df.select("age", "weight", "height").summary("count", "min", "25%", "75%", "max").show()
+-------+---+------+------+
|summary|age|weight|height|
+-------+---+------+------+
|  count|  3|     3|     3|
|    min| 11|  37.8| 142.2|
|    25%| 11|  37.8| 142.2|
|    75%| 13|  44.1| 150.5|
|    max| 13|  44.1| 150.5|
+-------+---+------+------+

tail

Returns the last num rows as a list of Row.

Running tail requires moving data into the application’s driver process, and doing so with a very large num can crash the driver process with OutOfMemoryError.

df.tail(2)

#> [Row(age=23, name='Alice'), Row(age=16, name='Bob')]

take

Returns the first num rows as a list of Row.

Example - return the first 2 rows of the DataFrame:

df.take(2)

#> [Row(age=14, name='Tom'), Row(age=23, name='Alice')]

toLocalIterator

Returns an iterator that contains all of the rows in this DataFrame. The iterator will consume as much memory as the largest partition in this DataFrame. With prefetch it may consume up to the memory of the 2 largest partitions.

list(df.toLocalIterator())

#> [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')]

FUNCTIONS

Caching

cache

Persists the DataFrame with the default storage level (MEMORY_AND_DISK).

df.cache()

persist

Sets the storage level to persist the contents of the DataFrame across operations after the first time it is computed. This can only be used to assign a new storage level if the DataFrame does not have a storage level set yet. If no storage level is specified, it defaults to MEMORY_AND_DISK_DESER.

Can take values:

  • NONE - StorageLevel.NONE
  • DISK_ONLY - StorageLevel.DISK_ONLY
  • DISK_ONLY_2 - StorageLevel.DISK_ONLY_2
  • DISK_ONLY_3 - StorageLevel.DISK_ONLY_3
  • MEMORY_AND_DISK - StorageLevel.MEMORY_AND_DISK
  • MEMORY_AND_DISK_2 - StorageLevel.MEMORY_AND_DISK_2
  • MEMORY_AND_DISK_DESER - StorageLevel.MEMORY_AND_DISK_DESER
  • MEMORY_ONLY - StorageLevel.MEMORY_ONLY
  • MEMORY_ONLY_2 - StorageLevel.MEMORY_ONLY_2
  • OFF_HEAP - StorageLevel.OFF_HEAP

from pyspark.storagelevel import StorageLevel

# default
df.persist()

df.persist(StorageLevel.DISK_ONLY)

unpersist

Marks the DataFrame as non-persistent, and removes all blocks for it from memory and disk.

df.unpersist()

checkpoint

This API is experimental.

Returns a checkpointed version of this DataFrame. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. It will be saved to files inside the checkpoint directory set with SparkContext.setCheckpointDir().

Parameters

  • eager : bool, optional, default True. Whether to checkpoint this Dataframe immediately (or lazily).

df.checkpoint()

checkpointed_df = df.checkpoint(False)
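
Note that checkpoint() only works once a checkpoint directory has been set on the SparkContext (the path below is just a placeholder):

# set once per application, before calling checkpoint()
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

checkpointed_df = df.checkpoint()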

localCheckpoint

Returns a locally checkpointed version of this DataFrame. Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints are stored in the executors using the caching subsystem and therefore they are not reliable.

Parameters

  • eager : bool, optional, default True. Whether to checkpoint this Dataframe immediately (or lazily).

locally_checkpointed_df = df.localCheckpoint(False)

Temp Tables and Views

registerTempTable

Registers this DataFrame as a temporary table using the given name. Deprecated since Spark 2.0 in favour of createOrReplaceTempView().

The lifetime of this temporary table is tied to the SparkSession that was used to create this DataFrame.

Parameters

  • name : str - Name of the temporary table to register.

df.registerTempTable("people")

Can drop the table with

spark.catalog.dropTempView("people")

createTempView

Creates a local temporary view with this DataFrame.

The lifetime of this temporary view is tied to the SparkSession that was used to create this DataFrame. Throws TempTableAlreadyExistsException if the view name already exists in the catalog.

Parameters

  • name : str - Name of the view.

people_df.createTempView("people")

createOrReplaceTempView

Creates or replaces a local temporary view with this DataFrame (scoped to the current SparkSession).

Parameters

  • name : str - Name of the view.

people_df.createOrReplaceTempView("people")

createOrReplaceGlobalTempView

Creates or replaces a global temporary view using the given name.

The lifetime of this temporary view is tied to this Spark application.

people_df.createOrReplaceGlobalTempView("people")

# now, can use it
df2 = spark.sql("SELECT * FROM global_temp.people")

Can drop the view with

spark.catalog.dropGlobalTempView("people")

createGlobalTempView

Creates a global temporary view with this DataFrame.

The lifetime of this temporary view is tied to this Spark application. Throws TempTableAlreadyExistsException, if the view name already exists in the catalog.

people_df.createGlobalTempView("people")

# now, can use it
df2 = spark.sql("SELECT * FROM global_temp.people")

Can drop the view with

spark.catalog.dropGlobalTempView("people")

Perf Analysis

explain()

Parameters

  • optional
    • extended : bool, optional. Default False. If False, prints only the physical plan. If a string is passed here (without specifying mode), it is treated as the mode.
    • mode : str, optional - specifies the expected output format of plans.
      • simple: Print only a physical plan.
      • extended: Print both logical and physical plans.
      • codegen: Print a physical plan and generated codes if they are available.
      • cost: Print a logical plan and statistics if they are available.
      • formatted: Split explain output into two sections: a physical plan outline and node details.

Print out the physical plan only (default).

df.explain()

Print out all of the parsed, analyzed, optimized and physical plans.

df.explain(True)

== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...

Print out the plans with two sections: a physical plan outline and node details

df.explain(mode="formatted")  

== Physical Plan ==
* Scan ExistingRDD (...)
(1) Scan ExistingRDD [codegen id : ...]
Output [2]: [age..., name...]
...

Print a logical plan and statistics if they are available

df.explain("cost")

== Optimized Logical Plan ==
...Statistics...
...

hint

Used to give the optimizer a join strategy hint.

Example - explicitly trigger a broadcast hash join by providing the hint on df2:

df.join(df2.hint("broadcast"), "name")

Available hints:

  1. BROADCAST - Broadcast hash join. Join side (one with the hint) will be broadcasted, irrespective of the autoBroadcastJoinThreshold.
  2. MERGE or SHUFFLE_MERGE or MERGEJOIN - Shuffle sort merge join.
  3. SHUFFLE_HASH - Shuffle hash join
  4. SHUFFLE_REPLICATE_NL - Cartesian product join if inner join
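
For example, a sketch that requests a shuffle sort merge join instead, using one of the hint names listed above:

joined_df = df.join(df2.hint("shuffle_merge"), "name")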

Data exploration

printSchema()

Prints out the schema in tree format. Optionally lets you specify how many levels to print if the schema is nested.

Parameters

  • level : int, optional, default None. How many levels to print for nested schemas.

df.printSchema()
root
 |-- age: long (nullable = true)
 |-- name: string (nullable = true)

Misc

toDF()

Returns a new DataFrame with the specified column names.

Parameters

  • cols : tuple - a tuple of new column name strings. The length needs to be the same as the number of columns in the initial DataFrame.

df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
cols_renamed_df = df.toDF('col1', 'col2')

toJSON

Converts the DataFrame into an RDD of JSON strings - each row becomes one JSON document.

Returns an RDD.

Parameters

  • optional
    • use_unicode : bool, optional, default True. Whether to convert to unicode or not.

json_rdd = df.toJSON()
json_rdd.first()

withWatermark()

Used for streaming pipelines

Documentation link - pyspark.sql.DataFrame.withWatermark
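
A minimal sketch (streaming_df and event_time are placeholders for a streaming dataframe and its event-time column):

# accept late data arriving up to 10 minutes behind the latest event time seen
watermarked_df = streaming_df.withWatermark("event_time", "10 minutes")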