Patient Drug Prediction Using Decision Trees

January 5, 2025 · Jonesh Shrestha

📌TL;DR

Built an interpretable Decision Tree classifier for patient drug prescription using 200 patient records, reaching 98.33% accuracy on the 60-patient test set. The model predicts Drug Y with 100% precision and Drug X with 91.7% precision based on blood pressure, cholesterol, and sodium-to-potassium ratio. It uses criterion='entropy' for information-gain-based splitting, visualizes the complete decision tree structure, and demonstrates a medical AI workflow: categorical encoding, a 70/30 train/test split, and evaluation metrics including a confusion matrix and per-drug performance analysis.

Introduction

Predicting which medication will be most effective for a patient based on their characteristics is a valuable application of machine learning in healthcare. In this tutorial, I'll show you how to build a Decision Tree classifier to predict the appropriate drug for patients based on their age, sex, blood pressure, cholesterol levels, and sodium-to-potassium ratio. Decision trees are particularly well-suited for this task because they're interpretable: doctors can understand and verify the decision-making process.

Understanding Decision Trees

Decision trees work like a flowchart of yes/no questions. The algorithm learns which questions to ask and in what order to best separate different outcomes. For example:

  • "Is blood pressure HIGH?"
    • If yes: "Is sodium-to-potassium ratio > 15?"
      • If yes: Prescribe Drug Y
      • If no: Prescribe Drug X
    • If no: Check other conditions...
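The example flowchart can be written out as nested conditions. This is a minimal, hand-written sketch of the illustrative rules only: the function name recommend_drug and the "undetermined" placeholder for the unspecified branch are made up here, and the threshold of 15 comes from the example, not from a trained model.

```python
def recommend_drug(bp: str, na_to_k: float) -> str:
    """Toy version of the example flowchart (hypothetical rules, not a learned model)."""
    if bp == "HIGH":
        # The second question is only asked on the HIGH branch
        if na_to_k > 15:
            return "drugY"
        return "drugX"
    # The example leaves the other branches open ("Check other conditions...")
    return "undetermined"


print(recommend_drug("HIGH", 18.2))    # drugY
print(recommend_drug("HIGH", 9.5))     # drugX
print(recommend_drug("NORMAL", 9.5))   # undetermined
```

A trained tree has exactly this shape; the difference is that the algorithm picks the questions and thresholds from the data.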

The beauty of decision trees is their interpretability: we can literally draw out the decision-making process and explain it to non-technical stakeholders.

Dataset Overview

Our dataset contains 200 patient records with the following features:

  • Age: Patient age (numeric)
  • Sex: M or F (categorical)
  • BP: Blood pressure - HIGH, NORMAL, or LOW (categorical)
  • Cholesterol: HIGH or NORMAL (categorical)
  • Na_to_K: Sodium to potassium ratio in blood (numeric)
  • Drug: The appropriate drug (our target variable) - drugA, drugB, drugC, drugX, or drugY
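The post does not show how X and y are built from the raw records. A minimal sketch, assuming the data sits in a CSV file named drug200.csv (a hypothetical name here) with the column names listed above: X holds the five feature columns as a NumPy array, which is what the X[:, 1] indexing in the encoding step expects, and y holds the Drug column. The helper names are hypothetical and the two demo rows are made up.

```python
import pandas as pd

FEATURES = ["Age", "Sex", "BP", "Cholesterol", "Na_to_K"]
TARGET = "Drug"


def split_features_target(df):
    """Return X as a NumPy object array and y as a pandas Series."""
    # .values on mixed-type columns gives dtype=object, so the string columns
    # can later be overwritten in place with integer codes (X[:, 1] = ...)
    X = df[FEATURES].values
    y = df[TARGET]
    return X, y


def load_drug_data(path="drug200.csv"):
    """Read the CSV (hypothetical file name) and split it into X and y."""
    return split_features_target(pd.read_csv(path))


# Two made-up rows, only to show the resulting shapes and types
demo = pd.DataFrame({
    "Age": [23, 47], "Sex": ["F", "M"], "BP": ["HIGH", "LOW"],
    "Cholesterol": ["HIGH", "HIGH"], "Na_to_K": [25.4, 13.1],
    "Drug": ["drugY", "drugC"],
})
X, y = split_features_target(demo)
print(X.shape, X.dtype)  # (2, 5) object
print(list(y))           # ['drugY', 'drugC']
```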

Data Preprocessing: Handling Categorical Variables

One challenge with our dataset is that it contains categorical variables (Sex, BP, Cholesterol) stored as strings, but scikit-learn's DecisionTreeClassifier only accepts numeric feature values. Here's how I handled this:

Label Encoding

from sklearn import preprocessing

le_sex = preprocessing.LabelEncoder()
le_sex.fit(['M', 'F'])
X[:, 1] = le_sex.transform(X[:, 1])

Let me break down what's happening here:

LabelEncoder: This class converts categorical labels into integer codes, assigned in sorted (alphabetical) order of the categories. For sex, that means:

  • 'F' → 0
  • 'M' → 1

fit(): This method teaches the encoder what categories to expect. By calling le_sex.fit(['M', 'F']), I'm telling it "these are all the possible values for sex."

transform(): This method converts the actual data. X[:, 1] selects all rows (:) and the second column (1, which is sex), then replaces it with the numeric encoding.
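A quick way to see the fit()/transform() behavior, including the sorted class order, is to inspect classes_ and round-trip a few values. This sketch uses only standard LabelEncoder methods:

```python
from sklearn import preprocessing

le_sex = preprocessing.LabelEncoder()
le_sex.fit(["M", "F"])

# Classes are stored in sorted order, regardless of the order passed to fit()
print(le_sex.classes_)                    # ['F' 'M']
print(le_sex.transform(["M", "F", "M"]))  # [1 0 1]

# inverse_transform maps the integer codes back to the original labels
print(le_sex.inverse_transform([0, 1]))   # ['F' 'M']
```

inverse_transform is handy later for turning encoded rows back into readable patient descriptions.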

I repeated this process for blood pressure and cholesterol:

le_bp = preprocessing.LabelEncoder()
le_bp.fit(['HIGH', 'LOW', 'NORMAL'])
X[:, 2] = le_bp.transform(X[:, 2])

le_cho = preprocessing.LabelEncoder()
le_cho.fit(['HIGH', 'NORMAL'])
X[:, 3] = le_cho.transform(X[:, 3])
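To check what each integer means after encoding, you can print the label-to-code mapping for every encoder. This sketch (encoding_table is a hypothetical helper) shows that BP becomes HIGH=0, LOW=1, NORMAL=2: alphabetical order, not clinical order. A tree can still isolate any single BP level with at most two threshold splits, but this is worth keeping in mind when reading a rule such as "BP <= 0.5".

```python
from sklearn import preprocessing

# The same category lists the tutorial passes to fit()
CATEGORIES = {
    "Sex": ["M", "F"],
    "BP": ["HIGH", "LOW", "NORMAL"],
    "Cholesterol": ["HIGH", "NORMAL"],
}


def encoding_table(categories):
    """Fit one LabelEncoder per column and return {column: {label: integer code}}."""
    table = {}
    for column, labels in categories.items():
        le = preprocessing.LabelEncoder().fit(labels)
        codes = le.transform(le.classes_)
        table[column] = {str(label): int(code) for label, code in zip(le.classes_, codes)}
    return table


for column, mapping in encoding_table(CATEGORIES).items():
    print(column, mapping)
# Sex {'F': 0, 'M': 1}
# BP {'HIGH': 0, 'LOW': 1, 'NORMAL': 2}
# Cholesterol {'HIGH': 0, 'NORMAL': 1}
```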

Note that I didn't encode the Drug column (our target variable) because scikit-learn's DecisionTreeClassifier can handle string labels for the target variable directly.

Train/Test Split

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)

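I used a 70/30 split for our relatively small dataset of 200 samples:

  • Training set: 140 samples (70%)
  • Testing set: 60 samples (30%)

With only 200 records, holding out 30% keeps enough test patients (60) for a meaningful accuracy estimate, although each misclassified patient still moves the score by about 1.7 percentage points. For larger datasets, an 80/20 split (or an even smaller test fraction) is common, because the test set is already large in absolute terms.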

The random_state=3 ensures reproducibility: we'll get the same split every time we run the code. This is important for comparing different models fairly.
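The split sizes and the effect of random_state can be checked directly. This sketch uses placeholder arrays with 200 rows instead of the real patient data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the same number of rows as the dataset (200)
X = np.arange(200 * 5).reshape(200, 5)
y = np.arange(200) % 5

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)
print(len(X_train), len(X_test))  # 140 60

# Calling again with the same random_state reproduces the exact same split
_, X_test_again, _, _ = train_test_split(X, y, test_size=0.3, random_state=3)
print(np.array_equal(X_test, X_test_again))  # True
```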

Building the Decision Tree Model

from sklearn.tree import DecisionTreeClassifier

drugTree = DecisionTreeClassifier(criterion='entropy', max_depth=4)
drugTree.fit(X_train, y_train)

Understanding the Parameters

criterion='entropy': This specifies how the tree decides where to split the data. Entropy measures the "impurity" or "disorder" in a set of labels.

Here's the intuition: If all patients in a group need the same drug, entropy is 0 (perfect order). If patients are evenly mixed across all drug types, entropy is at its maximum (maximum disorder; with five drugs that is log2(5) ≈ 2.32 bits). The tree tries to make splits that maximize the reduction in entropy, a quantity called "information gain."
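The entropy and information-gain calculations can be written out in a few lines. This is an illustrative sketch of the formulas (entropy in bits, H = -Σ p_i log2 p_i), not scikit-learn's internal implementation; the toy label lists are made up.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy in bits over the classes present in `labels`."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, children):
    """Parent entropy minus the size-weighted entropy of the child groups."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted


parent = ["drugX", "drugX", "drugY", "drugY"]
print(entropy(parent))        # 1.0 (evenly mixed between two drugs)
print(entropy(["drugX"] * 4))  # 0.0 (pure group)

# A split that separates the drugs perfectly gains the full 1 bit...
print(information_gain(parent, [["drugX", "drugX"], ["drugY", "drugY"]]))  # 1.0
# ...while a split that leaves both halves mixed gains nothing
print(information_gain(parent, [["drugX", "drugY"], ["drugX", "drugY"]]))  # 0.0
```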

An alternative is criterion='gini' (scikit-learn's default), which uses the Gini impurity measure. The two criteria usually produce very similar trees; entropy connects directly to information theory, while Gini is slightly cheaper to compute because it avoids logarithms.
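For comparison, Gini impurity is 1 - Σ p_i², which is also 0 for a pure group but involves no logarithm. A minimal sketch on made-up labels:

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 for a pure group)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


print(gini(["drugX"] * 4))                         # 0.0
print(gini(["drugX", "drugX", "drugY", "drugY"]))  # 0.5
```

Both measures rank the same pure-versus-mixed groups in the same direction, which is why the resulting trees tend to look alike.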

max_depth=4: This limits the tree to at most 4 levels of decisions between the root and any leaf. It is a crucial parameter for controlling overfitting. Without a depth limit, scikit-learn keeps splitting until every leaf is pure (or too small to split), which can classify the training data perfectly but fail on new data. By limiting depth to 4, we force the tree to learn only the strongest patterns, which helps it generalize to unseen patients.

Think of max_depth like controlling how specific your rules can be. A shallow tree makes broad, general rules that work well for new cases. A deep tree makes very specific rules that might not generalize.
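One way to see the effect of max_depth is to compare training and test accuracy across depths. This sketch uses synthetic data from make_classification as a stand-in for the drug dataset, so the exact numbers will differ from the tutorial's; the typical pattern is that an unlimited tree reaches 100% training accuracy while its test accuracy usually stays noticeably lower.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (NOT the drug dataset): 200 samples, 5 features, 3 classes
X, y = make_classification(n_samples=200, n_features=5, n_informative=3,
                           n_redundant=0, n_classes=3, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)


def depth_sweep(depths):
    """Fit one tree per max_depth; return (max_depth, actual depth, train acc, test acc)."""
    rows = []
    for depth in depths:
        clf = DecisionTreeClassifier(criterion="entropy", max_depth=depth, random_state=0)
        clf.fit(X_train, y_train)
        rows.append((depth, clf.get_depth(),
                     clf.score(X_train, y_train), clf.score(X_test, y_test)))
    return rows


for max_depth, depth, train_acc, test_acc in depth_sweep([2, 4, None]):
    print(f"max_depth={max_depth}: depth={depth}, train={train_acc:.2f}, test={test_acc:.2f}")
```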

Making Predictions

predictionTree = drugTree.predict(X_test)

The trained model now predicts which drug is appropriate for each patient in the test set. Let's compare the first five predictions with the true labels:

print(predictionTree[:5])
print(y_test[:5])

Output:

['drugY' 'drugX' 'drugX' 'drugX' 'drugX']
40     drugY
51     drugX
139    drugX
197    drugX
170    drugX

Perfect match for the first 5 predictions! This is a good sign, but we need to evaluate the entire test set to understand overall performance.

Model Evaluation

from sklearn.metrics import accuracy_score
print('The accuracy of the decision tree is: ', accuracy_score(y_test, predictionTree))

Accuracy: 98.33%

This is excellent! Our model correctly predicts the appropriate drug for 98.33% of the 60 test patients, that is, 59 of 60. Such high accuracy suggests that, in this dataset, the prescribed drug is largely determined by the recorded patient characteristics.
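The accuracy figure is simply the fraction of test predictions that match the true labels. With 60 test patients, 98.33% corresponds to 59 correct predictions and 1 error, as this sketch with made-up labels shows:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up labels for 60 test patients, with exactly one wrong prediction
y_true = np.array(["drugY"] * 30 + ["drugX"] * 30)
y_pred = y_true.copy()
y_pred[0] = "drugX"  # the single misclassification

acc = accuracy_score(y_true, y_pred)
print(acc)                        # ≈ 0.9833 (59 / 60)
print(np.mean(y_true == y_pred))  # same value: the fraction of matching labels
print(round(59 / 60 * 100, 2))    # 98.33
```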

However, it's important to consider the context:

  • In medical applications, even 98% accuracy means some patients (here, 1 of the 60 test patients) might receive a non-optimal recommendation
  • We should also look at which specific drugs are being confused to understand the error patterns
  • In practice, such predictions should support, not replace, professional medical judgment
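To see which drugs are confused with which, scikit-learn's confusion_matrix and classification_report break the errors down per class. The labels below are made up for illustration; with the tutorial's variables you would pass y_test and predictionTree instead.

```python
from sklearn.metrics import classification_report, confusion_matrix

# Made-up labels for illustration
y_true = ["drugY", "drugX", "drugX", "drugA", "drugY", "drugC"]
y_pred = ["drugY", "drugX", "drugX", "drugX", "drugY", "drugC"]

labels = ["drugA", "drugB", "drugC", "drugX", "drugY"]
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(cm)  # rows = true drug, columns = predicted drug, both in the order of `labels`

# Per-drug precision, recall and F1; zero_division=0 avoids warnings for drugs never predicted
print(classification_report(y_true, y_pred, labels=labels, zero_division=0))
```

Here the off-diagonal entry in the drugA row, drugX column shows a drugA patient being recommended drugX, which is exactly the kind of error pattern worth reviewing clinically.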

Visualizing the Decision Tree

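import matplotlib.pyplot as plt
from sklearn import tree
tree.plot_tree(drugTree, feature_names=['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K'], class_names=list(drugTree.classes_), filled=True)
plt.show()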

This visualization is one of the most powerful aspects of decision trees. We can literally see the decision-making process:

  • Each node shows a decision rule (e.g., "Na_to_K <= 15")
  • Branches show the paths taken based on the answer
  • Leaf nodes show the final drug recommendation

This interpretability is crucial in healthcare: doctors can review the tree structure and verify that it makes medical sense, rather than trusting a "black box" model.
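If you prefer the rules as plain text (for example, to paste into a review document), sklearn.tree.export_text prints the same structure that plot_tree draws. A minimal sketch on four made-up patients with two features; with the tutorial's model you would call export_text(drugTree, feature_names=['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']).

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Four made-up patients: [Age, Na_to_K]; only Na_to_K separates the two drugs cleanly
X = [[30, 10.0], [60, 12.0], [35, 20.0], [55, 25.0]]
y = ["drugX", "drugX", "drugY", "drugY"]

clf = DecisionTreeClassifier(criterion="entropy", random_state=0).fit(X, y)
rules = export_text(clf, feature_names=["Age", "Na_to_K"])
print(rules)  # a single split on Na_to_K, with one drug on each side
```

The threshold lands halfway between the closest values on either side of the split (12.0 and 20.0 here).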

Key Takeaways

  1. Label Encoding is Essential: scikit-learn's decision trees need numeric inputs. Label encoding converts categorical variables like sex and blood pressure into integer codes while preserving the distinct categories (the codes follow alphabetical order, not any clinical ordering).

  2. The fit() Method Teaches, transform() Applies:

    • fit() learns the set of categories from the values it is given (here, an explicit list of every possible value)
    • transform() applies that learning to convert actual values
    • For categorical encoding, we typically fit on all possible categories upfront
  3. Entropy Criterion for Information Gain: Using entropy as the splitting criterion means the tree chooses splits that provide the maximum information gain, that is, the greatest reduction in uncertainty about which drug is appropriate.

  4. max_depth Prevents Overfitting: Limiting tree depth is a form of regularization. It prevents the model from learning overly specific patterns that don't generalize. With medical data, we want rules that apply broadly to new patients, not rules memorized from training examples.

  5. High Accuracy Doesn't Mean Perfect: While 98.33% accuracy is impressive, in medical applications, we must consider:

    • What happens to the 1.67% of misclassified cases (1 of the 60 test patients)?
    • Are certain drug confusions more problematic than others?
    • Should the model provide confidence scores to flag uncertain predictions?

Advantages of Decision Trees for Medical Prediction

  1. Interpretability: Healthcare professionals can review and validate the decision rules
  2. No Feature Scaling Required: Unlike logistic regression or SVM, decision trees don't require normalized features
  3. Handles Mixed Data Types: Once categorical features (sex, BP) are label-encoded, the tree splits on them alongside numeric features (age, Na_to_K) with no further transformation
  4. Non-linear Relationships: Can capture complex, non-linear patterns in how features relate to outcomes
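The "no feature scaling" point can be checked empirically: standardizing every feature preserves the ordering of values, so the tree picks the same features and produces the same predictions, with only the thresholds rescaled. This sketch uses synthetic data rather than the drug dataset:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (not the drug dataset)
X, y = make_classification(n_samples=200, n_features=5, n_informative=3,
                           n_redundant=0, n_classes=3, random_state=0)
X_scaled = StandardScaler().fit_transform(X)

raw_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
scaled_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_scaled, y)

# Same feature chosen at every node, and the same predictions; only thresholds differ
print(np.array_equal(raw_tree.tree_.feature, scaled_tree.tree_.feature))  # True
print(np.array_equal(raw_tree.predict(X), scaled_tree.predict(X_scaled)))  # True
```

Distance- or gradient-based models such as SVMs and regularized logistic regression do not share this property, which is why they usually need a scaler in front.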

Limitations and Considerations

  1. Overfitting Risk: Without proper depth control, trees can memorize training data
  2. Instability: Small changes in training data can create very different trees
  3. Bias Toward Dominant Classes: If one drug is prescribed much more frequently, the tree might be biased toward it
  4. Not Probabilistic by Default: Standard prediction gives only the class, not confidence levels (though predict_proba() can provide probabilities)

Practical Applications

This type of model can assist healthcare in several ways:

  • Clinical Decision Support: Provide drug recommendations to doctors for verification
  • Treatment Optimization: Identify which patient characteristics most influence drug selection
  • Medical Education: Use the tree structure to teach medical students about drug selection criteria
  • Efficiency: Quick initial screening to identify obvious cases, allowing doctors to focus on complex cases
  • Consistency: Ensure standardized criteria are applied across different practitioners

Extending the Model

To improve this basic model, consider:

  1. Ensemble Methods: Use Random Forests (multiple decision trees) for more robust predictions
  2. Feature Engineering: Create interaction features (e.g., age × BP) if they have medical significance
  3. Cross-Validation: Use k-fold cross-validation for more reliable accuracy estimates
  4. Class Balancing: If some drugs are rare in the training data, set class_weight='balanced' on the classifier or use resampling techniques like SMOTE (from the imbalanced-learn library) to balance classes
  5. Confidence Scores: Use predict_proba() to get probability estimates and flag low-confidence predictions for manual review
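The cross-validation, ensemble and confidence-score ideas can be combined in a short sketch. It runs on synthetic data (not the drug dataset), and the 0.8 confidence cutoff is an arbitrary, hypothetical value. Note that a decision tree's predict_proba returns the class fractions in the leaf a sample lands in, so a pure leaf always reports 1.0; its probabilities are coarse compared with those of a random forest, which averages over many trees.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (not the drug dataset)
X, y = make_classification(n_samples=200, n_features=5, n_informative=3,
                           n_redundant=0, n_classes=3, random_state=0)

models = {
    "tree": DecisionTreeClassifier(criterion="entropy", max_depth=4, random_state=0),
    "forest": RandomForestClassifier(n_estimators=100, random_state=0),
}

# 5-fold cross-validation: mean and spread of accuracy across folds
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=5)
    print(f"{name}: {scores.mean():.3f} +/- {scores.std():.3f}")

# Flag low-confidence predictions for manual review
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)
clf = models["tree"].fit(X_train, y_train)
proba = clf.predict_proba(X_test)  # one column per class, in clf.classes_ order
confidence = proba.max(axis=1)
needs_review = confidence < 0.8    # hypothetical cutoff
print(f"{needs_review.sum()} of {len(X_test)} test cases flagged for review")
```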

Conclusion

Decision trees provide an excellent balance of accuracy and interpretability for medical prediction tasks. With 98.33% accuracy on our 60-patient test set, the model suggests that machine learning can effectively support clinical decision-making.

The key advantages (interpretability, handling of mixed data types, and no need for feature scaling) make decision trees particularly suitable for healthcare applications where understanding the reasoning behind predictions is just as important as the predictions themselves.

Remember, such models should augment, not replace, medical expertise. They're tools to support clinical judgment, provide consistency, and potentially catch cases that might otherwise be overlooked. The interpretable nature of decision trees makes them trustworthy partners in healthcare decision-making.


📓 Jupyter Notebook

Want to explore the complete code and run it yourself? Access the full Jupyter notebook with detailed implementations and visualizations:

→ View Notebook on GitHub
