kmeans_params.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
10 #include <raft/random/rng_state.hpp>
11 
12 #include <rapids_logger/logger.hpp>
13 
15 
// Opening line restored: the listing skipped its own line 14, which left the
// closing brace below unmatched. The namespace name is taken from the closing
// comment and from the qualified return type of KMeansParams::to_cuvs().
namespace cuvs::cluster::kmeans {

// Forward declaration only: `params` is an incomplete type in this header.
// It is named as the by-value return type of KMeansParams::to_cuvs(); the
// full definition is not visible in this file.
struct params;

}  // end namespace cuvs::cluster::kmeans
19 
20 namespace ML::kmeans {
21 
// KMeansParams: plain aggregate of k-means settings. Every member shown here
// has an in-class default initializer.
//
// NOTE(review): this listing skips its own line numbers 23, 24 and 26, and the
// symbol index further down this page lists `InitMethod` (line 23),
// `ML::distance::DistanceType metric` (line 24) and `InitMethod init` (line 26)
// as belonging to this header. Those declarations and their defaults are not
// visible here. Restore them from the real header before treating this block
// as the complete struct.
22 struct KMeansParams {
25  int n_clusters = 8;  // default: 8
27  int max_iter = 300;  // default: 300
28  double tol = 1e-4;  // default: 1e-4
29  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info;  // default: level_enum::info
30  raft::random::RngState rng_state{0};  // constructed with argument 0 (presumably the seed; confirm against raft::random::RngState)
31  int n_init = 1;  // default: 1
32  double oversampling_factor = 2.0;  // default: 2.0
33  int batch_samples = 1 << 15;  // default: 1 << 15 == 32768
34  int batch_centroids = 0;  // default: 0 (NOTE(review): 0 may be a sentinel rather than a literal size; confirm in the implementation)
35  bool inertia_check = false;  // default: false
36 
// Declared here only. Const-qualified member returning a
// cuvs::cluster::kmeans::params by value. This page's index lists the
// definition at kmeans_params.cpp:13.
37  cuvs::cluster::kmeans::params to_cuvs() const;
38 };
39 
40 } // end namespace ML::kmeans
DistanceType
Definition: distance_type.hpp:10
Definition: kmeans.hpp:16
cuvs::cluster::kmeans
Definition: kmeans_params.hpp:14
ML::kmeans::KMeansParams
Definition: kmeans_params.hpp:22
bool inertia_check
Definition: kmeans_params.hpp:35
int n_init
Definition: kmeans_params.hpp:31
int batch_samples
Definition: kmeans_params.hpp:33
double oversampling_factor
Definition: kmeans_params.hpp:32
int n_clusters
Definition: kmeans_params.hpp:25
ML::distance::DistanceType metric
Definition: kmeans_params.hpp:24
double tol
Definition: kmeans_params.hpp:28
InitMethod
Definition: kmeans_params.hpp:23
raft::random::RngState rng_state
Definition: kmeans_params.hpp:30
InitMethod init
Definition: kmeans_params.hpp:26
int batch_centroids
Definition: kmeans_params.hpp:34
cuvs::cluster::kmeans::params to_cuvs() const
Definition: kmeans_params.cpp:13
int max_iter
Definition: kmeans_params.hpp:27
rapids_logger::level_enum verbosity
Definition: kmeans_params.hpp:29