Spark Scala Cache Best Practices
Caching a DataFrame tells Spark to keep it in memory (or on disk) after the first time it's computed. This avoids recomputing the same transformations every time you trigger an action. Used well, it can dramatically speed up your pipelines. Used carelessly, it can eat all your memory and make things slower.
When to Cache
The short answer: cache when you're going to reuse a DataFrame more than once. If you filter a large dataset and then run two different aggregations on the result, without caching Spark will re-read and re-filter the data for each aggregation. Caching it means the filter runs once.
import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

val engineers = df
  .filter(col("department") === "Engineering")
  .cache()
// First action materializes the cache
engineers.count()
// res: 2
// Second action reads from cache — no recomputation
engineers.show(false)
// +-------+-----------+------+
// |name |department |salary|
// +-------+-----------+------+
// |Alice |Engineering|95000 |
// |Charlie|Engineering|105000|
// +-------+-----------+------+
A key thing to understand: calling .cache() doesn't do anything immediately. It marks the DataFrame as cacheable. The data is actually cached when the first action (like count() or show()) forces Spark to compute it. Every subsequent action then reads from the cache instead of recomputing from scratch.
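One way to see this laziness for yourself is to time the same action twice. This is a rough sketch assuming a spark-shell session; the `timed` helper and the row count are invented for illustration, and wall-clock numbers will vary:

```scala
import org.apache.spark.sql.functions.col
import spark.implicits._

// Hypothetical timing helper, not part of Spark's API
def timed[T](label: String)(block: => T): T = {
  val start = System.nanoTime()
  val result = block
  println(f"$label: ${(System.nanoTime() - start) / 1e6}%.1f ms")
  result
}

val expensive = spark.range(1, 10000000L)         // synthetic large input
  .withColumn("squared", col("id") * col("id"))
  .cache()                                        // marks as cacheable; computes nothing yet

timed("first count (materializes cache)") { expensive.count() }
timed("second count (reads from cache)") { expensive.count() }
```

On any non-trivial input the second action should come back noticeably faster, because it scans the in-memory cache instead of recomputing the range and the projection.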
Cache Before Branching
The most common and valuable pattern is caching right before a DataFrame branches into multiple paths. If you're building a pipeline that produces several outputs from the same intermediate result, cache that intermediate result.
import org.apache.spark.sql.functions.{avg, col, count, lit, when}
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

val enriched = df
  .withColumn("salary_band",
    when(col("salary") >= 100000, lit("Senior"))
      .when(col("salary") >= 70000, lit("Mid"))
      .otherwise(lit("Junior"))
  )
  .cache()

// Branch 1: department summary
val deptSummary = enriched
  .groupBy("department")
  .agg(
    count("*").alias("headcount"),
    avg("salary").alias("avg_salary")
  )
deptSummary.show(false)
// +-----------+---------+----------+
// |department |headcount|avg_salary|
// +-----------+---------+----------+
// |Sales |1 |68000.0 |
// |Engineering|2 |100000.0 |
// |Marketing |2 |75000.0 |
// +-----------+---------+----------+
// Branch 2: salary band breakdown
val bandBreakdown = enriched
  .groupBy("salary_band")
  .agg(count("*").alias("count"))
bandBreakdown.show(false)
// +-----------+-----+
// |salary_band|count|
// +-----------+-----+
// |Senior |1 |
// |Mid |3 |
// |Junior |1 |
// +-----------+-----+
Without the .cache() call, Spark would read the source data and recompute the salary_band column twice — once for each branch. With caching, the enriched DataFrame is computed once and both aggregations read from memory.
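To convince yourself the enrichment really runs only once, you can count evaluations with an accumulator inside a UDF. This is an illustrative sketch, not the article's pipeline: the `evalCount` accumulator and `tracked` UDF are invented names, and accumulator totals can overcount slightly if tasks are retried:

```scala
import org.apache.spark.sql.functions.{avg, col, sum, udf}
import spark.implicits._

val df = Seq(("Alice", 95000), ("Bob", 72000)).toDF("name", "salary")

// Counts how many rows the "expensive" step actually processes
val evalCount = spark.sparkContext.longAccumulator("evaluations")
val tracked = udf { (s: Int) => evalCount.add(1); s * 2 }

val enriched = df.withColumn("doubled", tracked(col("salary"))).cache()

enriched.agg(sum("doubled")).collect()   // first action materializes the cache
enriched.agg(avg("doubled")).collect()   // second action reads from cache

// Expect ~2 (one evaluation per row) with .cache(); ~4 without it
println(evalCount.value)
```

Aggregations are used here instead of `count()` because Catalyst can prune away a column a plain count doesn't need, which would skip the UDF entirely on the uncached path.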
Persist with Storage Levels
The .cache() method is actually shorthand for .persist(StorageLevel.MEMORY_AND_DISK). If you need more control over where the data is stored, use .persist() directly with a storage level:
| Storage Level | Behavior |
|---|---|
| MEMORY_ONLY | Store in memory. Partitions that don't fit are recomputed on the fly. |
| MEMORY_AND_DISK | Store in memory, spill to disk if needed. This is what .cache() uses. |
| DISK_ONLY | Store only on disk. Useful for very large DataFrames you don't want in memory. |
| MEMORY_ONLY_SER | Store serialized in memory. Uses less space but more CPU to deserialize. |
| MEMORY_AND_DISK_SER | Serialized in memory, spill to disk. Good balance for large datasets. |
import org.apache.spark.sql.functions.col
import org.apache.spark.storage.StorageLevel
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

val persisted = df
  .filter(col("salary") > 70000)
  .persist(StorageLevel.MEMORY_AND_DISK)
persisted.show(false)
// +-------+-----------+------+
// |name |department |salary|
// +-------+-----------+------+
// |Alice |Engineering|95000 |
// |Bob |Marketing |72000 |
// |Charlie|Engineering|105000|
// |Eve |Marketing |78000 |
// +-------+-----------+------+
For most use cases, the default .cache() (which uses MEMORY_AND_DISK) is the right choice. Reach for explicit storage levels when you're tuning memory pressure on a cluster or dealing with very large datasets that you know won't fit in memory.
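You can check which level a DataFrame is currently persisted at through `Dataset.storageLevel`. A quick sketch, assuming a spark-shell session:

```scala
import org.apache.spark.storage.StorageLevel
import spark.implicits._

val df = Seq(1, 2, 3).toDF("n")

println(df.storageLevel)            // NONE: nothing persisted yet
df.persist(StorageLevel.DISK_ONLY)
println(df.storageLevel)            // now reports disk-only storage
df.unpersist()                      // release before choosing a different level
df.cache()
println(df.storageLevel)            // MEMORY_AND_DISK, the .cache() default
```

The unpersist step in the middle matters: Spark won't silently switch storage levels on an already-persisted DataFrame, so release the old level before assigning a new one.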
Always Unpersist When Done
Cached DataFrames hold onto executor memory until the application ends, Spark evicts them under memory pressure, or you explicitly release them with unpersist(). In long-running applications or notebooks, forgetting to unpersist can slowly starve your cluster of memory.
import org.apache.spark.sql.functions.col
import spark.implicits._

val df = Seq(
  ("Alice", "Engineering", 95000),
  ("Bob", "Marketing", 72000),
  ("Charlie", "Engineering", 105000),
  ("Diana", "Sales", 68000),
  ("Eve", "Marketing", 78000)
).toDF("name", "department", "salary")

val cached = df.filter(col("department") === "Marketing").cache()
// Use the cached data
cached.count()
// res: 2
cached.show(false)
// +----+----------+------+
// |name|department|salary|
// +----+----------+------+
// |Bob |Marketing |72000 |
// |Eve |Marketing |78000 |
// +----+----------+------+
// Release the memory when done
cached.unpersist()
A good habit: treat cache() and unpersist() as a pair. If you cache something, plan where you'll release it. This is especially important in notebooks where cells are run repeatedly — each run can accumulate cached data if you're not cleaning up.
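One way to enforce that pairing is a loan-pattern wrapper that scopes the cache to a block of code. The `withCached` helper below is a hypothetical utility, not part of Spark's API:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import spark.implicits._

// Caches the DataFrame, runs the body, and always releases the cache,
// even if the body throws.
def withCached[T](df: DataFrame)(body: DataFrame => T): T = {
  df.cache()
  try body(df)
  finally df.unpersist()
}

val df = Seq(("Alice", "Engineering"), ("Bob", "Marketing"))
  .toDF("name", "department")

// The cache lives exactly as long as the block
val (total, engineers) = withCached(df) { d =>
  (d.count(), d.filter(col("department") === "Engineering").count())
}
```

Because the unpersist sits in a `finally`, the cache is released even when an action inside the block fails, which is exactly the discipline that's easy to forget in a notebook.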
When NOT to Cache
Caching isn't free. It consumes memory and there's overhead in storing and retrieving the data. Avoid caching when:
- You only use the DataFrame once. There's nothing to save by caching — you're just wasting memory storing data you'll never read again.
- The DataFrame is trivially cheap to compute. If you're just reading a small file with no transformations, the recomputation cost is negligible.
- You're memory constrained. Caching a large DataFrame when memory is tight can cause spilling, garbage collection pressure, or even OOM errors. In these cases, you might make things worse.
- The DataFrame is used in a single linear chain. If each transformation feeds directly into the next with no branching, Spark's lazy evaluation already handles this efficiently.
The rule of thumb: cache at branch points where the same expensive computation feeds multiple downstream paths. Everywhere else, let Spark's optimizer do its job.