tree_shap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <cstddef>
22 #include <cstdint>
23 #include <memory>
24 #include <variant>
25 
26 namespace ML {
27 namespace Explainer {
28 
29 template <typename T>
31 
33  std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>>;
34 
35 using FloatPointer = std::variant<float*, double*>;
36 
38 
40  const FloatPointer data,
41  std::size_t n_rows,
42  std::size_t n_cols,
43  FloatPointer out_preds,
44  std::size_t out_preds_size);
45 
47  const FloatPointer data,
48  std::size_t n_rows,
49  std::size_t n_cols,
50  const FloatPointer background_data,
51  std::size_t background_n_rows,
52  std::size_t background_n_cols,
53  FloatPointer out_preds,
54  std::size_t out_preds_size);
55 
57  const FloatPointer data,
58  std::size_t n_rows,
59  std::size_t n_cols,
60  FloatPointer out_preds,
61  std::size_t out_preds_size);
62 
64  const FloatPointer data,
65  std::size_t n_rows,
66  std::size_t n_cols,
67  FloatPointer out_preds,
68  std::size_t out_preds_size);
69 
70 } // namespace Explainer
71 } // namespace ML
Definition: tree_shap.hpp:30
void gpu_treeshap_taylor_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
TreePathHandle extract_path_info(TreeliteModelHandle model)
void gpu_treeshap_interventional(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, const FloatPointer background_data, std::size_t background_n_rows, std::size_t background_n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
std::variant< float *, double * > FloatPointer
Definition: tree_shap.hpp:35
std::variant< std::shared_ptr< TreePathInfo< float > >, std::shared_ptr< TreePathInfo< double > >> TreePathHandle
Definition: tree_shap.hpp:33
Definition: dbscan.hpp:30
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23