How to Split Data into Training & Test Sets in PySpark


Often when we fit machine learning algorithms to datasets, we first split the dataset into a training set and a test set.

The easiest way to split a dataset into a training and test set in PySpark is to use the randomSplit function as follows:

train_df, test_df = df.randomSplit(weights=[0.7,0.3], seed=100)

The weights argument specifies the fraction of rows from the original DataFrame to place in the training and test set, respectively. If the weights do not sum to 1, PySpark normalizes them.

In this example, we chose to place 70% of the observations into the training set and 30% in the test set.
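Because the weights are normalized, passing [70, 30] produces the same 70/30 split as [0.7, 0.3]. A quick pure-Python sketch of the normalization randomSplit applies:

```python
# Sketch of how randomSplit normalizes its weights argument:
# weights that do not sum to 1 are divided by their total.
weights = [70, 30]
total = sum(weights)
fractions = [w / total for w in weights]
print(fractions)  # [0.7, 0.3]
```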

The seed argument is an integer used to make the random split reproducible: running the code again with the same seed produces the same split each time.
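Under the hood, randomSplit decides each row's destination with a seeded random draw per row, which is both why a fixed seed makes the split reproducible and why the resulting counts are only approximately 70/30. A minimal pure-Python sketch of the idea (a toy illustration, not PySpark's actual implementation):

```python
import random

def random_split(rows, weight, seed):
    """Toy per-row split: a seeded random draw decides each row's set."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        # Each row independently lands in train with probability `weight`.
        (train if rng.random() < weight else test).append(row)
    return train, test

rows = list(range(20))
train, test = random_split(rows, 0.7, seed=100)

# The same seed always yields the same split...
assert random_split(rows, 0.7, seed=100) == (train, test)

# ...but the counts are only approximately 70/30.
print(len(train), len(test))
```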

The following example shows how to split a PySpark DataFrame into a training and test set in practice.

Example: Split Data into Training and Test Set in PySpark

First, let’s create the following PySpark DataFrame that contains information about hours spent studying, number of prep exams taken, and final exam score for various students at some university:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

#define data
data = [[1, 1, 76],
        [2, 3, 78],
        [2, 3, 85],
        [4, 5, 88],
        [2, 2, 72],
        [1, 2, 69],
        [5, 1, 94],
        [4, 1, 94],
        [2, 0, 88],
        [4, 3, 92],
        [4, 4, 90],
        [3, 3, 75],
        [6, 2, 96],
        [5, 4, 90],
        [3, 4, 82],
        [4, 4, 85],
        [6, 5, 99],
        [2, 1, 83],
        [1, 0, 62],
        [2, 1, 76]]
  
#define column names
columns = ['hours', 'prep_exams', 'score'] 
  
#create dataframe using data and column names
df = spark.createDataFrame(data, columns) 
  
#view first five rows of dataframe
df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    1|         1|   76|
|    2|         3|   78|
|    2|         3|   85|
|    4|         5|   88|
|    2|         2|   72|
+-----+----------+-----+

Suppose we would like to fit a multiple linear regression model to this dataset, using hours and prep_exams as the predictor variables and score as the response variable.

Before we do so, we may first want to randomly split the dataset so that 70% of the total rows are used for training and 30% are used for testing.

We can use the following syntax to do so:

#split dataset into training and test sets
train_df, test_df = df.randomSplit(weights=[0.7,0.3], seed=100)

We can then use the count() function to view the number of rows in each resulting dataset:

#view count of rows in train_df
print(train_df.count())

14

#view count of rows in test_df
print(test_df.count())

6

We can see that 14 of the 20 original rows (70%) were placed in the training set and 6 of the 20 (30%) were placed in the test set.

Note that because randomSplit samples each row independently, the split is only approximately 70/30 in general; with this particular seed and dataset, the counts happen to match the weights exactly.

If we’d like, we can also view the first five rows of both the training and test sets:

#view first five rows of training set
train_df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    1|         1|   76|
|    2|         3|   78|
|    2|         3|   85|
|    4|         5|   88|
|    1|         2|   69|
+-----+----------+-----+

#view first five rows of test set
test_df.limit(5).show()

+-----+----------+-----+
|hours|prep_exams|score|
+-----+----------+-----+
|    2|         2|   72|
|    2|         0|   88|
|    4|         1|   94|
|    3|         4|   82|
|    4|         4|   85|
+-----+----------+-----+

We have successfully split the original dataset into a training and test set.

We can now proceed to fit whatever model we’d like to the training set and then test the performance of the model on the test set.

Note: You can find the complete documentation for the PySpark randomSplit function here.

Additional Resources

The following tutorials explain how to perform other common tasks in PySpark:

How to Calculate the Mean of a Column in PySpark
How to Sum Multiple Columns in PySpark
How to Add Multiple Columns to a PySpark DataFrame
