A case statement evaluates a series of conditions and returns a value for the first condition that is met.
The easiest way to implement a case statement in a PySpark DataFrame is by using the following syntax:
from pyspark.sql.functions import when

df.withColumn('class', when(df.points<9, 'Bad')
                      .when(df.points<12, 'OK')
                      .when(df.points<15, 'Good')
                      .otherwise('Great')).show()
This particular example adds a new column to a DataFrame called class that takes on the following values:
- Bad if the value in the points column is less than 9
- OK if the value in the points column is less than 12
- Good if the value in the points column is less than 15
- Great if none of the previous conditions are true
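Equivalently, the same rules can be written as a SQL-style CASE expression and passed to expr(). Here is a minimal sketch; it assumes the same df and points column used throughout this tutorial:

from pyspark.sql.functions import expr

#same case statement written as a SQL CASE expression
df.withColumn('class', expr("""
    CASE WHEN points < 9 THEN 'Bad'
         WHEN points < 12 THEN 'OK'
         WHEN points < 15 THEN 'Good'
         ELSE 'Great'
    END
""")).show()

Whether you use chained when() calls or a SQL expression is a matter of preference; both produce the same column.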
The following example shows how to use this syntax in practice.
Example: How to Use a Case Statement in PySpark
Suppose we have the following PySpark DataFrame:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

#define data
data = [['A', 6],
        ['B', 8],
        ['C', 9],
        ['D', 9],
        ['E', 12],
        ['F', 14],
        ['G', 15],
        ['H', 17],
        ['I', 19],
        ['J', 22]]

#define column names
columns = ['player', 'points']

#create dataframe using data and column names
df = spark.createDataFrame(data, columns)

#view dataframe
df.show()

+------+------+
|player|points|
+------+------+
|     A|     6|
|     B|     8|
|     C|     9|
|     D|     9|
|     E|    12|
|     F|    14|
|     G|    15|
|     H|    17|
|     I|    19|
|     J|    22|
+------+------+
We can use the following syntax to write a case statement that creates a new column called class whose values are determined by the values in the points column:
from pyspark.sql.functions import when

#add new 'class' column whose values are based on 'points' column
df.withColumn('class', when(df.points<9, 'Bad')
                      .when(df.points<12, 'OK')
                      .when(df.points<15, 'Good')
                      .otherwise('Great')).show()

+------+------+-----+
|player|points|class|
+------+------+-----+
|     A|     6|  Bad|
|     B|     8|  Bad|
|     C|     9|   OK|
|     D|     9|   OK|
|     E|    12| Good|
|     F|    14| Good|
|     G|    15|Great|
|     H|    17|Great|
|     I|    19|Great|
|     J|    22|Great|
+------+------+-----+
The case statement looked at the value in the points column and returned:
- Bad if the value in the points column was less than 9
- OK if the value in the points column was less than 12
- Good if the value in the points column was less than 15
- Great if none of the previous conditions were true
Note: We chose to use three conditions in this particular example, but you can chain together as many when() calls as you’d like to include even more conditions in your own case statement.
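If you have many cutoffs, you can also build the chain in a loop instead of typing each when() by hand. Below is a minimal sketch; the bins list of (threshold, label) pairs and the 'Great' fallback simply reuse the values from this example and should be adjusted for your own data:

from pyspark.sql.functions import when

#(threshold, label) pairs - one when() is chained per pair
bins = [(9, 'Bad'), (12, 'OK'), (15, 'Good')]

#start the chain with the first condition
case_col = when(df.points < bins[0][0], bins[0][1])

#chain the remaining conditions
for threshold, label in bins[1:]:
    case_col = case_col.when(df.points < threshold, label)

#attach the fallback value and add the new column
df.withColumn('class', case_col.otherwise('Great')).show()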
Additional Resources
The following tutorials explain how to perform other common tasks in PySpark:
PySpark: How to Check Data Type of Columns in DataFrame
PySpark: How to Drop Multiple Columns from DataFrame
PySpark: How to Drop Duplicate Rows from DataFrame