Visualizing tree-based regressors#

(View this notebook in Colab)

The dtreeviz library is designed to help machine learning practitioners visualize and interpret decision trees and decision-tree-based models, such as gradient boosting machines.

The purpose of this notebook is to illustrate the main capabilities and functions of the dtreeviz API. To do that, we will use scikit-learn and the toy but well-known Titanic data set. Currently, dtreeviz supports the following decision tree libraries: scikit-learn, XGBoost, Spark MLlib, LightGBM, and TensorFlow.

To interoperate with these different libraries, dtreeviz uses an adaptor object, obtained from the dtreeviz.model() function, to extract the model information necessary for visualization. Given such an adaptor object, all of the dtreeviz functionality is available through the same programming interface. The basic dtreeviz usage recipe is as follows (a minimal sketch combining the steps appears after the list):

  1. Import dtreeviz and your decision tree library

  2. Acquire and load data into memory

  3. Train a classifier or regressor model using your decision tree library

  4. Obtain a dtreeviz adaptor model using
    viz_model = dtreeviz.model(your_trained_model,...)

  5. Call dtreeviz functions, such as
    viz_model.view() or viz_model.explain_prediction_path(sample_x)
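Putting the recipe together, here is a minimal end-to-end sketch. It uses scikit-learn's built-in diabetes data set purely for illustration; the data set and target name are assumptions, not part of this notebook's examples:

import dtreeviz
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

# Steps 1-3: import the libraries, load data, and train a tree model
data = load_diabetes()
model = DecisionTreeRegressor(max_depth=3, random_state=1234)
model.fit(data.data, data.target)

# Step 4: wrap the trained model in a dtreeviz adaptor
viz_model = dtreeviz.model(model,
                           X_train=data.data, y_train=data.target,
                           feature_names=data.feature_names,
                           target_name="progression")

# Step 5: call dtreeviz functions through the adaptor
viz_model.view()                                  # draw the full tree
viz_model.explain_prediction_path(data.data[0])   # explain one sample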

The four categories of dtreeviz functionality are:

  1. Tree visualizations

  2. Prediction path explanations

  3. Leaf information

  4. Feature space exploration

We have grouped the code examples by classifiers and regressors, with a follow-up section on partitioning feature space.

These examples require dtreeviz 2.0 or above because the code uses the new API introduced in 2.0.
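To check which version you have installed (assuming the package exposes the conventional __version__ attribute):

import dtreeviz
print(dtreeviz.__version__)  # expect 2.0.0 or above for this notebook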

Setup#

import sys
import os
%config InlineBackend.figure_format = 'retina' # Make visualizations look good
#%config InlineBackend.figure_format = 'svg'
%matplotlib inline

if 'google.colab' in sys.modules:
  !pip install -q "dtreeviz>=2.0"  # this notebook requires the 2.0 API
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

import dtreeviz

random_state = 1234 # get reproducible trees

Load Sample Data#

dataset_url = "https://raw.githubusercontent.com/parrt/dtreeviz/master/data/titanic/titanic.csv"
dataset = pd.read_csv(dataset_url)
# Fill missing values for Age
dataset.fillna({"Age":dataset.Age.mean()}, inplace=True)
# Encode categorical variables
dataset["Sex_label"] = dataset.Sex.astype("category").cat.codes
dataset["Cabin_label"] = dataset.Cabin.astype("category").cat.codes
dataset["Embarked_label"] = dataset.Embarked.astype("category").cat.codes

To demonstrate regressor tree visualization, we start by creating a regressor model that predicts age instead of survival:

features_reg = ["Pclass", "Fare", "Sex_label", "Cabin_label", "Embarked_label", "Survived"]
target_reg = "Age"

tree_regressor = DecisionTreeRegressor(max_depth=3, random_state=random_state,
                                       criterion="absolute_error")  # called "mae" before scikit-learn 1.0
tree_regressor.fit(dataset[features_reg].values, dataset[target_reg].values)
DecisionTreeRegressor(criterion='absolute_error', max_depth=3, random_state=1234)

Initialize dtreeviz model (adaptor)#

viz_rmodel = dtreeviz.model(model=tree_regressor,
                            X_train=dataset[features_reg],
                            y_train=dataset[target_reg],
                            feature_names=features_reg,
                            target_name=target_reg)

Tree structure visualizations#

viz_rmodel.view()
viz_rmodel.view(orientation="LR")
viz_rmodel.view(fancy=False)
viz_rmodel.view(depth_range_to_display=(0, 2))
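In dtreeviz 2.x, view() returns a render object; in a notebook the returned object displays itself, but it can also be saved to disk (the file name below is illustrative):

v = viz_rmodel.view()        # render object for the tree above
v.save("decision_tree.svg")  # write the rendering to an SVG file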

Prediction path explanations#

x = dataset[features_reg].iloc[10]
x
Pclass              3.0
Fare               16.7
Sex_label           0.0
Cabin_label       145.0
Embarked_label      2.0
Survived            1.0
Name: 10, dtype: float64
viz_rmodel.view(x=x)
viz_rmodel.view(show_just_path=True, x=x)
print(viz_rmodel.explain_prediction_path(x))
1.5 <= Pclass 
Fare < 27.82
139.5 <= Cabin_label 
viz_rmodel.instance_feature_importance(x, figsize=(3.5,2))
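To tie the explanation back to a concrete number, we can ask the underlying scikit-learn model for its prediction on this same sample:

# the leaf reached by the highlighted path determines the predicted age
print(tree_regressor.predict(x.values.reshape(1, -1)))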

Leaf info#

viz_rmodel.leaf_sizes(figsize=(3.5,2))
viz_rmodel.rtree_leaf_distributions()
viz_rmodel.node_stats(node_id=4)
       Pclass        Fare  Sex_label  Cabin_label  Embarked_label  Survived
count    72.0        72.0       72.0         72.0            72.0      72.0
mean      1.0  152.167936   0.347222        39.25        0.916667  0.763889
std       0.0   97.808005   0.479428    26.556742        1.031203  0.427672
min       1.0        66.6        0.0         -1.0            -1.0       0.0
25%       1.0     83.1583        0.0        20.75             0.0       1.0
50%       1.0       120.0        0.0         40.0             0.0       1.0
75%       1.0    211.3375        1.0         63.0             2.0       1.0
max       1.0    512.3292        1.0         79.0             2.0       1.0
viz_rmodel.leaf_purity(figsize=(3.5,2))

Partitioning#

To demonstrate partitioning, let's load a toy Cars data set and visualize how the tree splits up univariate and bivariate feature spaces.

dataset_url = "https://raw.githubusercontent.com/parrt/dtreeviz/master/data/cars.csv"
df_cars = pd.read_csv(dataset_url)
X = df_cars.drop('MPG', axis=1)
y = df_cars['MPG']
features = list(X.columns)
dtr_cars = DecisionTreeRegressor(max_depth=3, criterion="absolute_error")  # called "mae" before scikit-learn 1.0
dtr_cars.fit(X.values, y.values)
DecisionTreeRegressor(criterion='absolute_error', max_depth=3)
viz_rmodel = dtreeviz.model(dtr_cars, X, y,
                            feature_names=features,
                            target_name='MPG')

The following visualization illustrates how the decision tree breaks up the WGT (car weight) feature in order to get regions with relatively pure MPG (miles per gallon) target values.

viz_rmodel.rtree_feature_space(features=['WGT'])

To visualize a two-dimensional feature space, we can draw it in three dimensions:

viz_rmodel.rtree_feature_space3D(features=['WGT','ENG'],
                                 fontsize=10,
                                 elev=30, azim=20,
                                 show={'splits', 'title'},
                                 colors={'tessellation_alpha': .5})

Equivalently, we can show a heat map as if we were looking at the three-dimensional plot from the top down:

viz_rmodel.rtree_feature_space(features=['WGT','ENG'])