Python Decision Trees: Classification, Regression, and Tuning with scikit-learn

Decision trees are among the most intuitive machine learning algorithms available. They mirror the way humans naturally make decisions -- by asking a series of yes-or-no questions to arrive at a conclusion. In Python, scikit-learn (version 1.8 as of this writing) provides a powerful and straightforward implementation through its DecisionTreeClassifier and DecisionTreeRegressor classes. This article walks through how they work, how to build them, and how to tune them for real-world performance.

If you have ever used a flowchart to troubleshoot a problem, you already understand the core idea behind decision trees. The algorithm takes a dataset, identifies which feature best separates the data at each step, and builds a tree-shaped model of nested if-then-else rules. The result is a model that is easy to interpret, handles both classification and regression, and requires minimal preprocessing compared to many other algorithms. One caveat: scikit-learn's tree estimators expect numeric input, so categorical features must first be encoded as numbers (for example with OrdinalEncoder or OneHotEncoder).

What Is a Decision Tree?

A decision tree is a supervised learning algorithm used for both classification (predicting categories) and regression (predicting continuous values). It works by recursively partitioning data into subsets based on the values of input features, creating a branching structure that leads to predictions at the terminal nodes.

There are a few key terms to understand before going further. The root node is the topmost node, representing the entire dataset before any splitting occurs. Internal nodes (also called decision nodes) are points where the data is split based on a condition applied to a feature. Branches represent the outcome of each split. Leaf nodes are the terminal endpoints that provide the final prediction -- a class label in classification or a numerical value in regression.

At each internal node, the algorithm evaluates every feature and every possible threshold to find the split that best separates the target variable. It repeats this process on each resulting subset until it reaches a stopping condition, such as a maximum tree depth or a minimum number of samples in a node.
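To make the split search concrete, here is a minimal sketch of an exhaustive search over the data in a single node. It is an illustration only, not scikit-learn's actual implementation; the helper names `gini` and `best_split` and the toy arrays are hypothetical, and Gini impurity (explained later in this article) is assumed as the measure of split quality.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a 1-D array of class labels."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Return (feature_index, threshold, weighted_impurity) of the best split.

    Hypothetical helper: tries the midpoint between each pair of adjacent
    distinct values of every feature and keeps the split with the lowest
    sample-weighted impurity of the two children.
    """
    n_samples, n_features = X.shape
    best = (None, None, gini(y))  # start from the impurity of the unsplit node
    for j in range(n_features):
        values = np.unique(X[:, j])  # sorted distinct values of feature j
        thresholds = (values[:-1] + values[1:]) / 2.0
        for t in thresholds:
            left = y[X[:, j] <= t]
            right = y[X[:, j] > t]
            weighted = (len(left) * gini(left) + len(right) * gini(right)) / n_samples
            if weighted < best[2]:
                best = (j, t, weighted)
    return best

# Toy data: feature 1 separates the two classes perfectly, feature 0 does not
X_toy = np.array([[1.0, 10.0], [2.0, 11.0], [1.5, 20.0], [2.5, 21.0]])
y_toy = np.array([0, 0, 1, 1])

feature, threshold, impurity = best_split(X_toy, y_toy)
print(feature, threshold, impurity)
```

On this toy data the search selects feature 1 with a threshold of 15.5, which leaves both children pure. A real tree builder would then repeat the same search on each child node.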

Note

Decision trees are non-parametric, meaning they make no assumptions about the underlying data distribution. This makes them flexible for a wide range of problems, but also more prone to overfitting if left unconstrained.

Building a Classification Tree

Classification trees predict discrete class labels. scikit-learn's DecisionTreeClassifier handles this task. The following example uses the well-known Iris dataset, which contains measurements of 150 iris flowers across three species.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Load the dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split into training and test sets (80/20)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)

# Create and train the classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Make predictions and evaluate
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")

This code loads the data, splits it into training and test sets using an 80/20 ratio, creates a DecisionTreeClassifier with default settings, trains it on the training data, and then evaluates its accuracy on the held-out test set. The random_state parameter ensures reproducible results.

With default settings, the tree grows until every leaf is pure (contains samples from only one class) or until another stopping condition is met. This often produces a very accurate model on training data, but it can overfit. We will address that in the tuning section below.
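One way to observe this behavior, sketched here on the same Iris split, is to compare training and test accuracy and to inspect the size of the fully grown tree. The variable names (`full_tree`, `X_tr`, and so on) are chosen for this illustration and are not part of the example above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_iris, y_iris, test_size=0.20, random_state=42
)

# Default settings: no depth limit, so the tree grows until its leaves are pure
full_tree = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)

train_acc = full_tree.score(X_tr, y_tr)  # accuracy on data the tree has seen
test_acc = full_tree.score(X_te, y_te)   # accuracy on held-out data

print(f"Depth: {full_tree.get_depth()}, leaves: {full_tree.get_n_leaves()}")
print(f"Train accuracy: {train_acc:.4f}")
print(f"Test accuracy:  {test_acc:.4f}")
```

Training accuracy at or very near 1.0 is the signature of a fully grown tree. A large gap between the two numbers on your own data is a sign of overfitting.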

Predicting Probabilities

In addition to hard class predictions, you can retrieve the probability estimates for each class using the predict_proba() method. This is useful when you need confidence scores rather than just a label.

# Get probability estimates for each class
probabilities = clf.predict_proba(X_test[:5])

for i, prob in enumerate(probabilities):
    predicted_class = iris.target_names[y_pred[i]]
    print(f"Sample {i}: {prob} -> {predicted_class}")
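For a scikit-learn tree, the probability reported for a class is the fraction of training samples of that class in the leaf the sample lands in. The sketch below uses a deliberately shallow tree (an assumption made here so that some leaves stay mixed) to show that each row sums to 1 and that the hard prediction is the most probable class.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)

# A shallow tree, so that some leaves contain a mix of classes
shallow = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_iris, y_iris)

proba = shallow.predict_proba(X_iris)
labels = shallow.predict(X_iris)

# Each row is a distribution over the classes
rows_sum_to_one = np.allclose(proba.sum(axis=1), 1.0)

# The hard prediction is the class with the highest probability
argmax_matches = np.array_equal(shallow.classes_[proba.argmax(axis=1)], labels)

# Samples in the same leaf share the same probability row
distinct_rows = np.unique(proba.round(3), axis=0)

print(rows_sum_to_one, argmax_matches)
print(distinct_rows)
```

Because a fully grown tree has pure leaves, its probabilities are all 0 or 1; constraining the tree is what produces graded confidence scores.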

Building a Regression Tree

Regression trees predict continuous values instead of class labels. scikit-learn provides the DecisionTreeRegressor class for this purpose. The API is nearly identical to the classifier -- the primary difference is that the splitting criterion defaults to squared error instead of Gini impurity, and evaluation metrics change accordingly.

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Load the California Housing dataset (downloaded on first use)
housing = fetch_california_housing()
X_h, y_h = housing.data, housing.target

# Split the data. The _h suffix keeps the Iris variables from the
# classification example intact, since later sections reuse them.
X_train_h, X_test_h, y_train_h, y_test_h = train_test_split(
    X_h, y_h, test_size=0.20, random_state=42
)

# Create and train the regressor
reg = DecisionTreeRegressor(max_depth=6, random_state=42)
reg.fit(X_train_h, y_train_h)

# Make predictions and evaluate
y_pred_h = reg.predict(X_test_h)
mse = mean_squared_error(y_test_h, y_pred_h)
r2 = r2_score(y_test_h, y_pred_h)

print(f"Mean Squared Error: {mse:.4f}")
print(f"R-squared: {r2:.4f}")

Notice that max_depth=6 is set explicitly here. Without it, the regression tree would keep splitting until every leaf is pure, which for a continuous target usually means one training sample per leaf and a severely overfit model. Constraining the depth forces the tree to generalize better, though the optimal depth depends on the dataset.
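The effect of the depth limit can be sketched on a small synthetic problem. The noisy sine data, the depth of 4, and the variable names below are assumptions chosen for illustration; they are not taken from the housing example.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic data (an assumption for this sketch): a noisy sine wave
rng = np.random.default_rng(0)
X_syn = rng.uniform(0, 10, size=(500, 1))
y_syn = np.sin(X_syn[:, 0]) + rng.normal(scale=0.3, size=500)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, test_size=0.2, random_state=0
)

results = {}
for depth in (None, 4):
    model = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    # Store (number of leaves, train R-squared, test R-squared)
    results[depth] = (
        model.get_n_leaves(),
        model.score(X_tr, y_tr),
        model.score(X_te, y_te),
    )
    leaves, train_r2, test_r2 = results[depth]
    print(f"max_depth={depth}: leaves={leaves}, train R2={train_r2:.3f}, test R2={test_r2:.3f}")
```

The unconstrained tree fits the training noise almost perfectly and uses far more leaves, while the shallow tree is limited to at most 16 leaves. Comparing the two test scores shows how much of the training fit carries over to unseen data.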

Pro Tip

Decision trees do not require feature scaling. Unlike algorithms such as SVMs or k-nearest neighbors, decision trees split on feature thresholds directly, so the magnitude of feature values has no effect on performance. You can skip the StandardScaler step.
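A quick way to check this claim is to train the same tree on raw and on standardized features and compare the predictions. The sketch below uses synthetic data with deliberately mismatched feature scales; the data, the labeling rule, and the variable names are assumptions made for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (an assumption for this sketch): four features on very different scales
rng = np.random.default_rng(0)
X_raw = rng.normal(size=(300, 4)) * np.array([1.0, 10.0, 100.0, 1000.0])
y_lab = (X_raw[:, 0] + X_raw[:, 2] / 100.0 > 0).astype(int)

X_fit, X_new = X_raw[:200], X_raw[200:]
y_fit = y_lab[:200]

scaler = StandardScaler().fit(X_fit)

# Same tree settings, once on raw features and once on standardized features
tree_raw = DecisionTreeClassifier(random_state=0).fit(X_fit, y_fit)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(scaler.transform(X_fit), y_fit)

pred_raw = tree_raw.predict(X_new)
pred_scaled = tree_scaled.predict(scaler.transform(X_new))

agreement = np.mean(pred_raw == pred_scaled)
print(f"Fraction of identical predictions: {agreement:.3f}")
```

Standardization preserves the ordering of values within each feature, so both trees should partition the samples the same way and agree on essentially every prediction.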

How Splitting Works: Gini vs. Entropy

The criterion parameter controls how the tree measures the quality of each potential split. For classification trees, the two measures to know are "gini" (the default) and "entropy". scikit-learn also accepts "log_loss", which is equivalent to "entropy".

Gini impurity measures the probability that a randomly chosen sample from a node would be misclassified if labeled according to the distribution of classes in that node. A Gini value of 0 means the node is pure (all samples belong to one class). The formula is: Gini = 1 - sum(p_i^2), where p_i is the proportion of samples belonging to class i.

Entropy (information gain) measures the amount of disorder or uncertainty in a node. A node with entropy of 0 is completely pure. The formula is: Entropy = -sum(p_i * log2(p_i)). When used as a splitting criterion, the algorithm chooses the split that maximizes the reduction in entropy (the information gain).
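Both formulas are easy to evaluate by hand. The following sketch computes them for a few class-count vectors; the function names `gini_impurity` and `entropy` are hypothetical helpers written for this illustration, not scikit-learn functions.

```python
import numpy as np

def gini_impurity(class_counts):
    """Gini = 1 - sum(p_i^2) for a node with the given per-class sample counts."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(class_counts):
    """Entropy = -sum(p_i * log2(p_i)), treating empty classes as contributing 0."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    p = p[p > 0]  # drop empty classes to avoid log2(0)
    # sum(p * log2(1/p)) equals -sum(p * log2(p)) and avoids printing -0.0
    return np.sum(p * np.log2(1.0 / p))

for counts in ([50, 0], [25, 25], [40, 10]):
    print(counts, round(gini_impurity(counts), 4), round(entropy(counts), 4))
```

A pure node scores 0 on both measures, and an evenly mixed two-class node reaches the two-class maximum of 0.5 for Gini and 1.0 for entropy. Both measures rise and fall together, which is why they tend to rank candidate splits similarly.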

# Compare Gini vs. Entropy on the same dataset
clf_gini = DecisionTreeClassifier(criterion="gini", random_state=42)
clf_entropy = DecisionTreeClassifier(criterion="entropy", random_state=42)

clf_gini.fit(X_train, y_train)
clf_entropy.fit(X_train, y_train)

acc_gini = accuracy_score(y_test, clf_gini.predict(X_test))
acc_entropy = accuracy_score(y_test, clf_entropy.predict(X_test))

print(f"Gini accuracy:    {acc_gini:.4f}")
print(f"Entropy accuracy: {acc_entropy:.4f}")

In practice, the two criteria usually produce very similar trees. Gini is the default and is slightly cheaper to compute because it avoids the logarithm. Entropy can occasionally select different splits and yield a somewhat differently shaped tree, but the effect on accuracy is typically small, so treat the criterion as one more hyperparameter to validate rather than a decision to agonize over.

For regression trees, the default criterion is "squared_error" (equivalent to variance reduction). scikit-learn also offers "friedman_mse" (Friedman's improvement score) and "absolute_error" as alternatives.

Visualizing Your Tree

One of the biggest advantages of decision trees is that you can actually see the model's logic. scikit-learn provides the plot_tree() function for quick visualization and export_text() for a text-based representation.

from sklearn.tree import plot_tree, export_text
import matplotlib.pyplot as plt

# Train a shallow tree for a cleaner visualization
clf_viz = DecisionTreeClassifier(max_depth=3, random_state=42)
clf_viz.fit(X_train, y_train)

# Plot the tree
plt.figure(figsize=(16, 8))
plot_tree(
    clf_viz,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title("Decision Tree - Iris Dataset (max_depth=3)")
plt.tight_layout()
plt.savefig("decision_tree_iris.png", dpi=150)
plt.show()

The filled=True parameter colors each node based on the majority class, making it easy to trace the decision path visually. The rounded=True parameter adds rounded corners for a cleaner appearance.

For situations where a graphical plot is not practical (such as logging or terminal output), the text representation is valuable:

# Text-based tree representation
tree_rules = export_text(
    clf_viz,
    feature_names=iris.feature_names
)
print(tree_rules)

This outputs the tree as a series of indented rules, showing each split condition and the predicted class at each leaf. Passing show_weights=True additionally lists the number of samples of each class at each leaf. The text form is especially useful for debugging or documenting your model's logic.
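As a small sketch, `export_text` also accepts a `show_weights` flag that adds the classification weights to each leaf line. The example retrains a shallow Iris tree so that it runs on its own; the names `small_tree` and `rules` are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris_data = load_iris()
small_tree = DecisionTreeClassifier(max_depth=2, random_state=42)
small_tree.fit(iris_data.data, iris_data.target)

# show_weights=True adds the per-class weights to each leaf line
rules = export_text(
    small_tree,
    feature_names=list(iris_data.feature_names),
    show_weights=True,
)
print(rules)
```

Each leaf line then carries a `weights: [...]` list next to the predicted class, which makes it easy to spot leaves that are still mixed.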

Feature Importance

After training, you can inspect which features contributed the most to the model's decisions using the feature_importances_ attribute. This returns an array with one value per feature: the total reduction in the splitting criterion (Gini or entropy) contributed by that feature across all nodes, weighted by the number of samples reaching each node and normalized so that the values sum to 1.

# Display feature importances
importances = clf.feature_importances_

for name, importance in zip(iris.feature_names, importances):
    print(f"{name}: {importance:.4f}")
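The sketch below, which retrains an Iris tree so it is self-contained, confirms that the importances are normalized and prints the features ranked from most to least important. The variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris_data = load_iris()
model = DecisionTreeClassifier(random_state=42).fit(iris_data.data, iris_data.target)

imp = model.feature_importances_

# Importances are normalized, so they sum to 1
print(f"Sum of importances: {imp.sum():.4f}")

# Rank features from most to least important
order = np.argsort(imp)[::-1]
for idx in order:
    print(f"{iris_data.feature_names[idx]:<20} {imp[idx]:.4f}")
```

A feature with an importance of 0 was never used in a split. Keep in mind that these scores describe the fitted tree, so they can shift noticeably when the training data or random_state changes.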

Hyperparameter Tuning and Pruning

An unconstrained decision tree will keep splitting until every leaf is pure, which almost always leads to overfitting. The model memorizes the training data, including its noise, and performs poorly on new, unseen samples. Tuning hyperparameters is how you control this behavior.

Here are the key hyperparameters for decision trees in scikit-learn:

max_depth limits how many levels the tree can grow. A shallow tree (e.g., depth 3-5) captures only the strongest patterns. A deep tree captures finer details but risks overfitting. Setting this to None (the default) allows unlimited growth.

min_samples_split sets the minimum number of samples a node must contain before it can be split further. Higher values prevent the tree from creating branches based on very few data points. The default is 2, meaning a node with just 2 samples can still be split.

min_samples_leaf sets the minimum number of samples required at each leaf node. If a potential split would create a leaf with fewer samples than this threshold, the split is rejected. This acts as a smoothing mechanism that prevents the tree from creating overly specific leaves.

max_features controls how many features the tree considers when searching for the best split at each node. When set to a value less than the total number of features, it introduces randomness that can reduce overfitting. Options include an integer count, a float giving the fraction of features to consider, "sqrt", or "log2". The default, None, considers all features.

ccp_alpha is the complexity parameter for Minimal Cost-Complexity Pruning, introduced in scikit-learn 0.22. This is a form of post-pruning: after the tree is fully grown, the weakest subtrees, those whose impurity reduction per additional leaf falls below ccp_alpha, are collapsed into single leaves. A value of 0 (the default) means no pruning, and larger values prune more of the tree.

# Example: constrained decision tree
clf_tuned = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=4,
    max_features="sqrt",
    ccp_alpha=0.01,
    random_state=42
)
clf_tuned.fit(X_train, y_train)

y_pred_tuned = clf_tuned.predict(X_test)
accuracy_tuned = accuracy_score(y_test, y_pred_tuned)
print(f"Tuned accuracy: {accuracy_tuned:.4f}")
print(f"Tree depth: {clf_tuned.get_depth()}")
print(f"Number of leaves: {clf_tuned.get_n_leaves()}")

Warning

Setting hyperparameters too aggressively (e.g., max_depth=2 on a complex dataset) will cause underfitting. The tree will be too simple to capture meaningful patterns. Finding the right balance requires experimentation, and automated tuning (covered later in this article) is usually the most reliable way to find it.
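The trade-off between underfitting and overfitting can be sketched by sweeping max_depth and comparing training and test accuracy. The synthetic dataset from `make_classification`, its settings, and the list of depths are assumptions chosen for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (an assumption for this sketch), with 10% label noise
X_syn, y_syn = make_classification(
    n_samples=1000, n_features=20, n_informative=5,
    flip_y=0.1, random_state=0,
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, test_size=0.25, random_state=0
)

depth_scores = {}
for depth in (1, 2, 4, 8, None):
    model = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    # Store (train accuracy, test accuracy) for this depth
    depth_scores[depth] = (model.score(X_tr, y_tr), model.score(X_te, y_te))
    train_acc, test_acc = depth_scores[depth]
    print(f"max_depth={depth}: train={train_acc:.3f}, test={test_acc:.3f}")
```

Very shallow trees score poorly on both sets (underfitting). The unconstrained tree reaches near-perfect training accuracy, and any shortfall on the test set relative to that is the cost of memorizing noise.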

Understanding Cost-Complexity Pruning

Cost-complexity pruning is a principled way to find the right tree size. Instead of guessing at depth or leaf constraints, you can use the cost_complexity_pruning_path() method to compute the effective alpha values at which subtrees are pruned, then select the best one through cross-validation.

import numpy as np
from sklearn.model_selection import cross_val_score

# Compute the pruning path (clf is the Iris classifier trained earlier)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# Score a tree for each alpha with 5-fold cross-validation on the
# training set, so the test set is not used for model selection
cv_scores = []
for alpha in ccp_alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())

# Find the alpha that gives the best mean cross-validated accuracy
best_idx = int(np.argmax(cv_scores))
best_alpha = ccp_alphas[best_idx]
print(f"Best ccp_alpha: {best_alpha:.6f}")
print(f"Best CV accuracy: {cv_scores[best_idx]:.4f}")
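To see what the pruning path represents, the following sketch refits one tree per alpha and records how many nodes survive. It uses the breast cancer dataset bundled with scikit-learn rather than Iris, an assumption made here only because a larger tree gives a longer, more informative path; the variable names are illustrative.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_bc, y_bc = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X_bc, y_bc, random_state=0)

base = DecisionTreeClassifier(random_state=0)
pruning_path = base.cost_complexity_pruning_path(X_tr, y_tr)

# Refit with each alpha on the path and count the nodes that remain
node_counts = []
for alpha in pruning_path.ccp_alphas:
    pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=0).fit(X_tr, y_tr)
    node_counts.append(pruned.tree_.node_count)

for alpha, impurity, n_nodes in zip(
    pruning_path.ccp_alphas, pruning_path.impurities, node_counts
):
    print(f"alpha={alpha:.5f}  total leaf impurity={impurity:.5f}  nodes={n_nodes}")
```

As alpha grows, the tree shrinks and the total impurity of its leaves rises. The largest alpha on the path prunes everything down to the root node, so it is common to exclude that last value when searching for the best alpha.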

Automated Tuning with GridSearchCV

Manually testing hyperparameter combinations is tedious and error-prone. scikit-learn's GridSearchCV automates this process by exhaustively searching through a specified grid of parameter values and evaluating each combination using cross-validation.

from sklearn.model_selection import GridSearchCV

# Define the parameter grid
param_grid = {
    "max_depth": [3, 5, 7, 10, None],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
    "criterion": ["gini", "entropy"]
}

# Set up GridSearchCV with 5-fold cross-validation
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
    verbose=0
)

grid_search.fit(X_train, y_train)

# Results
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best CV accuracy: {grid_search.best_score_:.4f}")

# Evaluate on test set
best_clf = grid_search.best_estimator_
test_accuracy = accuracy_score(y_test, best_clf.predict(X_test))
print(f"Test accuracy: {test_accuracy:.4f}")

The n_jobs=-1 parameter tells scikit-learn to use all available CPU cores for parallel processing, which significantly speeds up the search. The cv=5 parameter means each parameter combination is evaluated using 5-fold cross-validation, providing a more robust estimate of performance than a single train-test split.

Pro Tip

For larger parameter spaces, consider RandomizedSearchCV instead. It samples a fixed number of random parameter combinations rather than testing all of them, which is more efficient when the search space is large. You can also use BayesSearchCV from the scikit-optimize library for smarter, guided search.
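A minimal RandomizedSearchCV sketch follows, assuming the Iris data and sampling ranges picked purely for illustration. Instead of lists of values, integer hyperparameters are drawn from `scipy.stats.randint` distributions, whose upper bound is exclusive.

```python
from scipy.stats import randint
from sklearn.datasets import load_iris
from sklearn.model_selection import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

X_iris, y_iris = load_iris(return_X_y=True)

# Distributions to sample from (ranges chosen for illustration only)
param_distributions = {
    "max_depth": randint(2, 11),          # integers 2..10
    "min_samples_split": randint(2, 21),  # integers 2..20
    "min_samples_leaf": randint(1, 11),   # integers 1..10
    "criterion": ["gini", "entropy"],
}

random_search = RandomizedSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_distributions,
    n_iter=20,        # number of sampled parameter combinations
    cv=5,
    scoring="accuracy",
    random_state=42,  # makes the sampled combinations reproducible
)
random_search.fit(X_iris, y_iris)

print(f"Best parameters: {random_search.best_params_}")
print(f"Best CV accuracy: {random_search.best_score_:.4f}")
```

The cost is controlled by n_iter rather than by the size of the grid, so widening a range does not increase the run time.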

Strengths, Weaknesses, and When to Use Them

Decision trees have clear strengths that make them a go-to choice for many tasks. They are easy to understand and explain -- you can show the tree to a non-technical stakeholder and walk through the logic. They require minimal data preprocessing: no feature scaling or normalization is needed. They handle both categorical and continuous targets, and training is fast compared to many other algorithms. As of scikit-learn 1.4, they also support monotonic constraints via the monotonic_cst parameter, which lets you enforce domain knowledge that a feature should have a consistently positive or negative relationship with the target.

However, decision trees also have notable weaknesses. They are prone to overfitting, especially on noisy datasets. They can be unstable -- small changes in the data can result in a completely different tree structure. Their predictions are piecewise constant (not smooth), which makes them poor at extrapolation. And finding the globally optimal tree is computationally intractable (NP-complete), so practical algorithms use greedy heuristics that may miss the best overall structure.
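The extrapolation weakness is easy to demonstrate. In this sketch a tree is trained on a noise-free line over the range 0 to 10 and then queried far outside that range; the data and variable names are assumptions made for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Noise-free linear data on [0, 10] (synthetic, for illustration): y = 2x
X_lin = np.linspace(0, 10, 101).reshape(-1, 1)
y_lin = 2.0 * X_lin[:, 0]

lin_tree = DecisionTreeRegressor(random_state=0).fit(X_lin, y_lin)

inside = lin_tree.predict([[5.0]])[0]    # within the training range
at_edge = lin_tree.predict([[10.0]])[0]  # largest training input
outside = lin_tree.predict([[20.0]])[0]  # far outside the training range

print(f"x=5  -> {inside:.2f}")
print(f"x=10 -> {at_edge:.2f}")
print(f"x=20 -> {outside:.2f}  (the true value would be 40.00)")
```

Any input beyond the largest training value falls into the same rightmost leaf, so the prediction at x=20 is identical to the one at x=10. The tree cannot continue the trend it saw during training.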

Decision trees are an excellent choice when interpretability is a priority, when the dataset is small to medium in size, or when you need a quick baseline model. For larger or more complex problems, consider ensemble methods like Random Forests or Gradient Boosted Trees (e.g., XGBoost, LightGBM), which combine many decision trees to achieve better performance while mitigating individual tree weaknesses.

Key Takeaways

  1. Two classes, one API: Use DecisionTreeClassifier for categorical targets and DecisionTreeRegressor for continuous targets. The training and prediction workflow is the same for both.
  2. Overfitting is the primary risk: An unconstrained tree will memorize your training data. Always set constraints like max_depth, min_samples_leaf, or ccp_alpha to control tree complexity.
  3. Gini and entropy produce similar results: Gini is slightly faster and is the default. Entropy can favor more balanced splits but rarely changes the outcome significantly.
  4. Visualization is a superpower: Use plot_tree() and export_text() to inspect your model's logic. Use feature_importances_ to understand which features drive predictions.
  5. Automate tuning: Use GridSearchCV or RandomizedSearchCV to systematically find the best hyperparameter combination rather than guessing manually.
  6. Know when to level up: When a single decision tree is not performing well enough, ensemble methods like Random Forest and Gradient Boosting build on the same tree concept to deliver stronger results.

Decision trees remain one of the foundational algorithms in machine learning. Their transparency, simplicity, and versatility make them a practical tool for everything from quick data exploration to production-ready models. By understanding how to build, visualize, and tune them effectively with scikit-learn, you add a reliable and interpretable technique to your toolkit that serves as both a standalone solution and a building block for more advanced methods.
