learning_rate.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
10 #include <math.h>
11 
12 namespace ML {
13 namespace Solver {
14 
/**
 * @brief Return the larger of two values of the same type.
 *
 * Uses only `operator<`. When neither argument compares less than the other
 * (for example, when they are equal), the first argument `a` is returned.
 *
 * @tparam math_t value type; must support `operator<`
 * @param a first value
 * @param b second value
 * @return `b` if `a < b`, otherwise `a`
 */
template <typename math_t>
math_t max(math_t a, math_t b)
{
  if (a < b) { return b; }
  return a;
}
21 
/**
 * @brief Scale a base rate by the inverse of `t` raised to `power_t`.
 *
 * Computes `eta / pow(t, power_t)`. The divisor keeps whatever type `pow`
 * yields for an (int, math_t) argument pair; the quotient is converted to
 * `math_t` on return.
 *
 * @tparam math_t floating-point type of the rate and exponent
 * @param eta     base rate (numerator)
 * @param power_t exponent applied to `t`
 * @param t       integer base of the power
 * @return `eta / pow(t, power_t)`
 */
template <typename math_t>
math_t invScaling(math_t eta, math_t power_t, int t)
{
  const auto divisor = pow(t, power_t);
  return eta / divisor;
}
27 
/**
 * @brief Return the difference `a - b`.
 *
 * @tparam math_t arithmetic type of the operands
 * @param a minuend
 * @param b subtrahend
 * @return `a - b`
 */
template <typename math_t>
math_t regDLoss(math_t a, math_t b)
{
  const math_t difference = a - b;
  return difference;
}
33 
/**
 * @brief Derive an initial value from `alpha`.
 *
 * Steps, exactly as computed below:
 *   1. typical_w = sqrt(1 / sqrt(alpha))
 *   2. eta0      = typical_w / max(1, regDLoss(-typical_w, 1))
 *   3. result    = 1 / (eta0 * alpha)
 *
 * No guard is applied for `alpha == 0` or negative `alpha`; the result then
 * follows floating-point rules (inf / NaN).
 *
 * @tparam math_t floating-point type
 * @param alpha input coefficient
 * @return `1 / (eta0 * alpha)` with `eta0` as defined above
 */
template <typename math_t>
math_t calOptimalInit(math_t alpha)
{
  const math_t one       = math_t(1.0);
  const math_t typical_w = sqrt(one / sqrt(alpha));
  const math_t dloss     = regDLoss(-typical_w, one);
  const math_t eta0      = typical_w / max(one, dloss);
  return one / (eta0 * alpha);
}
41 
/**
 * @brief Compute `1 / (alpha * (optimal_init + t - 1))`.
 *
 * The offset is evaluated left to right as `(optimal_init + t) - 1`, so the
 * integer `t` is promoted into `math_t` arithmetic before the subtraction.
 * A zero denominator is not guarded against.
 *
 * @tparam math_t floating-point type
 * @param alpha        multiplier applied to the offset
 * @param optimal_init additive offset
 * @param t            integer step counter
 * @return `1 / (alpha * (optimal_init + t - 1))`
 */
template <typename math_t>
math_t optimal(math_t alpha, math_t optimal_init, int t)
{
  const math_t offset      = optimal_init + t - 1;
  const math_t denominator = alpha * offset;
  return math_t(1.0) / denominator;
}
47 
48 template <typename math_t>
49 math_t calLearningRate(ML::lr_type lr_type, math_t eta, math_t power_t, math_t alpha, math_t t)
50 {
52  return eta;
53  } else if (lr_type == ML::lr_type::INVSCALING) {
54  return invScaling(eta, power_t, t);
55  } else if (lr_type == ML::lr_type::OPTIMAL) {
56  return optimal(alpha, eta, t);
57  } else {
58  return math_t(0);
59  }
60 }
61 
62 }; // namespace Solver
63 }; // namespace ML
64 // end namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
math_t invScaling(math_t eta, math_t power_t, int t)
Definition: learning_rate.h:23
math_t optimal(math_t alpha, math_t optimal_init, int t)
Definition: learning_rate.h:43
math_t regDLoss(math_t a, math_t b)
Definition: learning_rate.h:29
math_t calLearningRate(ML::lr_type lr_type, math_t eta, math_t power_t, math_t alpha, math_t t)
Definition: learning_rate.h:49
math_t calOptimalInit(math_t alpha)
Definition: learning_rate.h:35
Definition: dbscan.hpp:18
lr_type
Definition: params.hpp:10
@ OPTIMAL
Definition: params.hpp:11
@ INVSCALING
Definition: params.hpp:13
@ CONSTANT
Definition: params.hpp:12