Window Functions: Access Previous, Next & Cumulative Distribution Values
Last updated on: 2025-05-30
In the previous article, we explored different ways to rank rows within a window partition, which helped in organizing and analyzing grouped data efficiently. In this article, we’ll take it a step further and delve into three essential window functions that allow us to retrieve data from other rows within the same partition:
- lag()
- lead()
- cume_dist()
To use these functions, make sure to import the Window class and the built-in column functions (for col, lag, lead, and cume_dist):
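```scala
import org.apache.spark.sql.expressions.Window            // window specifications
import org.apache.spark.sql.functions.{col, lag, lead, cume_dist}
```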
We'll be working with the same DataFrame used in the previous article:
```text
+---+--------+-----------+----------+-------------------+-----+
| ID| Name|Room Number| DOB| Submit Time|Marks|
+---+--------+-----------+----------+-------------------+-----+
| 1| Ajay| 10|2010-01-01|2025-02-17 12:30:45|92.75|
| 2|Bharghav| 20|2009-06-04|2025-02-17 08:15:30| 88.5|
| 3| Chaitra| 30|2010-12-12|2025-02-17 14:45:10| 75.8|
| 4| Kamal| 20|2010-08-25|2025-02-17 17:10:05| 82.3|
| 5| Sohaib| 30|2009-04-14|2025-02-17 09:55:20| 90.6|
| 6| Tanish| 20|2009-05-11|2025-02-17 09:45:30| 88.5|
| 7| Uday| 20|2009-09-06|2025-02-17 09:45:30| 92.3|
+---+--------+-----------+----------+-------------------+-----+
```
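If you want to follow along without the earlier article, here is a minimal sketch that recreates this DataFrame; it assumes a SparkSession named spark is in scope, and the schema is simply inferred from the output above (the original article may build it differently):

```scala
import java.sql.{Date, Timestamp}
import spark.implicits._  // assumes an existing SparkSession named `spark`

// Rebuild the sample data shown above; column types are inferred from the printed output.
val df = Seq(
  (1, "Ajay",     10, Date.valueOf("2010-01-01"), Timestamp.valueOf("2025-02-17 12:30:45"), 92.75),
  (2, "Bharghav", 20, Date.valueOf("2009-06-04"), Timestamp.valueOf("2025-02-17 08:15:30"), 88.5),
  (3, "Chaitra",  30, Date.valueOf("2010-12-12"), Timestamp.valueOf("2025-02-17 14:45:10"), 75.8),
  (4, "Kamal",    20, Date.valueOf("2010-08-25"), Timestamp.valueOf("2025-02-17 17:10:05"), 82.3),
  (5, "Sohaib",   30, Date.valueOf("2009-04-14"), Timestamp.valueOf("2025-02-17 09:55:20"), 90.6),
  (6, "Tanish",   20, Date.valueOf("2009-05-11"), Timestamp.valueOf("2025-02-17 09:45:30"), 88.5),
  (7, "Uday",     20, Date.valueOf("2009-09-06"), Timestamp.valueOf("2025-02-17 09:45:30"), 92.3)
).toDF("ID", "Name", "Room Number", "DOB", "Submit Time", "Marks")
```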
We’ll define a common window specification to use across functions:
```scala
val rankRow = Window
  .partitionBy(col("Room Number"))
  .orderBy(col("Marks"))
```
1. lag() function
The lag() function retrieves a value from a previous row in the same window partition. It takes the following parameters:
- The column to retrieve.
- The offset (number of rows to look back).
- A default value to use when no such row exists.
Example:

```scala
val lagRows = df.withColumn("Lag Rows", lag(col("Marks"), 1, 0).over(rankRow))
lagRows.show()
```
Output

```text
+---+--------+-----------+----------+-------------------+-----+--------+
| ID| Name|Room Number| DOB| Submit Time|Marks|Lag Rows|
+---+--------+-----------+----------+-------------------+-----+--------+
| 4| Kamal| 20|2010-08-25|2025-02-17 17:10:05| 82.3| 0.0|
| 2|Bharghav| 20|2009-06-04|2025-02-17 08:15:30| 88.5| 82.3|
| 6| Tanish| 20|2009-05-11|2025-02-17 09:45:30| 88.5| 88.5|
| 7| Uday| 20|2009-09-06|2025-02-17 09:45:30| 92.3| 88.5|
| 1| Ajay| 10|2010-01-01|2025-02-17 12:30:45|92.75| 0.0|
| 3| Chaitra| 30|2010-12-12|2025-02-17 14:45:10| 75.8| 0.0|
| 5| Sohaib| 30|2009-04-14|2025-02-17 09:55:20| 90.6| 75.8|
+---+--------+-----------+----------+-------------------+-----+--------+
```
Explanation:

Let’s take Room Number 20 as an example:

- Kamal has 82.3, and since his is the first row in the partition (lowest marks), there is no row before him → Lag = 0.0
- Bharghav has 88.5, so we look at the row before him → Lag = 82.3
- Tanish also has 88.5; ordering by Marks alone leaves the order of tied rows arbitrary, and in this output Bharghav happens to come first, so the previous row is Bharghav → Lag = 88.5 (see the sketch below for making this deterministic)
- Uday has 92.3 → the previous row is Tanish → Lag = 88.5
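Because rankRow orders only by Marks, rows that tie (like Bharghav and Tanish) can come back in either order, so their lag values are not guaranteed. A common remedy, sketched below with ID as an arbitrary tiebreaker, is to add a second ordering column:

```scala
// Sketch: add a tiebreaker so the ordering (and therefore lag/lead) is deterministic.
// ID is used here purely for illustration; any unique column works.
val rankRowStable = Window
  .partitionBy(col("Room Number"))
  .orderBy(col("Marks"), col("ID"))

val lagRowsStable = df.withColumn("Lag Rows", lag(col("Marks"), 1, 0).over(rankRowStable))
```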
2. lead() function
The lead() function retrieves a value from a future row in the same window partition.
```scala
val leadRows = df.withColumn("Lead Rows", lead(col("Marks"), 1, 0).over(rankRow))
leadRows.show()
```
Output
```text
+---+--------+-----------+----------+-------------------+-----+---------+
| ID| Name|Room Number| DOB| Submit Time|Marks|Lead Rows|
+---+--------+-----------+----------+-------------------+-----+---------+
| 4| Kamal| 20|2010-08-25|2025-02-17 17:10:05| 82.3| 88.5|
| 2|Bharghav| 20|2009-06-04|2025-02-17 08:15:30| 88.5| 88.5|
| 6| Tanish| 20|2009-05-11|2025-02-17 09:45:30| 88.5| 92.3|
| 7| Uday| 20|2009-09-06|2025-02-17 09:45:30| 92.3| 0.0|
| 1| Ajay| 10|2010-01-01|2025-02-17 12:30:45|92.75| 0.0|
| 3| Chaitra| 30|2010-12-12|2025-02-17 14:45:10| 75.8| 90.6|
| 5| Sohaib| 30|2009-04-14|2025-02-17 09:55:20| 90.6| 0.0|
+---+--------+-----------+----------+-------------------+-----+---------+
```
Explanation:

Continuing with Room Number 20:

- Kamal (82.3) → next row is Bharghav (88.5) → Lead = 88.5
- Bharghav (88.5) → next is Tanish (88.5) → Lead = 88.5
- Tanish (88.5) → next is Uday (92.3) → Lead = 92.3
- Uday (92.3) → no next row → Lead = 0.0

We use the lag() and lead() functions to check what values the previous and next rows take, respectively; they make row-to-row comparisons straightforward, as the sketch below shows.
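As an illustration of combining the two, the sketch below adds the neighbouring scores and the gap to the next-higher score in a single pass (the new column names are made up for the example):

```scala
// Sketch: compare each student's marks with the rows just below and above in the same room.
val withNeighbours = df
  .withColumn("Prev Marks", lag(col("Marks"), 1).over(rankRow))   // null for the first row in the room
  .withColumn("Next Marks", lead(col("Marks"), 1).over(rankRow))  // null for the last row in the room
  .withColumn("Gap To Next", col("Next Marks") - col("Marks"))

withNeighbours.show()
```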
3. cume_dist() function
The cume_dist() function calculates the cumulative distribution of values within a partition. It returns a fraction between 0 and 1 indicating the proportion of rows whose value is less than or equal to the current row’s.
Example:
```scala
val cumDistRows = df.withColumn("Cumulative Distribution", cume_dist().over(rankRow))
cumDistRows.show()
```
Output
```text
+---+--------+-----------+----------+-------------------+-----+-----------------------+
| ID| Name|Room Number| DOB| Submit Time|Marks|Cumulative Distribution|
+---+--------+-----------+----------+-------------------+-----+-----------------------+
| 4| Kamal| 20|2010-08-25|2025-02-17 17:10:05| 82.3| 0.25|
| 2|Bharghav| 20|2009-06-04|2025-02-17 08:15:30| 88.5| 0.75|
| 6| Tanish| 20|2009-05-11|2025-02-17 09:45:30| 88.5| 0.75|
| 7| Uday| 20|2009-09-06|2025-02-17 09:45:30| 92.3| 1.0|
| 1| Ajay| 10|2010-01-01|2025-02-17 12:30:45|92.75| 1.0|
| 3| Chaitra| 30|2010-12-12|2025-02-17 14:45:10| 75.8| 0.5|
| 5| Sohaib| 30|2009-04-14|2025-02-17 09:55:20| 90.6| 1.0|
+---+--------+-----------+----------+-------------------+-----+-----------------------+
```
cume_dist() provides a percentile-like ranking. It’s particularly useful for analyzing distributions, computing percentiles, or plotting cumulative charts; a small filtering sketch follows the explanation below.
Explanation:
For Room Number 20:
- There are 4 students.
- Kamal (lowest marks) → 1 out of 4 rows ≤ his marks → 0.25
- Bharghav and Tanish (88.5) → 3 out of 4 rows ≤ their marks → 0.75
- Uday (highest) → 4 out of 4 rows ≤ his marks → 1.0
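Because the result is a fraction, it can be used directly as a filter. Here is a small sketch (the 0.5 cutoff is arbitrary) that keeps only students in the upper half of their room:

```scala
// Sketch: keep rows whose cumulative distribution puts them in the top half of the room.
val topHalf = df
  .withColumn("Cumulative Distribution", cume_dist().over(rankRow))
  .filter(col("Cumulative Distribution") > 0.5)

topHalf.show()
```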
Summary
In this article, we covered:
- How to use lag() and lead() to fetch values from previous and next rows.
- How cume_dist() helps analyze the distribution of values in a partition.
- Why these window functions are important for tasks like trend analysis, comparison, ranking, and scoring.
These functions eliminate the need for complex self-join logic and help keep your Spark transformations concise and efficient.