kmeans.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
// Forward declaration only: the declarations in this header take
// raft::handle_t by const reference, so its full definition is not needed.
namespace raft {
class handle_t;
}  // namespace raft
24 
25 namespace ML {
26 
27 namespace kmeans {
28 
53 void fit_predict(const raft::handle_t& handle,
54  const KMeansParams& params,
55  const float* X,
56  int n_samples,
57  int n_features,
58  const float* sample_weight,
59  float* centroids,
60  int* labels,
61  float& inertia,
62  int& n_iter);
63 
64 void fit_predict(const raft::handle_t& handle,
65  const KMeansParams& params,
66  const double* X,
67  int n_samples,
68  int n_features,
69  const double* sample_weight,
70  double* centroids,
71  int* labels,
72  double& inertia,
73  int& n_iter);
74 void fit_predict(const raft::handle_t& handle,
75  const KMeansParams& params,
76  const float* X,
77  int64_t n_samples,
78  int64_t n_features,
79  const float* sample_weight,
80  float* centroids,
81  int64_t* labels,
82  float& inertia,
83  int64_t& n_iter);
84 
85 void fit_predict(const raft::handle_t& handle,
86  const KMeansParams& params,
87  const double* X,
88  int64_t n_samples,
89  int64_t n_features,
90  const double* sample_weight,
91  double* centroids,
92  int64_t* labels,
93  double& inertia,
94  int64_t& n_iter);
95 
118 void predict(const raft::handle_t& handle,
119  const KMeansParams& params,
120  const float* centroids,
121  const float* X,
122  int n_samples,
123  int n_features,
124  const float* sample_weight,
125  bool normalize_weights,
126  int* labels,
127  float& inertia);
128 
129 void predict(const raft::handle_t& handle,
130  const KMeansParams& params,
131  const double* centroids,
132  const double* X,
133  int n_samples,
134  int n_features,
135  const double* sample_weight,
136  bool normalize_weights,
137  int* labels,
138  double& inertia);
139 void predict(const raft::handle_t& handle,
140  const KMeansParams& params,
141  const float* centroids,
142  const float* X,
143  int64_t n_samples,
144  int64_t n_features,
145  const float* sample_weight,
146  bool normalize_weights,
147  int64_t* labels,
148  float& inertia);
149 
150 void predict(const raft::handle_t& handle,
151  const KMeansParams& params,
152  const double* centroids,
153  const double* X,
154  int64_t n_samples,
155  int64_t n_features,
156  const double* sample_weight,
157  bool normalize_weights,
158  int64_t* labels,
159  double& inertia);
177 void transform(const raft::handle_t& handle,
178  const KMeansParams& params,
179  const float* centroids,
180  const float* X,
181  int n_samples,
182  int n_features,
183  float* X_new);
184 
185 void transform(const raft::handle_t& handle,
186  const KMeansParams& params,
187  const double* centroids,
188  const double* X,
189  int n_samples,
190  int n_features,
191  double* X_new);
192 void transform(const raft::handle_t& handle,
193  const KMeansParams& params,
194  const float* centroids,
195  const float* X,
196  int64_t n_samples,
197  int64_t n_features,
198  float* X_new);
199 
200 void transform(const raft::handle_t& handle,
201  const KMeansParams& params,
202  const double* centroids,
203  const double* X,
204  int64_t n_samples,
205  int64_t n_features,
206  double* X_new);
207 }; // end namespace kmeans
208 }; // end namespace ML
Definition: params.hpp:34
void fit_predict(const raft::handle_t &handle, const KMeansParams &params, const float *X, int n_samples, int n_features, const float *sample_weight, float *centroids, int *labels, float &inertia, int &n_iter)
Compute k-means clustering and predict the cluster index for each sample in the input.
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
void predict(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, const float *sample_weight, bool normalize_weights, int *labels, float &inertia)
Predict the closest cluster each sample in X belongs to.
Definition: dbscan.hpp:29
Definition: dbscan.hpp:25
Definition: kmeans_params.hpp:33