Python Decision Tree Regression

Decision Tree Regression is a supervised machine learning algorithm that predicts continuous numerical values by recursively partitioning data into smaller subsets based on feature thresholds. Unlike linear regression, which fits a single line through the data, a decision tree learns a series of if-then rules that divide the feature space into rectangular regions, each with its own predicted value. This article walks through how decision tree regression works, how to implement it in Python with scikit-learn, and how to tune and prune the tree for better generalization.

Regression problems require predicting a continuous value rather than a discrete class label. While linear models assume a straight-line relationship between features and the target, real-world data often contains nonlinear patterns, thresholds, and interactions that a line cannot capture. Decision tree regression handles these scenarios naturally. It builds a tree structure where each internal node tests a feature against a threshold, each branch represents the outcome of that test, and each leaf node holds a predicted value—typically the mean of the training samples that landed in that region.

How Decision Tree Regression Works

A decision tree regressor works by splitting the dataset recursively. At each node, the algorithm examines every feature and every possible threshold within that feature, then selects the split that minimizes a chosen impurity measure. For regression, the default criterion in scikit-learn is squared_error, which measures the variance of the target values within each resulting partition.

The splitting process continues until a stopping condition is met—such as reaching a maximum depth, having too few samples to split further, or achieving a minimum impurity decrease. Once the tree is built, predicting a new data point means traversing the tree from root to leaf by following the decision rules at each node. The predicted value is the mean of the training targets in the leaf node where the data point ends up.
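
The split search described above can be sketched by hand. The code below is a simplified illustration (not scikit-learn's actual implementation): it scans candidate thresholds on a tiny one-dimensional dataset, scores each by the weighted sum of squared errors in the two resulting partitions, and confirms that a depth-1 tree picks the same threshold.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Tiny illustrative dataset: two clusters of x-values with different targets
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([1.0, 1.1, 0.9, 5.0, 5.1, 4.9])

def split_cost(threshold):
    """Sum of squared deviations from each partition's mean (variance criterion)."""
    left = y[X.ravel() <= threshold]
    right = y[X.ravel() > threshold]
    return sum(((part - part.mean()) ** 2).sum()
               for part in (left, right) if len(part))

# Candidate thresholds: midpoints between consecutive sorted feature values
candidates = (X.ravel()[:-1] + X.ravel()[1:]) / 2
best = min(candidates, key=split_cost)
print(f"Best manual threshold: {best}")   # 6.5, between the two clusters

# A depth-1 tree (a "stump") finds the same split
stump = DecisionTreeRegressor(max_depth=1).fit(X, y)
print(f"Tree's root threshold:  {stump.tree_.threshold[0]}")
```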

Note

Because predictions are piecewise constant (the mean of training samples in each leaf), decision tree regression produces step-like prediction curves rather than smooth ones. Increasing the tree depth increases the number of steps, allowing the model to approximate more complex functions—but also increasing the risk of overfitting.
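
This step-like behavior is easy to verify: the number of distinct values a tree can predict is bounded by its number of leaves. A quick check, using the same kind of noisy sine data as the examples later in this article:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Noisy sine data, as in the examples below
rng = np.random.default_rng(0)
X = np.sort(5 * rng.random((200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 200)

tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# Predict on a dense grid: only a handful of distinct output levels appear,
# one per leaf that the grid reaches
grid = np.linspace(0, 5, 1000).reshape(-1, 1)
levels = np.unique(tree.predict(grid))
print(f"Leaves: {tree.get_n_leaves()}, distinct predictions: {len(levels)}")
```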

The splitting criterion options available in scikit-learn's DecisionTreeRegressor are:

  • squared_error — Minimizes the mean squared error (MSE) within each partition. This is the default and equivalent to variance reduction.
  • friedman_mse — A variation of MSE that uses Friedman's improvement score, which can produce slightly better splits in practice.
  • absolute_error — Minimizes the mean absolute error (MAE), making splits more robust to outliers than squared error.
  • poisson — Uses Poisson deviance as the splitting criterion, suitable for count data or targets that are strictly non-negative.

Building a Decision Tree Regressor in Python

The following example demonstrates how to train a DecisionTreeRegressor on a synthetic sine wave dataset with added noise. This is a common pedagogical example because it clearly shows how the tree's piecewise constant predictions approximate a smooth curve.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# Generate synthetic data: noisy sine wave
np.random.seed(42)
X = np.sort(5 * np.random.rand(200, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Create and train the decision tree regressor
regressor = DecisionTreeRegressor(max_depth=4, random_state=42)
regressor.fit(X_train, y_train)

# Predict on the test set
y_pred = regressor.predict(X_test)

# Evaluate performance
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"Mean Squared Error: {mse:.4f}")
print(f"R-squared Score:    {r2:.4f}")

This code generates 200 data points from a noisy sine function, splits them into training and test sets, fits a decision tree with a maximum depth of 4, and evaluates the model using mean squared error and the R-squared coefficient.

Visualizing Predictions Against the True Function

To see how the decision tree approximates the underlying sine curve, the following code generates a dense set of test points and plots the tree's predictions alongside the original data.

# Create a dense grid for smooth prediction visualization
X_grid = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_grid_pred = regressor.predict(X_grid)

# Plot the results
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, s=20, edgecolor="black",
            c="darkorange", label="Training data", alpha=0.7)
plt.scatter(X_test, y_test, s=30, edgecolor="black",
            c="royalblue", label="Test data", marker="^", alpha=0.8)
plt.plot(X_grid, y_grid_pred, color="crimson",
         label="Prediction (max_depth=4)", linewidth=2)
plt.xlabel("Feature (X)")
plt.ylabel("Target (y)")
plt.title("Decision Tree Regression: Noisy Sine Wave")
plt.legend()
plt.tight_layout()
plt.show()

The resulting plot reveals the characteristic staircase pattern of decision tree regression. Each horizontal segment corresponds to a leaf node in the tree, predicting the average of the training samples that fall within that region of the feature space.

Key Hyperparameters

The DecisionTreeRegressor in scikit-learn exposes several hyperparameters that control the tree's complexity and how it fits the data. Tuning these parameters is essential for balancing model accuracy against overfitting. Here are the parameters that matter in practice:

  • max_depth — The maximum number of levels the tree is allowed to grow to. Setting this to None (the default) allows the tree to expand until all leaves are pure or contain fewer than min_samples_split samples. In practice, limiting max_depth is one of the simplest and most effective ways to prevent overfitting.
  • min_samples_split — The minimum number of samples required to split an internal node. The default is 2. Raising this value forces the tree to generalize more by preventing splits on very small groups of data points.
  • min_samples_leaf — The minimum number of samples required to form a leaf node. The default is 1. Increasing this acts as a smoothing mechanism because it ensures each prediction region contains enough data to produce a stable average.
  • max_features — The number of features to consider when looking for the best split. Useful in high-dimensional datasets to introduce randomness and reduce overfitting. Options include "sqrt", "log2", an integer count, or a float representing a fraction of total features.
  • ccp_alpha — The complexity parameter for Minimal Cost-Complexity Pruning. A non-negative value where higher values prune the tree more aggressively. Covered in detail in the next section.
  • monotonic_cst — An array of monotonicity constraints for each feature. Set to 1 for an increasing constraint, -1 for decreasing, or 0 for no constraint. This is useful when domain knowledge dictates that the target should increase or decrease with a particular feature.

Pro Tip

Start by setting max_depth to a small value (3-5) and gradually increase it while monitoring performance on a validation set. A tree that is too shallow underfits, while a tree that is too deep memorizes the training data and performs poorly on unseen inputs.
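
One way to follow this advice is to sweep max_depth with cross-validation. A minimal sketch; the depth range and scoring choice here are illustrative, not prescriptive:

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

# Noisy sine data, as elsewhere in this article
rng = np.random.default_rng(42)
X = np.sort(5 * rng.random((200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 200)

scores_by_depth = {}
for depth in range(2, 11):
    tree = DecisionTreeRegressor(max_depth=depth, random_state=42)
    # 5-fold CV; higher (less negative) neg-MSE is better
    scores = cross_val_score(tree, X, y, cv=5, scoring="neg_mean_squared_error")
    scores_by_depth[depth] = scores.mean()

best_depth = max(scores_by_depth, key=scores_by_depth.get)
print(f"Best max_depth by CV: {best_depth}")
```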

Cost-Complexity Pruning

Pruning is the process of reducing the size of a fully grown tree by removing branches that contribute little to predictive accuracy. Scikit-learn implements Minimal Cost-Complexity Pruning through the ccp_alpha parameter. This approach defines a cost-complexity measure for each subtree and prunes nodes that fall below the threshold set by alpha.

The cost-complexity measure for a tree T is R_alpha(T) = R(T) + alpha * |T|, where R(T) is the total impurity of the leaf nodes and |T| is the number of leaves. A higher alpha penalizes larger trees more heavily, resulting in a simpler model.

Scikit-learn provides a method called cost_complexity_pruning_path that computes the effective alpha values for the entire pruning sequence. Here is how to use it to find the optimal alpha through cross-validation:

from sklearn.model_selection import cross_val_score

# Train a full tree to get the pruning path
full_tree = DecisionTreeRegressor(random_state=42)
full_tree.fit(X_train, y_train)

# Get the effective alphas and their corresponding impurities
pruning_path = full_tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = pruning_path.ccp_alphas
impurities = pruning_path.impurities

# Evaluate each alpha using cross-validation
cv_scores = []
for alpha in ccp_alphas:
    tree = DecisionTreeRegressor(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train,
                             cv=5, scoring="neg_mean_squared_error")
    cv_scores.append(scores.mean())

# Find the best alpha
best_idx = np.argmax(cv_scores)
best_alpha = ccp_alphas[best_idx]
print(f"Best ccp_alpha: {best_alpha:.6f}")

# Train the pruned tree
pruned_tree = DecisionTreeRegressor(
    ccp_alpha=best_alpha, random_state=42
)
pruned_tree.fit(X_train, y_train)

# Compare performance
y_pred_pruned = pruned_tree.predict(X_test)
mse_pruned = mean_squared_error(y_test, y_pred_pruned)
print(f"Pruned Tree MSE:   {mse_pruned:.4f}")
print(f"Pruned Tree Depth: {pruned_tree.get_depth()}")
print(f"Pruned Tree Leaves: {pruned_tree.get_n_leaves()}")

The cost_complexity_pruning_path method returns the sequence of effective alpha values, along with the total leaf impurity at each pruning step. Each alpha corresponds to the subtree produced by pruning at that level, and cross-validation then identifies which alpha gives the best generalization performance.

Visualizing the Tree

One of the key advantages of decision trees is interpretability. Scikit-learn provides the plot_tree function and the export_text function for visual and text-based representations of the tree structure.

from sklearn.tree import plot_tree, export_text

# Visual tree plot
plt.figure(figsize=(20, 10))
plot_tree(
    regressor,
    filled=True,
    rounded=True,
    impurity=True,
    fontsize=10,
    precision=3,
    feature_names=["X"]
)
plt.title("Decision Tree Structure (max_depth=4)")
plt.tight_layout()
plt.show()

# Text-based representation
tree_rules = export_text(regressor, feature_names=["X"])
print(tree_rules)

The filled=True argument colors the nodes based on their predicted values, making it easy to see which regions of the feature space produce higher or lower predictions. The text representation produced by export_text prints the decision rules in a human-readable format that can be useful for debugging or documentation.

Comparing Tree Depths

To understand the effect of max_depth on prediction quality, the following example trains three regressors with different depths and plots their predictions on the same chart.

# Train trees with different depths
depths = [2, 5, 8]
colors = ["cornflowerblue", "yellowgreen", "tomato"]
X_grid = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]

plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=15, edgecolor="black",
            c="darkorange", label="Data", alpha=0.6)

for depth, color in zip(depths, colors):
    tree = DecisionTreeRegressor(max_depth=depth, random_state=42)
    tree.fit(X, y)
    y_pred_grid = tree.predict(X_grid)
    plt.plot(X_grid, y_pred_grid, color=color,
             label=f"max_depth={depth}", linewidth=2)

plt.xlabel("Feature (X)")
plt.ylabel("Target (y)")
plt.title("Effect of max_depth on Decision Tree Regression")
plt.legend()
plt.tight_layout()
plt.show()

With max_depth=2, the tree produces only a few broad segments and underfits the data. At max_depth=5, the tree captures the overall shape of the sine wave without chasing every fluctuation. At max_depth=8, the tree begins to follow the noise in the training data, producing jagged predictions that are unlikely to generalize well to unseen data.

Warning

A decision tree with no depth limit (max_depth=None) will grow until every leaf contains a single training sample, producing zero training error but poor generalization. Always constrain the tree using max_depth, min_samples_leaf, or ccp_alpha.
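
The warning above is easy to demonstrate: an unconstrained tree drives training error to zero while test error suffers. A quick sketch on the same noisy sine data:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(42)
X = np.sort(5 * rng.random((200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 200)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

# No depth limit: the tree grows until every leaf is pure
full = DecisionTreeRegressor(random_state=42).fit(X_tr, y_tr)
print(f"Training MSE: {mean_squared_error(y_tr, full.predict(X_tr)):.6f}")  # 0.000000
print(f"Test MSE:     {mean_squared_error(y_te, full.predict(X_te)):.4f}")
print(f"Leaves:       {full.get_n_leaves()}")
```

Because every feature value in this dataset is unique, each training sample ends up in its own leaf and the training error is exactly zero, yet the test error reflects the noise the tree has memorized.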

Strengths, Weaknesses, and When to Use It

Strengths

  • Interpretability — The tree structure can be visualized and understood by non-technical stakeholders. Each prediction can be traced through a series of simple if-then rules.
  • No feature scaling required — Unlike algorithms such as SVMs or k-nearest neighbors, decision trees are invariant to the scale of input features because splits are based on thresholds, not distances.
  • Handles nonlinear relationships — Decision trees naturally capture complex interactions and nonlinear patterns without requiring the user to specify the functional form.
  • Built-in missing value support — Since scikit-learn 1.3, DecisionTreeRegressor natively handles missing values when using splitter='best' with the squared_error, friedman_mse, or poisson criteria.
  • Low computational cost — Predictions are fast because they only require traversing a single root-to-leaf path, making them suitable for real-time applications.
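
The scale-invariance point is easy to check directly: multiplying a feature by a positive constant rescales every learned threshold by the same constant and leaves the predictions unchanged. A small sketch:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.random((100, 1))
y = np.sin(4 * X).ravel() + rng.normal(0, 0.05, 100)

# Fit the same model on raw and rescaled features
tree_raw = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
tree_scaled = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X * 1000, y)

# Predictions on new points are identical as long as the same scaling
# is applied at prediction time
X_new = rng.random((20, 1))
same = np.allclose(tree_raw.predict(X_new), tree_scaled.predict(X_new * 1000))
print(f"Predictions identical under feature scaling: {same}")  # True
```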

Weaknesses

  • Prone to overfitting — Without constraints, decision trees will memorize the training data. Pruning and hyperparameter tuning are essential.
  • High variance — Small changes in the training data can produce completely different tree structures. Ensemble methods like Random Forests and Gradient Boosted Trees address this limitation.
  • Piecewise constant predictions — Because leaf nodes output averages, the model cannot extrapolate beyond the range of the training data and produces discontinuous prediction boundaries.
  • Greedy splitting — The algorithm makes locally optimal decisions at each node without considering the global optimality of the tree. This can lead to suboptimal tree structures.
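
The extrapolation limitation can be seen by querying a fitted tree far outside the training range: every point beyond the largest learned threshold falls into the same rightmost leaf and receives the same constant prediction. A sketch:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)
X = np.sort(5 * rng.random((200, 1)), axis=0)   # features in [0, 5)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 200)

tree = DecisionTreeRegressor(max_depth=4, random_state=42).fit(X, y)

# Far outside the training range, the prediction is flat: both queries
# land in the same rightmost leaf
p10 = tree.predict([[10.0]])
p1000 = tree.predict([[1000.0]])
print(f"Same prediction at x=10 and x=1000: {p10[0] == p1000[0]}")  # True
```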

When to Use Decision Tree Regression

Decision tree regression works well when interpretability is a priority, when the relationship between features and the target is nonlinear or involves thresholds, and when the dataset contains a mix of numerical and categorical features. For higher predictive accuracy, consider ensemble methods like RandomForestRegressor or GradientBoostingRegressor, which build on the decision tree foundation while reducing variance through aggregation or sequential boosting.
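
Swapping in an ensemble is a one-line change. The sketch below fits a single tree and a random forest on the same data; the hyperparameters shown are illustrative, not tuned values:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(42)
X = np.sort(5 * rng.random((300, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(0, 0.2, 300)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42)

tree = DecisionTreeRegressor(max_depth=6, random_state=42).fit(X_tr, y_tr)
# A forest averages many trees trained on bootstrap samples, reducing variance
forest = RandomForestRegressor(n_estimators=100, max_depth=6,
                               random_state=42).fit(X_tr, y_tr)

print(f"Single tree MSE: {mean_squared_error(y_te, tree.predict(X_te)):.4f}")
print(f"Forest MSE:      {mean_squared_error(y_te, forest.predict(X_te)):.4f}")
```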

Key Takeaways

  1. Recursive partitioning: Decision Tree Regression works by recursively splitting the feature space into regions and predicting the mean target value within each region, producing piecewise constant outputs.
  2. Hyperparameter control is essential: Parameters like max_depth, min_samples_split, min_samples_leaf, and ccp_alpha directly control model complexity. Without constraints, the tree will overfit the training data.
  3. Cost-complexity pruning: The ccp_alpha parameter and cost_complexity_pruning_path method provide a principled way to find the right balance between model complexity and generalization by penalizing large trees.
  4. Visualization and interpretability: Scikit-learn's plot_tree and export_text functions allow for clear visualization of the decision rules, making decision trees one of the more transparent machine learning models available.
  5. Foundation for ensemble methods: Understanding decision tree regression provides the groundwork for more powerful algorithms like Random Forests and Gradient Boosting, which use collections of trees to achieve better predictive performance.

Decision tree regression is a versatile and intuitive algorithm that serves as both a practical modeling tool and a stepping stone to more advanced ensemble techniques. By understanding how it partitions data, how to control its growth through hyperparameters, and how to prune it for better generalization, you gain a solid foundation for tackling regression problems across a wide range of domains.
