tree_shap.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
#include <cuml/ensemble/treelite_defs.hpp>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <variant>

15 namespace ML {
16 namespace Explainer {
17 
18 template <typename T>
20 
22  std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>>;
23 
24 using FloatPointer = std::variant<float*, double*>;
25 
27 
29  const FloatPointer data,
30  std::size_t n_rows,
31  std::size_t n_cols,
32  FloatPointer out_preds,
33  std::size_t out_preds_size);
34 
36  const FloatPointer data,
37  std::size_t n_rows,
38  std::size_t n_cols,
39  const FloatPointer background_data,
40  std::size_t background_n_rows,
41  std::size_t background_n_cols,
42  FloatPointer out_preds,
43  std::size_t out_preds_size);
44 
46  const FloatPointer data,
47  std::size_t n_rows,
48  std::size_t n_cols,
49  FloatPointer out_preds,
50  std::size_t out_preds_size);
51 
53  const FloatPointer data,
54  std::size_t n_rows,
55  std::size_t n_cols,
56  FloatPointer out_preds,
57  std::size_t out_preds_size);
58 
59 } // namespace Explainer
60 } // namespace ML
template <typename T> class TreePathInfo
Definition: tree_shap.hpp:19
void gpu_treeshap_taylor_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
TreePathHandle extract_path_info(TreeliteModelHandle model)
void gpu_treeshap_interventional(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, const FloatPointer background_data, std::size_t background_n_rows, std::size_t background_n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
std::variant< float *, double * > FloatPointer
Definition: tree_shap.hpp:24
std::variant< std::shared_ptr< TreePathInfo< float > >, std::shared_ptr< TreePathInfo< double > >> TreePathHandle
Definition: tree_shap.hpp:22
Definition: dbscan.hpp:18
void * TreeliteModelHandle
Definition: treelite_defs.hpp:12