umapparams.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
10 #include <cuml/common/logger.hpp>
11 
12 namespace ML {
13 
// Namespace holding the tunables used when building the KNN graph.
// NOTE(review): this listing is lossy. Listing lines 16-23 and 32-51 are not
// shown, so the opening lines of the two aggregates closed at listing lines 30
// and 61 are missing. Per the symbol index of this page, the omitted members
// are `size_t intermediate_graph_degree` (line 27) and
// `nn_descent_params_umap nn_descent_params` (line 60).
14 namespace graph_build_params {
15 
 // First aggregate (header not shown; presumably `nn_descent_params_umap`,
 // the type named by the index entry for line 60 -- TODO confirm).
24  // not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
25  // Descent parameters
26  size_t graph_degree = 64;
28  size_t max_iterations = 20;
 // The literal 0.0001 is a double that is converted to float on initialization.
29  float termination_threshold = 0.0001;
30 };
31 
 // Second aggregate (header not shown; presumably
 // `graph_build_params::graph_build_params`, the type the index gives for
 // `UMAPParams::build_params` -- TODO confirm).
52  size_t overlap_factor = 2;
 // NOTE(review): the default of 1 presumably means "no clustering/batching";
 // verify against the documentation that was dropped from this listing.
59  size_t n_clusters = 1;
61 };
62 } // namespace graph_build_params
63 
// Plain parameter bundle for UMAP: every member shown is public and carries an
// in-class default initializer, so a default-constructed object is fully set up.
// NOTE(review): this listing omits several declarations of the class. According
// to the symbol index of this page they are:
//   line  66: enum MetricType { EUCLIDEAN, CATEGORICAL }
//   line  67: enum graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT }
//   line 141: int negative_sample_rate
//   line 185: graph_build_algo build_algo
//   line 187: graph_build_params::graph_build_params build_params
//   line 193: int target_n_neighbors
//   line 195: MetricType target_metric
//   line 208: ML::distance::DistanceType metric
//   line 212: Internals::GraphBasedDimRedCallback* callback
// Their default values are not visible here.
64 class UMAPParams {
65  public:
68 
75  int n_neighbors = 15;
76 
80  int n_components = 2;
81 
 // NOTE(review): 0 is presumably a "choose automatically" sentinel -- confirm.
86  int n_epochs = 0;
87 
 // The float members below are initialized from double literals (1.0, 0.1, ...),
 // which are converted to float.
91  float learning_rate = 1.0;
92 
101  float min_dist = 0.1;
102 
107  float spread = 1.0;
108 
117  float set_op_mix_ratio = 1.0;
118 
126  float local_connectivity = 1.0;
127 
133  float repulsion_strength = 1.0;
134 
 // (listing line 141, `int negative_sample_rate`, is omitted here)
142 
149  float transform_queue_size = 4.0;
150 
 // Logging level; defaults to `info`.
154  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info;
155 
 // NOTE(review): the negative defaults of `a` and `b` look like "not set"
 // sentinels (presumably derived from min_dist/spread elsewhere) -- confirm.
161  float a = -1.0;
162 
168  float b = -1.0;
169 
173  float initial_alpha = 1.0;
174 
 // Integer code selecting the initialization strategy; the meaning of the
 // individual values is not visible in this listing.
180  int init = 1;
181 
 // (listing lines 185, 187, 193 and 195 -- build_algo, build_params,
 // target_n_neighbors, target_metric -- are omitted here)
186 
188 
194 
196 
197  float target_weight = 0.5;
198 
199  uint64_t random_state = 0;
200 
206  bool deterministic = true;
207 
 // (listing line 208, `ML::distance::DistanceType metric`, is omitted here)
209 
 // NOTE(review): presumably the order of the Minkowski metric, used together
 // with `metric` -- confirm.
210  float p = 2.0;
211 
 // (listing line 212, the `callback` pointer, is omitted here)
213 };
214 
215 } // namespace ML
Definition: callback.hpp:18
ML::UMAPParams
Definition: umapparams.h:64
bool deterministic
Definition: umapparams.h:206
ML::distance::DistanceType metric
Definition: umapparams.h:208
graph_build_algo
Definition: umapparams.h:67
@ BRUTE_FORCE_KNN
Definition: umapparams.h:67
@ NN_DESCENT
Definition: umapparams.h:67
float min_dist
Definition: umapparams.h:101
float repulsion_strength
Definition: umapparams.h:133
float local_connectivity
Definition: umapparams.h:126
rapids_logger::level_enum verbosity
Definition: umapparams.h:154
float set_op_mix_ratio
Definition: umapparams.h:117
float initial_alpha
Definition: umapparams.h:173
float target_weight
Definition: umapparams.h:197
float spread
Definition: umapparams.h:107
float a
Definition: umapparams.h:161
int n_components
Definition: umapparams.h:80
int n_neighbors
Definition: umapparams.h:75
float transform_queue_size
Definition: umapparams.h:149
MetricType target_metric
Definition: umapparams.h:195
float p
Definition: umapparams.h:210
graph_build_algo build_algo
Definition: umapparams.h:185
int negative_sample_rate
Definition: umapparams.h:141
graph_build_params::graph_build_params build_params
Definition: umapparams.h:187
float b
Definition: umapparams.h:168
int n_epochs
Definition: umapparams.h:86
int init
Definition: umapparams.h:180
MetricType
Definition: umapparams.h:66
@ EUCLIDEAN
Definition: umapparams.h:66
@ CATEGORICAL
Definition: umapparams.h:66
Internals::GraphBasedDimRedCallback * callback
Definition: umapparams.h:212
uint64_t random_state
Definition: umapparams.h:199
float learning_rate
Definition: umapparams.h:91
int target_n_neighbors
Definition: umapparams.h:193
@ BRUTE_FORCE_KNN
Definition: hdbscan.hpp:127
DistanceType
Definition: distance_type.hpp:10
Definition: dbscan.hpp:18
size_t n_clusters
Definition: umapparams.h:59
size_t overlap_factor
Definition: umapparams.h:52
nn_descent_params_umap nn_descent_params
Definition: umapparams.h:60
size_t max_iterations
Definition: umapparams.h:28
float termination_threshold
Definition: umapparams.h:29
size_t graph_degree
Definition: umapparams.h:26
size_t intermediate_graph_degree
Definition: umapparams.h:27