Spark Scala Cache Best Practices

Caching a DataFrame tells Spark to keep it in memory (or on disk) after the first time it's computed. This avoids recomputing the same transformations every time you trigger an action. Used well, it can dramatically speed up your pipelines. Used carelessly, it can eat all your memory and make things slower.

When to Cache

The short answer: cache when you're going to reuse a DataFrame more than once. If you filter a large dataset and then run two different aggregations on the result, without caching Spark will re-read and re-filter the data for each aggregation. Caching it means the filter runs once.

import org.apache.spark.sql.functions._
import spark.implicits._ // assumes an active SparkSession in scope as `spark`

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")

val engineers = df
  .filter(col("department") === "Engineering")
  .cache()

// First action materializes the cache
engineers.count()
// res: 2

// Second action reads from cache — no recomputation
engineers.show(false)
// +-------+-----------+------+
// |name   |department |salary|
// +-------+-----------+------+
// |Alice  |Engineering|95000 |
// |Charlie|Engineering|105000|
// +-------+-----------+------+

A key thing to understand: calling .cache() doesn't do anything immediately. It marks the DataFrame as cacheable. The data is actually cached when the first action (like count() or show()) forces Spark to compute it. Every subsequent action then reads from the cache instead of recomputing from scratch.
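You can see this laziness for yourself. The sketch below is illustrative (the local SparkSession setup is an assumption, not part of the tutorial's examples): after calling .cache(), explain() already shows an InMemoryRelation node in the physical plan, but nothing is stored until the first action runs.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Illustrative local session; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("cache-laziness").getOrCreate()
import spark.implicits._

val df = Seq(("Alice", 95000), ("Bob", 72000), ("Charlie", 105000)).toDF("name", "salary")

val marked = df.filter(col("salary") > 90000).cache()

// cache() has only annotated the plan — the physical plan now contains an
// InMemoryRelation node, but no partitions are stored yet.
marked.explain()

// The first action materializes the cache; later actions read from it.
marked.count()
marked.count()
```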

Cache Before Branching

The most common and valuable pattern is caching right before a DataFrame branches into multiple paths. If you're building a pipeline that produces several outputs from the same intermediate result, cache that intermediate result.

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")

val enriched = df
  .withColumn("salary_band",
    when(col("salary") >= 100000, lit("Senior"))
    .when(col("salary") >= 70000, lit("Mid"))
    .otherwise(lit("Junior"))
  )
  .cache()

// Branch 1: department summary
val deptSummary = enriched
  .groupBy("department")
  .agg(
    count("*").alias("headcount"),
    avg("salary").alias("avg_salary"),
  )

deptSummary.show(false)
// +-----------+---------+----------+
// |department |headcount|avg_salary|
// +-----------+---------+----------+
// |Sales      |1        |68000.0   |
// |Engineering|2        |100000.0  |
// |Marketing  |2        |75000.0   |
// +-----------+---------+----------+

// Branch 2: salary band breakdown
val bandBreakdown = enriched
  .groupBy("salary_band")
  .agg(count("*").alias("count"))

bandBreakdown.show(false)
// +-----------+-----+
// |salary_band|count|
// +-----------+-----+
// |Senior     |1    |
// |Mid        |3    |
// |Junior     |1    |
// +-----------+-----+

Without the .cache() call, Spark would read the source data and recompute the salary_band column twice — once for each branch. With caching, the enriched DataFrame is computed once and both aggregations read from memory.

Persist with Storage Levels

The .cache() method is actually shorthand for .persist(StorageLevel.MEMORY_AND_DISK). If you need more control over where the data is stored, use .persist() directly with a storage level:

Storage Level        Behavior
-------------------  ---------------------------------------------------------------------
MEMORY_ONLY          Store in memory. If it doesn't fit, recompute on the fly.
MEMORY_AND_DISK      Store in memory, spill to disk if needed. This is what .cache() uses.
DISK_ONLY            Store only on disk. Useful for very large DataFrames you don't want in memory.
MEMORY_ONLY_SER      Store serialized in memory. Uses less space but more CPU to deserialize.
MEMORY_AND_DISK_SER  Serialized in memory, spill to disk. Good balance for large datasets.

import org.apache.spark.storage.StorageLevel

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")

val persisted = df
  .filter(col("salary") > 70000)
  .persist(StorageLevel.MEMORY_AND_DISK)

persisted.show(false)
// +-------+-----------+------+
// |name   |department |salary|
// +-------+-----------+------+
// |Alice  |Engineering|95000 |
// |Bob    |Marketing  |72000 |
// |Charlie|Engineering|105000|
// |Eve    |Marketing  |78000 |
// +-------+-----------+------+

For most use cases, the default .cache() (which uses MEMORY_AND_DISK) is the right choice. Reach for explicit storage levels when you're tuning memory pressure on a cluster or dealing with very large datasets that you know won't fit in memory.

Always Unpersist When Done

Cached DataFrames hold onto cluster memory until the application ends, you explicitly release them, or Spark evicts them under memory pressure (least recently used first). In long-running applications or notebooks, forgetting to unpersist can slowly starve your cluster of memory.

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000),
).toDF("name", "department", "salary")

val cached = df.filter(col("department") === "Marketing").cache()

// Use the cached data
cached.count()
// res: 2

cached.show(false)
// +----+----------+------+
// |name|department|salary|
// +----+----------+------+
// |Bob |Marketing |72000 |
// |Eve |Marketing |78000 |
// +----+----------+------+

// Release the memory when done
cached.unpersist()

A good habit: treat cache() and unpersist() as a pair. If you cache something, plan where you'll release it. This is especially important in notebooks where cells are run repeatedly — each run can accumulate cached data if you're not cleaning up.
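One way to enforce that pairing is a small loan-pattern helper. This is a hypothetical function you'd define yourself, not part of Spark's API: it caches the DataFrame on entry and unpersists it on exit, even if the body throws.

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical helper (not a Spark API): cache on entry, unpersist on exit.
// The try/finally guarantees the release even if `body` fails.
def withCached[A](df: DataFrame)(body: DataFrame => A): A = {
  df.cache()
  try body(df)
  finally df.unpersist()
}

// Usage sketch: both actions inside the body share one cached copy,
// and the memory is released as soon as the block returns.
// val (n, summary) = withCached(enriched) { c =>
//   (c.count(), c.groupBy("department").count().collect())
// }
```

The scope of the block makes the cache's lifetime obvious, which is exactly what's easy to lose track of in a notebook.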

When NOT to Cache

Caching isn't free. It consumes memory and there's overhead in storing and retrieving the data. Avoid caching when:

  • You only use the DataFrame once. There's nothing to save by caching — you're just wasting memory storing data you'll never read again.
  • The DataFrame is trivially cheap to compute. If you're just reading a small file with no transformations, the recomputation cost is negligible.
  • You're memory constrained. Caching a large DataFrame when memory is tight can cause spilling, garbage collection pressure, or even OOM errors. In these cases, you might make things worse.
  • The DataFrame is used in a single linear chain. If each transformation feeds directly into the next with no branching, Spark's lazy evaluation already handles this efficiently.

The rule of thumb: cache at branch points where the same expensive computation feeds multiple downstream paths. Everywhere else, let Spark's optimizer do its job.
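To make the last bullet concrete, here is a sketch of a single linear chain using sample data like the examples above (the bonus column and local session setup are illustrative assumptions). Only one action fires, so the whole chain is computed exactly once; a .cache() in the middle would add storage overhead without saving any work.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Illustrative local session; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("linear-chain").getOrCreate()
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Diana", "Sales", 68000)
).toDF("name", "department", "salary")

// Each step feeds straight into the next — no branching, nothing reused.
val summary = df
  .filter(col("salary") > 70000)
  .withColumn("bonus", col("salary") * 0.1)
  .groupBy("department")
  .agg(sum("bonus").alias("total_bonus"))

summary.show(false) // the single action; the chain runs once regardless of caching
```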

Tutorial Details

Created: 2017-12-01 11:25:00 PM

Last Updated: 2026-03-15 07:12:00 PM