kmeans.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
#include <cuvs/cluster/kmeans.hpp>

#include <cstdint>
22 
// Forward declaration only: every entry point below takes the handle by const
// reference, so the complete raft handle definition is not required here.
namespace raft { class handle_t; }
26 
27 namespace ML {
28 
29 namespace kmeans {
30 
31 using KMeansParams = cuvs::cluster::kmeans::params;
32 
57 void fit_predict(const raft::handle_t& handle,
58  const KMeansParams& params,
59  const float* X,
60  int n_samples,
61  int n_features,
62  const float* sample_weight,
63  float* centroids,
64  int* labels,
65  float& inertia,
66  int& n_iter);
67 
68 void fit_predict(const raft::handle_t& handle,
69  const KMeansParams& params,
70  const double* X,
71  int n_samples,
72  int n_features,
73  const double* sample_weight,
74  double* centroids,
75  int* labels,
76  double& inertia,
77  int& n_iter);
78 void fit_predict(const raft::handle_t& handle,
79  const KMeansParams& params,
80  const float* X,
81  int64_t n_samples,
82  int64_t n_features,
83  const float* sample_weight,
84  float* centroids,
85  int64_t* labels,
86  float& inertia,
87  int64_t& n_iter);
88 
89 void fit_predict(const raft::handle_t& handle,
90  const KMeansParams& params,
91  const double* X,
92  int64_t n_samples,
93  int64_t n_features,
94  const double* sample_weight,
95  double* centroids,
96  int64_t* labels,
97  double& inertia,
98  int64_t& n_iter);
99 
122 void predict(const raft::handle_t& handle,
123  const KMeansParams& params,
124  const float* centroids,
125  const float* X,
126  int n_samples,
127  int n_features,
128  const float* sample_weight,
129  bool normalize_weights,
130  int* labels,
131  float& inertia);
132 
133 void predict(const raft::handle_t& handle,
134  const KMeansParams& params,
135  const double* centroids,
136  const double* X,
137  int n_samples,
138  int n_features,
139  const double* sample_weight,
140  bool normalize_weights,
141  int* labels,
142  double& inertia);
143 void predict(const raft::handle_t& handle,
144  const KMeansParams& params,
145  const float* centroids,
146  const float* X,
147  int64_t n_samples,
148  int64_t n_features,
149  const float* sample_weight,
150  bool normalize_weights,
151  int64_t* labels,
152  float& inertia);
153 
154 void predict(const raft::handle_t& handle,
155  const KMeansParams& params,
156  const double* centroids,
157  const double* X,
158  int64_t n_samples,
159  int64_t n_features,
160  const double* sample_weight,
161  bool normalize_weights,
162  int64_t* labels,
163  double& inertia);
181 void transform(const raft::handle_t& handle,
182  const KMeansParams& params,
183  const float* centroids,
184  const float* X,
185  int n_samples,
186  int n_features,
187  float* X_new);
188 
189 void transform(const raft::handle_t& handle,
190  const KMeansParams& params,
191  const double* centroids,
192  const double* X,
193  int n_samples,
194  int n_features,
195  double* X_new);
196 void transform(const raft::handle_t& handle,
197  const KMeansParams& params,
198  const float* centroids,
199  const float* X,
200  int64_t n_samples,
201  int64_t n_features,
202  float* X_new);
203 
204 void transform(const raft::handle_t& handle,
205  const KMeansParams& params,
206  const double* centroids,
207  const double* X,
208  int64_t n_samples,
209  int64_t n_features,
210  double* X_new);
211 }; // end namespace kmeans
212 }; // end namespace ML
Definition: params.hpp:34
void fit_predict(const raft::handle_t &handle, const KMeansParams &params, const float *X, int n_samples, int n_features, const float *sample_weight, float *centroids, int *labels, float &inertia, int &n_iter)
Compute k-means clustering and predict the cluster index for each sample in the input.
cuvs::cluster::kmeans::params KMeansParams
Definition: kmeans.hpp:31
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
void predict(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, const float *sample_weight, bool normalize_weights, int *labels, float &inertia)
Predict the closest cluster each sample in X belongs to.
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26