Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
infer_macros.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cstddef>
28 #include <variant>
29 
/* Internal helper shared by the four CUML_FIL_..._ARGS macros defined below.
 *
 * Expands to the parenthesized parameter list of an inference call for the
 * forest type selected by `variant_index` on device type `dev`. The four
 * public variants differ only in the types of the eighth and ninth parameters,
 * which are supplied here as `vector_leaf_type` and `nonlocal_cat_type`:
 *   - `vector_leaf_type` is `std::nullptr_t` for forests without vector leaves
 *     and `CUML_FIL_SPEC(variant_index)::threshold_type*` for forests with them;
 *   - `nonlocal_cat_type` is `std::nullptr_t` for forests without non-local
 *     categorical data and `CUML_FIL_SPEC(variant_index)::index_type*` for
 *     forests with it.
 * Keeping a single copy of the list means a change to the inference signature
 * only has to be made in one place instead of four. The trailing parameter is
 * named `stream`; all others are unnamed, exactly as before. */
#define CUML_FIL_INFER_ARGS_IMPL(dev, variant_index, vector_leaf_type, nonlocal_cat_type) \
  (CUML_FIL_FOREST(variant_index) const&,                                                \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&,                   \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                                        \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                                        \
   index_type,                                                                           \
   index_type,                                                                           \
   index_type,                                                                           \
   vector_leaf_type,                                                                     \
   nonlocal_cat_type,                                                                    \
   infer_kind,                                                                           \
   std::optional<index_type>,                                                            \
   raft_proto::device_id<dev>,                                                           \
   raft_proto::cuda_stream stream)

/* Expands to the parameter list of an inference call for a forest model
 * without vector leaves and without non-local categorical data: both the
 * vector-leaf and non-local categorical slots are std::nullptr_t. */
#define CUML_FIL_SCALAR_LOCAL_ARGS(dev, variant_index) \
  CUML_FIL_INFER_ARGS_IMPL(dev, variant_index, std::nullptr_t, std::nullptr_t)

/* Expands to the parameter list of an inference call for a forest model with
 * vector leaves but without non-local categorical data: the vector-leaf slot
 * is a threshold_type pointer, the categorical slot is std::nullptr_t. */
#define CUML_FIL_VECTOR_LOCAL_ARGS(dev, variant_index) \
  CUML_FIL_INFER_ARGS_IMPL(                            \
    dev, variant_index, CUML_FIL_SPEC(variant_index)::threshold_type*, std::nullptr_t)

/* Expands to the parameter list of an inference call for a forest model
 * without vector leaves but with non-local categorical data: the vector-leaf
 * slot is std::nullptr_t, the categorical slot is an index_type pointer. */
#define CUML_FIL_SCALAR_NONLOCAL_ARGS(dev, variant_index) \
  CUML_FIL_INFER_ARGS_IMPL(                               \
    dev, variant_index, std::nullptr_t, CUML_FIL_SPEC(variant_index)::index_type*)

/* Expands to the parameter list of an inference call for a forest model with
 * vector leaves and with non-local categorical data: the vector-leaf slot is
 * a threshold_type pointer and the categorical slot is an index_type pointer. */
#define CUML_FIL_VECTOR_NONLOCAL_ARGS(dev, variant_index)        \
  CUML_FIL_INFER_ARGS_IMPL(dev,                                  \
                           variant_index,                        \
                           CUML_FIL_SPEC(variant_index)::threshold_type*, \
                           CUML_FIL_SPEC(variant_index)::index_type*)
97 
/* Expands to the leading part of a declaration of the `infer` function
 * template specialized on <dev_type, has_categorical, forest type>, where the
 * forest type is CUML_FIL_FOREST(forest_idx). The caller appends the
 * parenthesized parameter list and the terminating semicolon. `decl_prefix`
 * is emitted verbatim in front of `void` (presumably `template` or
 * `extern template` for explicit instantiation -- confirm at call sites). */
#define CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, has_categorical) \
  decl_prefix void infer<dev_type, has_categorical, CUML_FIL_FOREST(forest_idx)>
102 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with scalar leaves and no
 * categorical nodes: categorical flag `false`, scalar/local parameter list.
 * `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE unchanged. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, false)             \
  CUML_FIL_SCALAR_LOCAL_ARGS(dev_type, forest_idx);
109 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with scalar leaves whose
 * categorical nodes are all local: categorical flag `true`, scalar/local
 * parameter list. `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE
 * unchanged. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, true)                 \
  CUML_FIL_SCALAR_LOCAL_ARGS(dev_type, forest_idx);
116 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with scalar leaves and
 * non-local categorical nodes: categorical flag `true`, scalar/non-local
 * parameter list. `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE
 * unchanged. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, true)                    \
  CUML_FIL_SCALAR_NONLOCAL_ARGS(dev_type, forest_idx);
123 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with vector leaves and no
 * categorical nodes: categorical flag `false`, vector/local parameter list.
 * `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE unchanged. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, false)             \
  CUML_FIL_VECTOR_LOCAL_ARGS(dev_type, forest_idx);
130 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with vector leaves whose
 * categorical nodes are all local: categorical flag `true`, vector/local
 * parameter list. `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE
 * unchanged. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, true)                 \
  CUML_FIL_VECTOR_LOCAL_ARGS(dev_type, forest_idx);
137 
/* Declares the `infer` specialization on device type `dev_type` for the
 * forest type selected by `forest_idx`, for forests with vector leaves and
 * non-local categorical nodes: categorical flag `true`, vector/non-local
 * parameter list. `decl_prefix` is forwarded to CUML_FIL_INFER_TEMPLATE
 * unchanged. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_TEMPLATE(decl_prefix, dev_type, forest_idx, true)                    \
  CUML_FIL_VECTOR_NONLOCAL_ARGS(dev_type, forest_idx);
144 
/* Declares all six valid `infer` specializations for the forest type selected
 * by `forest_idx` on device type `dev_type`. These cover every combination of
 * leaf kind (scalar, vector) and categorical support (none, local-only,
 * non-local). `decl_prefix` is forwarded unchanged to each of the six
 * per-combination macros, and each of them ends its declaration with a
 * semicolon. */
#define CUML_FIL_INFER_ALL(decl_prefix, dev_type, forest_idx)                   \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(decl_prefix, dev_type, forest_idx)       \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(decl_prefix, dev_type, forest_idx)    \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(decl_prefix, dev_type, forest_idx) \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(decl_prefix, dev_type, forest_idx)       \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(decl_prefix, dev_type, forest_idx)    \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(decl_prefix, dev_type, forest_idx)