Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
postprocessor.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #ifndef __CUDACC__
18 #include <math.h>
19 #endif
23 #include <limits>
24 #include <stddef.h>
25 #include <type_traits>
26 
27 namespace ML {
28 namespace experimental {
29 namespace fil {
30 
31 /* Convert the postprocessing operations into a single value
32  * representing what must be done in the inference kernel
33  */
34 HOST DEVICE inline auto constexpr ops_to_val(row_op row_wise, element_op elem_wise)
35 {
36  return (static_cast<std::underlying_type_t<row_op>>(row_wise) |
37  static_cast<std::underlying_type_t<element_op>>(elem_wise));
38 }
39 
40 /*
41  * Perform postprocessing on raw forest output
42  *
43  * @param val Pointer to the raw forest output
44  * @param output_count The number of output values per row
45  * @param out Pointer to the output buffer
46  * @param stride Number of elements between the first element that must be
47  * summed for a particular output element and the next. This is typically
48  * equal to the number of "groves" of trees over which the computation
49  * was divided.
50  * @param average_factor The factor by which to divide during the
51  * normalization step of postprocessing
52  * @param bias The bias factor to subtract off during the
53  * normalization step of postprocessing
54  * @param constant If the postprocessing operation requires a constant,
55  * it can be passed here.
56  */
57 template <row_op row_wise_v, element_op elem_wise_v, typename io_t>
58 HOST DEVICE void postprocess(io_t* val,
59  index_type output_count,
60  io_t* out,
61  index_type stride = index_type{1},
62  io_t average_factor = io_t{1},
63  io_t bias = io_t{0},
64  io_t constant = io_t{1})
65 {
66 #pragma GCC diagnostic push
67 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
68  auto max_index = index_type{};
69  auto max_value = std::numeric_limits<io_t>::lowest();
70 #pragma GCC diagnostic pop
71  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
72  auto workspace_index = output_index * stride;
73  val[workspace_index] = val[workspace_index] / average_factor + bias;
74  if constexpr (elem_wise_v == element_op::signed_square) {
75  val[workspace_index] =
76  copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
77  } else if constexpr (elem_wise_v == element_op::hinge) {
78  val[workspace_index] = io_t(val[workspace_index] > io_t{});
79  } else if constexpr (elem_wise_v == element_op::sigmoid) {
80  val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
81  } else if constexpr (elem_wise_v == element_op::exponential) {
82  val[workspace_index] = exp(val[workspace_index] / constant);
83  } else if constexpr (elem_wise_v == element_op::logarithm_one_plus_exp) {
84  val[workspace_index] = log1p(exp(val[workspace_index] / constant));
85  }
86  if constexpr (row_wise_v == row_op::softmax || row_wise_v == row_op::max_index) {
87  auto is_new_max = val[workspace_index] > max_value;
88  max_index = is_new_max * output_index + (!is_new_max) * max_index;
89  max_value = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
90  }
91  }
92 
93  if constexpr (row_wise_v == row_op::max_index) {
94  *out = max_index;
95  } else {
96 #pragma GCC diagnostic push
97 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
98  auto softmax_normalization = io_t{};
99 #pragma GCC diagnostic pop
100  if constexpr (row_wise_v == row_op::softmax) {
101  for (auto workspace_index = index_type{}; workspace_index < output_count * stride;
102  workspace_index += stride) {
103  val[workspace_index] = exp(val[workspace_index] - max_value);
104  softmax_normalization += val[workspace_index];
105  }
106  }
107 
108  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
109  auto workspace_index = output_index * stride;
110  if constexpr (row_wise_v == row_op::softmax) {
111  out[output_index] = val[workspace_index] / softmax_normalization;
112  } else {
113  out[output_index] = val[workspace_index];
114  }
115  }
116  }
117 }
118 
119 /*
120  * Struct which holds all data necessary to perform postprocessing on raw
121  * output of a forest model
122  *
123  * @tparam io_t The type used for input and output to/from the model
124  * (typically float/double)
125  * @param row_wise Enum value representing the row-wise post-processing
126  * operation to perform on the output
127  * @param elem_wise Enum value representing the element-wise post-processing
128  * operation to perform on the output
129  * @param average_factor The factor by which to divide during the
130  * normalization step of postprocessing
131  * @param bias The bias term added (after dividing by average_factor) during the
132  * normalization step of postprocessing
133  * @param constant If the postprocessing operation requires a constant,
134  * it can be passed here.
135  */
136 template <typename io_t>
139  element_op elem_wise = element_op::disable,
140  io_t average_factor = io_t{1},
141  io_t bias = io_t{0},
142  io_t constant = io_t{1})
143  : average_factor_{average_factor},
144  bias_{bias},
145  constant_{constant},
146  row_wise_{row_wise},
147  elem_wise_{elem_wise}
148  {
149  }
150 
151  HOST DEVICE void operator()(io_t* val,
152  index_type output_count,
153  io_t* out,
154  index_type stride = index_type{1}) const
155  {
156  switch (ops_to_val(row_wise_, elem_wise_)) {
159  val, output_count, out, stride, average_factor_, bias_, constant_);
160  break;
163  val, output_count, out, stride, average_factor_, bias_, constant_);
164  break;
167  val, output_count, out, stride, average_factor_, bias_, constant_);
168  break;
171  val, output_count, out, stride, average_factor_, bias_, constant_);
172  break;
175  val, output_count, out, stride, average_factor_, bias_, constant_);
176  break;
178  postprocess<row_op::softmax, element_op::disable>(
179  val, output_count, out, stride, average_factor_, bias_, constant_);
180  break;
183  val, output_count, out, stride, average_factor_, bias_, constant_);
184  break;
186  postprocess<row_op::softmax, element_op::hinge>(
187  val, output_count, out, stride, average_factor_, bias_, constant_);
188  break;
190  postprocess<row_op::softmax, element_op::sigmoid>(
191  val, output_count, out, stride, average_factor_, bias_, constant_);
192  break;
195  val, output_count, out, stride, average_factor_, bias_, constant_);
196  break;
199  val, output_count, out, stride, average_factor_, bias_, constant_);
200  break;
203  val, output_count, out, stride, average_factor_, bias_, constant_);
204  break;
207  val, output_count, out, stride, average_factor_, bias_, constant_);
208  break;
211  val, output_count, out, stride, average_factor_, bias_, constant_);
212  break;
215  val, output_count, out, stride, average_factor_, bias_, constant_);
216  break;
219  val, output_count, out, stride, average_factor_, bias_, constant_);
220  break;
223  val, output_count, out, stride, average_factor_, bias_, constant_);
224  break;
225  default:
226  postprocess<row_op::disable, element_op::disable>(
227  val, output_count, out, stride, average_factor_, bias_, constant_);
228  }
229  }
230 
231  private:
232  io_t average_factor_;
233  io_t bias_;
234  io_t constant_;
235  row_op row_wise_;
236  element_op elem_wise_;
237 };
238 } // namespace fil
239 } // namespace experimental
240 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:34
#define HOST
Definition: gpu_support.hpp:33
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:34
row_op
Definition: postproc_ops.hpp:22
HOST DEVICE void postprocess(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:58
Definition: dbscan.hpp:27
Definition: postprocessor.hpp:137
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:138
HOST DEVICE void operator()(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:151