#include <arima_common.h>
Public Member Functions
  void allocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr = false)
  void deallocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr = false)
  void pack(const ARIMAOrder &order, int batch_size, DataT *param_vec, cudaStream_t stream) const
  void unpack(const ARIMAOrder &order, int batch_size, const DataT *param_vec, cudaStream_t stream)
Public Attributes
  DataT *mu = nullptr
  DataT *beta = nullptr
  DataT *ar = nullptr
  DataT *ma = nullptr
  DataT *sar = nullptr
  DataT *sma = nullptr
  DataT *sigma2 = nullptr
Structure holding the ARIMA parameter arrays; grouping them makes it easier to pass all parameters as a single argument.
void allocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr = false)   [inline]
Allocate all the parameter device arrays.
Template Parameters:
  AllocatorT     Type of allocator used
Parameters:
  [in] order       ARIMA order
  [in] batch_size  Batch size
  [in] stream      CUDA stream
  [in] tr          Whether these are the transformed parameters
void deallocate(const ARIMAOrder &order, int batch_size, cudaStream_t stream, bool tr = false)   [inline]
Deallocate all the parameter device arrays.
Template Parameters:
  AllocatorT     Type of allocator used
Parameters:
  [in] order       ARIMA order
  [in] batch_size  Batch size
  [in] stream      CUDA stream
  [in] tr          Whether these are the transformed parameters
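void pack(const ARIMAOrder &order, int batch_size, DataT *param_vec, cudaStream_t stream) const   [inline]
Pack the separate parameter arrays into a single parameter vector.
Parameters:
  [in]  order       ARIMA order
  [in]  batch_size  Batch size
  [out] param_vec   Linear array of all parameters, grouped by batch: [mu, ar, ma, sar, sma, sigma2] (device).
                    Note: this list does not mention beta, which is also a member of the structure; confirm where (or whether) beta is placed in the vector.
  [in]  stream      CUDA stream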
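void unpack(const ARIMAOrder &order, int batch_size, const DataT *param_vec, cudaStream_t stream)   [inline]
Unpack a parameter vector into separate arrays of parameters.
Parameters:
  [in] order       ARIMA order
  [in] batch_size  Batch size
  [in] param_vec   Linear array of all parameters, grouped by batch: [mu, ar, ma, sar, sma, sigma2] (device).
                   Note: this list does not mention beta, which is also a member of the structure; confirm where (or whether) beta is read from the vector.
  [in] stream      CUDA stream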
Member Data Documentation
  DataT* ML::ARIMAParams<DataT>::ar = nullptr
  DataT* ML::ARIMAParams<DataT>::beta = nullptr
  DataT* ML::ARIMAParams<DataT>::ma = nullptr
  DataT* ML::ARIMAParams<DataT>::mu = nullptr
  DataT* ML::ARIMAParams<DataT>::sar = nullptr
  DataT* ML::ARIMAParams<DataT>::sigma2 = nullptr
  DataT* ML::ARIMAParams<DataT>::sma = nullptr