metrics.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021-2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/distance/distance_types.hpp>
20 
21 #include <cstdint>
22 
// Forward declaration only: every use of raft::handle_t in this header is by
// const reference, so the full class definition is not needed here and callers
// do not pay for including it.
namespace raft {
class handle_t;
}  // namespace raft
26 
27 namespace ML {
28 
29 namespace Metrics {
30 
47 float r2_score_py(const raft::handle_t& handle, float* y, float* y_hat, int n);
48 
65 double r2_score_py(const raft::handle_t& handle, double* y, double* y_hat, int n);
66 
79 double rand_index(const raft::handle_t& handle, double* y, double* y_hat, int n);
80 
101 double silhouette_score(const raft::handle_t& handle,
102  double* y,
103  int nRows,
104  int nCols,
105  int* labels,
106  int nLabels,
107  double* silScores,
108  raft::distance::DistanceType metric);
109 
110 namespace Batched {
133 float silhouette_score(const raft::handle_t& handle,
134  float* X,
135  int n_rows,
136  int n_cols,
137  int* y,
138  int n_labels,
139  float* scores,
140  int chunk,
141  raft::distance::DistanceType metric);
142 double silhouette_score(const raft::handle_t& handle,
143  double* X,
144  int n_rows,
145  int n_cols,
146  int* y,
147  int n_labels,
148  double* scores,
149  int chunk,
150  raft::distance::DistanceType metric);
151 
152 } // namespace Batched
165 double adjusted_rand_index(const raft::handle_t& handle,
166  const int64_t* y,
167  const int64_t* y_hat,
168  const int64_t n);
169 double adjusted_rand_index(const raft::handle_t& handle,
170  const int* y,
171  const int* y_hat,
172  const int n);
189 double kl_divergence(const raft::handle_t& handle, const double* y, const double* y_hat, int n);
190 
205 float kl_divergence(const raft::handle_t& handle, const float* y, const float* y_hat, int n);
206 
219 double entropy(const raft::handle_t& handle,
220  const int* y,
221  const int n,
222  const int lower_class_range,
223  const int upper_class_range);
224 
239 double mutual_info_score(const raft::handle_t& handle,
240  const int* y,
241  const int* y_hat,
242  const int n,
243  const int lower_class_range,
244  const int upper_class_range);
245 
260 double homogeneity_score(const raft::handle_t& handle,
261  const int* y,
262  const int* y_hat,
263  const int n,
264  const int lower_class_range,
265  const int upper_class_range);
266 
281 double completeness_score(const raft::handle_t& handle,
282  const int* y,
283  const int* y_hat,
284  const int n,
285  const int lower_class_range,
286  const int upper_class_range);
287 
303 double v_measure(const raft::handle_t& handle,
304  const int* y,
305  const int* y_hat,
306  const int n,
307  const int lower_class_range,
308  const int upper_class_range,
309  double beta);
310 
323 float accuracy_score_py(const raft::handle_t& handle,
324  const int* predictions,
325  const int* ref_predictions,
326  int n);
327 
345 void pairwise_distance(const raft::handle_t& handle,
346  const double* x,
347  const double* y,
348  double* dist,
349  int m,
350  int n,
351  int k,
352  raft::distance::DistanceType metric,
353  bool isRowMajor = true,
354  double metric_arg = 2.0);
355 
372 void pairwise_distance(const raft::handle_t& handle,
373  const float* x,
374  const float* y,
375  float* dist,
376  int m,
377  int n,
378  int k,
379  raft::distance::DistanceType metric,
380  bool isRowMajor = true,
381  float metric_arg = 2.0f);
382 
383 void pairwiseDistance_sparse(const raft::handle_t& handle,
384  double* x,
385  double* y,
386  double* dist,
387  int x_nrows,
388  int y_nrows,
389  int n_cols,
390  int x_nnz,
391  int y_nnz,
392  int* x_indptr,
393  int* y_indptr,
394  int* x_indices,
395  int* y_indices,
396  raft::distance::DistanceType metric,
397  float metric_arg);
398 void pairwiseDistance_sparse(const raft::handle_t& handle,
399  float* x,
400  float* y,
401  float* dist,
402  int x_nrows,
403  int y_nrows,
404  int n_cols,
405  int x_nnz,
406  int y_nnz,
407  int* x_indptr,
408  int* y_indptr,
409  int* x_indices,
410  int* y_indices,
411  raft::distance::DistanceType metric,
412  float metric_arg);
413 
428 template <typename math_t, raft::distance::DistanceType distance_type>
429 double trustworthiness_score(const raft::handle_t& h,
430  const math_t* X,
431  math_t* X_embedded,
432  int n,
433  int m,
434  int d,
435  int n_neighbors,
436  int batchSize = 512);
437 
438 } // namespace Metrics
439 } // namespace ML
float silhouette_score(const raft::handle_t &handle, float *X, int n_rows, int n_cols, int *y, int n_labels, float *scores, int chunk, raft::distance::DistanceType metric)
double homogeneity_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
double mutual_info_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
void pairwise_distance(const raft::handle_t &handle, const double *x, const double *y, double *dist, int m, int n, int k, raft::distance::DistanceType metric, bool isRowMajor=true, double metric_arg=2.0)
Calculates the pairwise distances (entry (i, j) of the output) between two input arrays of type double.
double completeness_score(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range)
double rand_index(const raft::handle_t &handle, double *y, double *y_hat, int n)
double v_measure(const raft::handle_t &handle, const int *y, const int *y_hat, const int n, const int lower_class_range, const int upper_class_range, double beta)
float accuracy_score_py(const raft::handle_t &handle, const int *predictions, const int *ref_predictions, int n)
void pairwiseDistance_sparse(const raft::handle_t &handle, double *x, double *y, double *dist, int x_nrows, int y_nrows, int n_cols, int x_nnz, int y_nnz, int *x_indptr, int *y_indptr, int *x_indices, int *y_indices, raft::distance::DistanceType metric, float metric_arg)
double entropy(const raft::handle_t &handle, const int *y, const int n, const int lower_class_range, const int upper_class_range)
double kl_divergence(const raft::handle_t &handle, const double *y, const double *y_hat, int n)
double adjusted_rand_index(const raft::handle_t &handle, const int64_t *y, const int64_t *y_hat, const int64_t n)
double silhouette_score(const raft::handle_t &handle, double *y, int nRows, int nCols, int *labels, int nLabels, double *silScores, raft::distance::DistanceType metric)
double trustworthiness_score(const raft::handle_t &h, const math_t *X, math_t *X_embedded, int n, int m, int d, int n_neighbors, int batchSize=512)
Compute the trustworthiness score.
float r2_score_py(const raft::handle_t &handle, float *y, float *y_hat, int n)
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26