TreeExplainer

class cuml.explainer.TreeExplainer(model, *, data=None, convert_dtype=True)

Model explainer that calculates Shapley values for the predictions of tree-based models. Shapley values attribute a given model prediction among its input features.

It uses GPUTreeShap [1] as its back end to accelerate the computation on GPUs.

Different variants of Shapley values exist based on different interpretations of marginalising out (or conditioning on) features. For the “tree_path_dependent” approach, see [2].

For the “interventional” approach, see [3].
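To illustrate what "marginalising out" means in the "tree_path_dependent" case, the following minimal pure-Python sketch assumes a hypothetical single tree stored as a dict, where each node records how many training rows (its "cover") reached it. The helper `expected_value_given` estimates the model output when only the features in a subset S are known: it follows the branch x takes for features in S and averages the children by cover otherwise, in the spirit of [2]. Exact Shapley values are then obtained by enumerating every subset. This enumeration is exponential in the number of features and only shows the definition; it is not how GPUTreeShap computes the values.

```python
from itertools import combinations
from math import factorial

# Hypothetical toy tree (not a cuML, XGBoost, or LightGBM format).
# Internal nodes split on `feature` at `threshold` (x < threshold goes left);
# "cover" is the number of training rows that reached the node.
TREE = {
    0: {"feature": 0, "threshold": 0.5, "left": 1, "right": 2, "cover": 100},
    1: {"feature": 1, "threshold": 1.0, "left": 3, "right": 4, "cover": 60},
    2: {"value": 3.0, "cover": 40},
    3: {"value": 1.0, "cover": 20},
    4: {"value": 2.0, "cover": 40},
}


def expected_value_given(x, subset, node=0):
    """Estimate E[f(X) | X_S = x_S] from node covers (tree_path_dependent style)."""
    n = TREE[node]
    if "value" in n:
        return n["value"]
    if n["feature"] in subset:
        # Known feature: follow the branch that x actually takes.
        child = n["left"] if x[n["feature"]] < n["threshold"] else n["right"]
        return expected_value_given(x, subset, child)
    # Unknown feature: weight each child by the fraction of training rows it received.
    left, right = n["left"], n["right"]
    return (
        TREE[left]["cover"] * expected_value_given(x, subset, left)
        + TREE[right]["cover"] * expected_value_given(x, subset, right)
    ) / n["cover"]


def shapley_values(x, value_fn, num_features):
    """Exact Shapley values by enumerating every subset of the other features."""
    phi = [0.0] * num_features
    for i in range(num_features):
        others = [j for j in range(num_features) if j != i]
        for size in range(num_features):
            weight = factorial(size) * factorial(num_features - size - 1) / factorial(num_features)
            for combo in combinations(others, size):
                s = set(combo)
                phi[i] += weight * (value_fn(x, s | {i}) - value_fn(x, s))
    return phi


x = [0.2, 2.0]
phi = shapley_values(x, expected_value_given, num_features=2)
base = expected_value_given(x, set())          # no features known: the expected value
prediction = expected_value_given(x, {0, 1})   # all features known: the tree's output
print(phi, base, prediction)
```

For this tree the base value is 2.2 and the prediction is 2.0, so the two Shapley values sum to -0.2.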

We also provide two variants of feature interactions: for the “shapley-interactions” variant, see [2]; for the “shapley-taylor” variant, see [4].
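The two interaction variants differ in how they split credit for joint effects. The brute-force sketch below assumes a hypothetical three-feature function `f` and a toy value function `v` that replaces unknown features with a zero baseline; this is chosen only to keep the arithmetic readable and is not the value function TreeExplainer uses. It computes the "shapley-interactions" matrix of [2], with main effects on the diagonal, and the order-2 Shapley-Taylor index of [4]. For the Taylor variant, this sketch stores singleton terms on the diagonal and splits each pairwise term evenly between (i, j) and (j, i); that layout is an assumption and may differ from GPUTreeShap's output.

```python
from itertools import combinations
from math import comb, factorial

X_EXPLAIN = [1.0, 2.0, 3.0]
BASELINE = [0.0, 0.0, 0.0]


def f(row):
    # Hypothetical model: a main effect, a pairwise term and a three-way term.
    return row[0] + 2 * row[1] * row[2] + row[0] * row[1] * row[2]


def v(subset):
    # Toy value function: features outside `subset` are set to the baseline.
    return f([X_EXPLAIN[j] if j in subset else BASELINE[j] for j in range(len(X_EXPLAIN))])


def delta(i, j, s):
    # Discrete second-order difference of v for features i and j at coalition s.
    return v(s | {i, j}) - v(s | {i}) - v(s | {j}) + v(s)


def subsets(items):
    for size in range(len(items) + 1):
        for combo in combinations(items, size):
            yield set(combo)


def shapley_values(m):
    phi = [0.0] * m
    for i in range(m):
        for s in subsets([k for k in range(m) if k != i]):
            w = factorial(len(s)) * factorial(m - len(s) - 1) / factorial(m)
            phi[i] += w * (v(s | {i}) - v(s))
    return phi


def shapley_interactions(m):
    """SHAP interaction values as defined in [2]; the diagonal holds main effects."""
    phi = shapley_values(m)
    out = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(m):
            if i == j:
                continue
            for s in subsets([k for k in range(m) if k not in (i, j)]):
                w = factorial(len(s)) * factorial(m - len(s) - 2) / (2 * factorial(m - 1))
                out[i][j] += w * delta(i, j, s)
    for i in range(m):
        out[i][i] = phi[i] - sum(out[i][j] for j in range(m) if j != i)
    return out


def shapley_taylor(m):
    """Order-2 Shapley-Taylor index; each pair value is split over (i, j) and (j, i)."""
    out = [[0.0] * m for _ in range(m)]
    for i in range(m):
        out[i][i] = v({i}) - v(set())
    for i, j in combinations(range(m), 2):
        total = 0.0
        for s in subsets([k for k in range(m) if k not in (i, j)]):
            total += delta(i, j, s) / comb(m - 1, len(s))
        pair = 2 / m * total
        out[i][j] = out[j][i] = pair / 2
    return out


m = len(X_EXPLAIN)
interactions = shapley_interactions(m)
taylor = shapley_taylor(m)
target = f(X_EXPLAIN) - v(set())  # what every entry of each matrix should add up to
print(interactions, taylor, target)
```

For this `f`, both matrices sum to 19, but the entry for features 1 and 2 is 7.5 under "shapley-interactions" and 7.0 under "shapley-taylor", because the two indices distribute the three-way term differently.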

[1] Mitchell, Rory, Eibe Frank, and Geoffrey Holmes. “GPUTreeShap: massively parallel exact calculation of SHAP scores for tree ensembles.” PeerJ Computer Science 8 (2022): e880.

[2] Lundberg, Scott M., et al. “From local explanations to global understanding with explainable AI for trees.” Nature Machine Intelligence 2.1 (2020): 56-67.

[3] Janzing, Dominik, Lenon Minorics, and Patrick Blöbaum. “Feature relevance quantification in explainable AI: A causal problem.” International Conference on Artificial Intelligence and Statistics. PMLR, 2020.

[4] Sundararajan, Mukund, Kedar Dhamdhere, and Ashish Agarwal. “The Shapley Taylor Interaction Index.” International Conference on Machine Learning. PMLR, 2020.

Parameters:
model : model object

The tree-based machine learning model. XGBoost, LightGBM, cuML random forest, and scikit-learn random forest models are supported. Categorical features in XGBoost and LightGBM models are supported natively.

data : array or DataFrame

Optional background dataset used to marginalise out features. If this argument is supplied, the “interventional” approach is used. Computation time grows with the size of the background dataset, so consider starting with 100 to 1000 examples. If this argument is not supplied, statistics from the tree model are used to marginalise out features (the “tree_path_dependent” approach).
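A minimal sketch of the "interventional" value function, assuming a hypothetical two-feature function `model` and a tiny background set in place of a real tree ensemble and dataset. Here v(S) averages the model output over the background rows, with the features in S fixed to the values in x and the remaining features taken from each background row; exact Shapley values are then computed by enumerating subsets. This is an illustration of the definition in [3], not of GPUTreeShap's algorithm.

```python
from itertools import combinations
from math import factorial


def model(row):
    # Hypothetical stand-in for a trained tree model (two axis-aligned splits).
    if row[0] < 0.5:
        return 1.0 if row[1] < 1.0 else 2.0
    return 3.0


def interventional_value(x, subset, background):
    """v(S): mean output when features in S come from x and the rest from each background row."""
    total = 0.0
    for b in background:
        hybrid = [x[j] if j in subset else b[j] for j in range(len(x))]
        total += model(hybrid)
    return total / len(background)


def shapley_values(x, background):
    """Exact interventional Shapley values by enumerating feature subsets."""
    m = len(x)
    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for combo in combinations(others, size):
                s = set(combo)
                phi[i] += weight * (
                    interventional_value(x, s | {i}, background)
                    - interventional_value(x, s, background)
                )
    return phi


# Hypothetical background rows; in practice a sample of training data.
background = [[0.1, 0.5], [0.9, 2.0], [0.3, 1.5], [0.7, 0.2]]
x = [0.2, 2.0]
phi = shapley_values(x, background)
expected_value = sum(model(b) for b in background) / len(background)
print(phi, expected_value, model(x))
```

In this setting the expected value is the mean prediction over the background (2.25), and the Shapley values sum to model(x) minus that value (-0.25). Every evaluation of v(S) loops over the whole background, which is why its size drives the computation time.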

Attributes:
expected_value : object

Methods

shap_interaction_values(self, X[, method, ...])

Estimate the SHAP interaction values for a set of samples.

shap_values(self, X[, convert_dtype])

Estimate the SHAP values for a set of samples.

Examples

>>> import numpy as np
>>> import cuml
>>> from cuml.explainer import TreeExplainer
>>> X = np.array([[0.0, 2.0], [1.0, 0.5]])
>>> y = np.array([0, 1])
>>> model = cuml.ensemble.RandomForestRegressor().fit(X, y)
>>> explainer = TreeExplainer(model=model)
>>> shap_values = explainer.shap_values(X)
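
expected_value

expected_value : object

The base value: for a given row, the SHAP values plus expected_value sum to the raw model prediction.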

shap_interaction_values(self, X, method='shapley-interactions', convert_dtype=True)

Estimate the SHAP interaction values for a set of samples. For a given row, the SHAP interaction values, summed over both feature axes, plus the expected_value attribute add up to the raw model prediction. ‘Raw model prediction’ means the output before any link function is applied; for example, the SHAP values of an XGBoost binary classifier are in additive logit (log-odds) space rather than probability space.

Interventional feature marginalisation is not supported, so the explainer must be constructed without the data argument.

Parameters:
X

A matrix of samples (# samples x # features) on which to explain the model’s output.

method

One of [‘shapley-interactions’, ‘shapley-taylor’]. Defaults to ‘shapley-interactions’.

Returns:
array

Returns an array of SHAP interaction values of shape (# classes x # samples x # features x # features).

shap_values(self, X, convert_dtype=True)

Estimate the SHAP values for a set of samples. For a given row, the SHAP values plus the expected_value attribute sum to the raw model prediction. ‘Raw model prediction’ means the output before any link function is applied; for example, the SHAP values of an XGBoost binary classifier are in additive logit (log-odds) space rather than probability space.
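As an illustration of the link-function remark, the sketch below uses made-up numbers standing in for explainer.expected_value and one row of explainer.shap_values(X) for a binary classifier whose raw output is in log-odds. The values are additive in log-odds space; a probability is obtained only by applying the sigmoid to the total, not by transforming the individual contributions.

```python
import math


def sigmoid(z):
    # Logistic link: maps a log-odds margin to a probability.
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical explainer outputs for one row, in log-odds (margin) space.
expected_value = -0.4
shap_row = [0.9, -0.3, 0.6]

raw_margin = expected_value + sum(shap_row)  # additivity holds here
probability = sigmoid(raw_margin)            # link applied to the total only
print(raw_margin, probability)
```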

Parameters:
X

A matrix of samples (# samples x # features) on which to explain the model’s output.

Returns:
array

Returns an array of SHAP values of shape (# classes x # samples x # features).