How to Collect() – Retrieve data from DataFrame in Databricks

Are you looking for ways to retrieve data from an Azure Databricks DataFrame, or perhaps for the difference between collect() and select() in Azure Databricks? Then you have landed on the right page. In this post I will explain, step by step, how to use the collect() function in Azure Databricks PySpark.

How to use the collect() function in Azure Databricks PySpark?

DataFrame collect() is an action that retrieves all elements of the dataset to the driver node. We should use collect() only on a small dataset, usually after transformations such as filter() or groupBy(). It is not suitable for larger datasets: retrieving a large dataset with collect() results in an OutOfMemory error on the driver.
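
For instance, a minimal sketch of the recommended pattern, assuming a DataFrame df with a dept_id column already exists (the names here are illustrative):

# Reduce the result with filter() first, then collect the small output
small_rows = df.filter(df.dept_id > 20).collect()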

In this article, I will explain the usage of collect() with DataFrame example, when to avoid it, and the difference between collect() and select().

In order to explain with an example, first, let’s create a DataFrame.

import pyspark
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

dept = [("Marketing ",10), \
    ("Finance",20), \
    ("IT ",30), \
    ("Sales",40) \
  ]
deptColumns = ["dept_name","dept_id"]
deptDF = spark.createDataFrame(data=dept, schema = deptColumns)
deptDF.show(truncate=False)

The show() function on a DataFrame in Databricks prints the result of the DataFrame in a table format. By default, it shows only the first 20 rows. The above snippet returns the data displayed below.

+---------+-------+
|dept_name|dept_id|
+---------+-------+
|Marketing|10     |
|Finance  |20     |
|IT       |30     |
|Sales    |40     |
+---------+-------+

Now, let’s use collect() to retrieve the data.

dataCollect = deptDF.collect()
print(dataCollect)

deptDF.collect() retrieves all elements of the DataFrame in Databricks as an Array of Row type to the driver node. Printing the resultant array yields the output below.

[Row(dept_name='Marketing', dept_id=10),
Row(dept_name='Finance', dept_id=20),
Row(dept_name='IT', dept_id=30),
Row(dept_name='Sales', dept_id=40)]

Note that collect() is an action, hence it does not return a DataFrame; instead, it returns the data in an Array to the driver. Once the data is in an array, we can use a Python for loop to process it.

for row in dataCollect:
    print(row['dept_name'] + "," + str(row['dept_id']))

If you want to get the value of the first row and second column from the DataFrame:

#Returns value of First Row, Second Column which is 10
deptDF.collect()[0][1]

Let’s understand what’s happening in the above statement.

  1. deptDF.collect() returns an Array of Row type.
  2. deptDF.collect()[0] returns the first element in the array (the first row).
  3. deptDF.collect()[0][1] returns the value of the first row and second column, which is 10.
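
As a side note, Row values can also be accessed by column name or attribute; a minimal sketch reusing deptDF from above:

firstRow = deptDF.collect()[0]
print(firstRow[1])           # positional access, returns 10
print(firstRow['dept_id'])   # access by column name, returns 10
print(firstRow.dept_id)      # attribute access, returns 10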

In case you want to return only certain columns of a DataFrame in Databricks, you should call the select() transformation first.
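
For example, to bring back only the dept_name column, a short sketch reusing deptDF:

# select() narrows the columns before collect() moves the rows to the driver
names = deptDF.select("dept_name").collect()
print(names)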

When to avoid the collect() function in Azure Databricks PySpark

Usually, collect() is used to retrieve the action output when the result set is very small. Calling collect() on an RDD/DataFrame with a bigger result set causes an out-of-memory error, as it returns the entire dataset (from all workers) to the driver. Hence, we should avoid calling collect() on a larger dataset.
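
If you only need a handful of rows from a large result, safer built-in alternatives exist; a minimal sketch reusing deptDF (the row counts are arbitrary):

# take(n) brings only the first n rows to the driver
print(deptDF.take(2))

# toLocalIterator() streams rows to the driver one partition at a time,
# avoiding a single large transfer
for row in deptDF.toLocalIterator():
    print(row)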

Difference between collect() and select() in Databricks PySpark

select() is a transformation that returns a new DataFrame holding only the selected columns.

collect() is an action that returns the entire data set in an Array to the driver.
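
A quick way to see the difference is to compare the return types; a minimal sketch reusing deptDF from above:

print(type(deptDF.select("dept_name")))  # a DataFrame (transformation, nothing is moved yet)
print(type(deptDF.collect()))            # a Python list of Row objects (action, data is on the driver)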

Example of collect() in Databricks PySpark

Below is an example of using collect() on a DataFrame; similarly, we can also use collect() with an RDD, as shown in the sketch after this example.

import pyspark
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

dept = [("Marketing",10), \
    ("Finance",20), \
    ("IT ",30), \
    ("Sales",40) \
  ]
deptColumns = ["dept_name","dept_id"]
deptDF = spark.createDataFrame(data=dept, schema = deptColumns)
deptDF.printSchema()
deptDF.show(truncate=False)

dataCollect = deptDF.collect()

print(dataCollect)

dataCollect2 = deptDF.select("dept_name").collect()
print(dataCollect2)

for row in dataCollect:
    print(row['dept_name'] + "," + str(row['dept_id']))
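
As mentioned above, a similar program can be written with an RDD; a minimal sketch built from the same dept list:

# collect() on an RDD likewise returns all elements to the driver,
# here as a list of tuples instead of Row objects
deptRDD = spark.sparkContext.parallelize(dept)
rddCollect = deptRDD.collect()
print(rddCollect)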

Databricks Official Documentation Link

Conclusion

In this article, you have learned that the collect() function of the DataFrame is an action operation that returns all elements of the DataFrame to the Spark driver program, and also that it is not good practice to use it on a bigger dataset.
