How to aggregate columns in PySpark Azure Databricks?

Are you looking to find out how to aggregate columns of a PySpark DataFrame in Azure Databricks, or are you looking for a solution to perform calculations on a group of records in PySpark Databricks using the in-built PySpark methods? If you are looking for any of these solutions, you have landed on the correct page. I will show you how to use PySpark to perform aggregations on both single and multiple column values in DataFrames in Azure Databricks, and I will explain it with practical examples. So, without wasting time, let's start with a step-by-step guide to understanding how to aggregate columns in a PySpark DataFrame.

In this blog, I will teach you the following with practical examples:

  • Single column aggregation
  • Multiple column aggregation
  • Group by column aggregation
  • Multiple aggregation of a group of records using agg() function
  • Rename aggregated column using alias() function

Create a simple DataFrame

Let’s understand how to use PySpark’s in-built aggregation functions with a variety of examples. First, let’s create a DataFrame.

Gentle reminder:

In Databricks,

  • a SparkSession is made available as spark
  • a SparkContext is made available as sc

In case you want to create them manually, use the below code.

from pyspark.sql.session import SparkSession

spark = SparkSession.builder \
    .master("local[*]") \
    .appName("azurelib.com") \
    .getOrCreate()

sc = spark.sparkContext

a) Create a PySpark DataFrame manually

data = [
    ("chevrolet vega 2300","USA",15.5,90,28.0,"1970-01-01"),
    ("chevrolet vega 2300","USA",15.5,90,28.0,"1970-01-01"),
    ("toyota corona","Japan",14.0,95,25.0,"1970-01-01"),
    ("ford pinto","USA",19.0,75,25.0,"1971-01-01"),
    ("amc gremlin","USA",13.0,100,19.0,"1971-01-01"),
    ("plymouth satellite custom","USA",15.5,105,16.0,"1971-01-01"),
    ("datsun 510 (sw)","Japan",17.0,92,28.0,"1972-01-01"),
    ("toyouta corona mark ii (sw)","Japan",14.5,97,23.0,"1972-01-01"),
    ("dodge colt (sw)","USA",15.0,80,28.0,"1972-01-01"),
    ("toyota corolla 1600 (sw)","Japan",16.5,88,27.0,"1972-01-01")
]

columns = ["name","origin","acceleration","horse_power","miles_per_gallon","year"]
df = spark.createDataFrame(data, schema=columns)
df.printSchema()
df.show(truncate=False)

"""
root
 |-- name: string (nullable = true)
 |-- origin: string (nullable = true)
 |-- acceleration: double (nullable = true)
 |-- horse_power: long (nullable = true)
 |-- miles_per_gallon: double (nullable = true)
 |-- year: string (nullable = true)

+---------------------------+------+------------+-----------+----------------+----------+
|name                       |origin|acceleration|horse_power|miles_per_gallon|year      |
+---------------------------+------+------------+-----------+----------------+----------+
|chevrolet vega 2300        |USA   |15.5        |90         |28.0            |1970-01-01|
|chevrolet vega 2300        |USA   |15.5        |90         |28.0            |1970-01-01|
|toyota corona              |Japan |14.0        |95         |25.0            |1970-01-01|
|ford pinto                 |USA   |19.0        |75         |25.0            |1971-01-01|
|amc gremlin                |USA   |13.0        |100        |19.0            |1971-01-01|
|plymouth satellite custom  |USA   |15.5        |105        |16.0            |1971-01-01|
|datsun 510 (sw)            |Japan |17.0        |92         |28.0            |1972-01-01|
|toyouta corona mark ii (sw)|Japan |14.5        |97         |23.0            |1972-01-01|
|dodge colt (sw)            |USA   |15.0        |80         |28.0            |1972-01-01|
|toyota corolla 1600 (sw)   |Japan |16.5        |88         |27.0            |1972-01-01|
+---------------------------+------+------------+-----------+----------------+----------+
"""

b) Creating a DataFrame by reading files

Download and use the below source file.

# Replace file_path with the location of the source file that you have downloaded.

df_2 = spark.read.format("csv").option("inferSchema", True).option("header", True).load(file_path)
df_2.printSchema()

"""
root
 |-- name: string (nullable = true)
 |-- origin: string (nullable = true)
 |-- acceleration: double (nullable = true)
 |-- horse_power: integer (nullable = true)
 |-- miles_per_gallon: double (nullable = true)
 |-- year: timestamp (nullable = true)
"""

Note: Here, I will be using the manually created DataFrame.

How to do aggregation on a column in PySpark Azure Databricks?

In this section, you will learn how to perform aggregation on a single column of a PySpark DataFrame. So, let’s try to find out how many records there are in our DataFrame.

Example:

from pyspark.sql.functions import count

df.select(count("name")).show()

"""
Output:

+-----------+
|count(name)|
+-----------+
|         10|
+-----------+

"""

How to do aggregation on multiple columns in PySpark Azure Databricks?

In this section, you will learn how to perform aggregation on multiple columns of a PySpark DataFrame. So, let’s try to find out the minimum and maximum horsepower in our DataFrame.

Example:

from pyspark.sql.functions import min, max

df.select(min("horse_power"), max("horse_power")).show()

"""
Output:

+----------------+----------------+
|min(horse_power)|max(horse_power)|
+----------------+----------------+
|              75|             105|
+----------------+----------------+

"""

How to do aggregation on a group of records in PySpark Azure Databricks?

In this section, you will learn how to perform aggregation on a group of records in a PySpark DataFrame. So, let’s dig in.

  • Find out the number of cars per region
  • Find out the year-wise maximum horsepower for each region.

Solution 1:

from pyspark.sql.functions import count

df.groupBy("origin").count().show()

"""
Output:

+------+-----+
|origin|count|
+------+-----+
|   USA|    6|
| Japan|    4|
+------+-----+

"""

Solution 2:

from pyspark.sql.functions import max

df.groupBy("year", "origin").max("horse_power").show()

"""
Output:

+----------+------+----------------+
|      year|origin|max(horse_power)|
+----------+------+----------------+
|1970-01-01|   USA|              90|
|1970-01-01| Japan|              95|
|1971-01-01|   USA|             105|
|1972-01-01| Japan|              97|
|1972-01-01|   USA|              80|
+----------+------+----------------+

"""

How to do multiple aggregation on a group of records in PySpark Azure Databricks?

In this section, you will learn how to perform multiple aggregations on a group of records in a PySpark DataFrame using the agg() function. So, let’s try to find out the minimum, average, and maximum horsepower for each region in our DataFrame.

Example:

from pyspark.sql.functions import min, avg, max

df.groupBy("origin").agg(
    min("horse_power"),
    avg("horse_power"),
    max("horse_power")    
).show()

"""
Output:

+------+----------------+----------------+----------------+
|origin|min(horse_power)|avg(horse_power)|max(horse_power)|
+------+----------------+----------------+----------------+
|   USA|              75|            90.0|             105|
| Japan|              88|            93.0|              97|
+------+----------------+----------------+----------------+

"""

How to rename aggregated columns in PySpark Azure Databricks?

In this section, you will learn how to rename aggregated columns in a PySpark DataFrame using the alias() function.

Example:

from pyspark.sql.functions import min, avg, max

df.groupBy("origin").agg(
    min("horse_power").alias("min_hp"),
    avg("horse_power").alias("avg_hp"),
    max("horse_power").alias("max_hp")
).show()

"""
Output:

+------+------+------+------+
|origin|min_hp|avg_hp|max_hp|
+------+------+------+------+
|   USA|    75|  90.0|   105|
| Japan|    88|  93.0|    97|
+------+------+------+------+

"""

What are the various inbuilt aggregation functions used commonly in PySpark Azure Databricks?

In this section, you will learn about the various standard aggregation functions available in PySpark.

1. count() and countDistinct()

  • count() returns the number of elements in a column.
  • countDistinct() returns the number of distinct elements in a column or a combination of columns.

from pyspark.sql.functions import count, countDistinct

df.select(count("name"), countDistinct("year", "acceleration")).show()

"""
Output:

+-----------+----------------------------------+
|count(name)|count(DISTINCT year, acceleration)|
+-----------+----------------------------------+
|         10|                                 9|
+-----------+----------------------------------+

"""

2. sum() and sumDistinct()

  • sum() returns the sum of elements in a column.
  • sumDistinct() returns the sum of the distinct values in a column.

from pyspark.sql.functions import sum, sumDistinct

df.select(sum("horse_power"), sumDistinct("horse_power")).show()

"""
Output:

+----------------+-------------------------+
|sum(horse_power)|sum(DISTINCT horse_power)|
+----------------+-------------------------+
|             912|                      822|
+----------------+-------------------------+

"""

3. stddev(), stddev_samp(), and stddev_pop()

  • stddev() is an alias for stddev_samp().
  • stddev_samp() returns the sample standard deviation of values in a column.
  • stddev_pop() returns the population standard deviation of the values in a column.

from pyspark.sql.functions import stddev, stddev_samp, stddev_pop

df.select(stddev("horse_power"), stddev_samp("horse_power"), stddev_pop("horse_power")).show()

"""
Output:

+------------------------+------------------------+-----------------------+
|stddev_samp(horse_power)|stddev_samp(horse_power)|stddev_pop(horse_power)|
+------------------------+------------------------+-----------------------+
|       8.929352346801718|       8.929352346801718|      8.471127433818948|
+------------------------+------------------------+-----------------------+

"""

4. avg() and mean()

  • avg() returns the average of values in the input column.
  • mean() is an alias of avg() and works exactly like it.

from pyspark.sql.functions import avg, mean

df.select(avg("acceleration"), mean("acceleration")).show()

"""
Output:

+-----------------+-----------------+
|avg(acceleration)|avg(acceleration)|
+-----------------+-----------------+
|            15.55|            15.55|
+-----------------+-----------------+

"""

5. first() and last()

  • first() returns the first value in a column.
  • last() returns the last value in a column.

from pyspark.sql.functions import first, last

df.select(first("horse_power"), last("horse_power")).show()

"""
Output:

+------------------+-----------------+
|first(horse_power)|last(horse_power)|
+------------------+-----------------+
|                90|               88|
+------------------+-----------------+

"""


6. min() and max()

  • min() returns the minimum value in a column.
  • max() returns the maximum value in a column.

from pyspark.sql.functions import min, max

df.select(min("horse_power"), max("horse_power")).show()

"""
Output:

+----------------+----------------+
|min(horse_power)|max(horse_power)|
+----------------+----------------+
|              75|             105|
+----------------+----------------+

"""


7. variance(), var_samp(), and var_pop()

  • variance() is an alias for var_samp().
  • var_samp() returns the unbiased (sample) variance of the values in a column.
  • var_pop() returns the population variance of the values in a column.

from pyspark.sql.functions import variance, var_samp, var_pop

df.select(variance("horse_power"), var_samp("horse_power"), var_pop("horse_power")).show()

"""
Output:

+---------------------+---------------------+--------------------+
|var_samp(horse_power)|var_samp(horse_power)|var_pop(horse_power)|
+---------------------+---------------------+--------------------+
|    79.73333333333333|    79.73333333333333|               71.76|
+---------------------+---------------------+--------------------+

"""


8. collect_list() and collect_set()

  • collect_list() returns all values from an input column with duplicates.
  • collect_set() returns all values from an input column with duplicate values eliminated.

from pyspark.sql.functions import collect_list, collect_set

df.select(collect_list("horse_power"), collect_set("horse_power")).show(truncate=False)

"""
Output:

+------------------------------------------+--------------------------------------+
|collect_list(horse_power)                 |collect_set(horse_power)              |
+------------------------------------------+--------------------------------------+
|[90, 90, 95, 75, 100, 105, 92, 97, 80, 88]|[88, 100, 75, 90, 105, 97, 80, 95, 92]|
+------------------------------------------+--------------------------------------+

"""

I have attached the complete code used in this blog in notebook format to this GitHub link. You can download and import this notebook into Databricks, Jupyter Notebook, etc.

When should you use aggregation functions in PySpark Azure Databricks?

These are some of the possible reasons:

  1. To count the number of records
  2. To count the distinct records
  3. To sum the values of a numeric column
  4. To sum the distinct values of a numeric column
  5. To get an average of a numeric column
  6. To find the standard deviation
  7. To find variance
  8. To find the first and last value
  9. To find the minimum and maximum values
  10. To collect column values into a list
  11. To collect distinct column values into a list

Real-World Use Case Scenarios for using aggregation functions in PySpark Azure Databricks

Assume that you were given a dataset to analyze. For example:

  • To find out the number of records in the dataset by grouping records, use the count() function along with groupBy().
  • To find out the maximum and minimum values in the numeric column, use the min() and max() aggregation functions.
  • To find out the average value in a numeric column, use the avg() or mean() function.

The commonly used PySpark aggregation functions are explained in detail with practical examples in the sections above, so have a look at them.
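For instance, all three of the bullets above can be answered in a single pass over the DataFrame created earlier in this blog. The snippet below is only a recap sketch chaining the functions already covered; alias names such as total_cars are purely illustrative.

from pyspark.sql.functions import count, min, max, avg

# Count, minimum, maximum, and average can all be computed per region in one go
df.groupBy("origin").agg(
    count("name").alias("total_cars"),        # number of records per region
    min("horse_power").alias("min_hp"),       # minimum horsepower per region
    max("horse_power").alias("max_hp"),       # maximum horsepower per region
    avg("miles_per_gallon").alias("avg_mpg")  # average mileage per region
).show()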

What are the alternatives to PySpark’s inbuilt aggregation function in PySpark Azure Databricks?

You can use PySpark User Defined Functions (UDFs) to aggregate values in a PySpark DataFrame. However, the in-built PySpark functions perform better than UDFs, are optimized by Spark's Catalyst engine, and should be preferred over writing your own custom functions. If the performance of your PySpark application is crucial, avoid custom UDFs wherever possible, because Spark cannot optimize the code inside them.
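If you do need a custom aggregation, a pandas grouped-aggregate UDF is one option. The snippet below is a minimal sketch, not the method used in this blog: the name mean_hp_udf is purely illustrative, and it assumes pandas and PyArrow are available on your cluster (they are typically pre-installed on Databricks runtimes).

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Grouped-aggregate pandas UDF: receives each group's values as a pandas Series
# and returns a single scalar (here, the mean horsepower of the group).
@pandas_udf("double")
def mean_hp_udf(hp: pd.Series) -> float:
    return hp.mean()

# Used inside agg() just like the in-built avg() shown earlier, but Spark
# cannot optimize the Python code running inside the UDF.
df.groupBy("origin").agg(mean_hp_udf(df["horse_power"]).alias("avg_hp")).show()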

Final Thoughts

In this article, we have learned how to aggregate the columns of a DataFrame in Azure Databricks using PySpark's in-built aggregation functions, along with clearly explained examples. I have also covered different possible scenarios with practical examples. I hope the information that was provided helped in gaining knowledge.

Please share your comments and suggestions in the comment section below and I will try to answer all your queries as time permits.

Arud Seka Berne S

As a big data engineer, I design and build scalable data processing systems and integrate them with various data sources and databases. I have a strong background in Python and am proficient in big data technologies such as Hadoop, Hive, Spark, Databricks, and Azure. My interest lies in working with large datasets and deriving actionable insights to support informed business decisions.