cuML C++ API 23.12 — tree_shap.hpp (source listing; see the documentation page for this file for symbol details)
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <variant>

#include <cuml/ensemble/treelite_defs.hpp>  // ModelHandle
24 
25 namespace ML {
26 namespace Explainer {
27 
28 template <typename T>
30 
32  std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>>;
33 
34 using FloatPointer = std::variant<float*, double*>;
35 
37 
39  const FloatPointer data,
40  std::size_t n_rows,
41  std::size_t n_cols,
42  FloatPointer out_preds,
43  std::size_t out_preds_size);
44 
46  const FloatPointer data,
47  std::size_t n_rows,
48  std::size_t n_cols,
49  const FloatPointer background_data,
50  std::size_t background_n_rows,
51  std::size_t background_n_cols,
52  FloatPointer out_preds,
53  std::size_t out_preds_size);
54 
56  const FloatPointer data,
57  std::size_t n_rows,
58  std::size_t n_cols,
59  FloatPointer out_preds,
60  std::size_t out_preds_size);
61 
63  const FloatPointer data,
64  std::size_t n_rows,
65  std::size_t n_cols,
66  FloatPointer out_preds,
67  std::size_t out_preds_size);
68 
69 } // namespace Explainer
70 } // namespace ML
TreePathInfo<T> — Definition: tree_shap.hpp:29
void gpu_treeshap_taylor_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap_interventional(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, const FloatPointer background_data, std::size_t background_n_rows, std::size_t background_n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap_interactions(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
void gpu_treeshap(TreePathHandle path_info, const FloatPointer data, std::size_t n_rows, std::size_t n_cols, FloatPointer out_preds, std::size_t out_preds_size)
TreePathHandle extract_path_info(ModelHandle model)
using FloatPointer = std::variant<float*, double*> — Definition: tree_shap.hpp:34
using TreePathHandle = std::variant<std::shared_ptr<TreePathInfo<float>>, std::shared_ptr<TreePathInfo<double>>> — Definition: tree_shap.hpp:32
Definition: dbscan.hpp:27 (symbol name lost in extraction; presumably namespace ML)
using ModelHandle = void* — Definition: treelite_defs.hpp:23