PySpark: How to Select Row with Max Value in Each Group


You can use the following syntax to select the row with the max value by group in a PySpark DataFrame:

from pyspark.sql import Window
import pyspark.sql.functions as F

#specify column to group by
w = Window.partitionBy('team')

#find row with max value in points column by team
df.withColumn('maxPoints', F.max('points').over(w))\
    .where(F.col('points') == F.col('maxPoints'))\
    .drop('maxPoints')\
    .show()  

This particular example returns a DataFrame that contains the rows with the max value in the points column for each unique value in the team column.

The following example shows how to use this syntax in practice.

Example: Select Row with Max Value in Each Group in PySpark

Suppose we have the following PySpark DataFrame that contains information about various basketball players:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [['A', 18, 3], 
        ['A', 33, 5], 
        ['A', 12, 8], 
        ['A', 15, 10], 
        ['B', 19, 4],
        ['B', 24, 4],
        ['B', 28, 2],
        ['C', 40, 7],
        ['C', 24, 3],
        ['C', 13, 4]]
  
#define column names
columns = ['team', 'points', 'assists'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view dataframe
df.show()

+----+------+-------+
|team|points|assists|
+----+------+-------+
|   A|    18|      3|
|   A|    33|      5|
|   A|    12|      8|
|   A|    15|     10|
|   B|    19|      4|
|   B|    24|      4|
|   B|    28|      2|
|   C|    40|      7|
|   C|    24|      3|
|   C|    13|      4|
+----+------+-------+

We can use the following syntax to return a DataFrame that contains the rows with the max value in the points column for each unique value in the team column:

from pyspark.sql import Window
import pyspark.sql.functions as F

#specify column to group by
w = Window.partitionBy('team')

#find row with max value in points column by team
df.withColumn('maxPoints', F.max('points').over(w))\
    .where(F.col('points') == F.col('maxPoints'))\
    .drop('maxPoints')\
    .show()

+----+------+-------+
|team|points|assists|
+----+------+-------+
|   A|    33|      5|
|   B|    28|      2|
|   C|    40|      7|
+----+------+-------+

The resulting DataFrame contains only the rows with the max value in the points column for each unique team.

For example, the max points value among players on team A was 33.

Thus, the entire row that contained this value was included in the final DataFrame.

Additional Resources

The following tutorials explain how to perform other common tasks in PySpark:

How to Calculate the Max by Group in PySpark
How to Calculate Max Value Across Columns in PySpark
How to Calculate the Max Value of a Column in PySpark
