umapparams.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/common/callback.hpp>
20 #include <cuml/common/logger.hpp>
21 
22 #include <cuvs/distance/distance.hpp>
23 #include <cuvs/neighbors/nn_descent.hpp>
24 
25 namespace ML {
26 
27 namespace graph_build_params {
28 
37  // not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
38  // Descent parameters
39  size_t graph_degree = 64;
41  size_t max_iterations = 20;
42  float termination_threshold = 0.0001;
43 };
44 
65  size_t overlap_factor = 2;
72  size_t n_clusters = 1;
74 };
75 } // namespace graph_build_params
76 
77 class UMAPParams {
78  public:
81 
88  int n_neighbors = 15;
89 
93  int n_components = 2;
94 
99  int n_epochs = 0;
100 
104  float learning_rate = 1.0;
105 
114  float min_dist = 0.1;
115 
120  float spread = 1.0;
121 
130  float set_op_mix_ratio = 1.0;
131 
139  float local_connectivity = 1.0;
140 
146  float repulsion_strength = 1.0;
147 
155 
162  float transform_queue_size = 4.0;
163 
167  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info;
168 
174  float a = -1.0;
175 
181  float b = -1.0;
182 
186  float initial_alpha = 1.0;
187 
193  int init = 1;
194 
198  graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;
199 
201 
207 
209 
210  float target_weight = 0.5;
211 
212  uint64_t random_state = 0;
213 
219  bool deterministic = true;
220 
221  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded;
222 
223  float p = 2.0;
224 
226 };
227 
228 } // namespace ML
Definition: callback.hpp:29
Definition: umapparams.h:77
bool deterministic
Definition: umapparams.h:219
graph_build_algo
Definition: umapparams.h:80
@ BRUTE_FORCE_KNN
Definition: umapparams.h:80
@ NN_DESCENT
Definition: umapparams.h:80
float min_dist
Definition: umapparams.h:114
float repulsion_strength
Definition: umapparams.h:146
float local_connectivity
Definition: umapparams.h:139
rapids_logger::level_enum verbosity
Definition: umapparams.h:167
float set_op_mix_ratio
Definition: umapparams.h:130
float initial_alpha
Definition: umapparams.h:186
float target_weight
Definition: umapparams.h:210
float spread
Definition: umapparams.h:120
float a
Definition: umapparams.h:174
int n_components
Definition: umapparams.h:93
int n_neighbors
Definition: umapparams.h:88
float transform_queue_size
Definition: umapparams.h:162
MetricType target_metric
Definition: umapparams.h:208
float p
Definition: umapparams.h:223
graph_build_algo build_algo
Definition: umapparams.h:198
int negative_sample_rate
Definition: umapparams.h:154
graph_build_params::graph_build_params build_params
Definition: umapparams.h:200
cuvs::distance::DistanceType metric
Definition: umapparams.h:221
float b
Definition: umapparams.h:181
int n_epochs
Definition: umapparams.h:99
int init
Definition: umapparams.h:193
MetricType
Definition: umapparams.h:79
@ EUCLIDEAN
Definition: umapparams.h:79
@ CATEGORICAL
Definition: umapparams.h:79
Internals::GraphBasedDimRedCallback * callback
Definition: umapparams.h:225
uint64_t random_state
Definition: umapparams.h:212
float learning_rate
Definition: umapparams.h:104
int target_n_neighbors
Definition: umapparams.h:206
Definition: dbscan.hpp:30
size_t n_clusters
Definition: umapparams.h:72
size_t overlap_factor
Definition: umapparams.h:65
nn_descent_params_umap nn_descent_params
Definition: umapparams.h:73
size_t max_iterations
Definition: umapparams.h:41
float termination_threshold
Definition: umapparams.h:42
size_t graph_degree
Definition: umapparams.h:39
size_t intermediate_graph_degree
Definition: umapparams.h:40