
Writing a Custom Accumulator to Count Filtered Rows in Spark Scala

When a Spark job filters bad rows out of a dataset, you usually want to know how many — and ideally why. Accumulators are Spark's built-in tool for collecting metrics from worker tasks back to the driver. The built-in LongAccumulator handles a single counter; for per-reason breakdowns you need to write a custom one by extending AccumulatorV2.

The Problem

You're loading orders from upstream and dropping rows that fail validation — missing email, non-positive amount, wrong status. The pipeline runs fine, but operators want a per-run report: how many rows did we drop, and for which reason?

You could compute these stats by running extra aggregations on the input DataFrame, but that means another pass over the data. Accumulators let you collect the counts as a side effect of the work you're already doing.

Starting Simple: the Built-in LongAccumulator

Before writing anything custom, see if a built-in accumulator is enough. LongAccumulator is a single mutable Long counter that workers can increment and the driver can read.

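// Assumes a SparkSession named spark is in scope, as in spark-shell.
import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq(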
  (1, "alice@example.com", 89.99),
  (2, null, 145.00),
  (3, "charlie@example.com", -15.00),
  (4, "diana@example.com", 250.75),
  (5, null, -5.00),
  (6, "frank@example.com", 50.00),
).toDF("order_id", "customer_email", "amount")

val rejected = spark.sparkContext.longAccumulator("rejected_rows")

val cached = df.cache()

cached.foreach { row =>
  val email = row.getAs[String]("customer_email")
  val amount = row.getAs[Double]("amount")
  if (email == null || amount <= 0) rejected.add(1)
}

val valid = cached.filter(col("customer_email").isNotNull && col("amount") > 0)

valid.show(false)
// +--------+-----------------+------+
// |order_id|customer_email   |amount|
// +--------+-----------------+------+
// |1       |alice@example.com|89.99 |
// |4       |diana@example.com|250.75|
// |6       |frank@example.com|50.0  |
// +--------+-----------------+------+

println(s"Rejected rows: ${rejected.value}")
// Rejected rows: 3

Two things to notice. First, foreach is an action: it forces Spark to compute the DataFrame and run the closure on every row. Second, we call cache() first because the later filter reads the same DataFrame again. cache() is lazy, so the foreach is what actually populates the cache, and the filter then reads from it instead of recomputing the source from scratch.

This works, but LongAccumulator is a single counter. It can tell you three rows were dropped, but not two for missing email, one for negative amount. For that you need a richer type.

Defining a Custom Accumulator

A custom accumulator extends AccumulatorV2[IN, OUT], where IN is the type added by workers and OUT is the type the driver reads back. For per-reason counts we want to add a String (the reason) and read back a Map[String, Long].

The abstract class requires six methods:

Method | Purpose
isZero | Returns true when the accumulator is in its zero (empty) state. Spark asserts this on the zeroed copy it serializes for each task.
copy() | Returns a new accumulator holding an independent copy of the current state.
reset() | Resets to the zero state. Spark's copyAndReset(), which by default calls copy() and then reset(), relies on it so that every task starts from zero.
add(in: IN) | Updates the accumulator with a single input value. Called on workers.
merge(other) | Merges another accumulator's state into this one. Called when task results return to the driver.
value | Returns the current value. Called on the driver.
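To make the table concrete, the sequence can be replayed by hand on a single JVM. The sketch below is an illustration only: it uses the built-in LongAccumulator, needs spark-core on the classpath but no running SparkContext, and the object name AccumulatorLifecycleSketch is made up for this example. Real Spark performs these steps across the driver and the executors.

```scala
import org.apache.spark.util.LongAccumulator

object AccumulatorLifecycleSketch {
  // Replays the driver/task sequence locally and returns the merged total.
  def run(): Long = {
    val driverSide = new LongAccumulator   // the instance that lives on the driver
    assert(driverSide.isZero)

    // Serialization for a task goes through copyAndReset():
    // by default copy() followed by reset().
    val task1 = driverSide.copyAndReset()
    val task2 = driverSide.copyAndReset()
    assert(task1.isZero && task2.isZero)

    // Each task calls add() on its own private copy.
    task1.add(1L)
    task1.add(1L)
    task2.add(1L)

    // As task results come back, the driver folds each copy in with merge().
    driverSide.merge(task1)
    driverSide.merge(task2)

    driverSide.sum                          // same number as value, as a Scala Long
  }
}
```

Calling AccumulatorLifecycleSketch.run() returns 3: two updates from the first copy plus one from the second. A custom accumulator has to support the same sequence.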

Here's the implementation:

import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable

class RejectReasonAccumulator extends AccumulatorV2[String, Map[String, Long]] {
  private val counts = mutable.Map.empty[String, Long].withDefaultValue(0L)

  override def isZero: Boolean = counts.isEmpty

  override def copy(): AccumulatorV2[String, Map[String, Long]] = {
    val c = new RejectReasonAccumulator
    counts.foreach { case (k, v) => c.counts(k) = v }
    c
  }

  override def reset(): Unit = counts.clear()

  override def add(reason: String): Unit = {
    counts(reason) += 1
  }

  override def merge(other: AccumulatorV2[String, Map[String, Long]]): Unit = other match {
    case o: RejectReasonAccumulator =>
      o.counts.foreach { case (k, v) => counts(k) = counts(k) + v }
    case _ =>
      throw new UnsupportedOperationException(
        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}"
      )
  }

  override def value: Map[String, Long] = counts.toMap
}

A few things worth pointing out:

  • The internal state is a mutable.Map[String, Long].withDefaultValue(0L). The default makes counts(reason) += 1 work even for reasons we've never seen before.
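  • copy() does a deep copy. If the copy shared the inner map with the original, updates made through one instance would leak into the other, and the zeroed copies Spark hands to tasks would no longer be independent.
  • merge pattern-matches on the other accumulator's runtime type. Its parameter is declared as the base type AccumulatorV2[String, Map[String, Long]], so the compiler cannot rule out a different subclass. Spark won't pass you the wrong type in practice, but throwing is safer than silently ignoring an accumulator you cannot read.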
  • value returns an immutable Map, not the mutable one. This prevents external code from mutating the accumulator's internals.
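Two of these points can be checked without Spark at all. The following is a minimal plain-Scala sketch, with CountMapSketch and its method names invented for illustration: it shows that reading a missing key from a withDefaultValue map does not insert it (so an isEmpty-based isZero stays accurate), and that a key-wise sum gives the same result whatever order the partial maps arrive in, which matters because tasks finish in no particular order.

```scala
import scala.collection.mutable

object CountMapSketch {
  // Same construction as the accumulator's internal state.
  // Reading an absent key returns 0L but does NOT add the key to the map.
  def newCounts(): mutable.Map[String, Long] =
    mutable.Map.empty[String, Long].withDefaultValue(0L)

  // Key-wise sum: the rule merge applies to two accumulators.
  def mergeInto(target: mutable.Map[String, Long],
                other: collection.Map[String, Long]): Unit =
    other.foreach { case (k, v) => target(k) = target(k) + v }

  // Folds any number of partial results into one immutable snapshot.
  def mergeAll(parts: Seq[Map[String, Long]]): Map[String, Long] = {
    val acc = newCounts()
    parts.foreach(p => mergeInto(acc, p))
    acc.toMap
  }
}
```

For example, merging Map("missing_email" -> 2L) with Map("missing_email" -> 1L, "non_positive_amount" -> 1L) yields Map("missing_email" -> 3L, "non_positive_amount" -> 1L) in either order.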

Using the Custom Accumulator

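To use a custom accumulator, instantiate it, register it with the SparkContext, then increment it from worker code. Registration is not optional: Spark throws if an unregistered accumulator is serialized to executors. Passing a name when you register also makes the accumulator show up in the Spark UI.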

val df = Seq(
  (1, "alice@example.com", 89.99),
  (2, null, 145.00),
  (3, "charlie@example.com", -15.00),
  (4, "diana@example.com", 250.75),
  (5, null, -5.00),
  (6, "frank@example.com", 50.00),
).toDF("order_id", "customer_email", "amount")

val rejectStats = new RejectReasonAccumulator
spark.sparkContext.register(rejectStats, "reject_stats")

val cached = df.cache()

cached.foreach { row =>
  val email = row.getAs[String]("customer_email")
  val amount = row.getAs[Double]("amount")
  if (email == null) rejectStats.add("missing_email")
  else if (amount <= 0) rejectStats.add("non_positive_amount")
}

val valid = cached.filter(col("customer_email").isNotNull && col("amount") > 0)

valid.show(false)
// +--------+-----------------+------+
// |order_id|customer_email   |amount|
// +--------+-----------------+------+
// |1       |alice@example.com|89.99 |
// |4       |diana@example.com|250.75|
// |6       |frank@example.com|50.0  |
// +--------+-----------------+------+

println(s"Reject stats: ${rejectStats.value}")
// Reject stats: Map(missing_email -> 2, non_positive_amount -> 1)

The driver gets back exactly the breakdown it needs: two rows with missing emails, one with a non-positive amount. Use isNull and isNotNull for the filter itself — those operators don't have the surprising null-propagation behavior that === and =!= do.

The string name passed to register is what shows up in the Spark UI under "Accumulators". Pick something descriptive; you'll thank yourself when you're debugging a job a month from now.

Tracking Multiple Reasons Per Row

What if a single row can fail for multiple reasons? In the previous example we used else if, so each row contributed to at most one bucket — that's fine if you only care about the primary reason. But sometimes you want to know about all of them: a row might both have a missing email AND a wrong status, and you want both counted.

val df = Seq(
  (1, "alice@example.com", 89.99, "completed"),
  (2, null, 145.00, "completed"),
  (3, "charlie@example.com", -15.00, "completed"),
  (4, "diana@example.com", 250.75, "completed"),
  (5, null, -5.00, "cancelled"),
  (6, "frank@example.com", 50.00, "pending"),
  (7, "grace@example.com", 75.00, "completed"),
).toDF("order_id", "customer_email", "amount", "status")

val rejectStats = new RejectReasonAccumulator
spark.sparkContext.register(rejectStats, "reject_stats_v2")

val cached = df.cache()

cached.foreach { row =>
  val email = row.getAs[String]("customer_email")
  val amount = row.getAs[Double]("amount")
  val status = row.getAs[String]("status")
  if (email == null) rejectStats.add("missing_email")
  if (amount <= 0) rejectStats.add("non_positive_amount")
  if (status != "completed") rejectStats.add("not_completed")
}

val valid = cached.filter(
  col("customer_email").isNotNull &&
    col("amount") > 0 &&
    col("status") === "completed"
)

valid.show(false)
// +--------+-----------------+------+---------+
// |order_id|customer_email   |amount|status   |
// +--------+-----------------+------+---------+
// |1       |alice@example.com|89.99 |completed|
// |4       |diana@example.com|250.75|completed|
// |7       |grace@example.com|75.0  |completed|
// +--------+-----------------+------+---------+

println(s"Reject stats: ${rejectStats.value}")
// Reject stats: Map(missing_email -> 2, not_completed -> 2, non_positive_amount -> 2)

Notice the totals don't add up to "number of rejected rows" anymore. Row 5 has a missing email, a non-positive amount, and a cancelled status, so it contributes to three buckets at once. Four rows were dropped (2, 3, 5, and 6), but the counters sum to six. That's intentional: the breakdown answers "how often does each problem occur", not "which problem caused the drop".
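The arithmetic can be reproduced with plain Scala collections. This is only a sketch to make the two numbers visible side by side: the Order case class and the RejectTally object are hypothetical stand-ins for the sample DataFrame, with Option[String] used in place of a nullable email column.

```scala
object RejectTally {
  // Hypothetical plain-Scala stand-in for one row of the sample data.
  final case class Order(id: Int, email: Option[String], amount: Double, status: String)

  // Every rule a row violates, mirroring the three independent ifs.
  def reasons(o: Order): List[String] =
    List(
      if (o.email.isEmpty) Some("missing_email") else None,
      if (o.amount <= 0) Some("non_positive_amount") else None,
      if (o.status != "completed") Some("not_completed") else None
    ).flatten

  val sample: List[Order] = List(
    Order(1, Some("alice@example.com"), 89.99, "completed"),
    Order(2, None, 145.00, "completed"),
    Order(3, Some("charlie@example.com"), -15.00, "completed"),
    Order(4, Some("diana@example.com"), 250.75, "completed"),
    Order(5, None, -5.00, "cancelled"),
    Order(6, Some("frank@example.com"), 50.00, "pending"),
    Order(7, Some("grace@example.com"), 75.00, "completed")
  )

  // How often each problem occurs (what the accumulator reports).
  val perReason: Map[String, Int] =
    sample.flatMap(reasons).groupBy(identity).map { case (k, v) => k -> v.size }

  // How many rows have at least one problem (what the filter drops).
  val rejectedRows: Int = sample.count(o => reasons(o).nonEmpty)
}
```

Here RejectTally.perReason.values.sum is 6 while RejectTally.rejectedRows is 4. If operators also need the number of dropped rows, it has to be tracked separately from the per-reason counts.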

Why the foreach + cache + filter Pattern?

This is the most important part of the article, because it's where people get burned.

Accumulator updates inside a transformation are unreliable. If you put the if (email == null) rejectStats.add(...) logic inside a filter predicate or a UDF used by withColumn, Spark may re-execute the closure multiple times for the same row — during speculative execution, stage retries, or simply because the DataFrame is evaluated more than once downstream. Each re-execution adds to the accumulator again, and your counts come out too high.

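Accumulator updates inside an action are reliable. For updates made in an action, Spark guarantees that each task's update is applied to the accumulator only once, so a retried or speculative task does not inflate the count. foreach and foreachPartition are actions, so updates made inside them are counted once per row for each run of the action. Running the same action a second time will, of course, count everything again.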
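The over-counting is easy to reproduce in miniature. The sketch below is an analogy, not Spark code: a lazy Scala view stands in for a transformation, a local var stands in for the accumulator, and ReevaluationSketch is a made-up name. It shows only the "evaluated more than once" failure mode, not retries or speculation.

```scala
object ReevaluationSketch {
  // Returns (counter after one traversal, counter after a second traversal).
  def run(): (Int, Int) = {
    var rejected = 0                       // stand-in for an accumulator
    val amounts = List(89.99, -15.00, 250.75, -5.00)

    // A lazy view behaves like a transformation: nothing runs until it is traversed.
    val valid = amounts.view.filter { a =>
      if (a <= 0) rejected += 1            // side effect inside the "transformation"
      a > 0
    }

    valid.foreach(_ => ())                 // first "action" runs the predicate
    val afterFirst = rejected
    valid.foreach(_ => ())                 // second "action" runs it again
    (afterFirst, rejected)
  }
}
```

ReevaluationSketch.run() returns (2, 4): the two bad amounts are counted once per traversal, so a second consumer of the same lazy pipeline doubles the total.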

That's why the pattern is:

  1. cache() the source DataFrame so it's only computed once.
  2. Run a foreach action to populate the accumulator.
  3. Run a separate filter to produce the cleaned DataFrame, which reads from the cache instead of recomputing.

If you skip the cache, the source DataFrame gets computed once for the foreach and again for the filter, doubling your input cost. See Spark Scala Cache Best Practices for more on when this matters.

When to Reach for a Custom Accumulator

Custom accumulators are not the answer to every "count something during a job" question. Before writing one, ask:

  • Can a built-in handle it? LongAccumulator, DoubleAccumulator, and CollectionAccumulator cover most simple cases. Reach for a custom accumulator only when you need a richer aggregate — a histogram, a set, a min/max, a per-key counter.
  • Could an extra aggregation do the same thing? A groupBy("reason").count() over a derived column is often simpler than an accumulator and runs as part of the same logical plan. Accumulators win when you want metrics from a job whose primary output is the filtered data, not when the metrics are the output.
  • Do you need the value during the job, or just at the end? Accumulator values are only reliable on the driver after an action completes. They're not a real-time progress meter.
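For comparison, here is a rough sketch of the aggregation alternative from the second bullet, using the first example's data and its "first failing rule wins" logic. The object name ReasonCountsByAggregation and the reason column are assumptions made for this illustration, and the caller is expected to supply a SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}

object ReasonCountsByAggregation {
  def run(spark: SparkSession): Map[String, Long] = {
    import spark.implicits._

    val df = Seq(
      (1, "alice@example.com", 89.99),
      (2, null, 145.00),
      (3, "charlie@example.com", -15.00),
      (4, "diana@example.com", 250.75),
      (5, null, -5.00),
      (6, "frank@example.com", 50.00)
    ).toDF("order_id", "customer_email", "amount")

    // First matching rule wins; rows that pass every rule get a null reason.
    val withReason = df.withColumn(
      "reason",
      when(col("customer_email").isNull, "missing_email")
        .when(col("amount") <= 0, "non_positive_amount")
    )

    // Count rejected rows per reason and bring the small result to the driver.
    withReason
      .filter(col("reason").isNotNull)
      .groupBy("reason")
      .count()
      .collect()
      .map(r => r.getString(0) -> r.getLong(1))
      .toMap
  }
}
```

On this data the result is Map(missing_email -> 2, non_positive_amount -> 1), the same breakdown the accumulator produced. The trade-off is that the aggregation runs as its own job over the input, so it is another pass unless the input is cached.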

For the use case it fits (collecting per-category metrics as a by-product of a filtering pipeline), a custom accumulator is the right tool. Just remember: update it from inside an action, register it under a descriptive name, and don't trust intermediate values.

Tutorial Details

Created: 2026-05-20 10:17:49 PM

Last Updated: 2026-05-20 10:17:49 PM