callback.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <type_traits>
20 
21 namespace ML {
22 namespace Internals {
23 
/**
 * @brief Root type of the callback class hierarchy.
 *
 * Declares nothing except a virtual destructor, so that an object of a
 * derived callback type can be destroyed correctly through a `Callback*`.
 */
class Callback {
 public:
  /** @brief Virtual so that deleting through a base pointer runs the derived destructor. */
  virtual ~Callback() = default;
};
28 
30  public:
31  template <typename T>
32  void setup(int n, int n_components)
33  {
34  this->n = n;
35  this->n_components = n_components;
36  this->isFloat = std::is_same<T, float>::value;
37  }
38 
39  virtual void on_preprocess_end(void* embeddings) = 0;
40  virtual void on_epoch_end(void* embeddings) = 0;
41  virtual void on_train_end(void* embeddings) = 0;
42 
43  protected:
44  int n;
46  bool isFloat;
47 };
48 
49 } // namespace Internals
50 } // namespace ML
ML::Internals::Callback
Definition: callback.hpp:24
virtual ~Callback()
Definition: callback.hpp:26
ML::Internals::GraphBasedDimRedCallback
Definition: callback.hpp:29
virtual void on_preprocess_end(void *embeddings)=0
virtual void on_epoch_end(void *embeddings)=0
virtual void on_train_end(void *embeddings)=0
void setup(int n, int n_components)
Definition: callback.hpp:32
bool isFloat
Definition: callback.hpp:46
int n
Definition: callback.hpp:44
int n_components
Definition: callback.hpp:45
Definition: dbscan.hpp:30