18 #include <type_traits>
25 namespace experimental {
48 template <
bool has_vector_leaves,
49 bool has_categorical_nodes,
52 typename node_id_mapping_t = std::nullptr_t>
54 io_t
const* __restrict__ row,
55 node_t
const* __restrict__ first_root_node =
nullptr,
56 node_id_mapping_t node_id_mapping =
nullptr)
59 auto cur_node = *
node;
61 auto input_val = row[cur_node.feature_index()];
62 auto condition =
true;
63 if constexpr (has_categorical_nodes) {
64 if (cur_node.is_categorical()) {
65 auto valid_categories = categorical_set_type{
67 condition = valid_categories.test(input_val);
69 condition = (input_val < cur_node.threshold());
72 condition = (input_val < cur_node.threshold());
74 if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
77 }
while (!cur_node.is_leaf());
78 if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
79 return cur_node.template output<has_vector_leaves>();
81 return node_id_mapping[
node - first_root_node];
112 template <
bool has_vector_leaves,
115 typename categorical_storage_t,
116 typename node_id_mapping_t = std::nullptr_t>
118 io_t
const* __restrict__ row,
119 categorical_storage_t
const* __restrict__ categorical_storage,
120 node_t
const* __restrict__ first_root_node =
nullptr,
121 node_id_mapping_t node_id_mapping =
nullptr)
124 auto cur_node = *
node;
126 auto input_val = row[cur_node.feature_index()];
128 if (!isnan(input_val)) {
129 if (cur_node.is_categorical()) {
130 auto valid_categories =
131 categorical_set_type{categorical_storage + cur_node.index() + 1,
132 uint32_t(categorical_storage[cur_node.index()])};
133 condition = valid_categories.test(input_val);
135 condition = (input_val < cur_node.threshold());
140 }
while (!cur_node.is_leaf());
141 if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
142 return cur_node.template output<has_vector_leaves>();
144 return node_id_mapping[
node - first_root_node];
166 template <
bool has_vector_leaves,
167 bool has_categorical_nodes,
168 bool has_nonlocal_categories,
172 typename categorical_data_t>
175 io_t
const* __restrict__ row,
176 categorical_data_t categorical_data)
178 using node_t =
typename forest_t::node_type;
179 if constexpr (predict_leaf) {
181 if constexpr (has_nonlocal_categories) {
188 leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
196 auto tree_output = std::conditional_t<has_vector_leaves,
198 typename node_t::threshold_type>{};
199 if constexpr (has_nonlocal_categories) {
200 tree_output = evaluate_tree_impl<has_vector_leaves>(
203 tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
#define DEVICE
Definition: gpu_support.hpp:34
#define HOST
Definition: gpu_support.hpp:33
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:173
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:53
uint32_t index_type
Definition: index_type.hpp:21
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:27
Definition: bitset.hpp:32
Definition: forest.hpp:34
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:56
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:63
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:159
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:169
HOST DEVICE auto const & index() const
Definition: node.hpp:186