arima_common.h
/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <raft/util/cudart_utils.hpp>

#include <rmm/mr/device/per_device_resource.hpp>

#include <cuda_runtime.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
namespace ML {

/**
 * Structure to hold the ARIMA order (makes it easier to pass as an argument)
 */
struct ARIMAOrder {
  int p;  // Basic order
  int d;
  int q;
  int P;  // Seasonal order
  int D;
  int Q;
  int s;       // Seasonal period
  int k;       // Fit intercept?
  int n_exog;  // Number of exogenous regressors

  inline int n_diff() const { return d + s * D; }
  inline int n_phi() const { return p + s * P; }
  inline int n_theta() const { return q + s * Q; }
  inline int r() const { return std::max(n_phi(), n_theta() + 1); }
  inline int rd() const { return n_diff() + r(); }
  inline int complexity() const { return p + P + q + Q + k + n_exog + 1; }
  inline bool need_diff() const { return static_cast<bool>(d + D); }
};
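
// Illustrative example of the derived quantities above: for a seasonal model with order
// (p,d,q)(P,D,Q)s = (1,1,1)(1,1,1)12, an intercept (k = 1) and no exogenous regressors:
//
//   ARIMAOrder order{1, 1, 1, 1, 1, 1, 12, 1, 0};  // p, d, q, P, D, Q, s, k, n_exog
//   order.n_diff();      // d + s*D = 1 + 12*1   = 13
//   order.n_phi();       // p + s*P = 1 + 12*1   = 13
//   order.n_theta();     // q + s*Q = 1 + 12*1   = 13
//   order.r();           // max(13, 13 + 1)      = 14
//   order.rd();          // 13 + 14              = 27
//   order.complexity();  // p+P+q+Q+k+n_exog+1   = 6 parameters per series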

/**
 * Structure to hold the parameters (makes it easier to pass as an argument)
 */
template <typename DataT>
struct ARIMAParams {
  DataT* mu     = nullptr;
  DataT* beta   = nullptr;
  DataT* ar     = nullptr;
  DataT* ma     = nullptr;
  DataT* sar    = nullptr;
  DataT* sma    = nullptr;
  DataT* sigma2 = nullptr;

  /**
   * Allocate all the parameter device arrays
   * (note: when tr is true, i.e. transformed parameters, mu and beta are not allocated)
   */
  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      beta = (DataT*)rmm_alloc->allocate(order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream);
    if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream);
    if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) sma = (DataT*)rmm_alloc->allocate(order.Q * batch_size * sizeof(DataT), stream);
    sigma2 = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
  }

  /**
   * Deallocate all the parameter device arrays (must mirror the corresponding allocate call)
   */
  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      rmm_alloc->deallocate(beta, order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream);
    if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream);
    if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) rmm_alloc->deallocate(sma, order.Q * batch_size * sizeof(DataT), stream);
    rmm_alloc->deallocate(sigma2, batch_size * sizeof(DataT), stream);
  }

  /**
   * Pack the separate parameter arrays into a single linear vector
   * (one contiguous block of complexity() values per batch member)
   */
  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
                *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        DataT* param = param_vec + bid * N;
        if (order.k) {
          *param = _mu[bid];
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          param[i] = _beta[order.n_exog * bid + i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          param[ip] = _ar[order.p * bid + ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          param[iq] = _ma[order.q * bid + iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          param[iP] = _sar[order.P * bid + iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          param[iQ] = _sma[order.Q * bid + iQ];
        }
        param += order.Q;
        *param = _sigma2[bid];
      });
  }
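
  // Illustrative note on the packed layout: each batch member bid owns a contiguous block of
  // N = order.complexity() values starting at param_vec + bid * N, laid out as
  //
  //   [ mu (k values) | beta (n_exog) | ar (p) | ma (q) | sar (P) | sma (Q) | sigma2 (1) ]
  //
  // e.g. for order (1,1,1)(1,1,1)12 with an intercept, N = 6 and the block is
  // [ mu, ar[0], ma[0], sar[0], sma[0], sigma2 ].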

  /**
   * Unpack a linear parameter vector into the separate parameter arrays
   * (inverse of pack, same per-batch-member layout)
   */
  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream)
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
          *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        const DataT* param = param_vec + bid * N;
        if (order.k) {
          _mu[bid] = *param;
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          _beta[order.n_exog * bid + i] = param[i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          _ar[order.p * bid + ip] = param[ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          _ma[order.q * bid + iq] = param[iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          _sar[order.P * bid + iP] = param[iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          _sma[order.Q * bid + iQ] = param[iQ];
        }
        param += order.Q;
        _sigma2[bid] = *param;
      });
  }
};
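
// Minimal host-side usage sketch for ARIMAParams (illustrative only: order, batch_size and
// stream are assumed to come from the caller, and rmm::device_uvector is just one way to hold
// the packed vector):
//
//   ML::ARIMAParams<double> params;
//   params.allocate(order, batch_size, stream);
//   // ... fill params.mu, params.ar, ... on the device ...
//   rmm::device_uvector<double> param_vec(order.complexity() * batch_size, stream);
//   params.pack(order, batch_size, param_vec.data(), stream);
//   // ... operate on the packed vector, e.g. optimize or transform it ...
//   params.unpack(order, batch_size, param_vec.data(), stream);
//   params.deallocate(order, batch_size, stream);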

/**
 * Structure to manage the ARIMA temporary memory: every temporary device array is carved out
 * of a single caller-provided buffer, each array padded to an ALIGN-byte boundary.
 */
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2;
  T *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams;
  T *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense;
  T *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff;
  T *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense, *loglike, *loglike_base, *loglike_pert, *x_pert;
  T *RQRs_dense, *Ps_dense;
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches;
  T **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches;
  T **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches;
  T **RQRs_batches, **Ps_batches;
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  size_t size;

 protected:
  char* buf;

  template <bool assign, typename ValType>
  inline void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
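
  // Illustrative note on the arithmetic above: each appended array is padded up to a multiple of
  // ALIGN bytes, so every array starts on an ALIGN-byte boundary within the shared buffer. For
  // example, with the default ALIGN = 256 and ValType = double, n_elem = 100 (800 bytes) advances
  // `size` by ((800 + 255) / 256) * 256 = 1024 bytes.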

  template <bool assign>
  inline void buf_offsets(const ARIMAOrder& order,
                          int batch_size,
                          int n_obs,
                          char* in_buf = nullptr)
  {
    buf  = in_buf;
    size = 0;

    int r      = order.r();
    int rd     = order.rd();
    int N      = order.complexity();
    int n_diff = order.n_diff();

    append_buffer<assign>(params_mu, order.k * batch_size);
    append_buffer<assign>(params_beta, order.n_exog * batch_size);
    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);
    append_buffer<assign>(params_sigma2, batch_size);

    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);
    append_buffer<assign>(Tparams_sigma2, batch_size);

    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);
    append_buffer<assign>(RQR_batches, batch_size);
    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);
    append_buffer<assign>(alpha_batches, batch_size);
    append_buffer<assign>(ImT_dense, r * r * batch_size);
    append_buffer<assign>(ImT_batches, batch_size);
    append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
    append_buffer<assign>(ImT_inv_batches, batch_size);
    append_buffer<assign>(ImT_inv_P, r * batch_size);
    append_buffer<assign>(ImT_inv_info, batch_size);
    append_buffer<assign>(v_tmp_dense, rd * batch_size);
    append_buffer<assign>(v_tmp_batches, batch_size);
    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
    append_buffer<assign>(m_tmp_batches, batch_size);
    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);
    append_buffer<assign>(exog_diff, n_obs * order.n_exog * batch_size);
    append_buffer<assign>(loglike, batch_size);
    append_buffer<assign>(loglike_base, batch_size);
    append_buffer<assign>(loglike_pert, batch_size);
    append_buffer<assign>(x_pert, N * batch_size);

    if (n_diff > 0) {
      append_buffer<assign>(Ts_dense, r * r * batch_size);
      append_buffer<assign>(Ts_batches, batch_size);
      append_buffer<assign>(RQRs_dense, r * r * batch_size);
      append_buffer<assign>(RQRs_batches, batch_size);
      append_buffer<assign>(Ps_dense, r * r * batch_size);
      append_buffer<assign>(Ps_batches, batch_size);
    }

    if (r <= 5) {
      // Note: temp mem for the direct Lyapunov solver grows very quickly!
      // This solver is used iff the condition above is satisfied
      append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_batches, batch_size);
      append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
      append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
      append_buffer<assign>(I_m_AxA_info, batch_size);
    }
  }

  /** Protected constructor: only runs the size computation, without assigning any pointers */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

 public:
  /**
   * Constructor: carves all the temporary arrays out of the given buffer
   * (in_buf must hold at least compute_size(order, batch_size, n_obs) bytes)
   */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }

  /** Compute the size (in bytes) of the temporary buffer needed for the given problem */
  static size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    ARIMAMemory temp(order, batch_size, n_obs);
    return temp.size;
  }
};
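
// Typical two-phase usage sketch for ARIMAMemory (illustrative; allocating the backing buffer
// with rmm::device_uvector<char> is an assumption, any sufficiently large device buffer works):
//
//   size_t n_bytes = ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
//   rmm::device_uvector<char> temp_buf(n_bytes, stream);
//   ML::ARIMAMemory<double> arima_mem(order, batch_size, n_obs, temp_buf.data());
//   // arima_mem.pred, arima_mem.loglike, arima_mem.P_dense, ... now point into temp_buf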

}  // namespace ML