Loading [MathJax]/jax/output/HTML-CSS/config.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
evaluate_tree.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <stdint.h>
18 #include <type_traits>
19 #ifndef __CUDACC__
20 #include <math.h>
21 #endif
24 namespace ML {
25 namespace experimental {
26 namespace fil {
27 namespace detail {
28 
29 /*
30  * Evaluate a single tree on a single row.
31  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
32  * instead of the leaf value.
33  *
34  * @tparam has_vector_leaves Whether or not this tree has vector leaves
35  * @tparam has_categorical_nodes Whether or not this tree has any nodes with
36  * categorical splits
37  * @tparam node_t The type of nodes in this tree
38  * @tparam io_t The type used for input to and output from this tree (typically
39  * either floats or doubles)
40  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
41  * node_id_mapping.
42  * @param node Pointer to the root node of this tree
43  * @param row Pointer to the input data for this row
44  * @param first_root_node Pointer to the root node of the first tree.
45  * @param node_id_mapping Array representing the mapping from internal node IDs to
46  * final leaf ID outputs
47  */
48 template <bool has_vector_leaves,
49  bool has_categorical_nodes,
50  typename node_t,
51  typename io_t,
52  typename node_id_mapping_t = std::nullptr_t>
53 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
54  io_t const* __restrict__ row,
55  node_t const* __restrict__ first_root_node = nullptr,
56  node_id_mapping_t node_id_mapping = nullptr)
57 {
58  using categorical_set_type = bitset<uint32_t, typename node_t::index_type const>;
59  auto cur_node = *node;
60  do {
61  auto input_val = row[cur_node.feature_index()];
62  auto condition = true;
63  if constexpr (has_categorical_nodes) {
64  if (cur_node.is_categorical()) {
65  auto valid_categories = categorical_set_type{
66  &cur_node.index(), uint32_t(sizeof(typename node_t::index_type) * 8)};
67  condition = valid_categories.test(input_val);
68  } else {
69  condition = (input_val < cur_node.threshold());
70  }
71  } else {
72  condition = (input_val < cur_node.threshold());
73  }
74  if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
75  node += cur_node.child_offset(condition);
76  cur_node = *node;
77  } while (!cur_node.is_leaf());
78  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
79  return cur_node.template output<has_vector_leaves>();
80  } else {
81  return node_id_mapping[node - first_root_node];
82  }
83 }
84 
85 /*
86  * Evaluate a single tree which requires external categorical storage on a
87  * single node.
88  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
89  * instead of the leaf value.
90  *
91  * For non-categorical models and models with a relatively small number of
92  * categories for any feature, all information necessary for model evaluation
93  * can be stored on a single node. If the number of categories for any
94  * feature exceeds the available space on a node, however, the
95  * categorical split data must be stored external to the node. We pass a
96  * pointer to this external data and reconstruct bitsets from it indicating
97  * the positive and negative categories for each categorical node.
98  *
99  * @tparam has_vector_leaves Whether or not this tree has vector leaves
100  * @tparam node_t The type of nodes in this tree
101  * @tparam io_t The type used for input to and output from this tree (typically
102  * either floats or doubles)
103  * @tparam categorical_storage_t The underlying type used for storing
104  * categorical data (typically char)
105  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
106  * node_id_mapping.
107  * @param node Pointer to the root node of this tree
108  * @param row Pointer to the input data for this row
109  * @param categorical_storage Pointer to where categorical split data is
110  * stored.
111  */
112 template <bool has_vector_leaves,
113  typename node_t,
114  typename io_t,
115  typename categorical_storage_t,
116  typename node_id_mapping_t = std::nullptr_t>
117 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
118  io_t const* __restrict__ row,
119  categorical_storage_t const* __restrict__ categorical_storage,
120  node_t const* __restrict__ first_root_node = nullptr,
121  node_id_mapping_t node_id_mapping = nullptr)
122 {
123  using categorical_set_type = bitset<uint32_t, categorical_storage_t const>;
124  auto cur_node = *node;
125  do {
126  auto input_val = row[cur_node.feature_index()];
127  auto condition = cur_node.default_distant();
128  if (!isnan(input_val)) {
129  if (cur_node.is_categorical()) {
130  auto valid_categories =
131  categorical_set_type{categorical_storage + cur_node.index() + 1,
132  uint32_t(categorical_storage[cur_node.index()])};
133  condition = valid_categories.test(input_val);
134  } else {
135  condition = (input_val < cur_node.threshold());
136  }
137  }
138  node += cur_node.child_offset(condition);
139  cur_node = *node;
140  } while (!cur_node.is_leaf());
141  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
142  return cur_node.template output<has_vector_leaves>();
143  } else {
144  return node_id_mapping[node - first_root_node];
145  }
146 }
147 
// Evaluate tree `tree_index` of `forest` on a single row.
//
// Dispatch, as visible in the body below:
//  - predict_leaf == true: returns an index_type obtained from
//    evaluate_tree_impl (presumably the mapped leaf-node ID; the trailing
//    arguments that select this are not shown in this listing).
//  - predict_leaf == false: returns the tree output, typed
//    node_t::index_type when has_vector_leaves is set and
//    node_t::threshold_type otherwise.
//  - has_nonlocal_categories selects the evaluate_tree_impl overload that
//    reads categorical splits from external `categorical_data`. Otherwise the
//    inline-categorical overload is used, parameterized on
//    has_categorical_nodes.
//
// NOTE(review): this rendered listing omits source lines 148-165 (the original
// doc comment), 173 (the function signature line), 185-186 and 191-192 (the
// remaining arguments of both predict_leaf calls, presumably first_root_node
// and node_id_mapping for evaluate_tree_impl). Consult evaluate_tree.hpp
// directly for the complete definition.
166 template <bool has_vector_leaves,
167  bool has_categorical_nodes,
168  bool has_nonlocal_categories,
169  bool predict_leaf,
170  typename forest_t,
171  typename io_t,
172  typename categorical_data_t>
174  index_type tree_index,
175  io_t const* __restrict__ row,
176  categorical_data_t categorical_data)
177 {
178  using node_t = typename forest_t::node_type;
// Leaf-ID prediction path.
179  if constexpr (predict_leaf) {
180  auto leaf_node_id = index_type{};
181  if constexpr (has_nonlocal_categories) {
182  leaf_node_id = evaluate_tree_impl<has_vector_leaves>(forest.get_tree_root(tree_index),
183  row,
184  categorical_data,
187  } else {
188  leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
189  forest.get_tree_root(tree_index),
190  row,
193  }
194  return leaf_node_id;
195  } else {
// Leaf-value path: vector leaves yield an index, scalar leaves a threshold-typed value.
196  auto tree_output = std::conditional_t<has_vector_leaves,
197  typename node_t::index_type,
198  typename node_t::threshold_type>{};
199  if constexpr (has_nonlocal_categories) {
200  tree_output = evaluate_tree_impl<has_vector_leaves>(
201  forest.get_tree_root(tree_index), row, categorical_data);
202  } else {
203  tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
204  forest.get_tree_root(tree_index), row);
205  }
206  return tree_output;
207  }
208 }
209 
210 } // namespace detail
211 } // namespace fil
212 } // namespace experimental
213 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:34
#define HOST
Definition: gpu_support.hpp:33
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:173
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:53
uint32_t index_type
Definition: index_type.hpp:21
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:27
Definition: bitset.hpp:32
Definition: forest.hpp:34
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:56
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:63
Definition: node.hpp:91
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:159
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:169
HOST DEVICE auto const & index() const
Definition: node.hpp:186