When and Otherwise in Spark Scala
The when function in Spark implements conditional logic within your DataFrame-based ETL pipelines. It allows you to perform fallthrough logic and create new columns whose values depend on the conditions you define.
The when function first appeared in version 1.4.0, and as of Spark 3.4.1 it has a straightforward definition:
def when(condition: Column, value: Any): Column
Let's see an example
Consider a situation where you have a store selling a wide variety of items and you want to categorize the most popular items as "Movers and Shakers". Let's use when to help us categorize that data based upon the quantity of items sold.
import org.apache.spark.sql.functions.{col, lit, when}
import spark.implicits._ // for toDF; assumes a SparkSession value named spark, as in spark-shell

// Sample sales data: product name and quantity sold
val df = Seq(
  ("Fidget Spinner", 2000),
  ("Glam Rock Makeup", 3788),
  ("Screwdriver", 257),
  ("Plunger", 199),
).toDF("product", "quantity_sold")
val df2 = df
  .withColumn("category", when(col("quantity_sold") >= 1000, lit("Movers and Shakers")))
df2.show()
// +----------------+-------------+------------------+
// | product|quantity_sold| category|
// +----------------+-------------+------------------+
// | Fidget Spinner| 2000|Movers and Shakers|
// |Glam Rock Makeup| 3788|Movers and Shakers|
// | Screwdriver| 257| null|
// | Plunger| 199| null|
// +----------------+-------------+------------------+
From the results we can see how when, with a greater-than-or-equal-to comparison, was used to perform a conditional check on the quantity_sold column. If 1000 or more items were sold, the product becomes a "Mover and Shaker"; rows that match no condition get null, because we haven't supplied a fallthrough value yet.
You can chain when function calls together: the standalone when in org.apache.spark.sql.functions starts the expression, and Column defines a when method of its own, so each call simply returns another column expression. Let's expand our categorization to include a few more categories:
val df = Seq(
  ("Fidget Spinner", 2000),
  ("Glam Rock Makeup", 3788),
  ("Screwdriver", 257),
  ("Plunger", 199),
  ("Screwdriver", 257),
).toDF("product", "quantity_sold")
val df2 = df
  .withColumn("category",
    when(col("quantity_sold") >= 1000, lit("Movers and Shakers"))
      .when(col("quantity_sold") >= 200, lit("Likable"))
  )
df2.show()
// +----------------+-------------+------------------+
// | product|quantity_sold| category|
// +----------------+-------------+------------------+
// | Fidget Spinner| 2000|Movers and Shakers|
// |Glam Rock Makeup| 3788|Movers and Shakers|
// | Screwdriver| 257| Likable|
// | Plunger| 199| null|
// | Screwdriver| 257| Likable|
// +----------------+-------------+------------------+
Here we've now captured two different categories. But what if we want a 'default' option when none of the other conditions match? A fallthrough value? We can use otherwise to provide such a value.
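For reference, otherwise is a method defined on Column, and its signature (as of Spark 3.4.1) is just as simple:
def otherwise(value: Any): Column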
val df = Seq(
  ("Fidget Spinner", 2000),
  ("Glam Rock Makeup", 3788),
  ("Screwdriver", 257),
  ("Plunger", 199),
  ("Screwdriver", 257),
  ("Cabbage Patch Doll", 5),
  ("Garbage Pail Kid", 2),
).toDF("product", "quantity_sold")
val df2 = df
  .withColumn("category",
    when(col("quantity_sold") >= 1000, lit("Movers and Shakers"))
      .when(col("quantity_sold") >= 200, lit("Likable"))
      .otherwise(lit("Dud"))
  )
df2.show()
// +------------------+-------------+------------------+
// | product|quantity_sold| category|
// +------------------+-------------+------------------+
// | Fidget Spinner| 2000|Movers and Shakers|
// | Glam Rock Makeup| 3788|Movers and Shakers|
// | Screwdriver| 257| Likable|
// | Plunger| 199| Dud|
// | Screwdriver| 257| Likable|
// |Cabbage Patch Doll| 5| Dud|
// | Garbage Pail Kid| 2| Dud|
// +------------------+-------------+------------------+
You can think of the when/otherwise combination much like a switch or case statement in other programming languages. The whens are the conditionals, evaluated in the order they are listed, and otherwise supplies the default value when all of the conditionals have fallen through (not been met).
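In fact, a chained when/otherwise expression evaluates the same way as a SQL CASE expression. Here's a minimal sketch of the equivalent logic written with expr, using the same df as above (the df3 name is just for illustration):
import org.apache.spark.sql.functions.expr

// The same categorization expressed as a SQL CASE expression
val df3 = df.withColumn("category",
  expr("""
    CASE
      WHEN quantity_sold >= 1000 THEN 'Movers and Shakers'
      WHEN quantity_sold >= 200  THEN 'Likable'
      ELSE 'Dud'
    END
  """))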
The when function is really powerful and can be used to great effect. I often use it for simple parsing and value extraction from unstructured data together with the regexp_replace function.
Combining it with many of the other Spark Scala functions provides a lot of value, and it quickly becomes a common tool in machine learning and other data pipelines.
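As a final illustration, here is a minimal sketch of that parsing pattern. The raw_sku column and the "SKU-" prefix are made up for this example; the idea is simply that when guards which rows regexp_replace is applied to:
import org.apache.spark.sql.functions.{col, regexp_replace, when}

// Hypothetical data: some SKUs carry a "SKU-" prefix, others don't
val skus = Seq("SKU-1001", "1002", "SKU-1003").toDF("raw_sku")

// Strip the prefix only where it is present; pass everything else through untouched
val cleaned = skus.withColumn("sku",
  when(col("raw_sku").startsWith("SKU-"),
    regexp_replace(col("raw_sku"), "^SKU-", ""))
    .otherwise(col("raw_sku")))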