Why collect() on a Large Spark Scala DataFrame Kills Your Driver
collect() is one of the most innocent-looking ways to crash a Spark job. It returns an Array[Row], which feels harmless — until you remember that the array has to fit in driver memory. This tutorial shows what collect() actually does, why it falls over on real datasets, and which APIs to reach for instead.
What collect() Actually Does
collect() is an action that pulls every row from every executor back across the network and into the JVM heap of the driver process. The result is a plain Scala Array[Row] — no laziness, no streaming, no partitioning. Whatever was distributed across the cluster a moment ago is now sitting in a single array on a single machine.
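import org.apache.spark.sql.Row
// Assumes an active SparkSession named spark, as in spark-shell or a notebook.
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

// collect() ships every row to the driver as one plain array.
val rows: Array[Row] = df.collect()
rows.foreach(println)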
// [Alice,Engineering,95000]
// [Bob,Marketing,72000]
// [Charlie,Engineering,105000]
// [Diana,Sales,68000]
// [Eve,Marketing,78000]
println(s"rows.length = ${rows.length}")
// rows.length = 5
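On five rows this is fine. On five billion rows it isn't. The driver is often the smallest JVM in your cluster (spark.driver.memory defaults to 1g, and even tuned drivers commonly get only a few gigabytes), and an Array[Row] has no spill-to-disk behavior. Once the collected data exceeds what the driver can hold, the application dies: either the driver heap fills and you get an OutOfMemoryError, or Spark aborts the job first because the serialized results exceed spark.driver.maxResultSize (1g by default).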
Why This Bites You in Production
The trap is that collect() looks like a normal Scala operation. You wrote it on a sample dataset in a notebook, it worked, you shipped it. Then the input grew, or someone joined in another table, or a partition skewed — and now the same line of code is trying to materialize 50 GB on a driver with 4 GB of heap.
A few common shapes:
- "I just want to print the data." df.collect().foreach(println) pulls everything back just to print a handful of rows.
- "I'll iterate over the rows in Scala." Anything that requires per-row logic in driver code usually wants foreachPartition or a UDF, not collect().
- "I need the values to build the next query." If the values live in a DataFrame, the next query can usually reference them as a DataFrame too, via a join or a broadcast, without round-tripping through the driver.
collect() has legitimate uses, but they almost always involve a tiny, already-aggregated DataFrame. If you're collecting raw rows from a source table, you're probably about to crash.
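The "values to build the next query" case can be sketched as follows. This is a minimal illustration under stated assumptions, not a definitive recipe: the orders and vipCustomers DataFrames, their column names, and the local SparkSession are all invented for the example, and the snippet is written in the same spark-shell/script style as the rest of this tutorial.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Local session for illustration only; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("semi-join-sketch")
  .getOrCreate()
import spark.implicits._

// Hypothetical data: names and columns are made up for this sketch.
val orders = Seq(
  (1, "c1", 30.0),
  (2, "c2", 12.5),
  (3, "c3", 99.0)
).toDF("order_id", "customer_id", "amount")
val vipCustomers = Seq("c1", "c3").toDF("customer_id")

// Anti-pattern: round-trip the ids through the driver, then filter with isin.
// val ids = vipCustomers.collect().map(_.getString(0))
// val vipOrders = orders.filter($"customer_id".isin(ids: _*))

// Alternative: keep the ids distributed and filter with a left-semi join.
// broadcast() is only a hint to the planner that the right side is small.
val vipOrders = orders.join(broadcast(vipCustomers), Seq("customer_id"), "left_semi")

vipOrders.show(false)
```

A left-semi join keeps only the columns of the left side, so vipOrders has the same schema as orders; the list of ids never has to exist as a Scala collection on the driver.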
Just Want to See the Data? Use show()
If your goal is to eyeball some rows during development, show() is built for exactly that. It pulls only the first N rows back (default 20) and prints them as a formatted table.
val df = Seq(
("Alice", "Engineering", 95000),
("Bob", "Marketing", 72000),
("Charlie", "Engineering", 105000),
("Diana", "Sales", 68000),
("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")
df.show(false)
// +-------+-----------+------+
// |name |department |salary|
// +-------+-----------+------+
// |Alice |Engineering|95000 |
// |Bob |Marketing |72000 |
// |Charlie|Engineering|105000|
// |Diana |Sales |68000 |
// |Eve |Marketing |78000 |
// +-------+-----------+------+
show(false) disables column truncation. Pass an integer to control how many rows come back — df.show(100, false) to inspect a hundred. The driver only ever sees those rows, no matter how big the underlying DataFrame is.
Need Rows Programmatically? Use take() or limit()
When you actually want rows as Scala values — say, to assert in a test or pull a small sample — use take(n) instead of collect(). It's collect() with an upper bound: Spark only materializes n rows on the driver.
val df = Seq(
("Alice", "Engineering", 95000),
("Bob", "Marketing", 72000),
("Charlie", "Engineering", 105000),
("Diana", "Sales", 68000),
("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")
val sample: Array[Row] = df.take(2)
sample.foreach(println)
// [Alice,Engineering,95000]
// [Bob,Marketing,72000]
If you want to stay in DataFrame land — for example, to feed a sample into another transformation without ever pulling rows to the driver — use limit(n) instead. limit returns a DataFrame; nothing leaves the cluster until you call an action.
val df = Seq(
("Alice", "Engineering", 95000),
("Bob", "Marketing", 72000),
("Charlie", "Engineering", 105000),
("Diana", "Sales", 68000),
("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")
import org.apache.spark.sql.DataFrame
val limited: DataFrame = df.limit(2)  // transformation: the result is still a DataFrame
limited.show(false)
// +-----+-----------+------+
// |name |department |salary|
// +-----+-----------+------+
// |Alice|Engineering|95000 |
// |Bob |Marketing |72000 |
// +-----+-----------+------+
The distinction matters: take is an action that returns to the driver. limit is a transformation that produces a smaller DataFrame.
Aggregate First, Then Collect
The legitimate use case for collect() is when you've already reduced the data to something small. Counting rows per department turns billions of rows into a handful of summary rows — at that point, collect() is fine.
val df = Seq(
("Alice", "Engineering", 95000),
("Bob", "Marketing", 72000),
("Charlie", "Engineering", 105000),
("Diana", "Sales", 68000),
("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")
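import org.apache.spark.sql.functions.count

// Aggregate on the cluster: one output row per department.
val deptCounts = df
  .groupBy("department")
  .agg(count("*").alias("headcount"))
deptCounts.show(false)  // row order after groupBy is not guaranteed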
// +-----------+---------+
// |department |headcount|
// +-----------+---------+
// |Sales |1 |
// |Engineering|2 |
// |Marketing |2 |
// +-----------+---------+
val rows = deptCounts.collect()
rows.foreach(r => println(s"${r.getString(0)} -> ${r.getLong(1)}"))
// Sales -> 1
// Engineering -> 2
// Marketing -> 2
The aggregation runs on the cluster. Only the small result — one row per department — comes back. This is the shape that collect() is designed for: a final pull of a result that's already small by construction.
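One way to turn "small by construction" into something the code checks is a guard around the final pull. The sketch below is an assumption-laden illustration rather than a Spark API: collectAtMost is a hypothetical helper name, and the local SparkSession and sample data exist only to make the snippet self-contained.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

// Local session for illustration only; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("bounded-collect-sketch")
  .getOrCreate()
import spark.implicits._

// Hypothetical helper (not part of Spark): collect only if the DataFrame
// has at most maxRows rows, otherwise fail loudly.
def collectAtMost(df: DataFrame, maxRows: Int): Array[Row] = {
  // take(maxRows + 1) bounds what reaches the driver; the one extra row reveals overflow.
  val rows = df.take(maxRows + 1)
  require(rows.length <= maxRows, s"DataFrame has more than $maxRows rows; refusing to collect")
  rows
}

// A tiny, already-aggregated result (made-up values).
val deptCounts = Seq(("Sales", 1L), ("Engineering", 2L), ("Marketing", 2L))
  .toDF("department", "headcount")

val withinBound = collectAtMost(deptCounts, 100)               // returns all 3 rows
val overBound = scala.util.Try(collectAtMost(deptCounts, 2))   // Failure(IllegalArgumentException)
```

Note that this bounds the row count, not the byte size: a few very wide rows can still be large, so the threshold should be chosen with the row width in mind.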
Streaming Rows with toLocalIterator
Sometimes you genuinely need to walk every row from driver code — for example, to write them to an external system that doesn't have a Spark connector. collect() would load them all at once. toLocalIterator() brings them back one partition at a time, so memory usage is bounded by the largest partition rather than by the size of the whole DataFrame.
val df = Seq(
("Alice", "Engineering", 95000),
("Bob", "Marketing", 72000),
("Charlie", "Engineering", 105000),
("Diana", "Sales", 68000),
("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")
val it = df.toLocalIterator()
while (it.hasNext) {
  println(it.next())
}
// [Alice,Engineering,95000]
// [Bob,Marketing,72000]
// [Charlie,Engineering,105000]
// [Diana,Sales,68000]
// [Eve,Marketing,78000]
toLocalIterator isn't free — it runs a separate job for each partition, so it's slower than collect() for data that would fit in memory. Use it when memory is the constraint, not throughput. And if your real goal is to do per-row work in parallel across the cluster, use foreachPartition on the executors instead of pulling rows back at all.
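For comparison, here is a rough sketch of the foreachPartition route. Everything outside the Spark calls is an assumption for illustration: the StringWriter stands in for a per-partition connection to some external system, the accumulator is there only so the driver can observe how many rows were handled, and the session runs in local mode.

```scala
import org.apache.spark.sql.{Row, SparkSession}

// Local session for illustration only; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("foreach-partition-sketch")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000)
).toDF("name", "department", "salary")

// Accumulator so the driver can see how many rows the executors processed.
val rowsSent = spark.sparkContext.longAccumulator("rowsSent")

// The explicit parameter type selects the Scala overload of foreachPartition.
df.foreachPartition { (rows: Iterator[Row]) =>
  // Hypothetical stand-in for opening one connection per partition;
  // a real job would create its external client here instead.
  val sink = new java.io.StringWriter()
  try {
    rows.foreach { row =>
      sink.write(s"${row.getString(0)},${row.getInt(2)}\n")
      rowsSent.add(1)
    }
  } finally {
    sink.close()
  }
}

println(s"rows sent: ${rowsSent.value}")
```

Creating the connection inside the closure means it is opened once per partition on the executor, rather than being serialized from the driver or reopened for every row.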
Quick Reference
| You want to… | Use | Notes |
|---|---|---|
| Eyeball a few rows during development | df.show(n) | Pretty-prints to stdout. Driver only sees those n rows. |
| Get a bounded sample as Scala values | df.take(n) | Returns Array[Row] of at most n rows. |
| Sample inside a pipeline | df.limit(n) | Returns a DataFrame. No driver materialization. |
| Pull a small aggregated result | df.collect() | Safe only when the DataFrame is already small. |
| Iterate over every row from the driver | df.toLocalIterator() | Streams one partition at a time. Slower but bounded memory. |
| Do per-row work in parallel | df.foreachPartition | Stays on the executors. No driver round-trip. |
The rule of thumb: treat collect() as an "I am confident this fits in driver memory" assertion. If you can't say that out loud about your data, you don't want collect(). For exploration use show(), for sampling use take() or limit(), and for big results push the work onto the cluster instead of pulling the data home.
If you find yourself reaching for collect() to reuse a DataFrame across multiple downstream queries, you probably want caching instead — that keeps the data distributed but avoids recomputing it. And if you're building DataFrames for tests where collect() is genuinely safe because the data is tiny, see creating DataFrames for testing with toDF.
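As a rough sketch of the caching alternative mentioned above: the filter threshold, variable names, and local SparkSession below are assumptions chosen for illustration, not a recommendation for any particular workload.

```scala
import org.apache.spark.sql.SparkSession

// Local session for illustration only; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("cache-sketch")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

// cache() is lazy: it marks the DataFrame for reuse, and the first action materializes it.
val highEarners = df.filter($"salary" > 70000).cache()

// Both queries below read the cached, still-distributed data.
val highEarnerCount = highEarners.count()
highEarners.groupBy("department").count().show(false)

// Release the cached blocks once the downstream queries are done.
highEarners.unpersist()
```

The reused data stays partitioned across the executors the whole time; only the small results of count() and show() reach the driver.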