metrics.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
10 #include <cstdint>
11 
12 namespace raft {
13 class handle_t;
14 }
15 
16 namespace ML {
17 
18 namespace Metrics {
19 
36 float r2_score_py(const raft::handle_t& handle, float* y, float* y_hat, int n);
37 
54 double r2_score_py(const raft::handle_t& handle, double* y, double* y_hat, int n);
55 
68 double rand_index(const raft::handle_t& handle, double* y, double* y_hat, int n);
69 
90 double silhouette_score(const raft::handle_t& handle,
91  double* y,
92  int nRows,
93  int nCols,
94  int* labels,
95  int nLabels,
96  double* silScores,
97  ML::distance::DistanceType metric);
98 
99 namespace Batched {
122 float silhouette_score(const raft::handle_t& handle,
123  float* X,
124  int n_rows,
125  int n_cols,
126  int* y,
127  int n_labels,
128  float* scores,
129  int chunk,
130  ML::distance::DistanceType metric);
131 double silhouette_score(const raft::handle_t& handle,
132  double* X,
133  int n_rows,
134  int n_cols,
135  int* y,
136  int n_labels,
137  double* scores,
138  int chunk,
139  ML::distance::DistanceType metric);
140 
141 } // namespace Batched
154 double adjusted_rand_index(const raft::handle_t& handle,
155  const int64_t* y,
156  const int64_t* y_hat,
157  const int64_t n);
158 double adjusted_rand_index(const raft::handle_t& handle,
159  const int* y,
160  const int* y_hat,
161  const int n);
178 double kl_divergence(const raft::handle_t& handle, const double* y, const double* y_hat, int n);
179 
194 float kl_divergence(const raft::handle_t& handle, const float* y, const float* y_hat, int n);
195 
208 double entropy(const raft::handle_t& handle,
209  const int* y,
210  const int n,
211  const int lower_class_range,
212  const int upper_class_range);
213 
228 double mutual_info_score(const raft::handle_t& handle,
229  const int* y,
230  const int* y_hat,
231  const int n,
232  const int lower_class_range,
233  const int upper_class_range);
234 
249 double homogeneity_score(const raft::handle_t& handle,
250  const int* y,
251  const int* y_hat,
252  const int n,
253  const int lower_class_range,
254  const int upper_class_range);
255 
270 double completeness_score(const raft::handle_t& handle,
271  const int* y,
272  const int* y_hat,
273  const int n,
274  const int lower_class_range,
275  const int upper_class_range);
276 
292 double v_measure(const raft::handle_t& handle,
293  const int* y,
294  const int* y_hat,
295  const int n,
296  const int lower_class_range,
297  const int upper_class_range,
298  double beta);
299 
312 float accuracy_score_py(const raft::handle_t& handle,
313  const int* predictions,
314  const int* ref_predictions,
315  int n);
316 
334 void pairwise_distance(const raft::handle_t& handle,
335  const double* x,
336  const double* y,
337  double* dist,
338  int m,
339  int n,
340  int k,
341  ML::distance::DistanceType metric,
342  bool isRowMajor = true,
343  double metric_arg = 2.0);
344 
361 void pairwise_distance(const raft::handle_t& handle,
362  const float* x,
363  const float* y,
364  float* dist,
365  int m,
366  int n,
367  int k,
368  ML::distance::DistanceType metric,
369  bool isRowMajor = true,
370  float metric_arg = 2.0f);
371 
372 void pairwiseDistance_sparse(const raft::handle_t& handle,
373  double* x,
374  double* y,
375  double* dist,
376  int x_nrows,
377  int y_nrows,
378  int n_cols,
379  int x_nnz,
380  int y_nnz,
381  int* x_indptr,
382  int* y_indptr,
383  int* x_indices,
384  int* y_indices,
385  ML::distance::DistanceType metric,
386  float metric_arg);
387 void pairwiseDistance_sparse(const raft::handle_t& handle,
388  float* x,
389  float* y,
390  float* dist,
391  int x_nrows,
392  int y_nrows,
393  int n_cols,
394  int x_nnz,
395  int y_nnz,
396  int* x_indptr,
397  int* y_indptr,
398  int* x_indices,
399  int* y_indices,
400  ML::distance::DistanceType metric,
401  float metric_arg);
402 
417 template <typename math_t, ML::distance::DistanceType distance_type>
418 double trustworthiness_score(const raft::handle_t& h,
419  const math_t* X,
420  math_t* X_embedded,
421  int n,
422  int m,
423  int d,
424  int n_neighbors,
425  int batchSize = 512);
426 
427 } // namespace Metrics
428 } // namespace ML
float silhouette_score(const raft::handle_t &handle, float *X, int n_rows, int n_cols, int *y, int n_labels, float *scores, int chunk, ML::distance::DistanceType metric)
void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y, double *dist, int x_nrows, int y_nrows, int n_cols, int x_nnz, int y_nnz, int *x_indptr, int *y_indptr, int *x_indices, int *y_indices, ML::distance::DistanceType metric, float metric_arg)
double homogeneity_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
double mutual_info_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
double completeness_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
double rand_index(const raft::handle_t &handle, double *y, double *y_hat, int n)
double v_measure(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range, double beta)
float accuracy_score_py(const raft::handle_t &handle, const int *predictions, const int *ref_predictions, int n)
double silhouette_score(const raft::handle_t &handle, double *y, int nRows, int nCols, int *labels, int nLabels, double *silScores, ML::distance::DistanceType metric)
void pairwise_distance(const raft::handle_t &handle, const double *x, const double *y, double *dist, int m, int n, int k, ML::distance::DistanceType metric, bool isRowMajor=true, double metric_arg=2.0)
Calculates the ij pairwise distances between two input arrays of double type.
double entropy(const raft::handle_t &handle, const int *y, const int n, const int lower_class_range, const int upper_class_range)
double kl_divergence(const raft::handle_t &handle, const double *y, const double *y_hat, int n)
double adjusted_rand_index(const raft::handle_t &handle, const int64_t *y, const int64_t *y_hat, const int64_t n)
double trustworthiness_score(const raft::handle_t &h, const math_t *X, math_t *X_embedded, int n, int m, int d, int n_neighbors, int batchSize=512)
Compute the trustworthiness score.
float r2_score_py(const raft::handle_t &handle, float *y, float *y_hat, int n)
DistanceType
Definition: distance_type.hpp:10
Definition: dbscan.hpp:18
Definition: dbscan.hpp:14