umapparams.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/common/callback.hpp>
21 #include <cuml/common/logger.hpp>
22 
23 namespace ML {
24 
25 namespace graph_build_params {
26 
// This namespace holds the two graph-build parameter types whose bodies close at listing
// lines 41 and 72.
// NOTE(review): this rendered listing skips lines 27-34. Per the cross-reference index, the type
// closed at line 41 is presumably `nn_descent_params_umap` (the type of the `nn_descent_params`
// member declared at line 71); its opening line is among the skipped ones — confirm against the
// real header.
35  // not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
36  // Descent parameters
37  size_t graph_degree = 64;
// NOTE(review): line 38 is skipped in this listing; the cross-reference index places
// `size_t intermediate_graph_degree` at umapparams.h:38 (its default value is not shown here).
39  size_t max_iterations = 20;
40  float termination_threshold = 0.0001;
41 };
42 
// NOTE(review): lines 43-62 are skipped. The type closed at line 72 is presumably
// `graph_build_params` (UMAPParams::build_params is typed `graph_build_params::graph_build_params`
// in the cross-reference index); its opening line is not shown in this listing.
63  size_t overlap_factor = 2;
70  size_t n_clusters = 1;
// NOTE(review): line 71 is skipped; the cross-reference index places
// `nn_descent_params_umap nn_descent_params` at umapparams.h:71.
72 };
73 } // namespace graph_build_params
74 
75 class UMAPParams {
76  public:
// Parameter bundle for UMAP; every member visible in this listing carries an in-class default.
// NOTE(review): this rendered listing omits every line whose number does not appear. Most gaps
// presumably held doc comments. Gaps that hold declarations (per the cross-reference index on this
// page) are marked inline below.
// NOTE(review): lines 77-78 are skipped. They declare the enums
// `MetricType { EUCLIDEAN, CATEGORICAL }` (line 77) and
// `graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT }` (line 78); the latter is used by build_algo.
79 
86  int n_neighbors = 15;
87 
91  int n_components = 2;
92 
97  int n_epochs = 0;
98 
102  float learning_rate = 1.0;
103 
112  float min_dist = 0.1;
113 
118  float spread = 1.0;
119 
128  float set_op_mix_ratio = 1.0;
129 
137  float local_connectivity = 1.0;
138 
144  float repulsion_strength = 1.0;
145 
// NOTE(review): line 152 (`int negative_sample_rate`) is skipped in this listing; its default is not shown.
153 
160  float transform_queue_size = 4.0;
161 
// Logging level; the type comes from rapids_logger (cuml/common/logger.hpp is included by this file).
165  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info;
166 
// NOTE(review): `a` and `b` both default to -1.0 — presumably a sentinel meaning "derive from
// spread/min_dist"; TODO confirm, the code that consumes them is not visible here.
172  float a = -1.0;
173 
179  float b = -1.0;
180 
184  float initial_alpha = 1.0;
185 
// NOTE(review): `init` is an integer code; the meaning of each value is not visible in this listing.
191  int init = 1;
192 
196  graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;
197 
// NOTE(review): line 198 (`graph_build_params::graph_build_params build_params`) is skipped here.
199 
// NOTE(review): line 204 (`int target_n_neighbors`) is skipped here; its default is not shown.
205 
// NOTE(review): line 206 (`MetricType target_metric`) is skipped here; its default is not shown.
207 
208  float target_weight = 0.5;
209 
// NOTE(review): no <cstdint> include is visible for uint64_t, but include line 20 is skipped in
// this listing — confirm the type is brought in directly rather than transitively.
210  uint64_t random_state = 0;
211 
217  bool deterministic = true;
218 
// NOTE(review): line 219 (`ML::distance::DistanceType metric`) is skipped here; its default is not shown.
220 
// NOTE(review): `p` defaults to 2.0 — presumably the Minkowski exponent used with `metric`; TODO confirm.
221  float p = 2.0;
222 
// NOTE(review): line 223 (`Internals::GraphBasedDimRedCallback * callback`) is skipped here.
224 };
225 
226 } // namespace ML
Definition: callback.hpp:29
ML::UMAPParams
Definition: umapparams.h:75
bool deterministic
Definition: umapparams.h:217
ML::distance::DistanceType metric
Definition: umapparams.h:219
graph_build_algo
Definition: umapparams.h:78
@ BRUTE_FORCE_KNN
Definition: umapparams.h:78
@ NN_DESCENT
Definition: umapparams.h:78
float min_dist
Definition: umapparams.h:112
float repulsion_strength
Definition: umapparams.h:144
float local_connectivity
Definition: umapparams.h:137
rapids_logger::level_enum verbosity
Definition: umapparams.h:165
float set_op_mix_ratio
Definition: umapparams.h:128
float initial_alpha
Definition: umapparams.h:184
float target_weight
Definition: umapparams.h:208
float spread
Definition: umapparams.h:118
float a
Definition: umapparams.h:172
int n_components
Definition: umapparams.h:91
int n_neighbors
Definition: umapparams.h:86
float transform_queue_size
Definition: umapparams.h:160
MetricType target_metric
Definition: umapparams.h:206
float p
Definition: umapparams.h:221
graph_build_algo build_algo
Definition: umapparams.h:196
int negative_sample_rate
Definition: umapparams.h:152
graph_build_params::graph_build_params build_params
Definition: umapparams.h:198
float b
Definition: umapparams.h:179
int n_epochs
Definition: umapparams.h:97
int init
Definition: umapparams.h:191
MetricType
Definition: umapparams.h:77
@ EUCLIDEAN
Definition: umapparams.h:77
@ CATEGORICAL
Definition: umapparams.h:77
Internals::GraphBasedDimRedCallback * callback
Definition: umapparams.h:223
uint64_t random_state
Definition: umapparams.h:210
float learning_rate
Definition: umapparams.h:102
int target_n_neighbors
Definition: umapparams.h:204
DistanceType
Definition: distance_type.hpp:21
Definition: dbscan.hpp:29
size_t n_clusters
Definition: umapparams.h:70
size_t overlap_factor
Definition: umapparams.h:63
nn_descent_params_umap nn_descent_params
Definition: umapparams.h:71
size_t max_iterations
Definition: umapparams.h:39
float termination_threshold
Definition: umapparams.h:40
size_t graph_degree
Definition: umapparams.h:37
size_t intermediate_graph_degree
Definition: umapparams.h:38