Softmax Regression for Multi-Class Classification
📌TL;DR
Compared softmax regression vs One-vs-Rest and One-vs-One strategies for multi-class Iris classification (3 species, 150 samples). Softmax achieved 94.7% test accuracy matching One-vs-Rest, with native multi-class probability outputs. Used only 2 features (sepal width, petal width) for clear decision boundary visualization. Demonstrates softmax function converting logits to probabilities, L2 regularization (C=10), and comparison of multi-class strategies showing softmax efficiency advantage (single model vs multiple binary classifiers) while maintaining competitive performance.
Introduction
Most real-world classification problems aren't just binary (yes or no). You might need to classify emails into multiple folders, identify different species of flowers, or categorize news articles into various topics. That's where softmax regression (also called multinomial logistic regression) comes in. In this tutorial, I'll show you how softmax regression extends logistic regression to handle multiple classes and compare it to other multi-class strategies.
The Problem: More Than Two Classes
Binary logistic regression works great for two-class problems, but what if you have three or more classes? The Iris dataset is perfect for exploring this. It has 150 flower samples from 3 species:
- Setosa
- Versicolor
- Virginica
Each flower is described by 4 measurements, but to make visualization easier, I'll use just 2 features: sepal width and petal width.
Loading and Preparing the Data
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:, [1, 3]]  # Sepal width and petal width
y = iris.target
Using [:, [1, 3]], I selected columns 1 and 3 (sepal width and petal width; the four Iris features are stored in the order sepal length, sepal width, petal length, petal width, and NumPy indices start at 0). This notation might look strange if you're new to NumPy, but it says "take all rows and only the columns at indices 1 and 3."
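If that indexing is new to you, here's a tiny standalone example (the two rows are just sample measurements for illustration):

```python
import numpy as np

demo = np.array([[5.1, 3.5, 1.4, 0.2],
                 [4.9, 3.0, 1.4, 0.2]])

demo[:, [1, 3]]
# array([[3.5, 0.2],
#        [3.0, 0.2]])
```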
Understanding Softmax: Probabilities for Each Class
Here's the key insight. Instead of predicting "class 0" or "class 1," softmax regression computes the probability that an instance belongs to each possible class. These probabilities always sum to 1.
For example, a flower might get probabilities:
- Setosa: 0.97
- Versicolor: 0.02
- Virginica: 0.01
Then we use argmax to pick the class with highest probability (Setosa in this case).
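To make that concrete, here's a minimal NumPy sketch of the softmax function itself; the logit values are made up purely for illustration:

```python
import numpy as np

def softmax(z):
    """Convert a vector of raw scores (logits) into probabilities."""
    z = z - z.max()              # subtract the max for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

logits = np.array([4.2, 0.3, -0.8])  # hypothetical scores for one flower
probs = softmax(logits)

print(probs)             # ~[0.974 0.020 0.007] -- always positive
print(probs.sum())       # 1.0
print(np.argmax(probs))  # 0 -> Setosa
```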
Training the Model
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(random_state=0).fit(X, y)
Here's something interesting. Scikit-learn's LogisticRegression automatically uses softmax (multinomial) for multi-class problems with its default solver; you don't need a separate class. It detects that you have more than 2 classes and handles them appropriately.
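One quick way to see this multinomial behaviour is to inspect the fitted coefficients: with 3 classes and 2 features, the model stores one weight vector (and one intercept) per class.

```python
print(lr.classes_)          # [0 1 2]
print(lr.coef_.shape)       # (3, 2) -- one row of weights per class
print(lr.intercept_.shape)  # (3,)   -- one bias term per class
```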
Visualizing Decision Boundaries
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay

disp = DecisionBoundaryDisplay.from_estimator(
    lr, X, response_method='predict',
    cmap=plt.cm.RdYlBu, alpha=0.5
)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
plt.xlabel('Sepal Width (cm)')
plt.ylabel('Petal Width (cm)')
The DecisionBoundaryDisplay is super useful. It creates a meshgrid (a grid of points covering the feature space), asks the model to classify each point, and colors regions by predicted class. The alpha=0.5 makes the background semi-transparent so we can see both the decision regions and the actual data points.
I added edgecolor="k" (where k means black in matplotlib) to draw black borders around each point, making them more visible against the colored background.
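If you're curious what that looks like without the helper, here's a rough hand-rolled equivalent (the grid resolution and padding are arbitrary choices of mine, not what DecisionBoundaryDisplay uses internally):

```python
import numpy as np
import matplotlib.pyplot as plt

# Cover the feature space with a grid of points
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                     np.linspace(y_min, y_max, 200))

# Ask the model to classify every grid point, then reshape back to the grid
Z = lr.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

# Color regions by predicted class and overlay the data
plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu, alpha=0.5)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
plt.xlabel('Sepal Width (cm)')
plt.ylabel('Petal Width (cm)')
plt.show()
```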
How Softmax Really Works
Let's verify what's happening under the hood:
probability = lr.predict_proba(X)
This returns a matrix where each row corresponds to a sample and each column to a class. For the first sample:
probability[0, :]
# Output: [0.97, 0.02, 0.01]
These are the probabilities for each of the 3 classes. Notice they sum to 1:
probability[0, :].sum()
# Output: 1.0
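And it's not just the first row; every row of the matrix sums to 1 (up to floating-point error):

```python
np.allclose(probability.sum(axis=1), 1.0)
# Output: True
```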
Getting the Prediction
np.argmax(probability[0, :])
# Output: 0
The argmax function returns the index of the maximum value. Since index 0 has the highest probability (0.97), the predicted class is 0 (Setosa).
We can do this for all samples at once:
softmax_prediction = np.argmax(probability, axis=1)
Using axis=1 means "find the maximum along each row." This gives us predictions for all samples.
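If the axis argument is confusing, a tiny made-up example shows the idea:

```python
demo_probs = np.array([[0.97, 0.02, 0.01],   # row 0: max at index 0
                       [0.10, 0.30, 0.60]])  # row 1: max at index 2

np.argmax(demo_probs, axis=1)
# Output: array([0, 2])
```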
Verifying Correctness
from sklearn.metrics import accuracy_score
yhat = lr.predict(X)
accuracy_score(softmax_prediction, yhat)
# Output: 1.0
Perfect match! This confirms that predict() is equivalent to calling predict_proba() and then taking the argmax. Understanding this is valuable because sometimes you want the probabilities (for confidence estimation) rather than just the final class.
Visualizing Probability Distributions
import seaborn as sns
sns.heatmap(probability, cmap="RdYlBu", linewidths=0.5)
Creating a heatmap of all probabilities is illuminating. You can see:
- Most samples have one very high probability (confident predictions)
- Samples near decision boundaries have more uncertain probabilities
- The patterns reveal which species are easier to distinguish
Multi-Class SVM Strategies
SVM wasn't originally designed for multi-class problems. It's fundamentally a binary classifier. So how do we use it for 3+ classes? There are two main strategies.
One-vs-Rest (One-vs-All)
from sklearn.multiclass import OneVsRestClassifier

# SVM_model is the linear SVC defined below in "Enabling Probability Estimates in SVM"
ovr_clf = OneVsRestClassifier(SVM_model)
ovr_clf.fit(X, y)
This trains 3 binary classifiers:
- Setosa vs (Versicolor + Virginica)
- Versicolor vs (Setosa + Virginica)
- Virginica vs (Setosa + Versicolor)
For prediction, it runs all 3 classifiers and picks the class with highest confidence.
With N classes, you train N classifiers. This is relatively fast and works well in practice.
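As a quick sanity check, the fitted wrapper really does hold one binary SVM per class (this assumes SVM_model is the SVC defined later in the post):

```python
print(len(ovr_clf.estimators_))  # 3 -- one binary classifier per class
print(ovr_clf.classes_)          # [0 1 2]
```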
One-vs-One
from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SVM_model)
ovo_clf.fit(X, y)
This trains a binary classifier for every pair of classes:
- Setosa vs Versicolor
- Setosa vs Virginica
- Versicolor vs Virginica
That's 3 classifiers for 3 classes. Generally, you need N×(N-1)/2 classifiers.
For prediction, it runs all pairwise classifiers and uses voting. If 2 out of 3 classifiers predict "Setosa," that's the final prediction.
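The same check works here; for 3 classes the pairwise count happens to equal the class count:

```python
print(len(ovo_clf.estimators_))  # 3 -- 3 * (3 - 1) / 2 = 3 pairwise classifiers
print(ovo_clf.classes_)          # [0 1 2]
```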
Comparing the Strategies
Both achieved the same accuracy on this dataset, but they have different trade-offs:
One-vs-Rest:
- Trains N classifiers (fewer)
- Each classifier sees all training data
- Faster training and prediction
- Can have ambiguous regions where multiple classifiers say "yes"
One-vs-One:
- Trains N×(N-1)/2 classifiers (more)
- Each classifier sees only data from 2 classes
- Each individual classifier is faster to train (smaller data)
- More robust to class imbalance
- Often slightly more accurate in practice
For this problem with 3 balanced classes, both work well. With imbalanced classes, one-vs-one often performs better despite the larger number of classifiers; with many classes, though, the quadratic growth in pairwise classifiers becomes a real cost, as the quick calculation below shows.
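The difference in classifier counts only becomes dramatic as the number of classes grows, which a two-line loop makes obvious:

```python
for n in (3, 5, 10, 20):
    print(f"N={n:2d}: one-vs-rest trains {n:2d} models, one-vs-one trains {n * (n - 1) // 2:3d}")
# N= 3: one-vs-rest trains  3 models, one-vs-one trains   3
# N= 5: one-vs-rest trains  5 models, one-vs-one trains  10
# N=10: one-vs-rest trains 10 models, one-vs-one trains  45
# N=20: one-vs-rest trains 20 models, one-vs-one trains 190
```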
Enabling Probability Estimates in SVM
from sklearn.svm import SVC
SVM_model = SVC(kernel='linear', gamma=0.5, probability=True)  # gamma is ignored by the linear kernel
Notice probability=True. By default, SVM doesn't output probabilities; it just gives you the predicted class. Setting this parameter enables probability estimates using Platt scaling, which fits a logistic (sigmoid) function to the SVM's decision values.
This adds computation time but gives you confidence scores like softmax regression provides naturally.
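Once probability=True is set, the fitted SVC exposes predict_proba just like the softmax model, so you can pull out calibrated class probabilities (a sketch; the exact values depend on the internal cross-validation that Platt scaling uses):

```python
svm_probability = SVM_model.fit(X, y).predict_proba(X)

print(svm_probability[0, :])        # three calibrated probabilities for the first flower
print(svm_probability[0, :].sum())  # ~1.0
```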
Key Takeaways
Softmax Extends Logistic Regression Naturally: For multi-class problems, softmax regression outputs probability distributions over all classes. The probabilities always sum to 1, giving you confidence estimates along with predictions.
argmax Converts Probabilities to Classes: The final prediction is simply the class with maximum probability. Understanding this two-step process (compute probabilities, then argmax) helps you decide when to use probabilities vs. hard predictions.
SVM Requires Multi-Class Strategies: Since SVM is fundamentally binary, you need either one-vs-rest or one-vs-one for multi-class problems. The choice depends on your specific requirements around accuracy, training time, and number of classes.
One-vs-Rest is Faster, One-vs-One is More Robust: One-vs-rest trains fewer classifiers but can struggle with imbalanced classes. One-vs-one trains more classifiers but handles imbalance better and often achieves higher accuracy.
Visualization Reveals Decision Boundaries: Using DecisionBoundaryDisplay to visualize how your classifier divides the feature space provides intuition about what the model learned and where it might struggle.
Probability Calibration Matters: Some models (like logistic regression) naturally output well-calibrated probabilities. Others (like SVM) need additional calibration steps. When you need reliable probability estimates, this matters.
Practical Considerations
When to Use Softmax Regression
- You need probability estimates, not just classifications
- Classes are mutually exclusive (each sample belongs to exactly one class)
- You want a fast, interpretable baseline model
- Your features have roughly linear relationships to log-odds of class membership
When to Use Multi-Class SVM
- You need maximum accuracy and can afford longer training
- You have complex, non-linear decision boundaries
- You have relatively few classes (otherwise one-vs-one explodes in complexity)
- You can work with hard predictions (or accept the overhead of probability calibration)
Choosing Between Strategies
For one-vs-rest:
- Few classes (2-10)
- Balanced class distributions
- Need fast training and prediction
- Tolerate slightly lower accuracy
For one-vs-one:
- Moderate number of classes (3-20)
- Imbalanced class distributions
- Need maximum accuracy
- Can afford extra computational cost
Real-World Applications
Email Classification: Route incoming emails to folders (Inbox, Social, Promotions, Spam)
Medical Diagnosis: Classify symptoms into multiple disease categories
Image Recognition: Identify objects in images (cat, dog, car, tree, etc.)
Voice Command Recognition: Classify spoken commands into action categories
Quality Control: Grade products into multiple quality levels
Conclusion
Multi-class classification extends binary classification to handle the complexity of real-world problems where items fall into multiple categories. Softmax regression provides a natural, probabilistic extension of logistic regression that directly models the probability distribution over classes.
For SVM and other binary classifiers, one-vs-rest and one-vs-one strategies enable multi-class prediction. One-vs-rest is simpler and faster but one-vs-one often achieves better accuracy, especially with imbalanced data.
Understanding both the probabilistic approach (softmax) and the decomposition approaches (one-vs-rest/one) gives you flexibility to choose the right tool for your specific problem. Whether you need probability estimates, maximum accuracy, fast training, or robustness to class imbalance determines which approach works best.
The key is understanding how each method works under the hood. Then you can make informed decisions rather than just trying different algorithms randomly and hoping for the best.
📓 Jupyter Notebook
Want to explore the complete code and run it yourself? Access the full Jupyter notebook with detailed implementations and visualizations: