Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
forest_model.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
// NOTE(review): the project headers declaring decision_forest_variant, infer_kind, index_type,
// type_error and the raft_proto utilities (buffer, handle_t, copy, ceildiv) were stripped from
// this listing (original lines 18-23) and must be restored alongside these standard headers.
#include <algorithm>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>
26 
27 namespace ML {
28 namespace experimental {
29 namespace fil {
30 
37 struct forest_model {
40  : decision_forest_{forest}
41  {
42  }
43 
45  auto num_features()
46  {
47  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_features(); },
48  decision_forest_);
49  }
50 
52  auto num_outputs()
53  {
54  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_outputs(); },
55  decision_forest_);
56  }
57 
59  auto num_trees()
60  {
61  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_trees(); },
62  decision_forest_);
63  }
64 
67  {
68  return std::visit([](auto&& concrete_forest) { return concrete_forest.has_vector_leaves(); },
69  decision_forest_);
70  }
71 
74  {
75  return std::visit([](auto&& concrete_forest) { return concrete_forest.row_postprocessing(); },
76  decision_forest_);
77  }
78 
81  {
82  return std::visit(
83  [&val](auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
84  decision_forest_);
85  }
86 
90  {
91  return std::visit([](auto&& concrete_forest) { return concrete_forest.elem_postprocessing(); },
92  decision_forest_);
93  }
94 
96  auto memory_type()
97  {
98  return std::visit([](auto&& concrete_forest) { return concrete_forest.memory_type(); },
99  decision_forest_);
100  }
101 
104  {
105  return std::visit([](auto&& concrete_forest) { return concrete_forest.device_index(); },
106  decision_forest_);
107  }
108 
111  {
112  return std::visit(
113  [](auto&& concrete_forest) {
114  return std::is_same_v<typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
115  double>;
116  },
117  decision_forest_);
118  }
119 
143  template <typename io_t>
145  raft_proto::buffer<io_t> const& input,
147  infer_kind predict_type = infer_kind::default_kind,
148  std::optional<index_type> specified_chunk_size = std::nullopt)
149  {
150  std::visit(
151  [this, predict_type, &output, &input, &stream, &specified_chunk_size](
152  auto&& concrete_forest) {
153  if constexpr (std::is_same_v<
154  typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
155  io_t>) {
156  concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
157  } else {
158  throw type_error("Input type does not match model_type");
159  }
160  },
161  decision_forest_);
162  }
163 
191  template <typename io_t>
192  void predict(raft_proto::handle_t const& handle,
193  raft_proto::buffer<io_t>& output,
194  raft_proto::buffer<io_t> const& input,
195  infer_kind predict_type = infer_kind::default_kind,
196  std::optional<index_type> specified_chunk_size = std::nullopt)
197  {
198  std::visit(
199  [this, predict_type, &handle, &output, &input, &specified_chunk_size](
200  auto&& concrete_forest) {
201  using model_io_t = typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
202  if constexpr (std::is_same_v<model_io_t, io_t>) {
203  if (output.memory_type() == memory_type() && input.memory_type() == memory_type()) {
204  concrete_forest.predict(
205  output, input, handle.get_next_usable_stream(), predict_type, specified_chunk_size);
206  } else {
207  auto constexpr static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
208  auto constexpr static const MAX_CHUNK_SIZE = std::size_t{64};
209 
210  auto row_count = input.size() / num_features();
211  auto partition_size =
213  specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
214  auto partition_count = raft_proto::ceildiv(row_count, partition_size);
215  for (auto i = std::size_t{}; i < partition_count; ++i) {
216  auto stream = handle.get_next_usable_stream();
217  auto rows_in_this_partition =
218  std::min(partition_size, row_count - i * partition_size);
219  auto partition_in = raft_proto::buffer<io_t>{};
220  if (input.memory_type() != memory_type()) {
221  partition_in =
222  raft_proto::buffer<io_t>{rows_in_this_partition * num_features(), memory_type()};
223  raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
224  input,
225  0,
226  i * partition_size * num_features(),
227  partition_in.size(),
228  stream);
229  } else {
230  partition_in =
231  raft_proto::buffer<io_t>{input.data() + i * partition_size * num_features(),
232  rows_in_this_partition * num_features(),
233  memory_type()};
234  }
235  auto partition_out = raft_proto::buffer<io_t>{};
236  if (output.memory_type() != memory_type()) {
237  partition_out =
238  raft_proto::buffer<io_t>{rows_in_this_partition * num_outputs(), memory_type()};
239  } else {
240  partition_out =
241  raft_proto::buffer<io_t>{output.data() + i * partition_size * num_outputs(),
242  rows_in_this_partition * num_outputs(),
243  memory_type()};
244  }
245  concrete_forest.predict(
246  partition_out, partition_in, stream, predict_type, specified_chunk_size);
247  if (output.memory_type() != memory_type()) {
248  raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
249  partition_out,
250  i * partition_size * num_outputs(),
251  0,
252  partition_out.size(),
253  stream);
254  }
255  }
256  }
257  } else {
258  throw type_error("Input type does not match model_type");
259  }
260  },
261  decision_forest_);
262  }
263 
290  template <typename io_t>
291  void predict(raft_proto::handle_t const& handle,
292  io_t* output,
293  io_t* input,
294  std::size_t num_rows,
295  raft_proto::device_type out_mem_type,
296  raft_proto::device_type in_mem_type,
297  infer_kind predict_type = infer_kind::default_kind,
298  std::optional<index_type> specified_chunk_size = std::nullopt)
299  {
300  // TODO(wphicks): Make sure buffer lands on same device as model
301  auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type};
302  auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type};
303  predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
304  }
305 
306  private:
307  decision_forest_variant decision_forest_;
308 };
309 
310 } // namespace fil
311 } // namespace experimental
312 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:26
infer_kind
Definition: infer_kind.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:414
row_op
Definition: postproc_ops.hpp:22
Definition: dbscan.hpp:27
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: forest_model.hpp:37
auto elem_postprocessing()
Definition: forest_model.hpp:89
auto num_features()
Definition: forest_model.hpp:45
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:39
auto is_double_precision()
Definition: forest_model.hpp:110
auto device_index()
Definition: forest_model.hpp:103
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:192
auto num_trees()
Definition: forest_model.hpp:59
auto has_vector_leaves()
Definition: forest_model.hpp:66
auto num_outputs()
Definition: forest_model.hpp:52
auto memory_type()
Definition: forest_model.hpp:96
auto row_postprocessing()
Definition: forest_model.hpp:73
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:80
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:144
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:291
Definition: forest.hpp:34
Definition: exceptions.hpp:52
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:39
auto size() const noexcept
Definition: buffer.hpp:291
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:292
auto memory_type() const noexcept
Definition: buffer.hpp:293
Definition: handle.hpp:46
auto get_usable_stream_count() const
Definition: handle.hpp:49
auto get_next_usable_stream() const
Definition: handle.hpp:47