knn.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
#include <cuml/common/distance_type.hpp>

#include <cstdint>
#include <memory>
#include <vector>
13 
// Forward declaration only: this header takes raft::handle_t by reference,
// so the full RAFT definition is not required here.
namespace raft {
class handle_t;
}  // namespace raft
17 
18 namespace ML {
19 
43 void brute_force_knn(const raft::handle_t& handle,
44  std::vector<float*>& input,
45  std::vector<int>& sizes,
46  int D,
47  float* search_items,
48  int n,
49  int64_t* res_I,
50  float* res_D,
51  int k,
52  bool rowMajorIndex = false,
53  bool rowMajorQuery = false,
55  float metric_arg = 2.0f,
56  std::vector<int64_t>* translations = nullptr);
57 
58 void rbc_build_index(const raft::handle_t& handle,
59  std::uintptr_t& rbc_index,
60  float* X,
61  int64_t n_rows,
62  int64_t n_cols,
64 
65 void rbc_knn_query(const raft::handle_t& handle,
66  const std::uintptr_t& rbc_index,
67  uint32_t k,
68  const float* search_items,
69  uint32_t n_search_items,
70  int64_t dim,
71  int64_t* out_inds,
72  float* out_dists);
73 
/**
 * @brief Free the RBC index.
 *
 * @param rbc_index Opaque handle previously produced by rbc_build_index().
 */
void rbc_free_index(std::uintptr_t rbc_index);
80 
81 struct knnIndexImpl;
82 
83 struct knnIndex {
86 
88  float metricArg;
89  int nprobe;
90  int device;
91 
92  std::unique_ptr<knnIndexImpl> pimpl;
93 };
94 
/// Polymorphic base of the parameter structs accepted by approx_knn_build_index().
struct knnIndexParam {
  virtual ~knnIndexParam() = default;
};

/// Parameters shared by the IVF-based index types.
struct IVFParam : knnIndexParam {
  int nlist;   ///< presumably number of inverted lists — confirm
  int nprobe;  ///< presumably number of lists probed per query — confirm
};

/// IVF-Flat parameters; adds nothing beyond IVFParam.
struct IVFFlatParam : IVFParam {};

/// IVF-PQ parameters.
struct IVFPQParam : IVFParam {
  int M;                      ///< presumably number of PQ sub-quantizers — confirm
  int n_bits;                 ///< presumably bits per PQ code — confirm
  bool usePrecomputedTables;  ///< Whether to use precomputed tables.
};
111 
125 void approx_knn_build_index(raft::handle_t& handle,
126  knnIndex* index,
129  float metricArg,
130  float* index_array,
131  int n,
132  int D);
133 
147 void approx_knn_search(raft::handle_t& handle,
148  float* distances,
149  int64_t* indices,
150  knnIndex* index,
151  int k,
152  float* query_array,
153  int n);
154 
171 void knn_classify(raft::handle_t& handle,
172  int* out,
173  int64_t* knn_indices,
174  std::vector<int*>& y,
175  size_t n_index_rows,
176  size_t n_query_rows,
177  int k,
178  float* sample_weight = nullptr);
179 
196 void knn_regress(raft::handle_t& handle,
197  float* out,
198  int64_t* knn_indices,
199  std::vector<float*>& y,
200  size_t n_index_rows,
201  size_t n_query_rows,
202  int k,
203  float* sample_weight = nullptr);
204 
221 void knn_class_proba(raft::handle_t& handle,
222  std::vector<float*>& out,
223  int64_t* knn_indices,
224  std::vector<int*>& y,
225  size_t n_index_rows,
226  size_t n_query_rows,
227  int k,
228  float* sample_weight = nullptr);
229 }; // namespace ML
Definition: params.hpp:23
DistanceType
Definition: distance_type.hpp:10
Definition: dbscan.hpp:18
void rbc_knn_query(const raft::handle_t &handle, const std::uintptr_t &rbc_index, uint32_t k, const float *search_items, uint32_t n_search_items, int64_t dim, int64_t *out_inds, float *out_dists)
void knn_class_proba(raft::handle_t &handle, std::vector< float * > &out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k, float *sample_weight=nullptr)
Flat C++ API function to compute knn class probabilities using a vector of device arrays containing d...
void rbc_free_index(std::uintptr_t rbc_index)
Free the RBC index.
void brute_force_knn(const raft::handle_t &handle, std::vector< float * > &input, std::vector< int > &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex=false, bool rowMajorQuery=false, ML::distance::DistanceType metric=ML::distance::DistanceType::L2Expanded, float metric_arg=2.0f, std::vector< int64_t > *translations=nullptr)
Flat C++ API function to perform a brute force knn on a series of input arrays and combine the result...
void rbc_build_index(const raft::handle_t &handle, std::uintptr_t &rbc_index, float *X, int64_t n_rows, int64_t n_cols, ML::distance::DistanceType metric)
void approx_knn_search(raft::handle_t &handle, float *distances, int64_t *indices, knnIndex *index, int k, float *query_array, int n)
Flat C++ API function to perform an approximate nearest neighbors search from previously built index ...
void knn_regress(raft::handle_t &handle, float *out, int64_t *knn_indices, std::vector< float * > &y, size_t n_index_rows, size_t n_query_rows, int k, float *sample_weight=nullptr)
Flat C++ API function to perform a knn regression using a given vector of label arrays....
void approx_knn_build_index(raft::handle_t &handle, knnIndex *index, knnIndexParam *params, ML::distance::DistanceType metric, float metricArg, float *index_array, int n, int D)
Flat C++ API function to build an approximate nearest neighbors index from an index array and a set o...
void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k, float *sample_weight=nullptr)
Flat C++ API function to perform a knn classification using a given vector of label arrays....
Definition: dbscan.hpp:14
Definition: knn.hpp:104
Definition: knn.hpp:106
int M
Definition: knn.hpp:107
int n_bits
Definition: knn.hpp:108
bool usePrecomputedTables
Definition: knn.hpp:109
Definition: knn.hpp:99
int nprobe
Definition: knn.hpp:101
int nlist
Definition: knn.hpp:100
Definition: knn.hpp:95
virtual ~knnIndexParam()
Definition: knn.hpp:96
Definition: knn.hpp:83
int nprobe
Definition: knn.hpp:89
std::unique_ptr< knnIndexImpl > pimpl
Definition: knn.hpp:92
int device
Definition: knn.hpp:90
float metricArg
Definition: knn.hpp:88
ML::distance::DistanceType metric
Definition: knn.hpp:87