1use std::io::{stderr, Write};
8
9use crate::distance_type::DistanceType;
10use crate::dlpack::ManagedTensor;
11use crate::error::{check_cuvs, Result};
12use crate::resources::Resources;
13
14#[derive(Debug)]
16pub struct Index(ffi::cuvsBruteForceIndex_t);
17
18impl Index {
19 pub fn build<T: Into<ManagedTensor>>(
28 res: &Resources,
29 metric: DistanceType,
30 metric_arg: Option<f32>,
31 dataset: T,
32 ) -> Result<Index> {
33 let dataset: ManagedTensor = dataset.into();
34 let index = Index::new()?;
35 unsafe {
36 check_cuvs(ffi::cuvsBruteForceBuild(
37 res.0,
38 dataset.as_ptr(),
39 metric,
40 metric_arg.unwrap_or(2.0),
41 index.0,
42 ))?;
43 }
44 Ok(index)
45 }
46
47 pub fn new() -> Result<Index> {
49 unsafe {
50 let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
51 check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
52 Ok(Index(index.assume_init()))
53 }
54 }
55
56 pub fn search(
65 &self,
66 res: &Resources,
67 queries: &ManagedTensor,
68 neighbors: &ManagedTensor,
69 distances: &ManagedTensor,
70 ) -> Result<()> {
71 unsafe {
72 let prefilter = ffi::cuvsFilter {
73 addr: 0,
74 type_: ffi::cuvsFilterType::NO_FILTER,
75 };
76
77 check_cuvs(ffi::cuvsBruteForceSearch(
78 res.0,
79 self.0,
80 queries.as_ptr(),
81 neighbors.as_ptr(),
82 distances.as_ptr(),
83 prefilter,
84 ))
85 }
86 }
87}
88
89impl Drop for Index {
90 fn drop(&mut self) {
91 if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
92 write!(stderr(), "failed to call bruteForceIndexDestroy {:?}", e)
93 .expect("failed to write to stderr");
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use mark_flaky_tests::flaky;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds an index over a small random dataset, searches it with the first few dataset
    /// rows, and checks that each of those rows is returned as its own nearest neighbor.
    fn test_bfknn(metric: DistanceType) {
        let res = Resources::new().unwrap();

        // Random host dataset with values drawn from Uniform(0, 1), copied to the device.
        let n_datapoints = 16;
        let n_features = 8;
        let shape = (n_datapoints, n_features);
        let dataset_host = ndarray::Array::<f32, _>::random(shape, Uniform::new(0., 1.0));
        let dataset_device = ManagedTensor::from(&dataset_host)
            .to_device(&res)
            .unwrap();

        println!("dataset {:#?}", dataset_host);

        let index = Index::build(&res, metric, None, dataset_device)
            .expect("failed to create brute force index");

        res.sync_stream().unwrap();

        // The first `n_queries` dataset rows double as the queries.
        let n_queries = 4;
        let k = 4;
        let query_view = dataset_host.slice(s![0..n_queries, ..]);

        println!("queries! {:#?}", query_view);
        let queries_device = ManagedTensor::from(&query_view).to_device(&res).unwrap();

        // Zero-initialized host buffers mirrored on the device to receive the results.
        let result_shape = (n_queries, k);
        let mut neighbors_host = ndarray::Array::<i64, _>::zeros(result_shape);
        let neighbors_device = ManagedTensor::from(&neighbors_host)
            .to_device(&res)
            .unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros(result_shape);
        let distances_device = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        index
            .search(&res, &queries_device, &neighbors_device, &distances_device)
            .unwrap();

        // Copy the results back to the host before inspecting them.
        distances_device.to_host(&res, &mut distances_host).unwrap();
        neighbors_device.to_host(&res, &mut neighbors_host).unwrap();
        res.sync_stream().unwrap();

        println!("distances {:#?}", distances_host);
        println!("neighbors {:#?}", neighbors_host);

        // Every query is a dataset row, so its nearest neighbor must be itself.
        for row in 0..n_queries {
            assert_eq!(neighbors_host[[row, 0]], row as i64);
        }
    }

    #[flaky]
    fn test_l2() {
        test_bfknn(DistanceType::L2Expanded);
    }
}