#include <raft/util/cudart_utils.hpp>

#include <rmm/aligned.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
  inline bool need_diff() const { return static_cast<bool>(d + D); }
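// Usage sketch (illustrative, not part of the original header): need_diff() is true
// exactly when simple or seasonal differencing is configured, i.e. when d + D > 0.
// The field values below are hypothetical example settings for an airline-style
// (0,1,1)(0,1,1)_12 order, assuming ARIMAOrder value-initializes its fields to zero.
inline bool example_order_needs_diff()
{
  ARIMAOrder airline{};  // all order fields start at zero
  airline.d = 1;         // one simple difference
  airline.q = 1;         // one MA term
  airline.D = 1;         // one seasonal difference
  airline.Q = 1;         // one seasonal MA term
  airline.s = 12;        // monthly seasonality
  return airline.need_diff();  // true, since d + D == 2
}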
template <typename DataT>
struct ARIMAParams {
  DataT* mu;
  DataT* beta;
  DataT* ar;
  DataT* ma;
  DataT* sar;
  DataT* sma;
  DataT* sigma2;

  void allocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) mu = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
    beta = (DataT*)rmm_alloc.allocate(stream, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) ar = (DataT*)rmm_alloc.allocate(stream, order.p * batch_size * sizeof(DataT));
    if (order.q) ma = (DataT*)rmm_alloc.allocate(stream, order.q * batch_size * sizeof(DataT));
    if (order.P) sar = (DataT*)rmm_alloc.allocate(stream, order.P * batch_size * sizeof(DataT));
    if (order.Q) sma = (DataT*)rmm_alloc.allocate(stream, order.Q * batch_size * sizeof(DataT));
    sigma2 = (DataT*)rmm_alloc.allocate(stream, batch_size * sizeof(DataT));
  }
  void deallocate(const ARIMAOrder& order, int batch_size, cudaStream_t stream, bool tr = false)
  {
    rmm::device_async_resource_ref rmm_alloc = rmm::mr::get_current_device_resource();
    if (order.k && !tr) rmm_alloc.deallocate(stream, mu, batch_size * sizeof(DataT));
    rmm_alloc.deallocate(stream, beta, order.n_exog * batch_size * sizeof(DataT));
    if (order.p) rmm_alloc.deallocate(stream, ar, order.p * batch_size * sizeof(DataT));
    if (order.q) rmm_alloc.deallocate(stream, ma, order.q * batch_size * sizeof(DataT));
    if (order.P) rmm_alloc.deallocate(stream, sar, order.P * batch_size * sizeof(DataT));
    if (order.Q) rmm_alloc.deallocate(stream, sma, order.Q * batch_size * sizeof(DataT));
    rmm_alloc.deallocate(stream, sigma2, batch_size * sizeof(DataT));
  }
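  // Usage sketch (illustrative, not part of the original header): allocate and
  // deallocate are meant to be paired on the same stream with the same order,
  // batch_size and `tr` flag, so every deallocation passes the same byte count
  // that was used at allocation time, e.g.
  //
  //   ARIMAParams<double> params;
  //   params.allocate(order, batch_size, stream);
  //   // ... fill / read params.mu, params.ar, params.ma, params.sigma2, ... on `stream`
  //   params.deallocate(order, batch_size, stream);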
  void pack(const ARIMAOrder& order, int batch_size, DataT* param_vec, cudaStream_t stream) const
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // Copy the member pointers into locals so the device lambda captures them by value.
    const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
                *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        DataT* param = param_vec + bid * N;
        if (order.k) {
          *param = _mu[bid];
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          param[i] = _beta[order.n_exog * bid + i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          param[ip] = _ar[order.p * bid + ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          param[iq] = _ma[order.q * bid + iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          param[iP] = _sar[order.P * bid + iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          param[iQ] = _sma[order.Q * bid + iQ];
        }
        param += order.Q;
        *param = _sigma2[bid];
      });
  }
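  // Layout note (illustrative): for each batch member bid, pack writes one
  // contiguous slice of N = order.complexity() values into param_vec, in the
  // order [mu (if k)] [beta x n_exog] [ar x p] [ma x q] [sar x P] [sma x Q] [sigma2],
  // i.e. param_vec[bid * N ... (bid + 1) * N - 1] holds all parameters of series bid.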
  void unpack(const ARIMAOrder& order, int batch_size, const DataT* param_vec, cudaStream_t stream)
  {
    int N         = order.complexity();
    auto counting = thrust::make_counting_iterator(0);
    // Copy the member pointers into locals so the device lambda captures them by value.
    DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma,
          *_sigma2 = sigma2;
    thrust::for_each(
      thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) {
        const DataT* param = param_vec + bid * N;
        if (order.k) {
          _mu[bid] = *param;
          param++;
        }
        for (int i = 0; i < order.n_exog; i++) {
          _beta[order.n_exog * bid + i] = param[i];
        }
        param += order.n_exog;
        for (int ip = 0; ip < order.p; ip++) {
          _ar[order.p * bid + ip] = param[ip];
        }
        param += order.p;
        for (int iq = 0; iq < order.q; iq++) {
          _ma[order.q * bid + iq] = param[iq];
        }
        param += order.q;
        for (int iP = 0; iP < order.P; iP++) {
          _sar[order.P * bid + iP] = param[iP];
        }
        param += order.P;
        for (int iQ = 0; iQ < order.Q; iQ++) {
          _sma[order.Q * bid + iQ] = param[iQ];
        }
        param += order.Q;
        _sigma2[bid] = *param;
      });
  }
};
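// Usage sketch (illustrative, not part of the original header): pack and unpack
// are inverses over one flat device vector of batch_size * order.complexity()
// values, which is convenient for optimizers that expect a single contiguous
// parameter vector. `example_pack_roundtrip` and `d_flat` are hypothetical names;
// d_flat must point to device memory of at least that many elements.
template <typename DataT>
inline void example_pack_roundtrip(const ARIMAOrder& order,
                                   int batch_size,
                                   ARIMAParams<DataT>& params,
                                   DataT* d_flat,
                                   cudaStream_t stream)
{
  params.pack(order, batch_size, d_flat, stream);    // struct-of-arrays -> flat vector
  params.unpack(order, batch_size, d_flat, stream);  // flat vector -> struct-of-arrays
}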
template <typename T, int ALIGN = 256>
struct ARIMAMemory {
  // Parameter arrays
  T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2;
  T *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams;
  // Dense work buffers
  T *Z_dense, *R_dense, *T_dense, *RQ_dense, *RQR_dense, *P_dense, *alpha_dense, *ImT_dense;
  T *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff;
  T *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense;
  T *RQRs_dense, *Ps_dense;
  // Batched pointer arrays (one device pointer per batch member)
  T **Z_batches, **R_batches, **T_batches, **RQ_batches, **RQR_batches, **P_batches;
  T **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches;
  T **K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches;
  T **RQRs_batches, **Ps_batches;
  // Integer work arrays
  int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

  // Total number of bytes accumulated by buf_offsets()
  size_t size;

  // Start of the workspace backing all of the pointers above
  char* buf;
  template <bool assign, typename ValType>
  void append_buffer(ValType*& ptr, size_t n_elem)
  {
    if (assign) { ptr = reinterpret_cast<ValType*>(buf + size); }
    size += ((n_elem * sizeof(ValType) + ALIGN - 1) / ALIGN) * ALIGN;
  }
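  // Worked example (illustrative): with ALIGN = 256, reserving e.g. 10 doubles
  // (80 bytes) advances `size` by ((80 + 255) / 256) * 256 = 256 bytes, so every
  // assigned pointer starts on a 256-byte boundary inside `buf`.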
  template <bool assign>
  inline void buf_offsets(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf = nullptr)
  {
    int r      = order.r();
    int rd     = order.rd();
    int N      = order.complexity();
    int n_diff = order.n_diff();

    append_buffer<assign>(params_mu, order.k * batch_size);

    append_buffer<assign>(params_ar, order.p * batch_size);
    append_buffer<assign>(params_ma, order.q * batch_size);
    append_buffer<assign>(params_sar, order.P * batch_size);
    append_buffer<assign>(params_sma, order.Q * batch_size);

    append_buffer<assign>(Tparams_ar, order.p * batch_size);
    append_buffer<assign>(Tparams_ma, order.q * batch_size);
    append_buffer<assign>(Tparams_sar, order.P * batch_size);
    append_buffer<assign>(Tparams_sma, order.Q * batch_size);

    append_buffer<assign>(d_params, N * batch_size);
    append_buffer<assign>(d_Tparams, N * batch_size);
    append_buffer<assign>(Z_dense, rd * batch_size);
    append_buffer<assign>(Z_batches, batch_size);
    append_buffer<assign>(R_dense, rd * batch_size);
    append_buffer<assign>(R_batches, batch_size);
    append_buffer<assign>(T_dense, rd * rd * batch_size);
    append_buffer<assign>(T_batches, batch_size);
    append_buffer<assign>(RQ_dense, rd * batch_size);
    append_buffer<assign>(RQ_batches, batch_size);
    append_buffer<assign>(RQR_dense, rd * rd * batch_size);

    append_buffer<assign>(P_dense, rd * rd * batch_size);
    append_buffer<assign>(P_batches, batch_size);
    append_buffer<assign>(alpha_dense, rd * batch_size);

    append_buffer<assign>(ImT_dense, r * r * batch_size);

    append_buffer<assign>(ImT_inv_P, r * batch_size);

    append_buffer<assign>(v_tmp_dense, rd * batch_size);

    append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);

    append_buffer<assign>(K_dense, rd * batch_size);
    append_buffer<assign>(K_batches, batch_size);
    append_buffer<assign>(TP_dense, rd * rd * batch_size);
    append_buffer<assign>(TP_batches, batch_size);

    append_buffer<assign>(pred, n_obs * batch_size);
    append_buffer<assign>(y_diff, n_obs * batch_size);

    append_buffer<assign>(loglike, batch_size);

    append_buffer<assign>(x_pert, N * batch_size);

    append_buffer<assign>(Ts_dense, r * r * batch_size);
    append_buffer<assign>(Ts_batches, batch_size);
    append_buffer<assign>(RQRs_dense, r * r * batch_size);

    append_buffer<assign>(Ps_dense, r * r * batch_size);
    append_buffer<assign>(Ps_batches, batch_size);

    append_buffer<assign>(I_m_AxA_dense, r * r * r * r * batch_size);

    append_buffer<assign>(I_m_AxA_P, r * r * batch_size);
  }

  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs)
  {
    buf_offsets<false>(order, batch_size, n_obs);
  }

  ARIMAMemory(const ARIMAOrder& order, int batch_size, int n_obs, char* in_buf)
  {
    buf_offsets<true>(order, batch_size, n_obs, in_buf);
  }
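  // Usage sketch (illustrative, not part of the original header): callers are
  // expected to size the workspace first and then construct ARIMAMemory over
  // storage they own; compute_size(order, batch_size, n_obs), a static helper
  // declared further down in this header, reports the number of bytes that
  // buf_offsets accumulates, e.g.
  //
  //   size_t ws_bytes = ARIMAMemory<double>::compute_size(order, batch_size, n_obs);
  //   // d_workspace: at least ws_bytes of device memory owned by the caller
  //   ARIMAMemory<double> mem(order, batch_size, n_obs, d_workspace);
  //   // mem.Z_dense, mem.T_dense, ... now point at ALIGN-aligned offsets inside d_workspace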