batched_kalman.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/tsa/arima_common.h>
20 
// Forward declaration only: this header refers to raft::handle_t purely by
// reference in the function declarations below, so the complete type is not
// required here.
namespace raft {
class handle_t;
}  // namespace raft
24 
25 namespace ML {
26 
53 void batched_kalman_filter(raft::handle_t& handle,
54  const ARIMAMemory<double>& arima_mem,
55  const double* d_ys,
56  const double* d_exog,
57  int nobs,
59  const ARIMAOrder& order,
60  int batch_size,
61  double* d_loglike,
62  double* d_pred,
63  int fc_steps = 0,
64  double* d_fc = nullptr,
65  const double* d_exog_fut = nullptr,
66  double level = 0,
67  double* d_lower = nullptr,
68  double* d_upper = nullptr);
69 
85 void batched_jones_transform(raft::handle_t& handle,
86  const ARIMAMemory<double>& arima_mem,
87  const ARIMAOrder& order,
88  int batch_size,
89  bool isInv,
90  const double* h_params,
91  double* h_Tparams);
92 } // namespace ML
Definition: params.hpp:34
Definition: dbscan.hpp:30
void batched_jones_transform(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const ARIMAOrder &order, int batch_size, bool isInv, const double *h_params, double *h_Tparams)
void batched_kalman_filter(raft::handle_t &handle, const ARIMAMemory< double > &arima_mem, const double *d_ys, const double *d_exog, int nobs, const ARIMAParams< double > &params, const ARIMAOrder &order, int batch_size, double *d_loglike, double *d_pred, int fc_steps=0, double *d_fc=nullptr, const double *d_exog_fut=nullptr, double level=0, double *d_lower=nullptr, double *d_upper=nullptr)
Definition: dbscan.hpp:26
Definition: arima_common.h:243
Definition: arima_common.h:37
Definition: arima_common.h:64