Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cmath>
18 #include <cstddef>
27 #include <queue>
28 #include <stack>
29 #include <treelite/c_api.h>
30 #include <treelite/tree.h>
31 #include <treelite/typeinfo.h>
32 
33 namespace ML {
34 namespace experimental {
35 namespace fil {
36 
37 namespace detail {
40 template <tree_layout layout, typename T>
43  std::conditional_t<layout == tree_layout::depth_first, std::stack<T>, std::queue<T>>;
44  void add(T const& val) { data_.push(val); }
45  void add(T const& hot, T const& distant)
46  {
47  if constexpr (layout == tree_layout::depth_first) {
48  data_.push(distant);
49  data_.push(hot);
50  } else {
51  data_.push(hot);
52  data_.push(distant);
53  }
54  }
55  auto next()
56  {
57  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
58  auto result = data_.top();
59  data_.pop();
60  return result;
61  } else {
62  auto result = data_.front();
63  data_.pop();
64  return result;
65  }
66  }
67  auto peek()
68  {
69  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
70  return data_.top();
71  } else {
72  return data_.front();
73  }
74  }
75  [[nodiscard]] auto empty() { return data_.empty(); }
76  auto size() { return data_.size(); }
77 
78  private:
79  backing_container_t data_;
80 };
81 
85  double constant = 1.0;
86 };
87 } // namespace detail
88 
94 template <tree_layout layout>
96  template <typename tl_threshold_t, typename tl_output_t>
97  struct treelite_node {
98  treelite::Tree<tl_threshold_t, tl_output_t> const& tree;
99  int node_id;
102 
103  auto is_leaf() { return tree.IsLeaf(node_id); }
104 
105  auto get_output()
106  {
107  auto result = std::vector<tl_output_t>{};
108  if (tree.HasLeafVector(node_id)) {
109  result = tree.LeafVector(node_id);
110  } else {
111  result.push_back(tree.LeafValue(node_id));
112  }
113  return result;
114  }
115 
116  auto get_categories() { return tree.MatchingCategories(node_id); }
117 
118  auto get_feature() { return tree.SplitIndex(node_id); }
119 
121  {
122  return tree.SplitType(node_id) == treelite::SplitFeatureType::kCategorical;
123  }
124 
126  {
127  auto result = false;
128  auto default_child = tree.DefaultChild(node_id);
129  if (is_categorical()) {
130  if (tree.CategoriesListRightChild(node_id)) {
131  result = (default_child == tree.RightChild(node_id));
132  } else {
133  result = (default_child == tree.LeftChild(node_id));
134  }
135  } else {
136  auto tl_operator = tree.ComparisonOp(node_id);
137  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
138  result = (default_child == tree.LeftChild(node_id));
139  } else {
140  result = (default_child == tree.RightChild(node_id));
141  }
142  }
143  return result;
144  }
145 
146  auto threshold() { return tree.Threshold(node_id); }
147 
148  auto categories()
149  {
150  auto result = decltype(tree.MatchingCategories(node_id)){};
151  if (is_categorical()) { result = tree.MatchingCategories(node_id); }
152  return result;
153  }
154 
156  {
157  auto tl_operator = tree.ComparisonOp(node_id);
158  return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
159  }
160  };
161 
162  template <typename tl_threshold_t, typename tl_output_t, typename lambda_t>
163  void node_for_each(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree, lambda_t&& lambda)
164  {
165  using node_index_t = decltype(tl_tree.LeftChild(0));
167  to_be_visited.add(node_index_t{});
168 
169  auto parent_indices = detail::traversal_container<layout, index_type>{};
170  auto cur_index = index_type{};
171  parent_indices.add(cur_index);
172 
173  while (!to_be_visited.empty()) {
174  auto node_id = to_be_visited.next();
175  auto remaining_size = to_be_visited.size();
176 
178  tl_tree, node_id, parent_indices.next(), cur_index};
179  lambda(tl_node, node_id);
180 
181  if (!tl_tree.IsLeaf(node_id)) {
182  auto tl_left_id = tl_tree.LeftChild(node_id);
183  auto tl_right_id = tl_tree.RightChild(node_id);
184  auto tl_operator = tl_tree.ComparisonOp(node_id);
185  if (!tl_node.is_categorical()) {
186  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
187  to_be_visited.add(tl_right_id, tl_left_id);
188  } else if (tl_operator == treelite::Operator::kGT ||
189  tl_operator == treelite::Operator::kGE) {
190  to_be_visited.add(tl_left_id, tl_right_id);
191  } else {
192  throw model_import_error("Unrecognized Treelite operator");
193  }
194  } else {
195  if (tl_tree.CategoriesListRightChild(node_id)) {
196  to_be_visited.add(tl_left_id, tl_right_id);
197  } else {
198  to_be_visited.add(tl_right_id, tl_left_id);
199  }
200  }
201  parent_indices.add(cur_index, cur_index);
202  }
203  ++cur_index;
204  }
205  }
206 
207  template <typename tl_threshold_t, typename tl_output_t, typename iter_t, typename lambda_t>
208  void node_transform(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
209  iter_t output_iter,
210  lambda_t&& lambda)
211  {
212  node_for_each(tl_tree, [&output_iter, &lambda](auto&& tl_node, int tl_node_id) {
213  *output_iter = lambda(tl_node);
214  ++output_iter;
215  });
216  }
217 
218  template <typename tl_threshold_t, typename tl_output_t, typename T, typename lambda_t>
219  auto node_accumulate(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
220  T init,
221  lambda_t&& lambda)
222  {
223  auto result = init;
224  node_for_each(tl_tree, [&result, &lambda](auto&& tl_node, int tl_node_id) {
225  result = lambda(result, tl_node);
226  });
227  return result;
228  }
229 
230  template <typename tl_threshold_t, typename tl_output_t>
231  auto get_nodes(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree)
232  {
233  auto result = std::vector<treelite_node<tl_threshold_t, tl_output_t>>{};
234  result.reserve(tl_tree.num_nodes);
235  node_transform(tl_tree, std::back_inserter(result), [](auto&& node) { return node; });
236  return result;
237  }
238 
239  template <typename tl_threshold_t, typename tl_output_t>
240  auto get_offsets(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree)
241  {
242  auto result = std::vector<index_type>(tl_tree.num_nodes);
243  auto nodes = get_nodes(tl_tree);
244  for (auto i = index_type{}; i < nodes.size(); ++i) {
245  // Current index should always be greater than or equal to parent index.
246  // Later children will overwrite values set by earlier children, ensuring
247  // that most distant offset is used.
248  result[nodes[i].parent_index] = index_type{i - nodes[i].parent_index};
249  }
250 
251  return result;
252  }
253 
254  template <typename lambda_t>
255  void tree_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
256  {
257  tl_model.Dispatch([&lambda](auto&& concrete_tl_model) {
258  std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
259  });
260  }
261 
262  template <typename iter_t, typename lambda_t>
263  void tree_transform(treelite::Model const& tl_model, iter_t output_iter, lambda_t&& lambda)
264  {
265  tl_model.Dispatch([&output_iter, &lambda](auto&& concrete_tl_model) {
266  std::transform(std::begin(concrete_tl_model.trees),
267  std::end(concrete_tl_model.trees),
268  output_iter,
269  lambda);
270  });
271  }
272 
273  template <typename T, typename lambda_t>
274  auto tree_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
275  {
276  auto result = init;
277  tree_for_each(tl_model, [&result, &lambda](auto&& tree) { result = lambda(result, tree); });
278  return result;
279  }
280 
281  auto num_trees(treelite::Model const& tl_model)
282  {
283  auto result = index_type{};
284  tl_model.Dispatch(
285  [&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); });
286  return result;
287  }
288 
289  auto get_offsets(treelite::Model const& tl_model)
290  {
291  auto result = std::vector<std::vector<index_type>>{};
292  result.reserve(num_trees(tl_model));
294  tl_model, std::back_inserter(result), [this](auto&& tree) { return get_offsets(tree); });
295  return result;
296  }
297 
298  auto get_tree_sizes(treelite::Model const& tl_model)
299  {
300  auto result = std::vector<index_type>{};
302  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
303  return result;
304  }
305 
306  auto get_num_class(treelite::Model const& tl_model)
307  {
308  auto result = index_type{};
309  tl_model.Dispatch(
310  [&result](auto&& concrete_tl_model) { result = concrete_tl_model.task_param.num_class; });
311  return result;
312  }
313 
314  auto get_num_feature(treelite::Model const& tl_model)
315  {
316  auto result = index_type{};
317  tl_model.Dispatch(
318  [&result](auto&& concrete_tl_model) { result = concrete_tl_model.num_feature; });
319  return result;
320  }
321 
322  auto get_max_num_categories(treelite::Model const& tl_model)
323  {
324  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
325  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
326  auto result = cur_accum;
327  for (auto&& cat : tl_node.categories()) {
328  result = (cat + 1 > result) ? cat + 1 : result;
329  }
330  return result;
331  });
332  });
333  }
334 
335  auto get_num_categorical_nodes(treelite::Model const& tl_model)
336  {
337  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
338  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
339  return cur_accum + tl_node.is_categorical();
340  });
341  });
342  }
343 
344  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
345  {
346  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
347  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
348  return cur_accum + (tl_node.is_leaf() && tl_node.get_output().size() > 1);
349  });
350  });
351  }
352 
353  auto get_average_factor(treelite::Model const& tl_model)
354  {
355  auto result = double{};
356  tl_model.Dispatch([&result](auto&& concrete_tl_model) {
357  if (concrete_tl_model.average_tree_output) {
358  if (concrete_tl_model.task_type == treelite::TaskType::kMultiClfGrovePerClass) {
359  result = concrete_tl_model.trees.size() / concrete_tl_model.task_param.num_class;
360  } else {
361  result = concrete_tl_model.trees.size();
362  }
363  } else {
364  result = 1.0;
365  }
366  });
367  return result;
368  }
369 
370  auto get_bias(treelite::Model const& tl_model)
371  {
372  auto result = double{};
373  tl_model.Dispatch(
374  [&result](auto&& concrete_tl_model) { result = concrete_tl_model.param.global_bias; });
375  return result;
376  }
377 
378  auto get_postproc_params(treelite::Model const& tl_model)
379  {
380  auto result = detail::postproc_params_t{};
381  tl_model.Dispatch([&result](auto&& concrete_tl_model) {
382  auto tl_pred_transform = std::string{concrete_tl_model.param.pred_transform};
383  if (tl_pred_transform == std::string{"identity"} ||
384  tl_pred_transform == std::string{"identity_multiclass"}) {
385  result.element = element_op::disable;
386  result.row = row_op::disable;
387  } else if (tl_pred_transform == std::string{"signed_square"}) {
388  result.element = element_op::signed_square;
389  } else if (tl_pred_transform == std::string{"hinge"}) {
390  result.element = element_op::hinge;
391  } else if (tl_pred_transform == std::string{"sigmoid"}) {
392  result.constant = concrete_tl_model.param.sigmoid_alpha;
393  result.element = element_op::sigmoid;
394  } else if (tl_pred_transform == std::string{"exponential"}) {
395  result.element = element_op::exponential;
396  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
397  result.constant = -concrete_tl_model.param.ratio_c / std::log(2);
398  result.element = element_op::exponential;
399  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
400  result.element = element_op::logarithm_one_plus_exp;
401  } else if (tl_pred_transform == std::string{"max_index"}) {
402  result.row = row_op::max_index;
403  } else if (tl_pred_transform == std::string{"softmax"}) {
404  result.row = row_op::softmax;
405  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
406  result.constant = concrete_tl_model.param.sigmoid_alpha;
407  result.element = element_op::sigmoid;
408  } else {
409  throw model_import_error{"Unrecognized Treelite pred_transform string"};
410  }
411  });
412  return result;
413  }
414 
415  auto uses_double_thresholds(treelite::Model const& tl_model)
416  {
417  auto result = false;
418  switch (tl_model.GetThresholdType()) {
419  case treelite::TypeInfo::kFloat64: result = true; break;
420  case treelite::TypeInfo::kFloat32: result = false; break;
421  default: throw model_import_error("Unrecognized Treelite threshold type");
422  }
423  return result;
424  }
425 
426  auto uses_double_outputs(treelite::Model const& tl_model)
427  {
428  auto result = false;
429  switch (tl_model.GetThresholdType()) {
430  case treelite::TypeInfo::kFloat64: result = true; break;
431  case treelite::TypeInfo::kFloat32: result = false; break;
432  case treelite::TypeInfo::kUInt32: result = false; break;
433  default: throw model_import_error("Unrecognized Treelite threshold type");
434  }
435  return result;
436  }
437 
438  auto uses_integer_outputs(treelite::Model const& tl_model)
439  {
440  auto result = false;
441  switch (tl_model.GetThresholdType()) {
442  case treelite::TypeInfo::kFloat64: result = false; break;
443  case treelite::TypeInfo::kFloat32: result = false; break;
444  case treelite::TypeInfo::kUInt32: result = true; break;
445  default: throw model_import_error("Unrecognized Treelite threshold type");
446  }
447  return result;
448  }
449 
454  template <index_type variant_index>
455  auto import_to_specific_variant(index_type target_variant_index,
456  treelite::Model const& tl_model,
457  index_type num_class,
458  index_type num_feature,
459  index_type max_num_categories,
460  std::vector<std::vector<index_type>> const& offsets,
461  index_type align_bytes = index_type{},
463  int device = 0,
465  {
466  auto result = decision_forest_variant{};
467  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
468  if (variant_index == target_variant_index) {
469  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
470  auto builder =
471  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
472  auto tree_count = num_trees(tl_model);
473  auto tree_index = index_type{};
474  tree_for_each(tl_model, [this, &builder, &tree_index, &offsets](auto&& tree) {
475  builder.start_new_tree();
476  auto node_index = index_type{};
477  node_for_each(
478  tree, [&builder, &tree_index, &node_index, &offsets](auto&& node, int tl_node_id) {
479  if (node.is_leaf()) {
480  auto output = node.get_output();
481  builder.set_output_size(output.size());
482  if (output.size() > index_type{1}) {
483  builder.add_leaf_vector_node(std::begin(output), std::end(output), tl_node_id);
484  } else {
485  builder.add_node(typename forest_model_t::io_type(output[0]), tl_node_id, true);
486  }
487  } else {
488  if (node.is_categorical()) {
489  auto categories = node.get_categories();
490  builder.add_categorical_node(std::begin(categories),
491  std::end(categories),
492  tl_node_id,
493  node.default_distant(),
494  node.get_feature(),
495  offsets[tree_index][node_index]);
496  } else {
497  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
498  tl_node_id,
499  false,
500  node.default_distant(),
501  false,
502  node.get_feature(),
503  offsets[tree_index][node_index],
504  node.is_inclusive());
505  }
506  }
507  ++node_index;
508  });
509  ++tree_index;
510  });
511 
512  builder.set_average_factor(get_average_factor(tl_model));
513  builder.set_bias(get_bias(tl_model));
514  auto postproc_params = get_postproc_params(tl_model);
515  builder.set_element_postproc(postproc_params.element);
516  builder.set_row_postproc(postproc_params.row);
517  builder.set_postproc_constant(postproc_params.constant);
518 
519  result.template emplace<variant_index>(
520  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
521  } else {
522  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
523  tl_model,
524  num_class,
525  num_feature,
526  max_num_categories,
527  offsets,
528  align_bytes,
529  mem_type,
530  device,
531  stream);
532  }
533  }
534  return result;
535  }
536 
559  auto import(treelite::Model const& tl_model,
560  index_type align_bytes = index_type{},
561  std::optional<bool> use_double_precision = std::nullopt,
563  int device = 0,
565  {
566  auto result = decision_forest_variant{};
567  auto num_feature = get_num_feature(tl_model);
568  auto max_num_categories = get_max_num_categories(tl_model);
569  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
570  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
571  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
572 
573  auto offsets = get_offsets(tl_model);
574  auto max_offset = std::accumulate(
575  std::begin(offsets),
576  std::end(offsets),
577  index_type{},
578  [&offsets](auto&& cur_max, auto&& tree_offsets) {
579  return std::max(cur_max,
580  *std::max_element(std::begin(tree_offsets), std::end(tree_offsets)));
581  });
582  auto tree_sizes = std::vector<index_type>{};
583  std::transform(std::begin(offsets),
584  std::end(offsets),
585  std::back_inserter(tree_sizes),
586  [](auto&& tree_offsets) { return tree_offsets.size(); });
587 
588  auto variant_index = get_forest_variant_index(use_double_thresholds,
589  max_offset,
590  num_feature,
591  num_categorical_nodes,
592  max_num_categories,
593  num_leaf_vector_nodes,
594  layout);
595  auto num_class = get_num_class(tl_model);
596  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
597  tl_model,
598  num_class,
599  num_feature,
600  max_num_categories,
601  offsets,
602  align_bytes,
603  dev_type,
604  device,
605  stream)};
606  }
607 };
608 
632 auto import_from_treelite_model(treelite::Model const& tl_model,
633  tree_layout layout = preferred_tree_layout,
634  index_type align_bytes = index_type{},
635  std::optional<bool> use_double_precision = std::nullopt,
637  int device = 0,
639 {
640  auto result = forest_model{};
641  switch (layout) {
642  case tree_layout::depth_first:
643  result = treelite_importer<tree_layout::depth_first>{}.import(
644  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
645  break;
646  case tree_layout::breadth_first:
647  result = treelite_importer<tree_layout::breadth_first>{}.import(
648  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
649  break;
650  }
651  return result;
652 }
653 
679  tree_layout layout = preferred_tree_layout,
680  index_type align_bytes = index_type{},
681  std::optional<bool> use_double_precision = std::nullopt,
683  int device = 0,
685 {
686  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
687  layout,
688  align_bytes,
689  use_double_precision,
690  dev_type,
691  device,
692  stream);
693 }
694 
695 } // namespace fil
696 } // namespace experimental
697 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:26
tree_layout
Definition: tree_layout.hpp:20
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:434
auto import_from_treelite_handle(ModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:678
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:414
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:632
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:27
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:82
element_op element
Definition: treelite_importer.hpp:83
double constant
Definition: treelite_importer.hpp:85
row_op row
Definition: treelite_importer.hpp:84
Definition: treelite_importer.hpp:41
auto empty()
Definition: treelite_importer.hpp:75
auto next()
Definition: treelite_importer.hpp:55
auto peek()
Definition: treelite_importer.hpp:67
void add(T const &val)
Definition: treelite_importer.hpp:44
auto size()
Definition: treelite_importer.hpp:76
void add(T const &hot, T const &distant)
Definition: treelite_importer.hpp:45
std::conditional_t< layout==tree_layout::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: treelite_importer.hpp:43
Definition: exceptions.hpp:36
Definition: node.hpp:91
auto get_feature()
Definition: treelite_importer.hpp:118
auto is_categorical()
Definition: treelite_importer.hpp:120
auto get_categories()
Definition: treelite_importer.hpp:116
auto threshold()
Definition: treelite_importer.hpp:146
index_type parent_index
Definition: treelite_importer.hpp:100
auto is_leaf()
Definition: treelite_importer.hpp:103
auto get_output()
Definition: treelite_importer.hpp:105
auto default_distant()
Definition: treelite_importer.hpp:125
index_type own_index
Definition: treelite_importer.hpp:101
auto is_inclusive()
Definition: treelite_importer.hpp:155
int node_id
Definition: treelite_importer.hpp:99
treelite::Tree< tl_threshold_t, tl_output_t > const & tree
Definition: treelite_importer.hpp:98
auto categories()
Definition: treelite_importer.hpp:148
Definition: treelite_importer.hpp:95
auto get_nodes(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:231
void tree_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:263
auto node_accumulate(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:219
void node_transform(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:208
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:426
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:314
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:298
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< std::vector< index_type >> const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:455
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:274
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:344
void node_for_each(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, lambda_t &&lambda)
Definition: treelite_importer.hpp:163
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:306
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:378
auto get_offsets(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:240
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:335
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:370
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:415
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:438
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:353
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:289
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite_importer.hpp:255
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:322
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:281
void * ModelHandle
Definition: treelite_defs.hpp:23