umapparams.h
Go to the documentation of this file.
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuml/common/callback.hpp>
#include <cuml/common/logger.hpp>

#include <raft/neighbors/nn_descent_types.hpp>

#include <cuvs/distance/distance.hpp>

#include <cstdint>

26 namespace ML {
27 
28 using nn_index_params = raft::neighbors::experimental::nn_descent::index_params;
29 
30 class UMAPParams {
31  public:
34 
41  int n_neighbors = 15;
42 
46  int n_components = 2;
47 
52  int n_epochs = 0;
53 
57  float learning_rate = 1.0;
58 
67  float min_dist = 0.1;
68 
73  float spread = 1.0;
74 
83  float set_op_mix_ratio = 1.0;
84 
92  float local_connectivity = 1.0;
93 
99  float repulsion_strength = 1.0;
100 
108 
115  float transform_queue_size = 4.0;
116 
121 
127  float a = -1.0;
128 
134  float b = -1.0;
135 
139  float initial_alpha = 1.0;
140 
146  int init = 1;
147 
151  graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;
152 
154 
160 
162 
163  float target_weight = 0.5;
164 
165  uint64_t random_state = 0;
166 
172  bool deterministic = true;
173 
174  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded;
175 
176  float p = 2.0;
177 
179 };
180 
181 } // namespace ML
Definition: callback.hpp:29
Definition: umapparams.h:30
bool deterministic
Definition: umapparams.h:172
graph_build_algo
Definition: umapparams.h:33
@ BRUTE_FORCE_KNN
Definition: umapparams.h:33
@ NN_DESCENT
Definition: umapparams.h:33
float min_dist
Definition: umapparams.h:67
float repulsion_strength
Definition: umapparams.h:99
float local_connectivity
Definition: umapparams.h:92
float set_op_mix_ratio
Definition: umapparams.h:83
float initial_alpha
Definition: umapparams.h:139
float target_weight
Definition: umapparams.h:163
float spread
Definition: umapparams.h:73
float a
Definition: umapparams.h:127
int n_components
Definition: umapparams.h:46
int n_neighbors
Definition: umapparams.h:41
float transform_queue_size
Definition: umapparams.h:115
MetricType target_metric
Definition: umapparams.h:161
float p
Definition: umapparams.h:176
graph_build_algo build_algo
Definition: umapparams.h:151
int negative_sample_rate
Definition: umapparams.h:107
cuvs::distance::DistanceType metric
Definition: umapparams.h:174
nn_index_params nn_descent_params
Definition: umapparams.h:153
float b
Definition: umapparams.h:134
int n_epochs
Definition: umapparams.h:52
int verbosity
Definition: umapparams.h:120
int init
Definition: umapparams.h:146
MetricType
Definition: umapparams.h:32
@ EUCLIDEAN
Definition: umapparams.h:32
@ CATEGORICAL
Definition: umapparams.h:32
Internals::GraphBasedDimRedCallback * callback
Definition: umapparams.h:178
uint64_t random_state
Definition: umapparams.h:165
float learning_rate
Definition: umapparams.h:57
int target_n_neighbors
Definition: umapparams.h:159
#define CUML_LEVEL_INFO
Definition: log_levels.hpp:28
Definition: dbscan.hpp:30
raft::neighbors::experimental::nn_descent::index_params nn_index_params
Definition: umapparams.h:28