List Aggregation in Spark

Last updated on: 2025-05-30

In our previous discussions, we explored Group By operations in Spark and how they are commonly paired with aggregate functions. Now we’ll take that a step further with list aggregations, which collect the values of each group into a single array for easier analysis.

Spark provides two key functions for this purpose:

  • collect_list() : Gathers all the values in a group into an array, including duplicates.
  • collect_set() : Gathers only the distinct values in a group into an array (duplicates are removed).
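The sample data used later in this article has no repeated prices within a category, so the difference between the two functions is easy to miss. The following is a minimal sketch on a hypothetical three-row dataset (the `sales` DataFrame and its values are made up for illustration) in which one price appears twice. It assumes a local SparkSession and can be pasted into spark-shell or run as a Scala script.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, collect_set, size}

val spark = SparkSession.builder()
  .appName("collect-list-vs-set")
  .master("local[*]")
  .getOrCreate()

// Hypothetical data: the price 50 appears twice in the same category.
val sales = spark.createDataFrame(Seq(
  ("Dairy", 50),
  ("Dairy", 50),
  ("Dairy", 57)
)).toDF("Category", "MRP")

val collected = sales.groupBy("Category").agg(
  collect_list("MRP").alias("as_list"), // keeps both 50s
  collect_set("MRP").alias("as_set")    // keeps a single 50
)

// Compare array lengths rather than contents, since element order is not guaranteed.
val sizes = collected.select(
  size(col("as_list")).alias("list_size"),
  size(col("as_set")).alias("set_size")
)

sizes.show()
```

Here the list has three elements and the set has two, because the repeated 50 is kept by collect_list() and dropped by collect_set().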

Example DataFrame

We'll be using the same sample DataFrame as in our Grouping Operations article:

+--------+-----------+----------+---+------------------+-----------------+
|Item no.|  Item Name|  Category|MRP|  Discounted Price|  Price After Tax|
+--------+-----------+----------+---+------------------+-----------------+
|       1|Paper Clips|Stationery| 23|              20.7|            24.84|
|       2|     Butter|     Dairy| 57|     51.3000000004|            61.56|
|       3|      Jeans|   Clothes|799|             719.1|           862.92|
|       4|      Shirt|   Clothes|570|             513.0|            615.6|
|       5|Butter Milk|     Dairy| 50|              45.0|             54.0|
|       6|        Bag|   Apparel|455|             409.5|            491.4|
|       7|      Shoes|   Apparel|901|             810.9|    973.079999999|
|       8|    Stapler|Stationery| 50|              45.0|             54.0|
|       9|       Pens|Stationery|120|             108.0|            129.6|
+--------+-----------+----------+---+------------------+-----------------+
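The snippets in this article operate on a DataFrame named priceAfterTax, which was built in the earlier article. For readers starting here, the sketch below is one way to rebuild it. The 10% discount and 20% tax rates are assumptions inferred from the values in the table, and the intermediate name `items` is hypothetical; the original article may have derived the columns differently.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder()
  .appName("list-aggregation-sample")
  .master("local[*]")
  .getOrCreate()

// Base columns copied from the sample table.
val items = spark.createDataFrame(Seq(
  (1, "Paper Clips", "Stationery", 23),
  (2, "Butter", "Dairy", 57),
  (3, "Jeans", "Clothes", 799),
  (4, "Shirt", "Clothes", 570),
  (5, "Butter Milk", "Dairy", 50),
  (6, "Bag", "Apparel", 455),
  (7, "Shoes", "Apparel", 901),
  (8, "Stapler", "Stationery", 50),
  (9, "Pens", "Stationery", 120)
)).toDF("Item no.", "Item Name", "Category", "MRP")

// Assumed rates, inferred from the table: a 10% discount, then 20% tax.
val priceAfterTax = items
  .withColumn("Discounted Price", col("MRP") * 0.9)
  .withColumn("Price After Tax", col("Discounted Price") * 1.2)

priceAfterTax.show()
```

Because the derived columns are doubles, a few values carry floating-point noise, which is why entries such as the Shoes price do not print as round numbers.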

Creating lists of data

In the example below, collect_list() gathers every MRP in each category, while collect_set() gathers the distinct Price After Tax values in each category.

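import org.apache.spark.sql.functions.{collect_list, collect_set}

// Per category: every MRP (duplicates kept) and the distinct post-tax prices
val result = priceAfterTax.groupBy("Category")
  .agg(collect_list("MRP").alias("All sales"),
    collect_set("Price After Tax").alias("Billable Amount"))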

result.show()

Output

+----------+-------------+--------------------+
|  Category|    All sales|     Billable Amount|
+----------+-------------+--------------------+
|Stationery|[23, 50, 120]|[54.0, 129.6, 24.84]|
|   Apparel|   [455, 901]|[491.4, 973.07999...|
|     Dairy|     [57, 50]|       [54.0, 61.56]|
|   Clothes|   [799, 570]|     [615.6, 862.92]|
+----------+-------------+--------------------+
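In the output above, the Stationery set is listed as [54.0, 129.6, 24.84], which is not the order of the rows in the sample table. Neither collect_list() nor collect_set() guarantees element order. If a stable order matters, one option is to wrap the aggregate in sort_array(). The sketch below shows this on a cut-down copy of the Stationery rows; the `stationery` and `sorted` names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_set, sort_array}

val spark = SparkSession.builder()
  .appName("sorted-collect-set")
  .master("local[*]")
  .getOrCreate()

// A cut-down copy of the Stationery rows from the sample table.
val stationery = spark.createDataFrame(Seq(
  ("Stationery", 24.84),
  ("Stationery", 54.0),
  ("Stationery", 129.6)
)).toDF("Category", "Price After Tax")

// sort_array() orders the collected elements ascending by default.
val sorted = stationery.groupBy("Category")
  .agg(sort_array(collect_set("Price After Tax")).alias("Billable Amount"))

sorted.show(false)
```

With the sort applied, the array is returned in ascending order regardless of how the rows were partitioned.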

Combining List Operations with Aggregate Functions

You can mix list aggregation with standard aggregate functions like sum(). Let's find the total billable amount for each product category.

import org.apache.spark.sql.functions.{collect_list, sum}

val subTotal = priceAfterTax.groupBy("Category")
  .agg(collect_list("Price After Tax").alias("Billable Amount"),
    sum("Price After Tax").alias("Sub Total"))

subTotal.show()

Output

+----------+--------------------+---------+
|  Category|     Billable Amount|Sub Total|
+----------+--------------------+---------+
|Stationery|[24.84, 54.0, 129.6]|   208.44|
|   Apparel|[491.4, 973.07999...|  1464.48|
|     Dairy|       [61.56, 54.0]|   115.56|
|   Clothes|     [862.92, 615.6]|  1478.52|
+----------+--------------------+---------+

You can also use count(), min(), max(), or avg() in a similar way to extract deeper insights.
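As a rough illustration of that pattern, the sketch below adds count(), min(), and max() next to collect_list() in a single agg() call. It uses a small subset of the sample rows rather than the full DataFrame, and the output column names ("Items", "Cheapest", "Costliest") are made up for this example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, count, max, min}

val spark = SparkSession.builder()
  .appName("list-with-count-min-max")
  .master("local[*]")
  .getOrCreate()

// Stationery and Dairy rows taken from the sample table.
val subset = spark.createDataFrame(Seq(
  ("Stationery", 24.84),
  ("Stationery", 54.0),
  ("Stationery", 129.6),
  ("Dairy", 61.56),
  ("Dairy", 54.0)
)).toDF("Category", "Price After Tax")

// Several aggregates can share one groupBy; each becomes its own column.
val stats = subset.groupBy("Category").agg(
  collect_list("Price After Tax").alias("Billable Amount"),
  count("Price After Tax").alias("Items"),
  min("Price After Tax").alias("Cheapest"),
  max("Price After Tax").alias("Costliest")
)

stats.show(false)
```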

To explore aggregate functions in more detail, see our Aggregate Functions article.

Calculating Average Billing by Category

Let's calculate the average billing amount for each category:

import org.apache.spark.sql.functions.{avg, collect_list}

val catAvg = priceAfterTax.groupBy("Category")
  .agg(collect_list("Price After Tax").alias("Billable Amount"),
    avg("Price After Tax").alias("Average Price"))

catAvg.show()

Output

+----------+--------------------+-------------+
|  Category|     Billable Amount|Average Price|
+----------+--------------------+-------------+
|Stationery|[24.84, 54.0, 129.6]|        69.48|
|   Apparel|[491.4, 973.07999...|       732.24|
|     Dairy|       [61.56, 54.0]|        57.78|
|   Clothes|     [862.92, 615.6]|       739.26|
+----------+--------------------+-------------+

Mapping Items to Their Prices

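Want to map each item's name to its price within a category? Combine map_from_entries() with struct() and collect_list(): struct() pairs each name with its price, collect_list() gathers those pairs per category, and map_from_entries() turns the collected pairs into a map.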

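import org.apache.spark.sql.functions.{col, collect_list, map_from_entries, struct}

val mapElements = priceAfterTax.groupBy("Category")
  .agg(
    collect_list("Price After Tax").alias("Billable Amount"),
    // Build (name, price) pairs, collect them per category, then convert to a map
    map_from_entries(
      collect_list(
        struct(col("Item Name"), col("Price After Tax"))
      )
    ).alias("Item-Price"))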

mapElements.show(false)

Output

+----------+--------------------------+------------------------------------------------------+
|Category  |Billable Amount           |Item-Price                                            |
+----------+--------------------------+------------------------------------------------------+
|Stationery|[24.84, 54.0, 129.6]      |[Paper Clips -> 24.84, Stapler -> 54.0, Pens -> 129.6]|
|Apparel   |[491.4, 973.0799999999999]|[Bag -> 491.4, Shoes -> 973.0799999999999]            |
|Dairy     |[61.56, 54.0]             |[Butter -> 61.56, Butter Milk -> 54.0]                |
|Clothes   |[862.92, 615.6]           |[Jeans -> 862.92, Shirt -> 615.6]                     |
+----------+--------------------------+------------------------------------------------------+
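Once the map exists, it can be turned back into rows when a later step needs one record per item. The sketch below is one way to do that with explode(), shown on the two Dairy rows only; the `dairy`, `itemPrice`, and `flattened` names are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, explode, map_from_entries, struct}

val spark = SparkSession.builder()
  .appName("explode-item-price-map")
  .master("local[*]")
  .getOrCreate()

// The Dairy rows from the sample table.
val dairy = spark.createDataFrame(Seq(
  ("Dairy", "Butter", 61.56),
  ("Dairy", "Butter Milk", 54.0)
)).toDF("Category", "Item Name", "Price After Tax")

// Same construction as above: one map of item name to price per category.
val itemPrice = dairy.groupBy("Category").agg(
  map_from_entries(
    collect_list(struct(col("Item Name"), col("Price After Tax")))
  ).alias("Item-Price")
)

// explode() on a map column emits one row per entry, in columns named key and value.
val flattened = itemPrice.select(col("Category"), explode(col("Item-Price")))

flattened.show(false)
```

The result has one row per item, with the item name in the key column and its price in the value column.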

Summary

In this article, we covered:

  • How to use list aggregations with collect_list() and collect_set()

  • How to combine them with traditional aggregate functions (sum, avg, count, etc.)

  • How to map individual elements to their values using map_from_entries()

List aggregations are especially useful when you need a complete picture of each group’s contents for further analysis. Because every value in a group is gathered into a single array, keep an eye on memory when groups are very large.
