Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
forest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
20 #include <stddef.h>
21 #include <type_traits>
22 
23 namespace ML {
24 namespace experimental {
25 namespace fil {
26 
27 /* A collection of trees which together form a forest model
28  */
29 template <tree_layout layout_v,
30  typename threshold_t,
31  typename index_t,
32  typename metadata_storage_t,
33  typename offset_t>
34 struct forest {
36  using io_type = threshold_t;
37  template <typename vector_output_t>
38  using raw_output_type = std::conditional_t<!std::is_same_v<vector_output_t, std::nullptr_t>,
39  std::remove_pointer_t<vector_output_t>,
40  typename node_type::threshold_type>;
41 
42  HOST DEVICE forest(node_type* forest_nodes,
43  index_type* forest_root_indexes,
44  index_type* node_id_mapping,
45  index_type num_trees,
47  : nodes_{forest_nodes},
48  root_node_indexes_{forest_root_indexes},
49  node_id_mapping_{node_id_mapping},
50  num_trees_{num_trees},
51  num_outputs_{num_outputs}
52  {
53  }
54 
55  /* Return pointer to the root node of the indicated tree */
56  HOST DEVICE auto* get_tree_root(index_type tree_index) const
57  {
58  return nodes_ + root_node_indexes_[tree_index];
59  }
60 
61  /* Return pointer to the mapping from internal node IDs to final node ID outputs.
62  * Only used when infer_type == infer_kind::leaf_id */
63  HOST DEVICE const auto* get_node_id_mapping() const { return node_id_mapping_; }
64 
65  /* Return the number of trees in this forest */
66  HOST DEVICE auto tree_count() const { return num_trees_; }
67 
68  /* Return the number of outputs per row for default evaluation of this
69  * forest */
70  HOST DEVICE auto num_outputs() const { return num_outputs_; }
71 
72  private:
73  node_type* nodes_;
74  index_type* root_node_indexes_;
75  index_type* node_id_mapping_;
76  index_type num_trees_;
77  index_type num_outputs_;
78 };
79 
80 } // namespace fil
81 } // namespace experimental
82 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:34
#define HOST
Definition: gpu_support.hpp:33
tree_layout
Definition: tree_layout.hpp:20
uint32_t index_type
Definition: index_type.hpp:21
Definition: dbscan.hpp:27
struct forest
Definition: forest.hpp:34
HOST DEVICE forest(node_type *forest_nodes, index_type *forest_root_indexes, index_type *node_id_mapping, index_type num_trees, index_type num_outputs)
Definition: forest.hpp:42
threshold_t io_type
Definition: forest.hpp:36
std::conditional_t<!std::is_same_v< vector_output_t, std::nullptr_t >, std::remove_pointer_t< vector_output_t >, typename node_type::threshold_type > raw_output_type
Definition: forest.hpp:40
HOST DEVICE auto tree_count() const
Definition: forest.hpp:66
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:56
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:63
node< layout_v, threshold_t, index_t, metadata_storage_t, offset_t > node_type
Definition: forest.hpp:35
HOST DEVICE auto num_outputs() const
Definition: forest.hpp:70
struct node
Definition: node.hpp:91
threshold_t threshold_type
Definition: node.hpp:95