8 #include <raft/util/cudart_utils.hpp>
10 #include <rmm/aligned.hpp>
11 #include <rmm/mr/per_device_resource.hpp>
12 #include <rmm/resource_ref.hpp>
14 #include <cuda_runtime.h>
40 inline bool need_diff()
const {
return static_cast<bool>(
d +
D); }
49 template <
typename DataT>
70 rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource_ref();
71 if (order.
k && !tr)
mu = (DataT*)rmm_alloc.allocate(stream, batch_size *
sizeof(DataT));
73 beta = (DataT*)rmm_alloc.allocate(stream, order.
n_exog * batch_size *
sizeof(DataT));
74 if (order.
p)
ar = (DataT*)rmm_alloc.allocate(stream, order.
p * batch_size *
sizeof(DataT));
75 if (order.
q)
ma = (DataT*)rmm_alloc.allocate(stream, order.
q * batch_size *
sizeof(DataT));
76 if (order.
P)
sar = (DataT*)rmm_alloc.allocate(stream, order.
P * batch_size *
sizeof(DataT));
77 if (order.
Q)
sma = (DataT*)rmm_alloc.allocate(stream, order.
Q * batch_size *
sizeof(DataT));
78 sigma2 = (DataT*)rmm_alloc.allocate(stream, batch_size *
sizeof(DataT));
92 rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource_ref();
93 if (order.
k && !tr) rmm_alloc.deallocate(stream,
mu, batch_size *
sizeof(DataT));
95 rmm_alloc.deallocate(stream,
beta, order.
n_exog * batch_size *
sizeof(DataT));
96 if (order.
p) rmm_alloc.deallocate(stream,
ar, order.
p * batch_size *
sizeof(DataT));
97 if (order.
q) rmm_alloc.deallocate(stream,
ma, order.
q * batch_size *
sizeof(DataT));
98 if (order.
P) rmm_alloc.deallocate(stream,
sar, order.
P * batch_size *
sizeof(DataT));
99 if (order.
Q) rmm_alloc.deallocate(stream,
sma, order.
Q * batch_size *
sizeof(DataT));
100 rmm_alloc.deallocate(stream,
sigma2, batch_size *
sizeof(DataT));
112 void pack(
const ARIMAOrder& order,
int batch_size, DataT* param_vec, cudaStream_t stream)
const;
123 void unpack(
const ARIMAOrder& order,
int batch_size,
const DataT* param_vec, cudaStream_t stream);
132 template <
typename T,
int ALIGN = 256>
151 template <
bool assign,
typename ValType>
154 if (assign) { ptr =
reinterpret_cast<ValType*
>(
buf +
size); }
155 size += ((n_elem *
sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
158 template <
bool assign>
162 char* in_buf =
nullptr)
170 int n_diff = order.
n_diff();
172 append_buffer<assign>(
params_mu, order.
k * batch_size);
174 append_buffer<assign>(
params_ar, order.
p * batch_size);
175 append_buffer<assign>(
params_ma, order.
q * batch_size);
176 append_buffer<assign>(
params_sar, order.
P * batch_size);
177 append_buffer<assign>(
params_sma, order.
Q * batch_size);
180 append_buffer<assign>(
Tparams_ar, order.
p * batch_size);
181 append_buffer<assign>(
Tparams_ma, order.
q * batch_size);
182 append_buffer<assign>(
Tparams_sar, order.
P * batch_size);
183 append_buffer<assign>(
Tparams_sma, order.
Q * batch_size);
186 append_buffer<assign>(
d_params, N * batch_size);
187 append_buffer<assign>(
d_Tparams, N * batch_size);
188 append_buffer<assign>(
Z_dense, rd * batch_size);
189 append_buffer<assign>(
Z_batches, batch_size);
190 append_buffer<assign>(
R_dense, rd * batch_size);
191 append_buffer<assign>(
R_batches, batch_size);
192 append_buffer<assign>(
T_dense, rd * rd * batch_size);
193 append_buffer<assign>(
T_batches, batch_size);
194 append_buffer<assign>(
RQ_dense, rd * batch_size);
195 append_buffer<assign>(
RQ_batches, batch_size);
196 append_buffer<assign>(
RQR_dense, rd * rd * batch_size);
198 append_buffer<assign>(
P_dense, rd * rd * batch_size);
199 append_buffer<assign>(
P_batches, batch_size);
200 append_buffer<assign>(
alpha_dense, rd * batch_size);
202 append_buffer<assign>(
ImT_dense, r * r * batch_size);
206 append_buffer<assign>(
ImT_inv_P, r * batch_size);
208 append_buffer<assign>(
v_tmp_dense, rd * batch_size);
210 append_buffer<assign>(
m_tmp_dense, rd * rd * batch_size);
212 append_buffer<assign>(
K_dense, rd * batch_size);
213 append_buffer<assign>(
K_batches, batch_size);
214 append_buffer<assign>(
TP_dense, rd * rd * batch_size);
215 append_buffer<assign>(
TP_batches, batch_size);
217 append_buffer<assign>(
pred, n_obs * batch_size);
218 append_buffer<assign>(
y_diff, n_obs * batch_size);
220 append_buffer<assign>(
loglike, batch_size);
223 append_buffer<assign>(
x_pert, N * batch_size);
226 append_buffer<assign>(
Ts_dense, r * r * batch_size);
227 append_buffer<assign>(
Ts_batches, batch_size);
228 append_buffer<assign>(
RQRs_dense, r * r * batch_size);
230 append_buffer<assign>(
Ps_dense, r * r * batch_size);
231 append_buffer<assign>(
Ps_batches, batch_size);
237 append_buffer<assign>(
I_m_AxA_dense, r * r * r * r * batch_size);
241 append_buffer<assign>(
I_m_AxA_P, r * r * batch_size);
249 buf_offsets<false>(order, batch_size, n_obs);
262 buf_offsets<true>(order, batch_size, n_obs, in_buf);
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
Definition: dbscan.hpp:18
Definition: arima_common.h:133
T * x_pert
Definition: arima_common.h:138
T * Tparams_sar
Definition: arima_common.h:135
T * K_dense
Definition: arima_common.h:137
void buf_offsets(const ARIMAOrder &order, int batch_size, int n_obs, char *in_buf=nullptr)
Definition: arima_common.h:159
T ** R_batches
Definition: arima_common.h:140
T ** RQ_batches
Definition: arima_common.h:140
T * params_mu
Definition: arima_common.h:134
T * T_dense
Definition: arima_common.h:136
T ** Ps_batches
Definition: arima_common.h:143
T * alpha_dense
Definition: arima_common.h:136
int * ImT_inv_P
Definition: arima_common.h:144
T * Z_dense
Definition: arima_common.h:136
T * d_params
Definition: arima_common.h:135
T * loglike_base
Definition: arima_common.h:138
T * ImT_inv_dense
Definition: arima_common.h:137
T * Tparams_sma
Definition: arima_common.h:135
static size_t compute_size(const ARIMAOrder &order, int batch_size, int n_obs)
Definition: arima_common.h:271
T * y_diff
Definition: arima_common.h:137
T * TP_dense
Definition: arima_common.h:137
T * Tparams_ar
Definition: arima_common.h:135
T * RQRs_dense
Definition: arima_common.h:139
T ** ImT_batches
Definition: arima_common.h:141
T ** T_batches
Definition: arima_common.h:140
T * I_m_AxA_inv_dense
Definition: arima_common.h:138
T * pred
Definition: arima_common.h:137
T * ImT_dense
Definition: arima_common.h:136
int * ImT_inv_info
Definition: arima_common.h:144
T ** alpha_batches
Definition: arima_common.h:141
T * params_sma
Definition: arima_common.h:134
T * Ps_dense
Definition: arima_common.h:139
T ** P_batches
Definition: arima_common.h:140
T * loglike_pert
Definition: arima_common.h:138
T ** m_tmp_batches
Definition: arima_common.h:141
T ** Z_batches
Definition: arima_common.h:140
int * I_m_AxA_P
Definition: arima_common.h:144
T * m_tmp_dense
Definition: arima_common.h:137
T * params_sar
Definition: arima_common.h:134
T ** K_batches
Definition: arima_common.h:142
T * RQR_dense
Definition: arima_common.h:136
size_t size
Definition: arima_common.h:146
T ** I_m_AxA_inv_batches
Definition: arima_common.h:142
T * d_Tparams
Definition: arima_common.h:135
T ** ImT_inv_batches
Definition: arima_common.h:141
T * I_m_AxA_dense
Definition: arima_common.h:138
T * v_tmp_dense
Definition: arima_common.h:137
T * Tparams_ma
Definition: arima_common.h:135
T * params_beta
Definition: arima_common.h:134
int * I_m_AxA_info
Definition: arima_common.h:144
T * P_dense
Definition: arima_common.h:136
T ** v_tmp_batches
Definition: arima_common.h:141
T ** I_m_AxA_batches
Definition: arima_common.h:142
T * params_ar
Definition: arima_common.h:134
char * buf
Definition: arima_common.h:149
ARIMAMemory(const ARIMAOrder &order, int batch_size, int n_obs)
Definition: arima_common.h:247
void append_buffer(ValType *&ptr, size_t n_elem)
Definition: arima_common.h:152
T * R_dense
Definition: arima_common.h:136
T ** RQRs_batches
Definition: arima_common.h:143
T * loglike
Definition: arima_common.h:138
T * exog_diff
Definition: arima_common.h:137
T * RQ_dense
Definition: arima_common.h:136
T * params_sigma2
Definition: arima_common.h:134
T ** RQR_batches
Definition: arima_common.h:140
ARIMAMemory(const ARIMAOrder &order, int batch_size, int n_obs, char *in_buf)
Definition: arima_common.h:260
T * Tparams_sigma2
Definition: arima_common.h:135
T ** TP_batches
Definition: arima_common.h:142
T * Ts_dense
Definition: arima_common.h:138
T * params_ma
Definition: arima_common.h:134
T ** Ts_batches
Definition: arima_common.h:142
Definition: arima_common.h:23
int p
Definition: arima_common.h:24
int s
Definition: arima_common.h:30
int n_phi() const
Definition: arima_common.h:35
int P
Definition: arima_common.h:27
int r() const
Definition: arima_common.h:37
int n_exog
Definition: arima_common.h:32
int rd() const
Definition: arima_common.h:38
int D
Definition: arima_common.h:28
int complexity() const
Definition: arima_common.h:39
int q
Definition: arima_common.h:26
bool need_diff() const
Definition: arima_common.h:40
int n_theta() const
Definition: arima_common.h:36
int Q
Definition: arima_common.h:29
int d
Definition: arima_common.h:25
int k
Definition: arima_common.h:31
int n_diff() const
Definition: arima_common.h:34
Definition: arima_common.h:50
DataT * mu
Definition: arima_common.h:51
DataT * sma
Definition: arima_common.h:56
DataT * beta
Definition: arima_common.h:52
void deallocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr=false)
Definition: arima_common.h:90
void unpack(const ARIMAOrder &order, int batch_size, const DataT *param_vec, cudaStream_t stream)
DataT * ma
Definition: arima_common.h:54
void allocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr=false)
Definition: arima_common.h:68
DataT * ar
Definition: arima_common.h:53
DataT * sar
Definition: arima_common.h:55
DataT * sigma2
Definition: arima_common.h:57
void pack(const ARIMAOrder &order, int batch_size, DataT *param_vec, cudaStream_t stream) const