#include <cuda_runtime.h>

#include <raft/util/cudart_utils.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
/**
 * Structure holding the ARIMA order (hyper-parameters of the model).
 */
struct ARIMAOrder {
  int p;       // Non-seasonal AR order
  int d;       // Non-seasonal differencing order
  int q;       // Non-seasonal MA order
  int P;       // Seasonal AR order
  int D;       // Seasonal differencing order
  int Q;       // Seasonal MA order
  int s;       // Seasonal period
  int k;       // Whether to fit an intercept
  int n_exog;  // Number of exogenous regressors

  // Derived-size helpers also defined here (bodies elided in this listing):
  // n_diff(), n_phi(), n_theta(), r(), rd(), complexity()

  inline bool need_diff() const { return static_cast<bool>(d + D); }
};
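// Illustrative only (not part of the original header): a minimal sketch showing
// how an ARIMAOrder could be filled in for a hypothetical ARIMA(1,1,1) model
// with an intercept, no seasonality and no exogenous regressors. The helper
// name below is an assumption for this example.
inline ARIMAOrder example_arima_111_order()
{
  ARIMAOrder order{};
  order.p = 1; order.d = 1; order.q = 1;  // non-seasonal AR / differencing / MA order
  order.P = 0; order.D = 0; order.Q = 0;  // no seasonal terms
  order.s = 0;                            // no seasonal period
  order.k = 1;                            // fit an intercept
  order.n_exog = 0;                       // no exogenous regressors
  // order.need_diff() is true here because d + D = 1
  return order;
}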
/**
 * Structure holding the device pointers to the ARIMA parameter arrays
 * (one value or vector per batch member for each parameter type).
 */
template <typename DataT>
struct ARIMAParams {
  DataT* mu;      // intercept, shape (batch_size,)
  DataT* beta;    // exogenous coefficients, shape (n_exog, batch_size)
  DataT* ar;      // AR coefficients, shape (p, batch_size)
  DataT* ma;      // MA coefficients, shape (q, batch_size)
  DataT* sar;     // seasonal AR coefficients, shape (P, batch_size)
  DataT* sma;     // seasonal MA coefficients, shape (Q, batch_size)
  DataT* sigma2;  // innovation variance, shape (batch_size,)
  /**
   * Allocate the parameter arrays with the current RMM device memory resource.
   * Arrays not needed for the given order are left unallocated; when tr is true,
   * only the transformed-parameter arrays are needed, so mu and beta are skipped.
   */
  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      beta = (DataT*)rmm_alloc->allocate(order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream);
    if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream);
    if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) sma = (DataT*)rmm_alloc->allocate(order.Q * batch_size * sizeof(DataT), stream);
    sigma2 = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream);
  }
  /** Release the arrays allocated by allocate() back to the RMM memory resource. */
  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream);
    if (order.n_exog && !tr)
      rmm_alloc->deallocate(beta, order.n_exog * batch_size * sizeof(DataT), stream);
    if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream);
    if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream);
    if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream);
    if (order.Q) rmm_alloc->deallocate(sma, order.Q * batch_size * sizeof(DataT), stream);
    rmm_alloc->deallocate(sigma2, batch_size * sizeof(DataT), stream);
  }
  /**
   * Pack the separate parameter arrays into a single parameter vector, with one
   * contiguous block of N = order.complexity() values per batch member.
   */
  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda cannot capture struct members directly, so copy the
    // pointers into local variables first
    const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
                *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        DataT* param = param_vec + bid * N;
        if (order.k) {
          *param = _mu[bid];
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          param[i] = _beta[order.n_exog * bid + i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          param[ip] = _ar[order.p * bid + ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          param[iq] = _ma[order.q * bid + iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          param[iP] = _sar[order.P * bid + iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          param[iQ] = _sma[order.Q * bid + iQ];
        }
        param += order.Q;
        *param = _sigma2[bid];
      });
  }
  /**
   * Unpack a single parameter vector back into the separate parameter arrays
   * (the inverse of pack()).
   */
  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream)
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda cannot capture struct members directly, so copy the
    // pointers into local variables first
    DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
          *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        const DataT* param = param_vec + bid * N;
        if (order.k) {
          _mu[bid] = *param;
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          _beta[order.n_exog * bid + i] = param[i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          _ar[order.p * bid + ip] = param[ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          _ma[order.q * bid + iq] = param[iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          _sar[order.P * bid + iP] = param[iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          _sma[order.Q * bid + iQ] = param[iQ];
        }
        param += order.Q;
        _sigma2[bid] = *param;
      });
  }
};
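// Illustrative only (not part of the original header): a minimal sketch of the
// allocate / unpack / pack / deallocate round trip on ARIMAParams. The helper
// name and the flat device vectors d_x_in / d_x_out are assumptions for this
// example; each holds order.complexity() values per batch member.
template <typename DataT>
void example_params_round_trip(const ARIMAOrder& order,
                               int batch_size,
                               const DataT* d_x_in,
                               DataT* d_x_out,
                               cudaStream_t stream)
{
  ARIMAParams<DataT> params;
  params.allocate(order, batch_size, stream);        // get device arrays from RMM
  params.unpack(order, batch_size, d_x_in, stream);  // flat vector -> mu/beta/ar/...
  // ... work with the individual parameter arrays here ...
  params.pack(order, batch_size, d_x_out, stream);   // mu/beta/ar/... -> flat vector
  params.deallocate(order, batch_size, stream);      // return the arrays to RMM
}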
/**
 * Structure holding all the working device memory used by the batched ARIMA
 * routines. Every sub-buffer is carved out of one contiguous allocation, with
 * each sub-buffer padded to a multiple of ALIGN bytes.
 */
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  // Dense working arrays
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2;
  T *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams;
  T *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense;
  T *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *y_diff, *exog_diff, *pred;
  T *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense;
  T *RQRs_dense, *Ps_dense;
  // Arrays of per-batch-member pointers
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches;
  T **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches;
  T **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches;
  T **RQRs_batches, **Ps_batches;
  // Pivot and info arrays for the batched solves
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  size_t size;  // total buffer size in bytes
  char* buf;    // underlying buffer (provided by the caller, not owned)
  /**
   * Reserve space for one sub-buffer of n_elem elements: round its size up to a
   * multiple of ALIGN bytes and, when assign is true, point ptr at its offset.
   */
  template <bool assign, typename ValType>
  void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
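  // Illustrative only (not part of the original header): the rounding above pads
  // each sub-buffer to a multiple of ALIGN bytes, so every carved-out pointer
  // stays ALIGN-aligned. For example, with the default ALIGN = 256 and
  // ValType = double, 100 elements occupy 800 bytes but advance the offset by 1024.
  static_assert(((100 * sizeof(double) + 256 - 1) / 256) * 256 == 1024,
                "100 doubles are padded to 1024 bytes with 256-byte alignment");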
  /**
   * Walk through all the working buffers: accumulate the total size in bytes and,
   * when assign is true, set every member pointer to its offset inside in_buf.
   */
  template <bool assign>
  void buf_offsets(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf = nullptr)
  {
    // Problem dimensions (accessors declared in ARIMAOrder)
    int N      = order.complexity();
    int r      = order.r();
    int rd     = order.rd();
    int n_diff = order.n_diff();  // used by parts of the original elided in this listing

    buf  = in_buf;
    size = 0;

    // Parameter arrays
    append_buffer<assign>(params_mu, order.k * batch_size);
    append_buffer<assign>(params_beta, order.n_exog * batch_size);
    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);
    append_buffer<assign>(params_sigma2, batch_size);

    // Transformed parameter arrays
    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);
    append_buffer<assign>(Tparams_sigma2, batch_size);

    // Packed parameter vectors and state-space working buffers
    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);
    append_buffer<assign>(RQR_batches, batch_size);
    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);
    append_buffer<assign>(alpha_batches, batch_size);
    append_buffer<assign>(ImT_dense, r * r * batch_size);
    append_buffer<assign>(ImT_batches, batch_size);
    append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
    append_buffer<assign>(ImT_inv_batches, batch_size);
    append_buffer<assign>(ImT_inv_P, r * batch_size);
    append_buffer<assign>(ImT_inv_info, batch_size);
    append_buffer<assign>(v_tmp_dense, rd * batch_size);
    append_buffer<assign>(v_tmp_batches, batch_size);
    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
    append_buffer<assign>(m_tmp_batches, batch_size);
    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    // Per-observation and per-batch outputs
    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);
    // (exog_diff is appended here in the original; its size is elided in this listing)
    append_buffer<assign>(loglike, batch_size);
    append_buffer<assign>(loglike_base, batch_size);
    append_buffer<assign>(loglike_pert, batch_size);
    append_buffer<assign>(x_pert, N * batch_size);

    // Reduced-size (r x r) buffers; any guards from the original are elided here
    append_buffer<assign>(Ts_dense, r * r * batch_size);
    append_buffer<assign>(Ts_batches, batch_size);
    append_buffer<assign>(RQRs_dense, r * r * batch_size);
    append_buffer<assign>(RQRs_batches, batch_size);
    append_buffer<assign>(Ps_dense, r * r * batch_size);
    append_buffer<assign>(Ps_batches, batch_size);

    // Batched r^2 x r^2 systems (I_m_AxA_*)
    append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
    append_buffer<assign>(I_m_AxA_batches, batch_size);
    append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
    append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
    append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
    append_buffer<assign>(I_m_AxA_info, batch_size);
  }

  /** Compute the offsets without touching any pointers (only the total size is needed). */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

  /** Carve all the working buffers out of the caller-provided buffer in_buf. */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }

  /** Compute the temporary buffer size required for the given problem dimensions. */
  static size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    ARIMAMemory temp(order, batch_size, n_obs);
    return temp.size;
  }
};
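// Illustrative only (not part of the original header): the intended two-phase
// use of ARIMAMemory. The helper name and the use of the current RMM device
// memory resource for the temporary buffer are assumptions for this sketch.
template <typename T>
void example_arima_memory_setup(const ARIMAOrder& order,
                                int batch_size,
                                int n_obs,
                                cudaStream_t stream)
{
  // Phase 1: ask how many bytes the working buffers need
  size_t bytes = ARIMAMemory<T>::compute_size(order, batch_size, n_obs);
  // Phase 2: allocate one temporary buffer and let ARIMAMemory carve it up
  auto* mr  = rmm::mr::get_current_device_resource();
  char* buf = static_cast<char*>(mr->allocate(bytes, stream));
  ARIMAMemory<T> arima_mem(order, batch_size, n_obs, buf);
  // arima_mem.Z_dense, arima_mem.T_batches, ... now point into buf
  // ... run the batched ARIMA computations here ...
  mr->deallocate(buf, bytes, stream);
}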