batched_arima.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
// Forward declaration of raft::handle_t. Every function declared below only
// takes the handle by reference, so an incomplete type is sufficient here.
10 namespace raft {
11 class handle_t;
12 }
13 
14 namespace ML {
15 
/**
 * Selector passed as the `method` argument of batched_loglike and
 * batched_loglike_grad, where it defaults to MLE.
 *
 * NOTE(review): CSS / MLE presumably stand for "conditional sum of squares"
 * and "maximum likelihood estimation" -- confirm against the implementation.
 */
16 enum LoglikeMethod { CSS, MLE };
17 
/**
 * Declaration of pack().
 *
 * NOTE(review): listing line 28 is not shown here. The member index at the
 * end of this page lists an additional `const ARIMAParams<double>& params`
 * parameter between `handle` and `order`.
 * NOTE(review): presumably gathers the per-series parameters into the flat
 * `param_vec` buffer -- confirm against the implementation.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param order      ARIMA order descriptor (read-only)
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param param_vec  non-const buffer; presumably the output -- confirm
 */
27 void pack(raft::handle_t& handle,
29  const ARIMAOrder& order,
30  int batch_size,
31  double* param_vec);
32 
/**
 * Declaration of unpack().
 *
 * NOTE(review): listing line 43 is not shown here. The member index at the
 * end of this page lists an additional non-const `ARIMAParams<double>& params`
 * parameter between `handle` and `order`.
 * NOTE(review): presumably the inverse of pack(), scattering the flat
 * `param_vec` into the parameter structure -- confirm.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param order      ARIMA order descriptor (read-only)
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param param_vec  read-only flat parameter buffer
 */
42 void unpack(raft::handle_t& handle,
44  const ARIMAOrder& order,
45  int batch_size,
46  const double* param_vec);
47 
/**
 * Declaration of detect_missing(); returns a bool.
 *
 * NOTE(review): presumably reports whether `d_y` contains missing
 * observations, and the `d_` prefix presumably marks a device pointer --
 * confirm both against the implementation.
 *
 * @param handle raft handle, taken by non-const reference
 * @param d_y    read-only buffer of doubles
 * @param n_elem presumably the number of elements in `d_y` -- confirm
 */
55 bool detect_missing(raft::handle_t& handle, const double* d_y, int n_elem);
56 
/**
 * Declaration of batched_diff().
 *
 * NOTE(review): presumably writes the differenced form of `d_y` (as
 * dictated by `order`) into `d_y_diff`; buffer sizes and layout are not
 * visible here -- confirm against the implementation.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param d_y_diff   non-const buffer; presumably the output -- confirm
 * @param d_y        read-only input series buffer
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param n_obs      presumably the number of observations per series -- confirm
 * @param order      ARIMA order descriptor (read-only)
 */
67 void batched_diff(raft::handle_t& handle,
68  double* d_y_diff,
69  const double* d_y,
70  int batch_size,
71  int n_obs,
72  const ARIMAOrder& order);
73 
/**
 * Declaration of the batched_loglike() overload that takes the model
 * parameters as a flat read-only buffer (`d_params`).
 *
 * NOTE(review): the `d_` prefix presumably marks device pointers, and
 * `host_loglike` presumably says whether `loglike` is a host buffer --
 * confirm both against the implementation.
 *
 * @param handle       raft handle, taken by non-const reference
 * @param arima_mem    read-only ARIMAMemory<double>; presumably pre-allocated
 *                     workspace -- confirm
 * @param d_y          read-only series buffer
 * @param d_exog       read-only buffer; presumably exogenous variables -- confirm
 * @param batch_size   presumably the number of series in the batch -- confirm
 * @param n_obs        presumably the number of observations per series -- confirm
 * @param order        ARIMA order descriptor (read-only)
 * @param d_params     read-only flat parameter buffer
 * @param loglike      non-const buffer; presumably receives one value per
 *                     series -- confirm
 * @param trans        defaults to true; meaning not visible here -- confirm
 * @param host_loglike defaults to true
 * @param method       log-likelihood method selector, defaults to MLE
 * @param truncate     defaults to 0; meaning not visible here -- confirm
 */
96 void batched_loglike(raft::handle_t& handle,
97  const ARIMAMemory<double>& arima_mem,
98  const double* d_y,
99  const double* d_exog,
100  int batch_size,
101  int n_obs,
102  const ARIMAOrder& order,
103  const double* d_params,
104  double* loglike,
105  bool trans = true,
106  bool host_loglike = true,
107  LoglikeMethod method = MLE,
108  int truncate = 0);
109 
/**
 * Declaration of the batched_loglike() overload with the optional forecast
 * arguments (`fc_steps`, `d_fc`, `d_exog_fut`, `level`, `d_lower`, `d_upper`).
 *
 * NOTE(review): listing line 149 (between `order` and `loglike`) is not
 * shown here, and the member index at the end of this page has no entry for
 * this overload. The hidden line is presumably the parameter that
 * distinguishes this overload from the `d_params` one -- confirm against
 * the real header.
 * NOTE(review): the `d_` prefix presumably marks device pointers -- confirm.
 *
 * @param handle       raft handle, taken by non-const reference
 * @param arima_mem    read-only ARIMAMemory<double>; presumably pre-allocated
 *                     workspace -- confirm
 * @param d_y          read-only series buffer
 * @param d_exog       read-only buffer; presumably exogenous variables -- confirm
 * @param batch_size   presumably the number of series in the batch -- confirm
 * @param n_obs        presumably the number of observations per series -- confirm
 * @param order        ARIMA order descriptor (read-only)
 * @param loglike      non-const buffer; presumably the output -- confirm
 * @param trans        defaults to true; meaning not visible here -- confirm
 * @param host_loglike defaults to true
 * @param method       log-likelihood method selector, defaults to MLE
 * @param truncate     defaults to 0; meaning not visible here -- confirm
 * @param fc_steps     defaults to 0; presumably the forecast horizon -- confirm
 * @param d_fc         defaults to nullptr; non-const, presumably receives the
 *                     forecasts -- confirm
 * @param d_exog_fut   defaults to nullptr; read-only, presumably future
 *                     exogenous values -- confirm
 * @param level        defaults to 0; presumably a confidence level -- confirm
 * @param d_lower      defaults to nullptr; non-const, presumably the lower
 *                     interval bound -- confirm
 * @param d_upper      defaults to nullptr; non-const, presumably the upper
 *                     interval bound -- confirm
 */
142 void batched_loglike(raft::handle_t& handle,
143  const ARIMAMemory<double>& arima_mem,
144  const double* d_y,
145  const double* d_exog,
146  int batch_size,
147  int n_obs,
148  const ARIMAOrder& order,
150  double* loglike,
151  bool trans = true,
152  bool host_loglike = true,
153  LoglikeMethod method = MLE,
154  int truncate = 0,
155  int fc_steps = 0,
156  double* d_fc = nullptr,
157  const double* d_exog_fut = nullptr,
158  double level = 0,
159  double* d_lower = nullptr,
160  double* d_upper = nullptr);
161 
/**
 * Declaration of batched_loglike_grad().
 *
 * NOTE(review): presumably evaluates the gradient of the batched
 * log-likelihood at `d_x` and writes it to `d_grad`, with `h` as a
 * finite-difference step. None of this is visible in the declaration --
 * confirm against the implementation.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param arima_mem  read-only ARIMAMemory<double>; presumably pre-allocated
 *                   workspace -- confirm
 * @param d_y        read-only series buffer
 * @param d_exog     read-only buffer; presumably exogenous variables -- confirm
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param n_obs      presumably the number of observations per series -- confirm
 * @param order      ARIMA order descriptor (read-only)
 * @param d_x        read-only buffer; presumably the evaluation point -- confirm
 * @param d_grad     non-const buffer; presumably the output -- confirm
 * @param h          double with no default
 * @param trans      defaults to true; meaning not visible here -- confirm
 * @param method     log-likelihood method selector, defaults to MLE
 * @param truncate   defaults to 0; meaning not visible here -- confirm
 */
182 void batched_loglike_grad(raft::handle_t& handle,
183  const ARIMAMemory<double>& arima_mem,
184  const double* d_y,
185  const double* d_exog,
186  int batch_size,
187  int n_obs,
188  const ARIMAOrder& order,
189  const double* d_x,
190  double* d_grad,
191  double h,
192  bool trans = true,
193  LoglikeMethod method = MLE,
194  int truncate = 0);
195 
/**
 * Declaration of the ARIMA predict().
 *
 * NOTE(review): listing line 232 (between `order` and `d_y_p`) is not shown
 * here. The `predict` entry in the member index at the end of this page
 * describes an unrelated overload, so it cannot be used to recover the
 * hidden parameter -- check the real header.
 * NOTE(review): the `d_` prefix presumably marks device pointers -- confirm.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param arima_mem  read-only ARIMAMemory<double>; presumably pre-allocated
 *                   workspace -- confirm
 * @param d_y        read-only series buffer
 * @param d_exog     read-only buffer; presumably exogenous variables -- confirm
 * @param d_exog_fut read-only buffer; presumably future exogenous values
 *                   -- confirm
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param n_obs      presumably the number of observations per series -- confirm
 * @param start      presumably the first predicted index; inclusive or
 *                   exclusive is not visible here -- confirm
 * @param end        presumably the end of the predicted range; inclusive or
 *                   exclusive is not visible here -- confirm
 * @param order      ARIMA order descriptor (read-only)
 * @param d_y_p      non-const buffer; presumably receives the predictions
 *                   -- confirm
 * @param pre_diff   defaults to true; meaning not visible here -- confirm
 * @param level      defaults to 0; presumably a confidence level -- confirm
 * @param d_lower    defaults to nullptr; non-const, presumably the lower
 *                   interval bound -- confirm
 * @param d_upper    defaults to nullptr; non-const, presumably the upper
 *                   interval bound -- confirm
 */
222 void predict(raft::handle_t& handle,
223  const ARIMAMemory<double>& arima_mem,
224  const double* d_y,
225  const double* d_exog,
226  const double* d_exog_fut,
227  int batch_size,
228  int n_obs,
229  int start,
230  int end,
231  const ARIMAOrder& order,
233  double* d_y_p,
234  bool pre_diff = true,
235  double level = 0,
236  double* d_lower = nullptr,
237  double* d_upper = nullptr);
238 
/**
 * Declaration of information_criterion().
 *
 * NOTE(review): listing line 265 is not shown here. The member index at the
 * end of this page lists an additional `const ARIMAParams<double>& params`
 * parameter between `order` and `ic`.
 * NOTE(review): `ic_type` is a plain int; which value selects which
 * criterion is not visible here -- confirm against the implementation.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param arima_mem  read-only ARIMAMemory<double>; presumably pre-allocated
 *                   workspace -- confirm
 * @param d_y        read-only series buffer
 * @param d_exog     read-only buffer; presumably exogenous variables -- confirm
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param n_obs      presumably the number of observations per series -- confirm
 * @param order      ARIMA order descriptor (read-only)
 * @param ic         non-const buffer; presumably receives one value per
 *                   series -- confirm
 * @param ic_type    integer selector for the criterion
 */
258 void information_criterion(raft::handle_t& handle,
259  const ARIMAMemory<double>& arima_mem,
260  const double* d_y,
261  const double* d_exog,
262  int batch_size,
263  int n_obs,
264  const ARIMAOrder& order,
266  double* ic,
267  int ic_type);
268 
/**
 * Declaration of estimate_x0().
 *
 * NOTE(review): listing line 285 is not shown here. The member index at the
 * end of this page lists an additional non-const `ARIMAParams<double>& params`
 * parameter between `handle` and `d_y`, presumably where the estimate is
 * written -- confirm.
 * NOTE(review): `missing` presumably tells the routine that `d_y` contains
 * missing observations (compare detect_missing) -- confirm.
 *
 * @param handle     raft handle, taken by non-const reference
 * @param d_y        read-only series buffer
 * @param d_exog     read-only buffer; presumably exogenous variables -- confirm
 * @param batch_size presumably the number of series in the batch -- confirm
 * @param n_obs      presumably the number of observations per series -- confirm
 * @param order      ARIMA order descriptor (read-only)
 * @param missing    bool flag with no default
 */
284 void estimate_x0(raft::handle_t& handle,
286  const double* d_y,
287  const double* d_exog,
288  int batch_size,
289  int n_obs,
290  const ARIMAOrder& order,
291  bool missing);
292 
293 } // namespace ML
Definition: params.hpp:23
Definition: dbscan.hpp:18
LoglikeMethod
Definition: batched_arima.hpp:16
@ CSS
Definition: batched_arima.hpp:16
@ MLE
Definition: batched_arima.hpp:16
void estimate_x0(raft::handle_t &handle, ARIMAParams< double > &params, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, bool missing)
void batched_loglike(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const double *d_params, double *loglike, bool trans=true, bool host_loglike=true, LoglikeMethod method=MLE, int truncate=0)
bool detect_missing(raft::handle_t &handle, const double *d_y, int n_elem)
void unpack(raft::handle_t &handle, ARIMAParams< double > &params, const ARIMAOrder &order, int batch_size, const double *param_vec)
void information_criterion(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const ARIMAParams< double > &params, double *ic, int ic_type)
void batched_diff(raft::handle_t &handle, double *d_y_diff, const double *d_y, int batch_size, int n_obs, const ARIMAOrder &order)
void batched_loglike_grad(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const double *d_x, double *d_grad, double h, bool trans=true, LoglikeMethod method=MLE, int truncate=0)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void pack(raft::handle_t &handle, const ARIMAParams< double > &params, const ARIMAOrder &order, int batch_size, double *param_vec)
Definition: dbscan.hpp:14
Definition: arima_common.h:208
Definition: arima_common.h:26
Definition: arima_common.h:53