knn.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/distance/distance_types.hpp>
20 #include <raft/spatial/knn/ball_cover_types.hpp>
21 #include <raft/spatial/knn/detail/processing.hpp> // MetricProcessor
22 
23 #include <cuvs/neighbors/ivf_flat.hpp>
24 #include <cuvs/neighbors/ivf_pq.hpp>
25 
// Forward declaration only: every function below receives raft::handle_t by
// reference, so the complete class definition is not required in this header.
namespace raft {
class handle_t;
}  // namespace raft
29 
30 namespace ML {
31 
55 void brute_force_knn(const raft::handle_t& handle,
56  std::vector<float*>& input,
57  std::vector<int>& sizes,
58  int D,
59  float* search_items,
60  int n,
61  int64_t* res_I,
62  float* res_D,
63  int k,
64  bool rowMajorIndex = false,
65  bool rowMajorQuery = false,
66  raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
67  float metric_arg = 2.0f,
68  std::vector<int64_t>* translations = nullptr);
69 
70 void rbc_build_index(const raft::handle_t& handle,
71  raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index);
72 
73 void rbc_knn_query(const raft::handle_t& handle,
74  raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index,
75  uint32_t k,
76  const float* search_items,
77  uint32_t n_search_items,
78  int64_t* out_inds,
79  float* out_dists);
80 
81 struct knnIndex {
82  raft::distance::DistanceType metric;
83  float metricArg;
84  int nprobe;
85  std::unique_ptr<raft::spatial::knn::MetricProcessor<float>> metric_processor;
86 
87  std::unique_ptr<cuvs::neighbors::ivf_flat::index<float, int64_t>> ivf_flat;
88  std::unique_ptr<cuvs::neighbors::ivf_pq::index<int64_t>> ivf_pq;
89 
90  int device;
91 };
92 
/**
 * @brief Polymorphic base for the build parameters passed to
 *        approx_knn_build_index() (which takes a knnIndexParam*).
 *
 * The virtual destructor lets derived parameter structs be destroyed through a
 * base-class pointer without undefined behavior.
 */
struct knnIndexParam {
  virtual ~knnIndexParam() = default;
};
96 
98  int nlist;
99  int nprobe;
100 };
101 
102 struct IVFFlatParam : IVFParam {};
103 
105  int M;
106  int n_bits;
108 };
109 
123 void approx_knn_build_index(raft::handle_t& handle,
124  knnIndex* index,
126  raft::distance::DistanceType metric,
127  float metricArg,
128  float* index_array,
129  int n,
130  int D);
131 
145 void approx_knn_search(raft::handle_t& handle,
146  float* distances,
147  int64_t* indices,
148  knnIndex* index,
149  int k,
150  float* query_array,
151  int n);
152 
167 void knn_classify(raft::handle_t& handle,
168  int* out,
169  int64_t* knn_indices,
170  std::vector<int*>& y,
171  size_t n_index_rows,
172  size_t n_query_rows,
173  int k);
174 
189 void knn_regress(raft::handle_t& handle,
190  float* out,
191  int64_t* knn_indices,
192  std::vector<float*>& y,
193  size_t n_index_rows,
194  size_t n_query_rows,
195  int k);
196 
211 void knn_class_proba(raft::handle_t& handle,
212  std::vector<float*>& out,
213  int64_t* knn_indices,
214  std::vector<int*>& y,
215  size_t n_index_rows,
216  size_t n_query_rows,
217  int k);
218 }; // namespace ML
Definition: params.hpp:34
Definition: dbscan.hpp:30
void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn classification using a given vector of label arrays...
void brute_force_knn(const raft::handle_t &handle, std::vector< float * > &input, std::vector< int > &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex=false, bool rowMajorQuery=false, raft::distance::DistanceType metric=raft::distance::DistanceType::L2Expanded, float metric_arg=2.0f, std::vector< int64_t > *translations=nullptr)
Flat C++ API function to perform a brute force knn on a series of input arrays and combine the result...
void knn_class_proba(raft::handle_t &handle, std::vector< float * > &out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to compute knn class probabilities using a vector of device arrays containing d...
void knn_regress(raft::handle_t &handle, float *out, int64_t *knn_indices, std::vector< float * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn regression using a given vector of label arrays...
void approx_knn_search(raft::handle_t &handle, float *distances, int64_t *indices, knnIndex *index, int k, float *query_array, int n)
Flat C++ API function to perform an approximate nearest neighbors search from previously built index ...
void rbc_build_index(const raft::handle_t &handle, raft::spatial::knn::BallCoverIndex< int64_t, float, uint32_t > &index)
void rbc_knn_query(const raft::handle_t &handle, raft::spatial::knn::BallCoverIndex< int64_t, float, uint32_t > &index, uint32_t k, const float *search_items, uint32_t n_search_items, int64_t *out_inds, float *out_dists)
void approx_knn_build_index(raft::handle_t &handle, knnIndex *index, knnIndexParam *params, raft::distance::DistanceType metric, float metricArg, float *index_array, int n, int D)
Flat C++ API function to build an approximate nearest neighbors index from an index array and a set o...
Definition: dbscan.hpp:26
Definition: knn.hpp:102
Definition: knn.hpp:104
int M
Definition: knn.hpp:105
int n_bits
Definition: knn.hpp:106
bool usePrecomputedTables
Definition: knn.hpp:107
Definition: knn.hpp:97
int nprobe
Definition: knn.hpp:99
int nlist
Definition: knn.hpp:98
Definition: knn.hpp:93
virtual ~knnIndexParam()
Definition: knn.hpp:94
Definition: knn.hpp:81
int nprobe
Definition: knn.hpp:84
int device
Definition: knn.hpp:90
std::unique_ptr< cuvs::neighbors::ivf_pq::index< int64_t > > ivf_pq
Definition: knn.hpp:88
float metricArg
Definition: knn.hpp:83
std::unique_ptr< cuvs::neighbors::ivf_flat::index< float, int64_t > > ivf_flat
Definition: knn.hpp:87
std::unique_ptr< raft::spatial::knn::MetricProcessor< float > > metric_processor
Definition: knn.hpp:85
raft::distance::DistanceType metric
Definition: knn.hpp:82