Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
infer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cstddef>
21 #include <iostream>
22 #include <optional>
23 #include <type_traits>
24 #ifdef CUML_ENABLE_GPU
26 #endif
32 namespace ML {
33 namespace experimental {
34 namespace fil {
35 namespace detail {
36 
37 /*
38  * Perform inference based on the given forest and input parameters
39  *
40  * @tparam D The device type (CPU/GPU) used to perform inference
41  * @tparam forest_t The type of the forest
42  * @param forest The forest to be evaluated
43  * @param postproc The postprocessor object used to execute
44  * postprocessing
45  * @param output Pointer to where the output should be written
46  * @param input Pointer to where the input data can be read from
47  * @param row_count The number of rows in the input data
48  * @param col_count The number of columns in the input data
49  * @param output_count The number of outputs per row
50  * @param has_categorical_nodes Whether or not any node within the forest has
51  * a categorical split
52  * @param vector_output Pointer to the beginning of storage for vector
53  * outputs of leaves (nullptr for no vector output)
54  * @param categorical_data Pointer to external categorical data storage if
55  * required
56  * @param infer_type Type of inference to perform. Defaults to summing the outputs of all trees
57  * and produce an output per row. If set to "per_tree", we will instead output all outputs of
58  * individual trees. If set to "leaf_id", we will output the integer ID of the leaf node
59  * for each tree.
60  * @param specified_chunk_size If non-nullopt, the size of "mini-batches"
61  * used for distributing work across threads
62  * @param device The device on which to execute evaluation
63  * @param stream Optionally, the CUDA stream to use
64  */
65 template <raft_proto::device_type D, typename forest_t>
66 void infer(forest_t const& forest,
68  typename forest_t::io_type* output,
69  typename forest_t::io_type* input,
70  index_type row_count,
71  index_type col_count,
72  index_type output_count,
73  bool has_categorical_nodes,
74  typename forest_t::io_type* vector_output = nullptr,
75  typename forest_t::node_type::index_type* categorical_data = nullptr,
77  std::optional<index_type> specified_chunk_size = std::nullopt,
80 {
81  if (vector_output == nullptr) {
82  if (categorical_data == nullptr) {
83  if (!has_categorical_nodes) {
84  inference::infer<D, false, forest_t, std::nullptr_t, std::nullptr_t>(forest,
85  postproc,
86  output,
87  input,
88  row_count,
89  col_count,
90  output_count,
91  nullptr,
92  nullptr,
93  infer_type,
94  specified_chunk_size,
95  device,
96  stream);
97  } else {
98  inference::infer<D, true, forest_t, std::nullptr_t, std::nullptr_t>(forest,
99  postproc,
100  output,
101  input,
102  row_count,
103  col_count,
104  output_count,
105  nullptr,
106  nullptr,
107  infer_type,
108  specified_chunk_size,
109  device,
110  stream);
111  }
112  } else {
113  inference::infer<D, true, forest_t>(forest,
114  postproc,
115  output,
116  input,
117  row_count,
118  col_count,
119  output_count,
120  nullptr,
121  categorical_data,
122  infer_type,
123  specified_chunk_size,
124  device,
125  stream);
126  }
127  } else {
128  if (categorical_data == nullptr) {
129  if (!has_categorical_nodes) {
130  inference::infer<D, false, forest_t>(forest,
131  postproc,
132  output,
133  input,
134  row_count,
135  col_count,
136  output_count,
137  vector_output,
138  nullptr,
139  infer_type,
140  specified_chunk_size,
141  device,
142  stream);
143  } else {
144  inference::infer<D, true, forest_t>(forest,
145  postproc,
146  output,
147  input,
148  row_count,
149  col_count,
150  output_count,
151  vector_output,
152  nullptr,
153  infer_type,
154  specified_chunk_size,
155  device,
156  stream);
157  }
158  } else {
159  inference::infer<D, true, forest_t>(forest,
160  postproc,
161  output,
162  input,
163  row_count,
164  col_count,
165  output_count,
166  vector_output,
167  categorical_data,
168  infer_type,
169  specified_chunk_size,
170  device,
171  stream);
172  }
173  }
174 }
175 
176 } // namespace detail
177 } // namespace fil
178 } // namespace experimental
179 } // namespace ML
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:66
uint32_t index_type
Definition: index_type.hpp:21
infer_kind
Definition: infer_kind.hpp:20
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:27
int cuda_stream
Definition: cuda_stream.hpp:25
Definition: forest.hpp:34
Definition: postprocessor.hpp:137
Definition: base.hpp:22