cuvs/ivf_flat/
search_params.rs

1/*
2 * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::error::{check_cuvs, Result};
7use std::fmt;
8use std::io::{stderr, Write};
9
10/// Supplemental parameters to search IvfFlat index
11pub struct SearchParams(pub ffi::cuvsIvfFlatSearchParams_t);
12
13impl SearchParams {
14    /// Returns a new SearchParams object
15    pub fn new() -> Result<SearchParams> {
16        unsafe {
17            let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfFlatSearchParams_t>::uninit();
18            check_cuvs(ffi::cuvsIvfFlatSearchParamsCreate(params.as_mut_ptr()))?;
19            Ok(SearchParams(params.assume_init()))
20        }
21    }
22
23    /// Supplemental parameters to search IVF-Flat index
24    pub fn set_n_probes(self, n_probes: u32) -> SearchParams {
25        unsafe {
26            (*self.0).n_probes = n_probes;
27        }
28        self
29    }
30}
31
32impl fmt::Debug for SearchParams {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        // custom debug trait here, default value will show the pointer address
35        // for the inner params object which isn't that useful.
36        write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
37    }
38}
39
40impl Drop for SearchParams {
41    fn drop(&mut self) {
42        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfFlatSearchParamsDestroy(self.0) }) {
43            write!(
44                stderr(),
45                "failed to call cuvsIvfFlatSearchParamsDestroy {:?}",
46                e
47            )
48            .expect("failed to write to stderr");
49        }
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `set_n_probes` writes through to the underlying params object.
    #[test]
    fn test_search_params() {
        let expected: u32 = 128;
        let params = SearchParams::new().unwrap().set_n_probes(expected);

        // SAFETY: `params` is still alive, so its handle has not been destroyed.
        let stored = unsafe { (*params.0).n_probes };
        assert_eq!(stored, expected);
    }
}