batched_arima.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/tsa/arima_common.h>
20 
// Forward declaration of the raft handle type. Every declaration in this
// header takes it by reference, so the complete type (and the raft headers
// that define it) is not needed here.
namespace raft { class handle_t; }  // namespace raft
24 
25 namespace ML {
26 
27 enum LoglikeMethod { CSS, MLE };
28 
38 void pack(raft::handle_t& handle,
40  const ARIMAOrder& order,
41  int batch_size,
42  double* param_vec);
43 
53 void unpack(raft::handle_t& handle,
55  const ARIMAOrder& order,
56  int batch_size,
57  const double* param_vec);
58 
66 bool detect_missing(raft::handle_t& handle, const double* d_y, int n_elem);
67 
78 void batched_diff(raft::handle_t& handle,
79  double* d_y_diff,
80  const double* d_y,
81  int batch_size,
82  int n_obs,
83  const ARIMAOrder& order);
84 
107 void batched_loglike(raft::handle_t& handle,
108  const ARIMAMemory<double>& arima_mem,
109  const double* d_y,
110  const double* d_exog,
111  int batch_size,
112  int n_obs,
113  const ARIMAOrder& order,
114  const double* d_params,
115  double* loglike,
116  bool trans = true,
117  bool host_loglike = true,
118  LoglikeMethod method = MLE,
119  int truncate = 0);
120 
153 void batched_loglike(raft::handle_t& handle,
154  const ARIMAMemory<double>& arima_mem,
155  const double* d_y,
156  const double* d_exog,
157  int batch_size,
158  int n_obs,
159  const ARIMAOrder& order,
161  double* loglike,
162  bool trans = true,
163  bool host_loglike = true,
164  LoglikeMethod method = MLE,
165  int truncate = 0,
166  int fc_steps = 0,
167  double* d_fc = nullptr,
168  const double* d_exog_fut = nullptr,
169  double level = 0,
170  double* d_lower = nullptr,
171  double* d_upper = nullptr);
172 
193 void batched_loglike_grad(raft::handle_t& handle,
194  const ARIMAMemory<double>& arima_mem,
195  const double* d_y,
196  const double* d_exog,
197  int batch_size,
198  int n_obs,
199  const ARIMAOrder& order,
200  const double* d_x,
201  double* d_grad,
202  double h,
203  bool trans = true,
204  LoglikeMethod method = MLE,
205  int truncate = 0);
206 
233 void predict(raft::handle_t& handle,
234  const ARIMAMemory<double>& arima_mem,
235  const double* d_y,
236  const double* d_exog,
237  const double* d_exog_fut,
238  int batch_size,
239  int n_obs,
240  int start,
241  int end,
242  const ARIMAOrder& order,
244  double* d_y_p,
245  bool pre_diff = true,
246  double level = 0,
247  double* d_lower = nullptr,
248  double* d_upper = nullptr);
249 
269 void information_criterion(raft::handle_t& handle,
270  const ARIMAMemory<double>& arima_mem,
271  const double* d_y,
272  const double* d_exog,
273  int batch_size,
274  int n_obs,
275  const ARIMAOrder& order,
277  double* ic,
278  int ic_type);
279 
295 void estimate_x0(raft::handle_t& handle,
297  const double* d_y,
298  const double* d_exog,
299  int batch_size,
300  int n_obs,
301  const ARIMAOrder& order,
302  bool missing);
303 
304 } // namespace ML
Definition: params.hpp:34
Definition: dbscan.hpp:30
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, int verbosity=CUML_LEVEL_INFO)
LoglikeMethod
Definition: batched_arima.hpp:27
@ CSS
Definition: batched_arima.hpp:27
@ MLE
Definition: batched_arima.hpp:27
void estimate_x0(raft::handle_t &handle, ARIMAParams< double > &params, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, bool missing)
void batched_loglike(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const double *d_params, double *loglike, bool trans=true, bool host_loglike=true, LoglikeMethod method=MLE, int truncate=0)
bool detect_missing(raft::handle_t &handle, const double *d_y, int n_elem)
void unpack(raft::handle_t &handle, ARIMAParams< double > &params, const ARIMAOrder &order, int batch_size, const double *param_vec)
void information_criterion(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const ARIMAParams< double > &params, double *ic, int ic_type)
void batched_diff(raft::handle_t &handle, double *d_y_diff, const double *d_y, int batch_size, int n_obs, const ARIMAOrder &order)
void batched_loglike_grad(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_y, const double *d_exog, int batch_size, int n_obs, const ARIMAOrder &order, const double *d_x, double *d_grad, double h, bool trans=true, LoglikeMethod method=MLE, int truncate=0)
void pack(raft::handle_t &handle, const ARIMAParams< double > &params, const ARIMAOrder &order, int batch_size, double *param_vec)
Definition: dbscan.hpp:26
Definition: arima_common.h:243
Definition: arima_common.h:37
Definition: arima_common.h:64