cuvs/
brute_force.rs

1/*
2 * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5//! Brute Force KNN
6
7use std::io::{stderr, Write};
8
9use crate::distance_type::DistanceType;
10use crate::dlpack::ManagedTensor;
11use crate::error::{check_cuvs, Result};
12use crate::resources::Resources;
13
14/// Brute Force KNN Index
15#[derive(Debug)]
16pub struct Index(ffi::cuvsBruteForceIndex_t);
17
18impl Index {
19    /// Builds a new Brute Force KNN Index from the dataset for efficient search.
20    ///
21    /// # Arguments
22    ///
23    /// * `res` - Resources to use
24    /// * `metric` - DistanceType to use for building the index
25    /// * `metric_arg` - Optional value of `p` for Minkowski distances
26    /// * `dataset` - A row-major matrix on either the host or device to index
27    pub fn build<T: Into<ManagedTensor>>(
28        res: &Resources,
29        metric: DistanceType,
30        metric_arg: Option<f32>,
31        dataset: T,
32    ) -> Result<Index> {
33        let dataset: ManagedTensor = dataset.into();
34        let index = Index::new()?;
35        unsafe {
36            check_cuvs(ffi::cuvsBruteForceBuild(
37                res.0,
38                dataset.as_ptr(),
39                metric,
40                metric_arg.unwrap_or(2.0),
41                index.0,
42            ))?;
43        }
44        Ok(index)
45    }
46
47    /// Creates a new empty index
48    pub fn new() -> Result<Index> {
49        unsafe {
50            let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
51            check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
52            Ok(Index(index.assume_init()))
53        }
54    }
55
56    /// Perform a Nearest Neighbors search on the Index
57    ///
58    /// # Arguments
59    ///
60    /// * `res` - Resources to use
61    /// * `queries` - A matrix in device memory to query for
62    /// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
63    /// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
64    pub fn search(
65        &self,
66        res: &Resources,
67        queries: &ManagedTensor,
68        neighbors: &ManagedTensor,
69        distances: &ManagedTensor,
70    ) -> Result<()> {
71        unsafe {
72            let prefilter = ffi::cuvsFilter {
73                addr: 0,
74                type_: ffi::cuvsFilterType::NO_FILTER,
75            };
76
77            check_cuvs(ffi::cuvsBruteForceSearch(
78                res.0,
79                self.0,
80                queries.as_ptr(),
81                neighbors.as_ptr(),
82                distances.as_ptr(),
83                prefilter,
84            ))
85        }
86    }
87}
88
89impl Drop for Index {
90    fn drop(&mut self) {
91        if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
92            write!(stderr(), "failed to call bruteForceIndexDestroy {:?}", e)
93                .expect("failed to write to stderr");
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use mark_flaky_tests::flaky;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds a brute force index over a small random dataset with `metric`,
    /// queries it with the first rows of that dataset and checks that every
    /// query row comes back as its own nearest neighbor.
    fn test_bfknn(metric: DistanceType) {
        let res = Resources::new().unwrap();

        // Random host-side dataset that gets indexed
        let n_datapoints = 16;
        let n_features = 8;
        let host_data =
            ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

        let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();

        println!("dataset {:#?}", host_data);

        // Build the index from the device copy of the dataset
        let index = Index::build(&res, metric, None, device_data)
            .expect("failed to create brute force index");

        res.sync_stream().unwrap();

        // The first 4 dataset rows double as queries, so row `i` is expected
        // to be the closest match for query `i`.
        let n_queries = 4;
        let k = 4;
        let query_rows = host_data.slice(s![0..n_queries, ..]);

        println!("queries! {:#?}", query_rows);
        let queries = ManagedTensor::from(&query_rows).to_device(&res).unwrap();

        // Output buffers: zeroed on the host, then mirrored onto the device
        let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
        let neighbors = ManagedTensor::from(&neighbors_host)
            .to_device(&res)
            .unwrap();

        let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let distances = ManagedTensor::from(&distances_host)
            .to_device(&res)
            .unwrap();

        index
            .search(&res, &queries, &neighbors, &distances)
            .unwrap();

        // Bring the results back to host memory
        distances.to_host(&res, &mut distances_host).unwrap();
        neighbors.to_host(&res, &mut neighbors_host).unwrap();
        res.sync_stream().unwrap();

        println!("distances {:#?}", distances_host);
        println!("neighbors {:#?}", neighbors_host);

        // Each query was taken from the dataset, so its best match must be
        // the row it was copied from.
        for row in 0..n_queries {
            assert_eq!(neighbors_host[[row, 0]], row as i64);
        }
    }

    /*
        #[test]
        fn test_cosine() {
            test_bfknn(DistanceType::CosineExpanded);
        }
    */

    #[flaky]
    fn test_l2() {
        test_bfknn(DistanceType::L2Expanded);
    }

    // NOTE: brute_force multiple-search test is omitted here because the C++
    // brute_force::index stores a non-owning view into the dataset. Building
    // from device data via `build()` drops the ManagedTensor after the call,
    // leaving a dangling pointer. A follow-up PR will add dataset lifetime
    // enforcement (DatasetOwnership<'a>) to make this safe.
    // See: https://github.com/rapidsai/cuvs/issues/1838
}