arima_common.h
/*
 * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <raft/util/cudart_utils.hpp>

#include <rmm/aligned.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime.h>

#include <algorithm>

namespace ML {
/**
 * Structure to hold the ARIMA order (makes it easier to pass as an argument)
 */
struct ARIMAOrder {
  int p;  // Basic order
  int d;
  int q;
  int P;  // Seasonal order
  int D;
  int Q;
  int s;       // Seasonal period
  int k;       // Fit intercept?
  int n_exog;  // Number of exogenous regressors

  inline int n_diff() const { return d + s * D; }
  inline int n_phi() const { return p + s * P; }
  inline int n_theta() const { return q + s * Q; }
  inline int r() const { return std::max(n_phi(), n_theta() + 1); }
  inline int rd() const { return n_diff() + r(); }
  inline int complexity() const { return p + P + q + Q + k + n_exog + 1; }
  inline bool need_diff() const { return static_cast<bool>(d + D); }
};

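// Worked example (illustrative, not part of the original header): for a
// seasonal order (p,d,q)(P,D,Q)s = (1,1,1)(1,1,1)12 with an intercept (k = 1)
// and no exogenous regressors:
//   n_diff()     = d + s*D = 1 + 12*1 = 13
//   n_phi()      = p + s*P = 13,  n_theta() = q + s*Q = 13
//   r()          = max(n_phi(), n_theta() + 1) = 14
//   rd()         = n_diff() + r() = 27
//   complexity() = p + P + q + Q + k + n_exog + 1 = 6
//
//   ARIMAOrder order{1, 1, 1, 1, 1, 1, 12, 1, 0};  // p, d, q, P, D, Q, s, k, n_exog
//   assert(order.r() == 14 && order.rd() == 27 && order.complexity() == 6);
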
/**
 * Structure to hold the ARIMA parameters as separate device arrays
 * (one array per parameter type, contiguous over the batch)
 */
template <typename DataT>
struct ARIMAParams {
  DataT* mu     = nullptr;
  DataT* beta   = nullptr;
  DataT* ar     = nullptr;
  DataT* ma     = nullptr;
  DataT* sar    = nullptr;
  DataT* sma    = nullptr;
  DataT* sigma2 = nullptr;

  /**
   * Allocate all the parameter device arrays
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Batch size
   * @param[in] stream     CUDA stream
   * @param[in] tr         Whether these are the transformed parameters
   *                       (mu and beta are not allocated in that case)
   */
  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource_ref();
    if (order.k && !tr) mu = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
    if (order.n_exog && !tr)
      beta = (DataT*)rmm_alloc.allocate(stream, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) ar = (DataT*)rmm_alloc.allocate(stream, order.p * batch_size * sizeof(DataT));
    if (order.q) ma = (DataT*)rmm_alloc.allocate(stream, order.q * batch_size * sizeof(DataT));
    if (order.P) sar = (DataT*)rmm_alloc.allocate(stream, order.P * batch_size * sizeof(DataT));
    if (order.Q) sma = (DataT*)rmm_alloc.allocate(stream, order.Q * batch_size * sizeof(DataT));
    sigma2 = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
  }

  /**
   * Deallocate all the parameter device arrays (must be called with the same
   * order, batch_size and tr as the matching allocate call)
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Batch size
   * @param[in] stream     CUDA stream
   * @param[in] tr         Whether these are the transformed parameters
   */
  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource_ref();
    if (order.k && !tr) rmm_alloc.deallocate(stream, mu, batch_size * sizeof(DataT));
    if (order.n_exog && !tr)
      rmm_alloc.deallocate(stream, beta, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) rmm_alloc.deallocate(stream, ar, order.p * batch_size * sizeof(DataT));
    if (order.q) rmm_alloc.deallocate(stream, ma, order.q * batch_size * sizeof(DataT));
    if (order.P) rmm_alloc.deallocate(stream, sar, order.P * batch_size * sizeof(DataT));
    if (order.Q) rmm_alloc.deallocate(stream, sma, order.Q * batch_size * sizeof(DataT));
    rmm_alloc.deallocate(stream, sigma2, batch_size * sizeof(DataT));
  }

  /**
   * Pack the separate parameter arrays into a single parameter vector
   *
   * @param[in]  order      ARIMA order
   * @param[in]  batch_size Batch size
   * @param[out] param_vec  Linear array of all parameters, grouped by series
   * @param[in]  stream     CUDA stream
   */
  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const;

  /**
   * Unpack a parameter vector into the separate parameter arrays
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Batch size
   * @param[in] param_vec  Linear array of all parameters, grouped by series
   * @param[in] stream     CUDA stream
   */
  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream);
};

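// Lifecycle sketch (illustrative; assumes a valid stream and an optimizer that
// works on the packed vector param_vec):
//
//   ML::ARIMAParams<double> params;
//   params.allocate(order, batch_size, stream);
//   params.unpack(order, batch_size, param_vec, stream);  // vector -> arrays
//   /* ... use the individual parameter arrays ... */
//   params.pack(order, batch_size, param_vec, stream);    // arrays -> vector
//   params.deallocate(order, batch_size, stream);
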
/**
 * Structure to manage the temporary device memory used by ARIMA: a single
 * pre-allocated buffer is carved into ALIGN-aligned sub-buffers, one per
 * work array
 */
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2,
    *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
    *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense,
    *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff,
    *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense,
    *RQRs_dense, *Ps_dense;
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches,
    **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
    **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
    **RQRs_batches, **Ps_batches;
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  size_t size;

 protected:
  char* buf;

  // Reserve space for n_elem elements of ValType, rounding the sub-buffer up
  // to a multiple of ALIGN bytes; if assign is true, also point ptr into buf
  template <bool assign, typename ValType>
  inline void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
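
  // Rounding example: with ALIGN = 256, reserving 100 doubles (800 bytes)
  // advances size by ((800 + 255) / 256) * 256 = 1024 bytes, so the next
  // sub-buffer starts on a 256-byte boundary (relative to buf)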

  // Walk all sub-buffers in a fixed order, accumulating the total size;
  // when assign is true, also set the corresponding member pointers
  template <bool assign>
  inline void buf_offsets(const ARIMAOrder& order,
                          int batch_size,
                          int n_obs,
                          char* in_buf = nullptr)
  {
    buf  = in_buf;
    size = 0;

    int r      = order.r();
    int rd     = order.rd();
    int N      = order.complexity();
    int n_diff = order.n_diff();

    append_buffer<assign>(params_mu, order.k * batch_size);
    append_buffer<assign>(params_beta, order.n_exog * batch_size);
    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);
    append_buffer<assign>(params_sigma2, batch_size);

    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);
    append_buffer<assign>(Tparams_sigma2, batch_size);

    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);
    append_buffer<assign>(RQR_batches, batch_size);
    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);
    append_buffer<assign>(alpha_batches, batch_size);
    append_buffer<assign>(ImT_dense, r * r * batch_size);
    append_buffer<assign>(ImT_batches, batch_size);
    append_buffer<assign>(ImT_inv_dense, r * r * batch_size);
    append_buffer<assign>(ImT_inv_batches, batch_size);
    append_buffer<assign>(ImT_inv_P, r * batch_size);
    append_buffer<assign>(ImT_inv_info, batch_size);
    append_buffer<assign>(v_tmp_dense, rd * batch_size);
    append_buffer<assign>(v_tmp_batches, batch_size);
    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
    append_buffer<assign>(m_tmp_batches, batch_size);
    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);
    append_buffer<assign>(exog_diff, n_obs * order.n_exog * batch_size);
    append_buffer<assign>(loglike, batch_size);
    append_buffer<assign>(loglike_base, batch_size);
    append_buffer<assign>(loglike_pert, batch_size);
    append_buffer<assign>(x_pert, N * batch_size);

    if (n_diff > 0) {
      append_buffer<assign>(Ts_dense, r * r * batch_size);
      append_buffer<assign>(Ts_batches, batch_size);
      append_buffer<assign>(RQRs_dense, r * r * batch_size);
      append_buffer<assign>(RQRs_batches, batch_size);
      append_buffer<assign>(Ps_dense, r * r * batch_size);
      append_buffer<assign>(Ps_batches, batch_size);
    }

    if (r <= 5) {
      // Note: temp mem for the direct Lyapunov solver grows very quickly!
      // This solver is used iff the condition above is satisfied
      append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_batches, batch_size);
      append_buffer<assign>(I_m_AxA_inv_dense, r * r * r * r * batch_size);
      append_buffer<assign>(I_m_AxA_inv_batches, batch_size);
      append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
      append_buffer<assign>(I_m_AxA_info, batch_size);
    }
  }

  // Protected constructor: only computes the required size (no pointer assignment)
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

 public:
  /**
   * Constructor: sets the sub-buffer pointers into a pre-allocated buffer
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Batch size
   * @param[in] n_obs      Length of the series
   * @param[in] in_buf     Pointer to the pre-allocated buffer; it must hold at
   *                       least compute_size(order, batch_size, n_obs) bytes
   */
  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }

  /**
   * Compute the size in bytes of the buffer that this structure requires
   *
   * @param[in] order      ARIMA order
   * @param[in] batch_size Batch size
   * @param[in] n_obs      Length of the series
   */
  static size_t compute_size(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    ARIMAMemory temp(order, batch_size, n_obs);
    return temp.size;
  }
};

}  // namespace ML
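
// Usage sketch (illustrative, not part of the original header): the buffer is
// sized and carved in two phases, so the caller owns the single allocation.
// rmm::device_buffer is one plausible way to allocate it:
//
//   size_t bytes = ML::ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
//   rmm::device_buffer temp_mem(bytes, stream);
//   ML::ARIMAMemory<double> arima_mem(
//     order, batch_size, n_obs, static_cast<char*>(temp_mem.data()));
//   // arima_mem.params_mu, arima_mem.Z_dense, ... now point into temp_mem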