
Why .count() on a Filtered Spark Scala DataFrame Triggers a Full Scan

df.count() feels like it should be free — it's just counting rows. And on a plain table backed by Parquet or Iceberg, it more or less is: Spark can pull row counts straight out of the file footers. The moment you add a .filter() in front of it, that shortcut disappears. Spark has to read the filter column for every row to decide which ones to count, and every additional .count() you write re-runs the whole pipeline from scratch.

The Surprising Behavior

Here's the pattern that catches people. You have a transactions DataFrame and you want to know how many transactions were "high value":

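// Assumes a spark-shell or notebook session where `spark` is already defined.
import org.apache.spark.sql.functions.col
import spark.implicits._

val transactions = Seq(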
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
).toDF("txn_id", "region", "amount")

val highValueCount = transactions
  .filter(col("amount") > 100)
  .count()

println(s"High-value count: $highValueCount")
// High-value count: 3

This looks cheap. It's just a count. But under the hood Spark has to do real work: read the amount column for every row in the table, evaluate amount > 100, and then count the survivors. On a small toy DataFrame that's nothing. On a billion-row Parquet table on S3, that's a full column scan — gigabytes of data read just to get back a single integer.

Why This Happens

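Spark has a cheap path for unfiltered counts. When you call df.count() on a table backed by Parquet, ORC, or Iceberg, the scan asks for no columns at all, so the reader can usually take the row counts that the writers stored in the file footers instead of decoding column data. It still has to open each file's metadata, so the cost grows with the number of files rather than the number of rows, but no column values are read.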

That optimization is fragile. It depends on Spark being able to prove, statically, that every row in the file counts. The moment you put a .filter() between the source and the .count(), Spark can't prove that anymore. It has no way to know whether row #4,719,201 passes amount > 100 without reading its amount value. So it falls back to the general path: read the filter columns, evaluate the predicate, count the matches.

Column pruning helps — Spark only reads the columns referenced by the filter, not the whole row — and Parquet min/max statistics let it skip whole row groups that obviously can't match. But the rest of the column data really does have to come off disk and through the executors. It's not "look up a number," it's "scan every value in this column."

This is why df.filter(...).count() on a large table can run for minutes while df.count() on the same table returns instantly. They look almost identical in code. They are not even close in cost.
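One way to see the difference for yourself is to compare the physical plans. The sketch below is an illustration under assumptions, not part of the original tutorial: it writes the toy data to a throwaway Parquet directory (the temp path and app name are made up) and uses groupBy().count(), which is the aggregation df.count() is built on, so the plan can be printed without running the action.

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Local session for the sketch; in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("count-plan-sketch")
  .getOrCreate()
import spark.implicits._

// Write the toy rows to a temporary Parquet directory so the plans show a real file scan.
val path = Files.createTempDirectory("txn-demo").resolve("transactions").toString
Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25)
).toDF("txn_id", "region", "amount").write.parquet(path)

val onDisk = spark.read.parquet(path)

// The same aggregation with and without a filter in front of it.
val unfiltered = onDisk.groupBy().count()
val filtered   = onDisk.filter(col("amount") > 100).groupBy().count()

unfiltered.explain() // the scan node should request no columns (an empty ReadSchema)
filtered.explain()   // the scan node should now list amount in ReadSchema and PushedFilters

// Keep the plan text and the results around for inspection.
val filteredPlan    = filtered.queryExecution.executedPlan.toString
val unfilteredCount = unfiltered.head().getLong(0)
val filteredCount   = filtered.head().getLong(0)

println(s"unfiltered: $unfilteredCount, filtered: $filteredCount")
```

The exact plan text varies by Spark version, so treat the comments as what to look for rather than literal output. The point is that the filtered plan has to read the amount column, while the unfiltered one does not read any column.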

The Anti-Pattern: Multiple Counts, Multiple Scans

This gets worse the moment you reach for .count() more than once. Each .count() is its own action — Spark runs the whole filter pipeline again for every one of them, because nothing was cached in between.

val transactions = Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00),
).toDF("txn_id", "region", "amount")

val highCount = transactions.filter(col("amount") > 100).count()
val lowCount  = transactions.filter(col("amount") <= 100).count()
val usCount   = transactions.filter(col("region") === "US").count()

println(s"high: $highCount, low: $lowCount, us: $usCount")
// high: 4, low: 4, us: 4

Three counts. Three full scans of the source. If transactions is reading from S3 or a remote warehouse, that's three round-trips through the read path for three numbers that could have come from a single pass. The job will finish, but it'll take roughly three times as long as it needs to.

This pattern usually grows incrementally. Someone adds a filter-count to log how many rows passed validation. Someone else adds another to log how many failed. Six months later there are seven counts before the "real" computation even starts, and the job's wall-clock time is dominated by metric calculation, not by the work that produces output.

The Fix: One Pass with Conditional Aggregation

When you need several counts over the same data, ask Spark for all of them at once. sum(when(cond, 1).otherwise(0)) gives you a count of rows matching cond as part of a regular aggregation, and Spark folds the whole thing into a single scan.

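// Assumes `spark` is in scope and spark.implicits._ has been imported.
import org.apache.spark.sql.functions.{col, sum, when}

val transactions = Seq(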
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00),
).toDF("txn_id", "region", "amount")

val counts = transactions.agg(
  sum(when(col("amount") > 100, 1).otherwise(0)).as("high_count"),
  sum(when(col("amount") <= 100, 1).otherwise(0)).as("low_count"),
  sum(when(col("region") === "US", 1).otherwise(0)).as("us_count"),
)

counts.show(false)
// +----------+---------+--------+
// |high_count|low_count|us_count|
// +----------+---------+--------+
// |4         |4        |4       |
// +----------+---------+--------+

One scan, one shuffle, one row out. The same answers, at roughly one-third the cost. Spark SQL also has count_if(condition) (callable via expr("count_if(amount > 100)")) which expresses the same idea more compactly — see the count and count distinct example for that variant and for the difference between count("*") and count(col). The when/otherwise form is the most portable across Spark versions.
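If metric counts tend to pile up one at a time, a small helper can keep them in a single aggregation by construction. This is a minimal sketch under assumptions: countMatching is a hypothetical name, not a Spark API, and it simply builds the same sum(when(...).otherwise(0)) expressions from a list of named predicates.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{coalesce, col, lit, sum, when}

// Local session for the sketch; in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("count-helper-sketch")
  .getOrCreate()
import spark.implicits._

// Hypothetical helper: one named predicate per metric, one aggregation for all of them.
def countMatching(df: DataFrame, predicates: Seq[(String, Column)]): Map[String, Long] = {
  require(predicates.nonEmpty, "need at least one predicate")
  val aggs = predicates.map { case (name, cond) =>
    // coalesce guards against the null that sum returns on an empty DataFrame.
    coalesce(sum(when(cond, 1L).otherwise(0L)), lit(0L)).as(name)
  }
  // A single agg call means a single job over the source.
  val row = df.agg(aggs.head, aggs.tail: _*).head()
  predicates.map { case (name, _) => name -> row.getAs[Long](name) }.toMap
}

val transactions = Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00)
).toDF("txn_id", "region", "amount")

val counts = countMatching(transactions, Seq(
  "high" -> (col("amount") > 100),
  "low"  -> (col("amount") <= 100),
  "us"   -> (col("region") === "US")
))

println(s"high: ${counts("high")}, low: ${counts("low")}, us: ${counts("us")}")
// high: 4, low: 4, us: 4
```

Adding an eighth metric then means adding one entry to the list, not another action.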

The same trick works for averages, sums, distinct counts, and anything else you'd reach for. If you find yourself writing three filtered actions over the same DataFrame, you almost always want a single .agg(...) call instead.
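As a sketch of what that can look like, the aggregation below mixes a conditional sum, a conditional average, a conditional distinct count, and count_if in one agg call. The column aliases are made up for the example, and the count_if line assumes Spark 3.0 or later. Leaving off .otherwise(...) makes when return null for non-matching rows, which sum, avg, and countDistinct all ignore.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, countDistinct, expr, sum, when}

// Local session for the sketch; in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("conditional-agg-sketch")
  .getOrCreate()
import spark.implicits._

val transactions = Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00)
).toDF("txn_id", "region", "amount")

val isHigh = col("amount") > 100

// Four different conditional metrics computed in one aggregation.
val summary = transactions.agg(
  sum(when(isHigh, col("amount"))).as("high_total"),               // total of high-value amounts
  avg(when(col("region") === "US", col("amount"))).as("us_avg"),   // average US amount
  countDistinct(when(isHigh, col("region"))).as("high_regions"),   // regions with a high-value txn
  expr("count_if(amount > 100)").as("high_count")                  // SQL count_if, Spark 3.0+
).head()

val highTotal   = summary.getAs[Double]("high_total")
val usAvg       = summary.getAs[Double]("us_avg")
val highRegions = summary.getAs[Long]("high_regions")
val highCount   = summary.getAs[Long]("high_count")

println(f"high_total=$highTotal%.2f us_avg=$usAvg%.2f high_regions=$highRegions high_count=$highCount")
```

On the eight toy rows this should give a high-value total of 1089.75 across 3 regions and 4 transactions, with a US average of about 149.31.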

When You Really Do Need to Reuse the Filtered DataFrame

Sometimes the filter result isn't just a count — you want the filtered rows themselves, plus a few summary numbers about them, plus a downstream join. In that case the answer isn't "be cleverer with aggregation," it's "tell Spark to compute the filter once and hold onto the result."

val transactions = Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00),
).toDF("txn_id", "region", "amount")

val highValue = transactions.filter(col("amount") > 100).cache()

val totalHigh       = highValue.count()
val regionBreakdown = highValue.groupBy("region").count()

println(s"Total high-value: $totalHigh")
regionBreakdown.orderBy("region").show(false)
// Total high-value: 4
// +------+-----+
// |region|count|
// +------+-----+
// |CA    |1    |
// |UK    |1    |
// |US    |2    |
// +------+-----+

highValue.unpersist()

cache() tells Spark to keep the filtered DataFrame around after the first action materializes it. For a DataFrame the default storage level is MEMORY_AND_DISK, so partitions live in executor memory and fall back to local disk when they don't fit. The .count() triggers the scan and stores the filtered rows; the groupBy("region").count() reads from the cache instead of re-running the scan. Two actions, one scan, and you still have highValue available for whatever joins or writes come next.

This isn't free. Cached rows eat executor memory, and if the filtered DataFrame doesn't fit, you'll spill to disk or recompute partitions. Cache when you'll touch the same intermediate result more than once. Don't cache reflexively. The cache best practices tutorial goes deeper on when caching helps and when it hurts.
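To keep the cache's lifetime explicit, one option is to name the storage level and release it in a finally block. This is a sketch of that pattern rather than a recommendation from the original tutorial; the choice of MEMORY_AND_DISK and the variable names are assumptions for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.storage.StorageLevel

// Local session for the sketch; in spark-shell a `spark` value already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("persist-sketch")
  .getOrCreate()
import spark.implicits._

val transactions = Seq(
  ("T001", "US", 49.99),
  ("T002", "US", 124.50),
  ("T003", "UK", 18.00),
  ("T004", "CA", 250.00),
  ("T005", "US", 12.75),
  ("T006", "UK", 305.25),
  ("T007", "CA", 78.00),
  ("T008", "US", 410.00)
).toDF("txn_id", "region", "amount")

// persist() with an explicit level documents the memory/disk trade-off at the call site.
val highValue = transactions
  .filter(col("amount") > 100)
  .persist(StorageLevel.MEMORY_AND_DISK)

// The level is registered immediately, even though nothing is materialized yet.
val level = highValue.storageLevel

val (totalHigh, byRegion) =
  try {
    val total = highValue.count() // first action runs the filter and fills the cache
    val regions = highValue.groupBy("region").count().collect()
      .map(r => r.getString(0) -> r.getLong(1)).toMap // second action reads the cached rows
    (total, regions)
  } finally {
    highValue.unpersist() // release the cached partitions even if an action fails
  }

println(s"Total high-value: $totalHigh, by region: $byRegion")
```

Collecting the region breakdown is only reasonable here because a grouped result over a handful of regions is tiny. The filtered rows themselves stay on the executors.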

Quick Reference

| You wrote | What Spark does | Cost |
|---|---|---|
| df.count() on Parquet/Iceberg | Reads row counts from file footers | Metadata-only, near instant |
| df.filter(cond).count() | Reads filter columns, evaluates predicate, counts | Full column scan of filter columns |
| Multiple df.filter(...).count() calls | Re-runs the pipeline for each call | N scans for N counts |
| df.agg(sum(when(cond, 1).otherwise(0)).as(...), ...) | One scan, one aggregation | One scan total |
| df.filter(cond).cache(); .count(); .groupBy(...).count() | Scan once, reuse from memory | One scan, plus cache cost |

The rule of thumb: treat .count() as a real action, not a free check. If a filter sits in front of it, you're paying for the scan. If you're calling .count() more than once on the same source, you're paying for it more than once. Roll counts into a single .agg(...) when you can, and reach for .cache() when you genuinely need the filtered DataFrame for several downstream uses.

And if you're tempted to collect() the filtered DataFrame to "count it in Scala" instead, that's a faster way to crash the job, not a faster way to count.

Tutorial Details

Created: 2026-06-10 10:35:08 PM

Last Updated: 2026-06-10 10:35:08 PM