Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
kmeans.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <cstdint>

#include <cuvs/cluster/kmeans.hpp>
20 
// The k-means entry points declared in this header take the RAFT handle only
// by const reference, so an incomplete forward declaration is sufficient and
// avoids pulling the full RAFT handle header into every includer.
namespace raft { class handle_t; }  // namespace raft
24 
25 namespace ML {
26 
27 namespace kmeans {
28 
29 using KMeansParams = cuvs::cluster::kmeans::params;
30 
55 void fit_predict(const raft::handle_t& handle,
56  const KMeansParams& params,
57  const float* X,
58  int n_samples,
59  int n_features,
60  const float* sample_weight,
61  float* centroids,
62  int* labels,
63  float& inertia,
64  int& n_iter);
65 
66 void fit_predict(const raft::handle_t& handle,
67  const KMeansParams& params,
68  const double* X,
69  int n_samples,
70  int n_features,
71  const double* sample_weight,
72  double* centroids,
73  int* labels,
74  double& inertia,
75  int& n_iter);
76 void fit_predict(const raft::handle_t& handle,
77  const KMeansParams& params,
78  const float* X,
79  int64_t n_samples,
80  int64_t n_features,
81  const float* sample_weight,
82  float* centroids,
83  int64_t* labels,
84  float& inertia,
85  int64_t& n_iter);
86 
87 void fit_predict(const raft::handle_t& handle,
88  const KMeansParams& params,
89  const double* X,
90  int64_t n_samples,
91  int64_t n_features,
92  const double* sample_weight,
93  double* centroids,
94  int64_t* labels,
95  double& inertia,
96  int64_t& n_iter);
97 
120 void predict(const raft::handle_t& handle,
121  const KMeansParams& params,
122  const float* centroids,
123  const float* X,
124  int n_samples,
125  int n_features,
126  const float* sample_weight,
127  bool normalize_weights,
128  int* labels,
129  float& inertia);
130 
131 void predict(const raft::handle_t& handle,
132  const KMeansParams& params,
133  const double* centroids,
134  const double* X,
135  int n_samples,
136  int n_features,
137  const double* sample_weight,
138  bool normalize_weights,
139  int* labels,
140  double& inertia);
141 void predict(const raft::handle_t& handle,
142  const KMeansParams& params,
143  const float* centroids,
144  const float* X,
145  int64_t n_samples,
146  int64_t n_features,
147  const float* sample_weight,
148  bool normalize_weights,
149  int64_t* labels,
150  float& inertia);
151 
152 void predict(const raft::handle_t& handle,
153  const KMeansParams& params,
154  const double* centroids,
155  const double* X,
156  int64_t n_samples,
157  int64_t n_features,
158  const double* sample_weight,
159  bool normalize_weights,
160  int64_t* labels,
161  double& inertia);
179 void transform(const raft::handle_t& handle,
180  const KMeansParams& params,
181  const float* centroids,
182  const float* X,
183  int n_samples,
184  int n_features,
185  float* X_new);
186 
187 void transform(const raft::handle_t& handle,
188  const KMeansParams& params,
189  const double* centroids,
190  const double* X,
191  int n_samples,
192  int n_features,
193  double* X_new);
194 void transform(const raft::handle_t& handle,
195  const KMeansParams& params,
196  const float* centroids,
197  const float* X,
198  int64_t n_samples,
199  int64_t n_features,
200  float* X_new);
201 
202 void transform(const raft::handle_t& handle,
203  const KMeansParams& params,
204  const double* centroids,
205  const double* X,
206  int64_t n_samples,
207  int64_t n_features,
208  double* X_new);
209 }; // end namespace kmeans
210 }; // end namespace ML
Definition: params.hpp:34
void fit_predict(const raft::handle_t &handle, const KMeansParams &params, const float *X, int n_samples, int n_features, const float *sample_weight, float *centroids, int *labels, float &inertia, int &n_iter)
Compute k-means clustering and predict the cluster index for each sample in the input.
cuvs::cluster::kmeans::params KMeansParams
Definition: kmeans.hpp:29
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
void predict(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, const float *sample_weight, bool normalize_weights, int *labels, float &inertia)
Predict the closest cluster each sample in X belongs to.
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26