Handling imbalanced datasets is a major challenge in machine learning, particularly in large-scale settings such as Apache Spark's MLlib. Imbalanced (or unbalanced) datasets, in which classes are represented unequally, are common in real-world problems. When a model is skewed toward the majority class, its ability to predict the minority class suffers. This post walks through how to handle such datasets with Spark MLlib, including a working code example.
Understanding the Problem of Imbalanced Data
Unbalanced datasets occur when the number of instances in one class significantly outnumbers the instances in other classes.
In fields like risk management, medical diagnosis, and fraud detection, imbalanced datasets are frequently encountered. For example, when it comes to credit card fraud detection, there are significantly fewer fraudulent transactions than there are valid ones.
In credit card fraud detection, for example, only about 0.17% of transactions are fraudulent. A model that predicts every transaction as legitimate therefore reaches over 99.8% accuracy while catching no fraud at all, which makes it useless for fraud detection. Spark MLlib supports several techniques for addressing this problem, including resampling, cost-sensitive learning, and ensemble methods.
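The arithmetic behind this accuracy paradox can be checked directly. The sketch below uses the class counts of the credit card dataset reported later in this article (284,315 legitimate and 492 fraudulent transactions) and a trivial baseline that always predicts "not fraud".

```python
# Class counts for the credit card dataset, as reported later in this article
counts = {0: 284315, 1: 492}
total = sum(counts.values())  # 284,807 transactions in total

fraud_rate = counts[1] / total

# A baseline that always predicts class 0 is correct on every legitimate row...
always_legit_accuracy = counts[0] / total
# ...but catches none of the fraud, so its recall on class 1 is zero
always_legit_fraud_recall = 0 / counts[1]

print(f"Fraud rate: {fraud_rate:.2%}")
print(f"Accuracy of always predicting 'not fraud': {always_legit_accuracy:.2%}")
print(f"Recall on the fraud class: {always_legit_fraud_recall:.0%}")
```

The baseline scores about 99.83% accuracy yet has zero recall on fraud, which is why accuracy alone is a poor metric for this kind of data.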
Techniques to Handle Imbalanced Data in Spark MLlib
1. Data-Level Methods
Resampling balances the classes by changing the training dataset, either by oversampling the minority class or by undersampling the majority class.
- Oversampling the Minority Class: This method entails creating synthetic data (using SMOTE, for example) or copying instances of the minority class. While Spark MLlib does not support SMOTE natively, we can use DataFrame operations to implement oversampling.
- Undersampling the Majority Class: This strategy reduces the number of majority-class instances, which can discard important information, especially when the dataset is small.
2. Algorithm-Level Methods
These methods modify the learning algorithm to account for class imbalance.
- Cost-Sensitive Learning: Assigns different costs to misclassifications of different classes, making the algorithm more sensitive to the minority class. In Spark MLlib this is typically done through instance weights rather than by editing the loss function directly: giving minority-class rows larger weights makes their misclassification cost more during training.
- Class Weights: Spark MLlib does not take a per-class weight map; instead, you add a column holding a weight for each row (for example, based on its class) and pass its name through the weightCol parameter, which algorithms such as Logistic Regression support.
3. Ensemble Methods
For unbalanced datasets, ensemble techniques like Random Forests and Gradient-Boosted Trees (GBT) can be especially useful. These techniques enhance overall performance by combining the predictions of several models.
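- Random Forests: Random Forests can be made more sensitive to the minority class through per-row weights (the weightCol parameter, which Spark's tree ensembles support since Spark 3.0). You can also adjust the number of trees (numTrees) and the maximum depth (maxDepth) in Spark MLlib's implementation to improve performance.
- Gradient-Boosted Trees (GBT): GBTs build trees sequentially, each fitted to the remaining errors of the ensemble so far, which can make them more sensitive to hard-to-classify minority examples. They need careful tuning (for example, maxIter, stepSize, and maxDepth) to prevent overfitting.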
Handling unbalanced datasets in Spark MLlib
Let's apply these ideas to the Credit Card Fraud Detection dataset. With only 492 fraudulent transactions out of 284,807 (about 0.17%), this dataset is highly skewed. The example below uses Spark MLlib to oversample the minority class and train a Random Forest classifier on the result.
Download the Credit Card Fraud Detection dataset (creditcard.csv) before running the code.
The following steps explain each part of the code in detail:
Step 1: Import Required Libraries
These are the necessary libraries for creating a Spark session, preparing data, and building a machine learning pipeline.
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col
Step 2: Create a Spark Session
This line initializes a Spark session, which is the entry point for using Spark to process data.
spark = SparkSession.builder.appName("CreditCardFraudDetection").getOrCreate()
Step 3: Load the Dataset
Here, the dataset is loaded from a CSV file into a DataFrame. header=True uses the first row as column names, and inferSchema=True lets Spark determine each column's data type.
data_path = "creditcard.csv"  # path to the downloaded file; adjust for your machine
df = spark.read.csv(data_path, header=True, inferSchema=True)
Step 4: Display Schema and Sample Data
These commands display the structure (schema) of the DataFrame and show the first few rows of data for inspection.
df.printSchema()
df.show(5)
Output:
root
|-- Time: double (nullable = true)
|-- V1: double (nullable = true)
|-- V2: double (nullable = true)
|-- V3: double (nullable = true)
|-- V4: double (nullable = true)
|-- V5: double (nullable = true)
|-- V6: double (nullable = true)
|-- V7: double (nullable = true)
|-- V8: double (nullable = true)
|-- V9: double (nullable = true)
|-- V10: double (nullable = true)
|-- V11: double (nullable = true)
|-- V12: double (nullable = true)
|-- V13: double (nullable = true)
|-- V14: double (nullable = true)
|-- V15: double (nullable = true)
|-- V16: double (nullable = true)
|-- V17: double (nullable = true)
|-- V18: double (nullable = true)
|-- V19: double (nullable = true)
|-- V20: double (nullable = true)
|-- V21: double (nullable = true)
|-- V22: double (nullable = true)
|-- V23: double (nullable = true)
|-- V24: double (nullable = true)
|-- V25: double (nullable = true)
|-- V26: double (nullable = true)
|-- V27: double (nullable = true)
|-- V28: double (nullable = true)
|-- Amount: double (nullable = true)
|-- Class: integer (nullable = true)
+----+------------------+-------------------+----------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------+-------------------+------------------+------------------+------------------+------------------+--------------------+-------------------+------+-----+
|Time| V1| V2| V3| V4| V5| V6| V7| V8| V9| V10| V11| V12| V13| V14| V15| V16| V17| V18| V19| V20| V21| V22| V23| V24| V25| V26| V27| V28|Amount|Class|
+----+------------------+-------------------+----------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------+-------------------+------------------+------------------+------------------+------------------+--------------------+-------------------+------+-----+
| 0.0| -1.3598071336738|-0.0727811733098497|2.53634673796914| 1.37815522427443| -0.338320769942518| 0.462387777762292| 0.239598554061257|0.0986979012610507| 0.363786969611213| 0.0907941719789316|-0.551599533260813|-0.617800855762348|-0.991389847235408|-0.311169353699879| 1.46817697209427|-0.470400525259478| 0.207971241929242| 0.0257905801985591| 0.403992960255733| 0.251412098239705| -0.018306777944153| 0.277837575558899|-0.110473910188767|0.0669280749146731| 0.128539358273528|-0.189114843888824| 0.133558376740387|-0.0210530534538215|149.62| 0|
| 0.0| 1.19185711131486| 0.26615071205963|0.16648011335321| 0.448154078460911| 0.0600176492822243|-0.0823608088155687|-0.0788029833323113|0.0851016549148104|-0.255425128109186| -0.166974414004614| 1.61272666105479| 1.06523531137287| 0.48909501589608|-0.143772296441519| 0.635558093258208| 0.463917041022171|-0.114804663102346| -0.183361270123994|-0.145783041325259|-0.0690831352230203| -0.225775248033138| -0.638671952771851| 0.101288021253234|-0.339846475529127| 0.167170404418143| 0.125894532368176|-0.00898309914322813| 0.0147241691924927| 2.69| 0|
| 1.0| -1.35835406159823| -1.34016307473609|1.77320934263119| 0.379779593034328| -0.503198133318193| 1.80049938079263| 0.791460956450422| 0.247675786588991| -1.51465432260583| 0.207642865216696| 0.624501459424895| 0.066083685268831| 0.717292731410831|-0.165945922763554| 2.34586494901581| -2.89008319444231| 1.10996937869599| -0.121359313195888| -2.26185709530414| 0.524979725224404| 0.247998153469754| 0.771679401917229| 0.909412262347719|-0.689280956490685|-0.327641833735251|-0.139096571514147| -0.0553527940384261|-0.0597518405929204|378.66| 0|
| 1.0|-0.966271711572087| -0.185226008082898|1.79299333957872|-0.863291275036453|-0.0103088796030823| 1.24720316752486| 0.23760893977178| 0.377435874652262| -1.38702406270197|-0.0549519224713749|-0.226487263835401| 0.178228225877303| 0.507756869957169| -0.28792374549456|-0.631418117709045| -1.0596472454325|-0.684092786345479| 1.96577500349538| -1.2326219700892| -0.208037781160366| -0.108300452035545|0.00527359678253453|-0.190320518742841| -1.17557533186321| 0.647376034602038|-0.221928844458407| 0.0627228487293033| 0.0614576285006353| 123.5| 0|
| 2.0| -1.15823309349523| 0.877736754848451| 1.548717846511| 0.403033933955121| -0.407193377311653| 0.0959214624684256| 0.592940745385545|-0.270532677192282| 0.817739308235294| 0.753074431976354|-0.822842877946363| 0.53819555014995| 1.3458515932154| -1.11966983471731| 0.175121130008994|-0.451449182813529|-0.237033239362776|-0.0381947870352842| 0.803486924960175| 0.408542360392758|-0.00943069713232919| 0.79827849458971|-0.137458079619063| 0.141266983824769|-0.206009587619756| 0.502292224181569| 0.219422229513348| 0.215153147499206| 69.99| 0|
+----+------------------+-------------------+----------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------+-------------------+------------------+------------------+------------------+------------------+--------------------+-------------------+------+-----+
only showing top 5 rows
Step 5: Check Dataset Balance
This line groups the data by the Class column (indicating fraud or non-fraud) and counts the number of instances in each class to check for imbalance.
df.groupBy("Class").count().show()
Output:
+-----+------+
|Class| count|
+-----+------+
| 1| 492|
| 0|284315|
+-----+------+
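The counts above make the imbalance concrete. The short calculation below (plain Python, using the two counts from this output) shows the majority-to-minority ratio and the expected minority size for a given oversampling fraction. It relies on the fact that `sample(withReplacement=True, fraction=f)` picks each row f times on average, so actual counts vary around these values.

```python
# Class counts from the groupBy("Class").count() output above
majority_count = 284315
minority_count = 492

imbalance_ratio = majority_count / minority_count
print(f"Majority-to-minority ratio: {imbalance_ratio:.1f} : 1")

# With sample(withReplacement=True, fraction=f), the expected number of
# minority rows is f * minority_count
for fraction in (4.0, imbalance_ratio):
    expected_minority = fraction * minority_count
    print(f"fraction={fraction:.1f}: ~{expected_minority:.0f} minority rows, "
          f"ratio ~{majority_count / expected_minority:.1f} : 1")
```

With fraction=4.0, the value used in the next step, the expected ratio drops from about 578:1 to about 144:1; a fraction near 578 would be needed for an approximately 1:1 split.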
Step 6: Oversample the Minority Class
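To reduce the class imbalance, the minority class (fraudulent transactions) is oversampled by sampling its rows with replacement. With fraction=4.0, each fraud row is drawn about four times on average (roughly 1,968 rows in total), which narrows but does not close the gap with the 284,315 non-fraud rows.
# Separate the two classes
minority_df = df.filter(col('Class') == 1)
majority_df = df.filter(col('Class') == 0)
# Sample the minority class with replacement; fraction=4.0 draws each row about 4 times on average
oversampled_minority_df = minority_df.sample(withReplacement=True, fraction=4.0, seed=42)
# Combine the majority class with the oversampled minority class
balanced_df = majority_df.union(oversampled_minority_df)
balanced_df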
Output:
DataFrame[Time: double, V1: double, V2: double, V3: double, V4: double, V5: double, V6: double, V7: double, V8: double, V9: double, V10: double, V11: double, V12: double, V13: double, V14: double, V15: double, V16: double, V17: double, V18: double, V19: double, V20: double, V21: double, V22: double, V23: double, V24: double, V25: double, V26: double, V27: double, V28: double, Amount: double, Class: int]
Step 7: Assemble Features
The features (input variables) are assembled into a single vector, which is used as input for the machine learning model.
feature_cols = df.columns[:-1]  # every column except the last one, the label 'Class'
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
assembled_df = assembler.transform(balanced_df).select('features', 'Class')
assembled_df
Output:
DataFrame[features: vector, Class: int]
Step 8: Split Data into Training and Test Sets
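The dataset is split into training (70%) and test (30%) sets to evaluate the model's performance. Because the split happens after oversampling, copies of the same fraudulent transaction can end up in both sets, which makes the test scores optimistic; splitting before oversampling avoids this leakage.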
train_df, test_df = assembled_df.randomSplit([0.7, 0.3], seed=42)
Step 9: Train a RandomForest Model
A RandomForest classifier is initialized and trained on the training data.
rf = RandomForestClassifier(featuresCol='features', labelCol='Class', numTrees=100)
model = rf.fit(train_df)
Step 10: Make Predictions
The trained model is used to predict the classes of the test dataset.
predictions = model.transform(test_df)
Step 11: Evaluate the Model
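The model is evaluated using F1 score, precision, and recall. Note that metricName="f1", "weightedPrecision", and "weightedRecall" average the per-class scores weighted by how many rows each class has, so on this test set they mostly reflect the majority (non-fraud) class rather than fraud detection.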
evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="f1")
f1_score = evaluator.evaluate(predictions)
print(f"F1 Score: {f1_score}")
precision_evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="weightedPrecision")
recall_evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="weightedRecall")
precision = precision_evaluator.evaluate(predictions)
recall = recall_evaluator.evaluate(predictions)
print(f"Precision: {precision}")
print(f"Recall: {recall}")
Output:
F1 Score: 0.9983080814787306
Precision: 0.9983450537543812
Recall: 0.9983863046376201
Step 12: Stop the Spark Session
Finally, the Spark session is stopped to free up resources.
spark.stop()
This step-by-step breakdown helps in understanding each phase of the process, from loading data to evaluating a machine learning model in a Spark environment.
Full Complete Code: Handling Unbalanced Dataset with Spark MLlib
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col
# Create a Spark session
spark = SparkSession.builder.appName("CreditCardFraudDetection").getOrCreate()
# Load the dataset
data_path = "creditcard.csv"  # path to the downloaded file; adjust for your machine
df = spark.read.csv(data_path, header=True, inferSchema=True)
# Display schema and first few rows
df.printSchema()
df.show(5)
# Check the balance of the dataset
df.groupBy("Class").count().show()
# Oversample the minority class
minority_df = df.filter(col('Class') == 1)
majority_df = df.filter(col('Class') == 0)
# Perform oversampling on the minority class (increase the fraction as needed)
oversampled_minority_df = minority_df.sample(withReplacement=True, fraction=4.0, seed=42)
balanced_df = majority_df.union(oversampled_minority_df)
# Assemble features
feature_cols = df.columns[:-1] # All columns except the label column 'Class'
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
assembled_df = assembler.transform(balanced_df).select('features', 'Class')
# Split data into training and test sets
train_df, test_df = assembled_df.randomSplit([0.7, 0.3], seed=42)
# Train a RandomForest model
rf = RandomForestClassifier(featuresCol='features', labelCol='Class', numTrees=100)
model = rf.fit(train_df)
# Make predictions
predictions = model.transform(test_df)
# Evaluate the model
evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="f1")
f1_score = evaluator.evaluate(predictions)
print(f"F1 Score: {f1_score}")
# Precision and Recall
precision_evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="weightedPrecision")
recall_evaluator = MulticlassClassificationEvaluator(labelCol="Class", predictionCol="prediction", metricName="weightedRecall")
precision = precision_evaluator.evaluate(predictions)
recall = recall_evaluator.evaluate(predictions)
print(f"Precision: {precision}")
print(f"Recall: {recall}")
# Stop the Spark session
spark.stop()
Explanation of the Code:
- Loading the Dataset: The dataset is loaded from the specified path using Spark’s read.csv function, with headers and schema inferred automatically.
- Class Balance Check: The groupBy("Class").count().show() command checks the distribution of the classes, confirming the imbalance.
- Oversampling: The minority class is oversampled by sampling its instances with replacement. With fraction=4.0 this reduces, but does not eliminate, the bias toward the majority class, because non-fraud rows still far outnumber fraud rows.
- Feature Assembly: Features are assembled into a single vector using the VectorAssembler, excluding the label (Class).
- Model Training: The dataset is split into training and test sets, and a Random Forest classifier is trained on the training data.
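- Evaluation: The model is evaluated using F1 score, precision, and recall. For imbalanced data these metrics are most informative when computed for the minority class; the weighted versions used here are dominated by the majority class.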
Challenges of Unbalanced Datasets in Spark
Handling unbalanced datasets in Spark involves unique challenges due to its distributed computing nature. The following are some of the key challenges:
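- Data Partitioning: Spark divides data into partitions. Because minority examples are rare, some partitions may contain few or none of them, so any processing done per partition sees an even more skewed sample.
- Scalability: Techniques that work well on smaller datasets may not scale effectively in a distributed environment like Spark; SMOTE, for example, relies on nearest-neighbor searches that are expensive across a cluster.
- Algorithm Limitations: Not all machine learning algorithms in Spark MLlib are designed to handle class imbalance natively. Spark has no built-in SMOTE, and class weights must be supplied as a per-row weight column wherever an algorithm supports one.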
Conclusion
Handling unbalanced datasets with Spark MLlib combines three things: data preparation, algorithm adjustments, and careful evaluation. Strategies such as oversampling, cost-sensitive learning, and ensemble methods help you build models that stay reliable even on severely unbalanced data. Applied together with evaluation that looks at each class, they help your models represent both the majority and minority classes fairly and accurately, producing more trustworthy predictions in important applications.