arima_common.h
/*
 * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <raft/util/cudart_utils.hpp>

#include <rmm/aligned.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
namespace ML {

/**
 * Structure to hold the ARIMA order (makes it easier to pass as an argument)
 */
struct ARIMAOrder {
  int p;  // Basic order
  int d;
  int q;
  int P;  // Seasonal order
  int D;
  int Q;
  int s;       // Seasonal period
  int k;       // Fit intercept?
  int n_exog;  // Number of exogenous regressors

  inline int n_diff() const { return d + s * D; }
  inline int n_phi() const { return p + s * P; }
  inline int n_theta() const { return q + s * Q; }
  inline int r() const { return std::max(n_phi(), n_theta() + 1); }
  inline int rd() const { return n_diff() + r(); }
  inline int complexity() const { return p + P + q + Q + k + n_exog + 1; }
  inline bool need_diff() const { return static_cast<bool>(d + D); }
};
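
/*
 * Illustrative sketch (not from the original header): how the helpers above
 * combine for a seasonal ARIMA(1,1,1)(1,1,1)_12 model with an intercept and
 * no exogenous regressors.
 *
 *   ML::ARIMAOrder order{1, 1, 1, 1, 1, 1, 12, 1, 0};
 *   order.n_phi();       // p + s*P = 1 + 12 = 13
 *   order.n_theta();     // q + s*Q = 1 + 12 = 13
 *   order.r();           // max(13, 13 + 1) = 14  (state size before differencing)
 *   order.rd();          // (d + s*D) + r = 13 + 14 = 27
 *   order.complexity();  // p+P+q+Q+k+n_exog+1 = 6 free parameters per series
 */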

/**
 * Structure to hold the ARIMA parameters of a batch of time series
 * (device arrays, one per parameter type)
 */
template <typename DataT>
struct ARIMAParams {
  DataT* mu     = nullptr;
  DataT* beta   = nullptr;
  DataT* ar     = nullptr;
  DataT* ma     = nullptr;
  DataT* sar    = nullptr;
  DataT* sma    = nullptr;
  DataT* sigma2 = nullptr;

  /**
   * Allocate all the parameter device arrays
   *
   * @note Only arrays whose corresponding order is nonzero are allocated.
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Number of series in the batch
   * @param[in] stream     CUDA stream
   * @param[in] tr         Whether these are the transformed parameters
   *                       (the transformed parameters have no mu nor beta)
   */
  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) mu = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
    if (order.n_exog && !tr)
      beta = (DataT*)rmm_alloc.allocate(stream, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) ar = (DataT*)rmm_alloc.allocate(stream, order.p * batch_size * sizeof(DataT));
    if (order.q) ma = (DataT*)rmm_alloc.allocate(stream, order.q * batch_size * sizeof(DataT));
    if (order.P) sar = (DataT*)rmm_alloc.allocate(stream, order.P * batch_size * sizeof(DataT));
    if (order.Q) sma = (DataT*)rmm_alloc.allocate(stream, order.Q * batch_size * sizeof(DataT));
    sigma2 = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
  }

  /**
   * Deallocate all the parameter device arrays
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Number of series in the batch
   * @param[in] stream     CUDA stream
   * @param[in] tr         Whether these are the transformed parameters
   */
  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) rmm_alloc.deallocate(stream, mu, batch_size * sizeof(DataT));
    if (order.n_exog && !tr)
      rmm_alloc.deallocate(stream, beta, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) rmm_alloc.deallocate(stream, ar, order.p * batch_size * sizeof(DataT));
    if (order.q) rmm_alloc.deallocate(stream, ma, order.q * batch_size * sizeof(DataT));
    if (order.P) rmm_alloc.deallocate(stream, sar, order.P * batch_size * sizeof(DataT));
    if (order.Q) rmm_alloc.deallocate(stream, sma, order.Q * batch_size * sizeof(DataT));
    rmm_alloc.deallocate(stream, sigma2, batch_size * sizeof(DataT));
  }
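
  /*
   * Illustrative lifecycle sketch (not from the original header). The caller
   * must pair allocate/deallocate with the same order, batch size, stream and
   * current device memory resource; `order`, `batch_size` and `stream` are
   * assumed to be defined by the caller.
   *
   *   ML::ARIMAParams<double> params;
   *   params.allocate(order, batch_size, stream);         // full parameter set
   *   // ... fit or predict ...
   *   params.deallocate(order, batch_size, stream);
   *
   *   ML::ARIMAParams<double> Tparams;
   *   Tparams.allocate(order, batch_size, stream, true);  // transformed: no mu/beta
   *   Tparams.deallocate(order, batch_size, stream, true);
   */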

  /**
   * Pack the separate parameter arrays into a single parameter vector
   *
   * @param[in]  order      ARIMA order
   * @param[in]  batch_size Number of series in the batch
   * @param[out] param_vec  Packed parameter vector (size: batch_size * complexity())
   * @param[in]  stream     CUDA stream
   */
  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
                *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        DataT* param = param_vec + bid * N;
        if (order.k) {
          *param = _mu[bid];
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          param[i] = _beta[order.n_exog * bid + i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          param[ip] = _ar[order.p * bid + ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          param[iq] = _ma[order.q * bid + iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          param[iP] = _sar[order.P * bid + iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          param[iQ] = _sma[order.Q * bid + iQ];
        }
        param += order.Q;
        *param = _sigma2[bid];
      });
  }
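
  /*
   * Layout sketch (not from the original header): for each batch member bid,
   * pack/unpack use the contiguous slice param_vec[bid * N .. (bid + 1) * N - 1],
   * with N = order.complexity(), laid out as
   *
   *   [ mu (k) | beta (n_exog) | ar (p) | ma (q) | sar (P) | sma (Q) | sigma2 (1) ]
   *
   * e.g. for ARIMA(2,1,1) with an intercept and no seasonality or exogenous
   * regressors: N = 2 + 1 + 1 + 1 = 5, giving [mu, ar0, ar1, ma0, sigma2].
   */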

  /**
   * Unpack a parameter vector into the separate parameter arrays (inverse of pack)
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Number of series in the batch
   * @param[in] param_vec  Packed parameter vector (size: batch_size * complexity())
   * @param[in] stream     CUDA stream
   */
  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream)
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // The device lambda can't capture structure members...
    DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
          *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        const DataT* param = param_vec + bid * N;
        if (order.k) {
          _mu[bid] = *param;
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          _beta[order.n_exog * bid + i] = param[i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          _ar[order.p * bid + ip] = param[ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          _ma[order.q * bid + iq] = param[iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          _sar[order.P * bid + iP] = param[iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          _sma[order.Q * bid + iQ] = param[iQ];
        }
        param += order.Q;
        _sigma2[bid] = *param;
      });
  }
};

/**
 * Structure to manage the temporary memory used by the ARIMA implementation.
 * All member arrays are offsets into a single caller-provided buffer whose
 * required size is given by compute_size().
 *
 * @tparam T     Scalar type
 * @tparam ALIGN Alignment (in bytes) of each array inside the buffer
 */
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  // Working arrays; all assigned by buf_offsets to point inside the shared buffer
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2;
  T *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams;
  T *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense;
  T *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff;
  T *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense;
  T *RQRs_dense, *Ps_dense;
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches;
  T **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches;
  T **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches;
  T **RQRs_batches, **Ps_batches;
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  // Total size in bytes of the underlying buffer
  size_t size;

 protected:
  char* buf;

  // Reserve an ALIGN-aligned slot of n_elem elements; when assign is true,
  // also point ptr at the slot's offset inside buf
  template <bool assign, typename ValType>
  inline void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
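
  /*
   * Worked example (not from the original header): with ValType = double and
   * ALIGN = 256, appending 100 elements needs 100 * 8 = 800 bytes, which
   * rounds up to ((800 + 255) / 256) * 256 = 1024 bytes, so the next slot
   * starts on a 256-byte boundary (buf itself is assumed suitably aligned).
   */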

  // Single walk over the buffer layout: computes the total size and, when
  // assign is true, also sets the member pointers to their offsets in in_buf
  template <bool assign>
  inline void buf_offsets(const ARIMAOrder& order,
                          int batch_size,
                          int n_obs,
                          char* in_buf = nullptr)
  {
    buf  = in_buf;
    size = 0;

    int r      = order.r();
    int rd     = order.rd();
    int N      = order.complexity();
    int n_diff = order.n_diff();

    append_buffer<assign>(params_mu, order.k * batch_size);
    append_buffer<assign>(params_beta, order.n_exog * batch_size);
    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);
    append_buffer<assign>(params_sigma2, batch_size);

    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);
    append_buffer<assign>(Tparams_sigma2, batch_size);

    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);
    append_buffer<assign>(RQR_batches, batch_size);
    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);
    append_buffer<assign>(alpha_batches, batch_size);
    append_buffer<assign>(ImT_dense, r * r * batch_size);
    append_buffer<assign>(ImT_batches, batch_size);
    append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
    append_buffer<assign>(ImT_inv_batches, batch_size);
    append_buffer<assign>(ImT_inv_P, r * batch_size);
    append_buffer<assign>(ImT_inv_info, batch_size);
    append_buffer<assign>(v_tmp_dense, rd * batch_size);
    append_buffer<assign>(v_tmp_batches, batch_size);
    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
    append_buffer<assign>(m_tmp_batches, batch_size);
    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);
    append_buffer<assign>(exog_diff, n_obs * order.n_exog * batch_size);
    append_buffer<assign>(loglike, batch_size);
    append_buffer<assign>(loglike_base, batch_size);
    append_buffer<assign>(loglike_pert, batch_size);
    append_buffer<assign>(x_pert, N * batch_size);

    if (n_diff > 0) {
      append_buffer<assign>(Ts_dense, r * r * batch_size);
      append_buffer<assign>(Ts_batches, batch_size);
      append_buffer<assign>(RQRs_dense, r * r * batch_size);
      append_buffer<assign>(RQRs_batches, batch_size);
      append_buffer<assign>(Ps_dense, r * r * batch_size);
      append_buffer<assign>(Ps_batches, batch_size);
    }

    if (r <= 5) {
      // Note: temp mem for the direct Lyapunov solver grows very quickly!
      // This solver is used iff the condition above is satisfied
      append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_batches, batch_size);
      append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
      append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
      append_buffer<assign>(I_m_AxA_info, batch_size);
    }
  }

  // Internal constructor used by compute_size: only computes the required
  // size, without assigning the member pointers
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

 public:
  /**
   * Constructor: sets the member pointers to their offsets inside the
   * provided buffer
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Number of series in the batch
   * @param[in] n_obs      Number of observations per series
   * @param[in] in_buf     Buffer of at least compute_size(order, batch_size, n_obs) bytes
   */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }

  /**
   * Compute the size (in bytes) of the buffer required for the given problem
   * dimensions
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Number of series in the batch
   * @param[in] n_obs      Number of observations per series
   * @return Required buffer size in bytes
   */
  static size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    ARIMAMemory temp(order, batch_size, n_obs);
    return temp.size;
  }
};
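
/*
 * Illustrative usage sketch (not from the original header): query the size,
 * allocate one device buffer, then bind the workspace to it. Assumes an
 * rmm::device_uvector<char> for the allocation and a valid `stream`.
 *
 *   size_t bytes = ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
 *   rmm::device_uvector<char> temp_mem(bytes, stream);
 *   ML::ARIMAMemory<double> arima_mem(order, batch_size, n_obs, temp_mem.data());
 *   // arima_mem.d_params, arima_mem.Z_dense, ... now point inside temp_mem
 */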

}  // namespace ML