cuML C++ API 23.12

arima_common.h
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>

#include <algorithm>

#include <raft/util/cudart_utils.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
namespace ML {

struct ARIMAOrder {
  int p;  // Basic order
  int d;
  int q;
  int P;  // Seasonal order
  int D;
  int Q;
  int s;       // Seasonal period
  int k;       // Fit intercept?
  int n_exog;  // Number of exogenous regressors

  inline int n_diff() const { return d + s * D; }
  inline int n_phi() const { return p + s * P; }
  inline int n_theta() const { return q + s * Q; }
  inline int r() const { return std::max(n_phi(), n_theta() + 1); }
  inline int rd() const { return n_diff() + r(); }
  inline int complexity() const { return p + P + q + Q + k + n_exog + 1; }
  inline bool need_diff() const { return static_cast<bool>(d + D); }
};
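The derived quantities above fix the state-space dimensions used throughout the ARIMA code. A quick illustration (the order values below are an arbitrary example, not taken from this file):

  // Seasonal model with period 12, an intercept and no exogenous regressors
  ML::ARIMAOrder order{/*p=*/1, /*d=*/1, /*q=*/1, /*P=*/1, /*D=*/1, /*Q=*/1,
                       /*s=*/12, /*k=*/1, /*n_exog=*/0};
  // n_phi   = p + s*P = 13          n_theta = q + s*Q = 13
  // r       = max(13, 13 + 1) = 14  rd      = n_diff + r = 13 + 14 = 27
  // complexity = p + P + q + Q + k + n_exog + 1 = 6 parameters per series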

template <typename DataT>
struct ARIMAParams {
  DataT* mu     = nullptr;
  DataT* beta   = nullptr;
  DataT* ar     = nullptr;
  DataT* ma     = nullptr;
  DataT* sar    = nullptr;
  DataT* sma    = nullptr;
  DataT* sigma2 = nullptr;

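  // allocate/deallocate reserve or release one device array per parameter type
  // that is active in `order`. When tr is true, the intercept (mu) and the
  // exogenous coefficients (beta) are skipped; pass the same tr value to the
  // matching deallocate call.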
  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      beta = (DataT*)rmm_alloc->allocate(order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream);
    if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream);
    if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) sma = (DataT*)rmm_alloc->allocate(order.Q * batch_size * sizeof(DataT), stream);
    sigma2 = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
  }

  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      rmm_alloc->deallocate(beta, order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream);
    if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream);
    if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) rmm_alloc->deallocate(sma, order.Q * batch_size * sizeof(DataT), stream);
    rmm_alloc->deallocate(sigma2, batch_size * sizeof(DataT), stream);
  }

  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
                *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        DataT* param = param_vec + bid * N;
        if (order.k) {
          *param = _mu[bid];
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          param[i] = _beta[order.n_exog * bid + i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          param[ip] = _ar[order.p * bid + ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          param[iq] = _ma[order.q * bid + iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          param[iP] = _sar[order.P * bid + iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          param[iQ] = _sma[order.Q * bid + iQ];
        }
        param += order.Q;
        *param = _sigma2[bid];
      });
  }

  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream)
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
          *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        const DataT* param = param_vec + bid * N;
        if (order.k) {
          _mu[bid] = *param;
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          _beta[order.n_exog * bid + i] = param[i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          _ar[order.p * bid + ip] = param[ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          _ma[order.q * bid + iq] = param[iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          _sar[order.P * bid + iP] = param[iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          _sma[order.Q * bid + iQ] = param[iQ];
        }
        param += order.Q;
        _sigma2[bid] = *param;
      });
  }
};
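A minimal usage sketch, assuming a valid CUDA stream `stream` and that the per-series parameter values are filled in on the device (the batch size and order below are illustrative):

  #include <rmm/device_uvector.hpp>

  ML::ARIMAOrder order{1, 1, 1, 0, 0, 0, 0, 1, 0};  // ARIMA(1,1,1) with intercept
  int batch_size = 100;

  ML::ARIMAParams<double> params;
  params.allocate(order, batch_size, stream);  // for this order: mu, ar, ma, sigma2
  // ... fill params.mu, params.ar, params.ma, params.sigma2 on the device ...

  // Linearize into one contiguous vector: one block of order.complexity() values
  // per batch member, laid out as mu | beta | ar | ma | sar | sma | sigma2.
  rmm::device_uvector<double> param_vec(batch_size * order.complexity(), stream);
  params.pack(order, batch_size, param_vec.data(), stream);

  params.deallocate(order, batch_size, stream);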

template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2,
    *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
    *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense,
    *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff,
    *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
    *RQRs_dense, *Ps_dense;
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches,
    **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
    **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
    **RQRs_batches, **Ps_batches;
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  size_t size;

 protected:
  char* buf;

  template <bool assign, typename ValType>
  inline void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
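  // Worked example of the rounding above, assuming ALIGN = 256 and ValType = double:
  // n_elem = 100 -> 100 * 8 = 800 bytes -> ((800 + 255) / 256) * 256 = 1024 bytes,
  // so the next sub-buffer starts on a 256-byte boundary within buf.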

  template <bool assign>
  inline void buf_offsets(const ARIMAOrder& order,
                          int batch_size,
                          int n_obs,
                          char* in_buf = nullptr)
  {
    buf  = in_buf;
    size = 0;

    int r      = order.r();
    int rd     = order.rd();
    int N      = order.complexity();
    int n_diff = order.n_diff();

    append_buffer<assign>(params_mu, order.k * batch_size);
    append_buffer<assign>(params_beta, order.n_exog * batch_size);
    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);
    append_buffer<assign>(params_sigma2, batch_size);

    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);
    append_buffer<assign>(Tparams_sigma2, batch_size);

    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);
    append_buffer<assign>(RQR_batches, batch_size);
    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);
    append_buffer<assign>(alpha_batches, batch_size);
    append_buffer<assign>(ImT_dense, r * r * batch_size);
    append_buffer<assign>(ImT_batches, batch_size);
    append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
    append_buffer<assign>(ImT_inv_batches, batch_size);
    append_buffer<assign>(ImT_inv_P, r * batch_size);
    append_buffer<assign>(ImT_inv_info, batch_size);
    append_buffer<assign>(v_tmp_dense, rd * batch_size);
    append_buffer<assign>(v_tmp_batches, batch_size);
    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
    append_buffer<assign>(m_tmp_batches, batch_size);
    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);
    append_buffer<assign>(exog_diff, n_obs * order.n_exog * batch_size);
    append_buffer<assign>(loglike, batch_size);
    append_buffer<assign>(loglike_base, batch_size);
    append_buffer<assign>(loglike_pert, batch_size);
    append_buffer<assign>(x_pert, N * batch_size);

    if (n_diff > 0) {
      append_buffer<assign>(Ts_dense, r * r * batch_size);
      append_buffer<assign>(Ts_batches, batch_size);
      append_buffer<assign>(RQRs_dense, r * r * batch_size);
      append_buffer<assign>(RQRs_batches, batch_size);
      append_buffer<assign>(Ps_dense, r * r * batch_size);
      append_buffer<assign>(Ps_batches, batch_size);
    }

    if (r <= 5) {
      // Note: temp mem for the direct Lyapunov solver grows very quickly!
      // This solver is used iff the condition above is satisfied
      append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_batches, batch_size);
      append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
      append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
      append_buffer<assign>(I_m_AxA_info, batch_size);
    }
  }

  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

 public:
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }

  static size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    ARIMAMemory temp(order, batch_size, n_obs);
    return temp.size;
  }
};

}  // namespace ML
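The two constructors implement a two-pass sizing pattern: compute_size runs buf_offsets with assign=false just to accumulate the total aligned size, and the public constructor runs it again with assign=true to carve sub-buffers out of a caller-provided allocation. A hedged usage sketch, assuming an existing order, batch_size, n_obs and stream:

  #include <rmm/device_uvector.hpp>

  size_t ws_size = ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
  rmm::device_uvector<char> workspace(ws_size, stream);
  ML::ARIMAMemory<double> arima_mem(order, batch_size, n_obs, workspace.data());
  // arima_mem.Z_dense, arima_mem.T_dense, arima_mem.loglike, ... now point into
  // `workspace`; each sub-buffer size is rounded up to a multiple of ALIGN
  // (256 bytes by default), so every pointer is suitably aligned.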