PySpark Window Functions

Window functions in PySpark are a powerful tool for performing calculations across a set of rows related to the current row within a DataFrame. Unlike a groupBy aggregation, which collapses each group into a single row, a window function returns a value for every input row, which makes it well suited to ordering, ranking, and running aggregate computations.

Window functions operate on a “window” of rows defined relative to the current row. The window specification determines which rows are included in the window for each row of the DataFrame. It typically consists of the following clauses (a minimal sketch follows the list):

  • Partition By: This clause divides the rows into partitions or groups based on one or more columns. The window function operates independently within each partition.

  • Order By: This clause defines the order in which rows are processed within each partition. It establishes the logical order of rows in the window.
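
For example, here is a minimal sketch of a window specification that combines both clauses; the department and salary column names are illustrative assumptions, not part of the original example:

from pyspark.sql.window import Window
from pyspark.sql.functions import col

# Rows are grouped by department; within each department they are
# processed in order of descending salary.
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())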

Commonly used PySpark window functions include:

  • Ranking Functions: These functions assign a rank to each row within a partition based on the specified order. Common ranking functions include rank(), dense_rank(), row_number(), and ntile(); see the ranking sketch after this list.

  • Aggregate Functions: You can calculate aggregate values, such as sums, averages, and maximum/minimum values, over a window of rows. Functions like sum(), avg(), max(), and min() can be used.

  • Lead and Lag Functions: These functions allow you to access values from the “next” or “previous” rows within the window. They are useful for time-series calculations; a lead/lag sketch appears after the salary example at the end of this section.

  • Window Spec: Rather than a function, this is a WindowSpec object built with Window.partitionBy() and Window.orderBy(); it defines the partitioning and ordering criteria and is passed to a window function’s over() clause.
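
As a sketch of the ranking functions above, the following assumes an employee DataFrame df with department and salary columns (like the one built in the example that follows) and assigns each employee a rank within their department by descending salary:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank, dense_rank, row_number

# Illustrative window: order employees within each department by salary
rank_spec = Window.partitionBy("department").orderBy(col("salary").desc())

df.withColumn("rank", rank().over(rank_spec)) \
  .withColumn("dense_rank", dense_rank().over(rank_spec)) \
  .withColumn("row_number", row_number().over(rank_spec)) \
  .show()
# rank() leaves gaps after ties, dense_rank() does not,
# and row_number() assigns a unique sequential number regardless of ties.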

Here’s an example of using a PySpark window function to calculate the average salary within each department, taking advantage of the partition by clause:

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, avg

# Initialize Spark session
spark = SparkSession.builder.appName("WindowFunctionExample").getOrCreate()

# Create a DataFrame with employee data (illustrative sample rows)
df = spark.createDataFrame(
    [("Alice", "Engineering", 85000),
     ("Bob", "Engineering", 92000),
     ("Carol", "Sales", 60000),
     ("Dave", "Sales", 73000)],
    ["name", "department", "salary"],
)

# Define a window specification partitioned by department
window_spec = Window.partitionBy("department")

# Calculate the average salary within each department
df.withColumn("avg_salary", avg(col("salary")).over(window_spec)).show()
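
Lead and lag follow the same pattern. Here is a minimal sketch that reuses the same DataFrame and assumes rows should be ordered by salary within each department (lag() and lead() require an ordered window so that “previous” and “next” are well defined):

from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, lead

# Ordered window: previous/next are defined by ascending salary within a department
ordered_spec = Window.partitionBy("department").orderBy(col("salary"))

df.withColumn("prev_salary", lag("salary", 1).over(ordered_spec)) \
  .withColumn("next_salary", lead("salary", 1).over(ordered_spec)) \
  .show()
# prev_salary is null for the first row of each department,
# next_salary is null for the last row.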