callback.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include <type_traits>
9 
10 namespace ML {
11 namespace Internals {
12 
/**
 * @brief Stateless polymorphic root for callback objects in ML::Internals.
 *
 * The class declares nothing but a virtual destructor. That is enough to make
 * it polymorphic, and it lets an object of a derived callback type be deleted
 * safely through a `Callback*`.
 */
class Callback {
 public:
  virtual ~Callback() = default;
};
17 
19  public:
20  template <typename T>
21  void setup(int n, int n_components)
22  {
23  this->n = n;
24  this->n_components = n_components;
25  this->isFloat = std::is_same<T, float>::value;
26  }
27 
28  virtual void on_preprocess_end(void* embeddings) = 0;
29  virtual void on_epoch_end(void* embeddings) = 0;
30  virtual void on_train_end(void* embeddings) = 0;
31 
32  protected:
33  int n;
35  bool isFloat;
36 };
37 
38 } // namespace Internals
39 } // namespace ML
Definition: callback.hpp:13
virtual ~Callback()
Definition: callback.hpp:15
Definition: callback.hpp:18
virtual void on_preprocess_end(void *embeddings)=0
virtual void on_epoch_end(void *embeddings)=0
virtual void on_train_end(void *embeddings)=0
void setup(int n, int n_components)
Definition: callback.hpp:21
bool isFloat
Definition: callback.hpp:35
int n
Definition: callback.hpp:33
int n_components
Definition: callback.hpp:34
Definition: dbscan.hpp:18