knn.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/spatial/knn/detail/processing.hpp> // MetricProcessor
20 
21 #include <cuvs/distance/distance.hpp>
22 #include <cuvs/neighbors/ivf_flat.hpp>
23 #include <cuvs/neighbors/ivf_pq.hpp>
24 
// Forward declaration only: every function in this header takes
// raft::handle_t by (const) reference, so the complete type is not needed here.
namespace raft { class handle_t; }
28 
29 namespace ML {
30 
54 void brute_force_knn(const raft::handle_t& handle,
55  std::vector<float*>& input,
56  std::vector<int>& sizes,
57  int D,
58  float* search_items,
59  int n,
60  int64_t* res_I,
61  float* res_D,
62  int k,
63  bool rowMajorIndex = false,
64  bool rowMajorQuery = false,
65  cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded,
66  float metric_arg = 2.0f,
67  std::vector<int64_t>* translations = nullptr);
68 
69 void rbc_build_index(const raft::handle_t& handle,
70  std::uintptr_t& rbc_index,
71  float* X,
72  int64_t n_rows,
73  int64_t n_cols,
74  cuvs::distance::DistanceType metric);
75 
76 void rbc_knn_query(const raft::handle_t& handle,
77  const std::uintptr_t& rbc_index,
78  uint32_t k,
79  const float* search_items,
80  uint32_t n_search_items,
81  int64_t dim,
82  int64_t* out_inds,
83  float* out_dists);
84 
/**
 * @brief Free the RBC index.
 *
 * @param[in] rbc_index  opaque handle produced by rbc_build_index();
 *                       presumably invalid after this call
 */
void rbc_free_index(std::uintptr_t rbc_index);
91 
92 struct knnIndex {
93  cuvs::distance::DistanceType metric;
94  float metricArg;
95  int nprobe;
96  std::unique_ptr<raft::spatial::knn::MetricProcessor<float>> metric_processor;
97 
98  std::unique_ptr<cuvs::neighbors::ivf_flat::index<float, int64_t>> ivf_flat;
99  std::unique_ptr<cuvs::neighbors::ivf_pq::index<int64_t>> ivf_pq;
100 
101  int device;
102 };
103 
/**
 * @brief Polymorphic base of the parameter structs accepted by
 * approx_knn_build_index(). The virtual destructor makes deleting a derived
 * parameter object through a knnIndexParam* well defined.
 *
 * NOTE(review): presumably the dynamic type (IVFFlatParam / IVFPQParam)
 * selects which index is built — confirm in the implementation.
 */
struct knnIndexParam {
  virtual ~knnIndexParam() = default;
};

/**
 * @brief Parameters shared by the IVF-based indexes.
 * Members are value-initialized so an unset field is never read as garbage.
 */
struct IVFParam : knnIndexParam {
  int nlist{};   ///< number of inverted lists (presumably clusters) — TODO confirm
  int nprobe{};  ///< number of lists probed at search time — TODO confirm
};

/** @brief IVF-Flat parameters; adds nothing beyond IVFParam. */
struct IVFFlatParam : IVFParam {};

/** @brief IVF-PQ parameters. */
struct IVFPQParam : IVFParam {
  int M{};                          ///< presumably number of PQ sub-quantizers — confirm
  int n_bits{};                     ///< presumably bits per PQ code — confirm
  bool usePrecomputedTables{false};  ///< whether to use precomputed lookup tables
};
120 
134 void approx_knn_build_index(raft::handle_t& handle,
135  knnIndex* index,
137  cuvs::distance::DistanceType metric,
138  float metricArg,
139  float* index_array,
140  int n,
141  int D);
142 
156 void approx_knn_search(raft::handle_t& handle,
157  float* distances,
158  int64_t* indices,
159  knnIndex* index,
160  int k,
161  float* query_array,
162  int n);
163 
178 void knn_classify(raft::handle_t& handle,
179  int* out,
180  int64_t* knn_indices,
181  std::vector<int*>& y,
182  size_t n_index_rows,
183  size_t n_query_rows,
184  int k);
185 
200 void knn_regress(raft::handle_t& handle,
201  float* out,
202  int64_t* knn_indices,
203  std::vector<float*>& y,
204  size_t n_index_rows,
205  size_t n_query_rows,
206  int k);
207 
222 void knn_class_proba(raft::handle_t& handle,
223  std::vector<float*>& out,
224  int64_t* knn_indices,
225  std::vector<int*>& y,
226  size_t n_index_rows,
227  size_t n_query_rows,
228  int k);
229 }; // namespace ML
Definition: params.hpp:34
Definition: dbscan.hpp:30
void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn classification using a given vector of label arrays....
void rbc_knn_query(const raft::handle_t &handle, const std::uintptr_t &rbc_index, uint32_t k, const float *search_items, uint32_t n_search_items, int64_t dim, int64_t *out_inds, float *out_dists)
void brute_force_knn(const raft::handle_t &handle, std::vector< float * > &input, std::vector< int > &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex=false, bool rowMajorQuery=false, cuvs::distance::DistanceType metric=cuvs::distance::DistanceType::L2Expanded, float metric_arg=2.0f, std::vector< int64_t > *translations=nullptr)
Flat C++ API function to perform a brute force knn on a series of input arrays and combine the result...
void rbc_free_index(std::uintptr_t rbc_index)
Free the RBC index.
void knn_class_proba(raft::handle_t &handle, std::vector< float * > &out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to compute knn class probabilities using a vector of device arrays containing d...
void knn_regress(raft::handle_t &handle, float *out, int64_t *knn_indices, std::vector< float * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn regression using a given vector of label arrays....
void rbc_build_index(const raft::handle_t &handle, std::uintptr_t &rbc_index, float *X, int64_t n_rows, int64_t n_cols, cuvs::distance::DistanceType metric)
void approx_knn_search(raft::handle_t &handle, float *distances, int64_t *indices, knnIndex *index, int k, float *query_array, int n)
Flat C++ API function to perform an approximate nearest neighbors search from previously built index ...
void approx_knn_build_index(raft::handle_t &handle, knnIndex *index, knnIndexParam *params, cuvs::distance::DistanceType metric, float metricArg, float *index_array, int n, int D)
Flat C++ API function to build an approximate nearest neighbors index from an index array and a set o...
Definition: dbscan.hpp:26
Definition: knn.hpp:113
Definition: knn.hpp:115
int M
Definition: knn.hpp:116
int n_bits
Definition: knn.hpp:117
bool usePrecomputedTables
Definition: knn.hpp:118
Definition: knn.hpp:108
int nprobe
Definition: knn.hpp:110
int nlist
Definition: knn.hpp:109
Definition: knn.hpp:104
virtual ~knnIndexParam()
Definition: knn.hpp:105
Definition: knn.hpp:92
int nprobe
Definition: knn.hpp:95
cuvs::distance::DistanceType metric
Definition: knn.hpp:93
int device
Definition: knn.hpp:101
std::unique_ptr< cuvs::neighbors::ivf_pq::index< int64_t > > ivf_pq
Definition: knn.hpp:99
float metricArg
Definition: knn.hpp:94
std::unique_ptr< cuvs::neighbors::ivf_flat::index< float, int64_t > > ivf_flat
Definition: knn.hpp:98
std::unique_ptr< raft::spatial::knn::MetricProcessor< float > > metric_processor
Definition: knn.hpp:96