24 #include <type_traits>
28 namespace experimental {
47 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_features(); },
54 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_outputs(); },
61 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_trees(); },
68 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.has_vector_leaves(); },
75 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.row_postprocessing(); },
83 [&val](
auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
91 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.elem_postprocessing(); },
98 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.memory_type(); },
105 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.device_index(); },
113 [](
auto&& concrete_forest) {
114 return std::is_same_v<
typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
143 template <
typename io_t>
148 std::optional<index_type> specified_chunk_size = std::nullopt)
151 [
this, predict_type, &output, &input, &stream, &specified_chunk_size](
152 auto&& concrete_forest) {
153 if constexpr (std::is_same_v<
154 typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
156 concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
158 throw type_error(
"Input type does not match model_type");
191 template <
typename io_t>
196 std::optional<index_type> specified_chunk_size = std::nullopt)
199 [
this, predict_type, &handle, &output, &input, &specified_chunk_size](
200 auto&& concrete_forest) {
201 using model_io_t =
typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
202 if constexpr (std::is_same_v<model_io_t, io_t>) {
204 concrete_forest.predict(
207 auto constexpr
static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
208 auto constexpr
static const MAX_CHUNK_SIZE = std::size_t{64};
211 auto partition_size =
213 specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
215 for (
auto i = std::size_t{}; i < partition_count; ++i) {
217 auto rows_in_this_partition =
218 std::min(partition_size, row_count - i * partition_size);
223 raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
245 concrete_forest.predict(
246 partition_out, partition_in, stream, predict_type, specified_chunk_size);
248 raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
252 partition_out.size(),
258 throw type_error(
"Input type does not match model_type");
290 template <
typename io_t>
294 std::size_t num_rows,
298 std::optional<index_type> specified_chunk_size = std::nullopt)
303 predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
math_t max(math_t a, math_t b)
Definition: learning_rate.h:26
infer_kind
Definition: infer_kind.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:414
row_op
Definition: postproc_ops.hpp:22
Definition: dbscan.hpp:27
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: forest_model.hpp:37
auto elem_postprocessing()
Definition: forest_model.hpp:89
auto num_features()
Definition: forest_model.hpp:45
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:39
auto is_double_precision()
Definition: forest_model.hpp:110
auto device_index()
Definition: forest_model.hpp:103
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:192
auto num_trees()
Definition: forest_model.hpp:59
auto has_vector_leaves()
Definition: forest_model.hpp:66
auto num_outputs()
Definition: forest_model.hpp:52
auto memory_type()
Definition: forest_model.hpp:96
auto row_postprocessing()
Definition: forest_model.hpp:73
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:80
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:144
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:291
Definition: forest.hpp:34
Definition: exceptions.hpp:52
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:39
auto size() const noexcept
Definition: buffer.hpp:291
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:292
auto memory_type() const noexcept
Definition: buffer.hpp:293
Definition: handle.hpp:46
auto get_usable_stream_count() const
Definition: handle.hpp:49
auto get_next_usable_stream() const
Definition: handle.hpp:47