kmeans.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
// Forward declaration only: this header refers to raft::handle_t solely by
// const reference, so the complete type is not required here.
namespace raft {
class handle_t;
}  // namespace raft
24 
25 namespace ML {
26 
27 namespace kmeans {
28 
50 void fit(const raft::handle_t& handle,
51  const KMeansParams& params,
52  const float* X,
53  int n_samples,
54  int n_features,
55  const float* sample_weight,
56  float* centroids,
57  float& inertia,
58  int& n_iter);
59 
60 void fit(const raft::handle_t& handle,
61  const KMeansParams& params,
62  const double* X,
63  int n_samples,
64  int n_features,
65  const double* sample_weight,
66  double* centroids,
67  double& inertia,
68  int& n_iter);
69 
70 void fit(const raft::handle_t& handle,
71  const KMeansParams& params,
72  const float* X,
73  int64_t n_samples,
74  int64_t n_features,
75  const float* sample_weight,
76  float* centroids,
77  float& inertia,
78  int64_t& n_iter);
79 
80 void fit(const raft::handle_t& handle,
81  const KMeansParams& params,
82  const double* X,
83  int64_t n_samples,
84  int64_t n_features,
85  const double* sample_weight,
86  double* centroids,
87  double& inertia,
88  int64_t& n_iter);
89 
112 void predict(const raft::handle_t& handle,
113  const KMeansParams& params,
114  const float* centroids,
115  const float* X,
116  int n_samples,
117  int n_features,
118  const float* sample_weight,
119  bool normalize_weights,
120  int* labels,
121  float& inertia);
122 
123 void predict(const raft::handle_t& handle,
124  const KMeansParams& params,
125  const double* centroids,
126  const double* X,
127  int n_samples,
128  int n_features,
129  const double* sample_weight,
130  bool normalize_weights,
131  int* labels,
132  double& inertia);
133 void predict(const raft::handle_t& handle,
134  const KMeansParams& params,
135  const float* centroids,
136  const float* X,
137  int64_t n_samples,
138  int64_t n_features,
139  const float* sample_weight,
140  bool normalize_weights,
141  int64_t* labels,
142  float& inertia);
143 
144 void predict(const raft::handle_t& handle,
145  const KMeansParams& params,
146  const double* centroids,
147  const double* X,
148  int64_t n_samples,
149  int64_t n_features,
150  const double* sample_weight,
151  bool normalize_weights,
152  int64_t* labels,
153  double& inertia);
171 void transform(const raft::handle_t& handle,
172  const KMeansParams& params,
173  const float* centroids,
174  const float* X,
175  int n_samples,
176  int n_features,
177  float* X_new);
178 
179 void transform(const raft::handle_t& handle,
180  const KMeansParams& params,
181  const double* centroids,
182  const double* X,
183  int n_samples,
184  int n_features,
185  double* X_new);
186 void transform(const raft::handle_t& handle,
187  const KMeansParams& params,
188  const float* centroids,
189  const float* X,
190  int64_t n_samples,
191  int64_t n_features,
192  float* X_new);
193 
194 void transform(const raft::handle_t& handle,
195  const KMeansParams& params,
196  const double* centroids,
197  const double* X,
198  int64_t n_samples,
199  int64_t n_features,
200  double* X_new);
201 }; // end namespace kmeans
202 }; // end namespace ML
Definition: params.hpp:34
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
void fit(const raft::handle_t &handle, const KMeansParams &params, const float *X, int n_samples, int n_features, const float *sample_weight, float *centroids, float &inertia, int &n_iter)
Compute k-means clustering for each sample in the input.
void predict(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, const float *sample_weight, bool normalize_weights, int *labels, float &inertia)
Predict the closest cluster each sample in X belongs to.
Definition: dbscan.hpp:29
Definition: dbscan.hpp:25
Definition: kmeans_params.hpp:33