CAGRA#
CAGRA is a graph-based nearest neighbors algorithm that was built from the ground up for GPU acceleration. CAGRA demonstrates state-of-the-art index build and query performance for both small- and large-batch searches.
#include <cuvs/neighbors/cagra.hpp>
namespace cuvs::neighbors::cagra
Index build parameters#
-
enum class hnsw_heuristic_type : uint32_t#
A strategy for selecting the graph build parameters based on similar HNSW index parameters.
Defines how cagra::index_params::from_hnsw_params should construct a graph that is to be converted to (used by) a CPU HNSW index.
Values:
-
enumerator SIMILAR_SEARCH_PERFORMANCE#
Create a graph that is very similar to an HNSW graph in terms of the number of nodes and search performance. Since HNSW produces a variable-degree graph (2M being the max graph degree) and CAGRA produces a fixed-degree graph, there’s always a difference in the performance of the two.
This heuristic attempts to produce a graph such that the QPS and recall of the two graphs, both searched by HNSW, are close for any search parameter combination. The CAGRA-produced graph tends to have a “longer tail” on the low-recall side (that is, it is slightly faster and less precise).
-
enumerator SAME_GRAPH_FOOTPRINT#
Create a graph that has the same binary size as an HNSW graph with the given parameters (graph_degree = 2 * M) while trying to match the search performance as closely as possible.
The reference HNSW index and the corresponding from-CAGRA generated HNSW index will NOT produce the same recalls and QPS for the same parameter ef. The graphs are different internally. For the same ef, the from-CAGRA index likely has a slightly higher recall and slightly lower QPS. However, the Recall-QPS curves should be similar (i.e. the points are just shifted along the curve).
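As a rough illustration of the footprint rule above, the sketch below computes the graph degree implied by an HNSW `M` and the size of the resulting fixed-degree graph. The helper names and the assumption of 32-bit neighbor ids are illustrative and are not part of the cuVS API; the real heuristic may account for further details.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: graph degree implied by the rule quoted above
// (graph_degree = 2 * M).
inline std::size_t same_footprint_graph_degree(std::size_t M)
{
  assert(M > 0);
  return 2 * M;
}

// Hypothetical helper: bytes occupied by a dense [n_rows, graph_degree]
// matrix of neighbor ids, assuming uint32_t ids.
inline std::size_t fixed_degree_graph_bytes(std::size_t n_rows, std::size_t graph_degree)
{
  return n_rows * graph_degree * sizeof(std::uint32_t);
}
```

Under these assumptions, M = 16 and one million rows give a degree of 32 and a graph of 1,000,000 × 32 × 4 = 128,000,000 bytes.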
-
struct index_params : public cuvs::neighbors::index_params#
- #include <cagra.hpp>
Public Members
-
size_t intermediate_graph_degree = 128#
Degree of input graph for pruning.
-
size_t graph_degree = 64#
Degree of output graph.
-
std::optional<cuvs::neighbors::vpq_params> compression = std::nullopt#
Specify compression parameters if compression is desired. If set, overrides the attach_dataset_on_build (and the compressed dataset is always added to the index).
-
std::variant<std::monostate, graph_build_params::ivf_pq_params, graph_build_params::nn_descent_params, graph_build_params::ace_params, graph_build_params::iterative_search_params> graph_build_params#
Parameters for graph building.
Set ivf_pq_params, nn_descent_params, ace_params, or iterative_search_params to select the graph build algorithm and control its parameters. The default (std::monostate) is to use a heuristic to decide the algorithm and its parameters.
cagra::index_params params;
// 1. Choose IVF-PQ algorithm
params.graph_build_params = cagra::graph_build_params::ivf_pq_params(dataset.extents(), params.metric);
// 2. Choose NN Descent algorithm for kNN graph construction
params.graph_build_params = cagra::graph_build_params::nn_descent_params(params.intermediate_graph_degree);
// 3. Choose ACE algorithm for graph construction
params.graph_build_params = cagra::graph_build_params::ace_params();
// 4. Choose iterative graph building using CAGRA's search() and optimize() [Experimental]
params.graph_build_params = cagra::graph_build_params::iterative_search_params();
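The selection mechanism relies on standard `std::variant` semantics: a default-constructed variant holds `std::monostate`, and assigning one of the parameter types switches the alternative. The sketch below shows this with empty stand-in structs; these types and `selected_algorithm` are hypothetical and only mirror the names in the declaration above.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Stand-in parameter types (hypothetical, for illustration only).
struct ivf_pq_params {};
struct nn_descent_params {};
struct ace_params {};
struct iterative_search_params {};

using build_params_variant = std::variant<std::monostate,
                                          ivf_pq_params,
                                          nn_descent_params,
                                          ace_params,
                                          iterative_search_params>;

// Standard C++17 idiom for building an overload set from lambdas.
template <class... Ts>
struct overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

// Report which build algorithm the variant currently selects.
inline std::string selected_algorithm(const build_params_variant& p)
{
  return std::visit(
    overloaded{[](std::monostate) { return std::string{"heuristic"}; },
               [](const ivf_pq_params&) { return std::string{"ivf_pq"}; },
               [](const nn_descent_params&) { return std::string{"nn_descent"}; },
               [](const ace_params&) { return std::string{"ace"}; },
               [](const iterative_search_params&) { return std::string{"iterative_search"}; }},
    p);
}
```

A default-constructed `build_params_variant` reports `"heuristic"`; after `p = nn_descent_params{}` it reports `"nn_descent"`.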
-
bool guarantee_connectivity = false#
Whether to use MST optimization to guarantee graph connectivity.
-
bool attach_dataset_on_build = true#
Whether to add the dataset content to the index, i.e.:
true means the index is filled with the dataset vectors and ready to search after calling build, provided there is enough memory available.
false means build only builds the graph and the user is expected to update the dataset using cuvs::neighbors::cagra::update_dataset.
Regardless of the value of attach_dataset_on_build, the search graph is created using all the vectors in the dataset. Setting attach_dataset_on_build = false can be useful if the user needs to build only the search graph but does not intend to search it using CAGRA (e.g. search using another graph search algorithm), or if specific memory placement options need to be applied on the dataset before it is attached to the index using the update_dataset API.
auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
// use default index_parameters
cagra::index_params index_params;
// update index_params to only build the CAGRA graph
index_params.attach_dataset_on_build = false;
auto index = cagra::build(res, index_params, dataset.view());
// assert that the dataset is not attached to the index
ASSERT(index.dataset().extent(0) == 0);
// update dataset
index.update_dataset(res, dataset.view());
// The index is now ready for search
cagra::search(res, search_params, index, queries, neighbors, distances);
Public Static Functions
- static cagra::index_params from_hnsw_params(
- raft::matrix_extent<int64_t> dataset,
- int M,
- int ef_construction,
- hnsw_heuristic_type heuristic = hnsw_heuristic_type::SIMILAR_SEARCH_PERFORMANCE,
- cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded
Create CAGRA index parameters compatible with an HNSW index.
IMPORTANT NOTE:
The reference HNSW index and the corresponding from-CAGRA generated HNSW index will NOT produce exactly the same recalls and QPS for the same parameter ef. The graphs are different internally. Depending on the selected heuristic, the points of the CAGRA-produced graph may be shifted right or left along the QPS-Recall curve. See the heuristics descriptions for more details.
Usage example:
using namespace cuvs::neighbors;
raft::resources res;
auto dataset = raft::make_device_matrix<float, int64_t>(res, N, D);
auto cagra_params = cagra::index_params::from_hnsw_params(dataset.extents(), M, efc);
auto cagra_index = cagra::build(res, cagra_params, dataset);
auto hnsw_index = hnsw::from_cagra(res, hnsw_params, cagra_index);
- Parameters:
dataset – The shape of the input dataset
M – HNSW index parameter M
ef_construction – HNSW index parameter ef_construction
heuristic – The heuristic to use for selecting the graph build parameters
metric – The distance metric to search
Index search parameters#
-
enum class search_algo#
Values:
-
enumerator SINGLE_CTA#
For large batch sizes.
-
enumerator MULTI_CTA#
For small batch sizes.
-
enumerator MULTI_KERNEL#
-
enumerator AUTO#
-
struct search_params : public cuvs::neighbors::search_params#
- #include <cagra.hpp>
Public Members
-
size_t max_queries = 0#
Maximum number of queries to search at the same time (batch size). Auto select when 0.
-
size_t itopk_size = 64#
Number of intermediate search results retained during the search.
This is the main knob to adjust the trade-off between accuracy and search speed. Higher values improve the search accuracy.
-
size_t max_iterations = 0#
Upper limit of search iterations. Auto select when 0.
-
search_algo algo = search_algo::AUTO#
Which search implementation to use.
-
size_t team_size = 0#
Number of threads used to calculate a single distance. 4, 8, 16, or 32.
-
size_t search_width = 1#
Number of graph nodes to select as starting points for the search in each iteration (the search width).
-
size_t min_iterations = 0#
Lower limit of search iterations.
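How itopk_size, search_width and max_iterations interact can be seen in a simplified CPU best-first search over a fixed-degree graph. This is only a sketch of the general idea, not CAGRA's GPU implementation; `toy_search_params`, `toy_graph_search` and the single seed node are hypothetical simplifications (the real search uses random seed sampling and a hash table for visited nodes).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct toy_search_params {
  std::size_t itopk_size     = 64;   // candidates retained between iterations
  std::size_t search_width   = 1;    // candidates expanded per iteration
  std::size_t max_iterations = 100;  // hard upper limit
};

inline float l2_sq(const float* a, const float* b, std::size_t dim)
{
  float s = 0.f;
  for (std::size_t i = 0; i < dim; ++i) { s += (a[i] - b[i]) * (a[i] - b[i]); }
  return s;
}

// dataset: row-major [n, dim]; graph: row-major [n, degree] neighbor ids.
inline std::vector<std::uint32_t> toy_graph_search(const std::vector<float>& dataset,
                                                   const std::vector<std::uint32_t>& graph,
                                                   std::size_t dim,
                                                   std::size_t degree,
                                                   const float* query,
                                                   std::size_t k,
                                                   std::uint32_t seed_node,
                                                   const toy_search_params& p)
{
  const std::size_t n = dataset.size() / dim;
  assert(k <= p.itopk_size && seed_node < n);
  struct cand {
    float dist;
    std::uint32_t id;
    bool expanded;
  };
  std::vector<cand> topk;
  topk.push_back({l2_sq(&dataset[seed_node * dim], query, dim), seed_node, false});
  std::vector<bool> visited(n, false);
  visited[seed_node] = true;

  for (std::size_t it = 0; it < p.max_iterations; ++it) {
    // Select up to search_width closest candidates that were not expanded yet.
    std::vector<std::uint32_t> parents;
    for (auto& c : topk) {
      if (!c.expanded && parents.size() < p.search_width) {
        c.expanded = true;
        parents.push_back(c.id);
      }
    }
    if (parents.empty()) { break; }  // nothing left to expand: converged
    // Compute distances to the unvisited neighbors of the selected parents.
    for (auto parent : parents) {
      for (std::size_t j = 0; j < degree; ++j) {
        const std::uint32_t nb = graph[parent * degree + j];
        if (visited[nb]) { continue; }
        visited[nb] = true;
        topk.push_back({l2_sq(&dataset[nb * dim], query, dim), nb, false});
      }
    }
    // Keep only the itopk_size closest candidates, sorted by distance.
    std::sort(topk.begin(), topk.end(), [](const cand& a, const cand& b) { return a.dist < b.dist; });
    if (topk.size() > p.itopk_size) { topk.resize(p.itopk_size); }
  }
  std::vector<std::uint32_t> out;
  for (std::size_t i = 0; i < std::min(k, topk.size()); ++i) { out.push_back(topk[i].id); }
  return out;
}
```

A larger itopk_size keeps more candidates alive, so the search explores more of the graph before it converges, which is why it trades speed for accuracy.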
-
size_t thread_block_size = 0#
Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
-
size_t hashmap_min_bitlen = 0#
Lower limit of hashmap bit length. More than 8.
-
float hashmap_max_fill_rate = 0.5#
Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
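One way to read the two hashmap limits together: a table with 2^bitlen slots has to be large enough that the expected number of entries stays below the fill rate. The helper below is a hypothetical sketch of that arithmetic and is not the selection logic cuVS uses.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: smallest bit length b >= min_bitlen such that
// n_entries <= max_fill_rate * 2^b.
inline std::size_t pick_hashmap_bitlen(std::size_t n_entries,
                                       std::size_t min_bitlen,
                                       float max_fill_rate)
{
  assert(max_fill_rate > 0.1f && max_fill_rate < 0.9f);  // range documented above
  std::size_t bitlen = min_bitlen;
  while (static_cast<double>(n_entries) >
         max_fill_rate * static_cast<double>(std::size_t{1} << bitlen)) {
    ++bitlen;
  }
  return bitlen;
}
```

For example, 1000 expected entries with a minimum bit length of 8 and a fill rate of 0.5 give a bit length of 11 (2048 slots, at most 1024 of them filled).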
-
uint32_t num_random_samplings = 1#
Number of iterations of initial random seed node selection. 1 or more.
-
uint64_t rand_xor_mask = 0x128394#
Bit mask used for initial random seed node selection.
-
bool persistent = false#
Whether to use the persistent version of the kernel (only SINGLE_CTA is supported at the moment).
-
float persistent_lifetime = 2#
Persistent kernel: time in seconds before the kernel stops if no requests received.
-
float persistent_device_usage = 1.0#
Set the fraction of maximum grid size used by persistent kernel. Value 1.0 means the kernel grid size is maximum possible for the selected device. The value must be greater than 0.0 and not greater than 1.0.
One may need to run other kernels alongside this persistent kernel. This parameter can be used to reduce the grid size of the persistent kernel to leave a few SMs idle. Note: running any other work on the GPU alongside the persistent kernel makes the setup fragile.
Running another kernel in another thread usually works, but progress is not guaranteed.
Any CUDA allocations block the context (this issue may be obscured by using pools).
Memory copies to non-pinned host memory may block the context.
Even when we know there are no other kernels working at the same time, setting persistent_device_usage to 1.0 surprisingly sometimes hurts performance. Proceed with care. If you suspect this is an issue, you can reduce this number to ~0.9 without a significant impact on the throughput.
-
float filtering_rate = -1.0#
A parameter indicating the rate of nodes to be filtered-out, when filtering is used. The value must be equal to or greater than 0.0 and less than 1.0. Default value is negative, in which case the filtering rate is automatically calculated.
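The value ranges documented for the members above can be collected into a small pre-flight check. The sketch below uses a local mirror struct rather than the cuVS type, and it assumes that 0 means "select automatically" for team_size, which the text above states only for thread_block_size; treat both as assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <vector>

// Local mirror of a few documented fields (not the cuVS struct itself).
struct search_params_mirror {
  std::size_t team_size              = 0;
  std::size_t thread_block_size      = 0;
  float hashmap_max_fill_rate        = 0.5f;
  std::uint32_t num_random_samplings = 1;
  float persistent_device_usage      = 1.0f;
  float filtering_rate               = -1.0f;
};

// Returns the names of the fields whose values fall outside the documented ranges.
inline std::vector<std::string> check_documented_ranges(const search_params_mirror& p)
{
  std::vector<std::string> errors;
  auto one_of = [](std::size_t v, std::initializer_list<std::size_t> allowed) {
    for (auto a : allowed) {
      if (v == a) { return true; }
    }
    return false;
  };
  // Assumption: 0 (the default) requests automatic selection.
  if (!one_of(p.team_size, {0, 4, 8, 16, 32})) { errors.push_back("team_size"); }
  if (!one_of(p.thread_block_size, {0, 64, 128, 256, 512, 1024})) {
    errors.push_back("thread_block_size");
  }
  if (!(p.hashmap_max_fill_rate > 0.1f && p.hashmap_max_fill_rate < 0.9f)) {
    errors.push_back("hashmap_max_fill_rate");
  }
  if (p.num_random_samplings < 1) { errors.push_back("num_random_samplings"); }
  if (!(p.persistent_device_usage > 0.0f && p.persistent_device_usage <= 1.0f)) {
    errors.push_back("persistent_device_usage");
  }
  // Negative values request automatic calculation; otherwise the rate must be in [0, 1).
  if (p.filtering_rate >= 1.0f) { errors.push_back("filtering_rate"); }
  return errors;
}
```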
Index extend parameters#
-
struct extend_params#
- #include <cagra.hpp>
Public Members
-
uint32_t max_chunk_size = 0#
The additional dataset is divided into chunks and added to the graph. This is the knob to adjust the trade-off between recall and operation throughput. Large chunk sizes can result in high throughput, but use more working memory (O(max_chunk_size*degree^2)). This can also degrade recall because no edges are added between the nodes in the same chunk. Auto select when 0.
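The chunking and the memory growth term can be made concrete with a short sketch. Both helpers are hypothetical and only restate the description above; the constant factor and element size behind the O(max_chunk_size*degree^2) bound are not specified in this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: split n_new_rows into [begin, end) row ranges of at most
// max_chunk_size rows each.
inline std::vector<std::pair<std::size_t, std::size_t>> make_chunks(std::size_t n_new_rows,
                                                                    std::size_t max_chunk_size)
{
  assert(max_chunk_size > 0);
  std::vector<std::pair<std::size_t, std::size_t>> chunks;
  for (std::size_t begin = 0; begin < n_new_rows; begin += max_chunk_size) {
    chunks.emplace_back(begin, std::min(begin + max_chunk_size, n_new_rows));
  }
  return chunks;
}

// Growth term of the working memory: max_chunk_size * degree^2, in elements and
// up to an unspecified constant factor.
inline std::size_t working_set_elements(std::size_t max_chunk_size, std::size_t degree)
{
  return max_chunk_size * degree * degree;
}
```

Ten new rows with a chunk size of 4 are processed as [0, 4), [4, 8) and [8, 10); rows inside one chunk get no edges to each other, which is the recall cost mentioned above.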
Index#
-
template<typename T, typename IdxT>
struct index : public cuvs::neighbors::index#
- #include <cagra.hpp>
CAGRA index.
The index stores the dataset and a kNN graph in device memory.
- Template Parameters:
T – data element type
IdxT – the data type used to store the neighbor indices in the search graph. It must be large enough to represent values up to dataset.extent(0).
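The requirement on IdxT can be checked before building. The helper below is a hypothetical sketch of such a check, not a cuVS function.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: true if every value up to n_rows is representable in IdxT.
template <typename IdxT>
constexpr bool index_type_fits(std::int64_t n_rows)
{
  return n_rows >= 0 && static_cast<std::uint64_t>(n_rows) <=
                          static_cast<std::uint64_t>(std::numeric_limits<IdxT>::max());
}
```

With `uint32_t`, a dataset of 4,000,000,000 rows still fits, while 5,000,000,000 rows does not.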
Public Functions
-
inline cuvs::distance::DistanceType metric() const noexcept#
Distance metric used by the index.
-
inline uint32_t dim() const noexcept#
Dimensionality of the data.
-
inline uint32_t graph_degree() const noexcept#
Graph degree
- inline const cuvs::neighbors::dataset<int64_t> &data(
Dataset [size, dim]
- inline raft::device_matrix_view<const graph_index_type, int64_t, raft::row_major> graph(
neighborhood graph [size, graph-degree]
- inline std::optional<raft::device_vector_view<const index_type, int64_t>> source_indices(
Mapping from internal graph node indices to the original user-provided indices.
- inline const std::optional<cuvs::util::file_descriptor> &dataset_fd(
Get the dataset file descriptor (for disk-backed index)
- inline const std::optional<cuvs::util::file_descriptor> &graph_fd(
Get the graph file descriptor (for disk-backed index)
- inline const std::optional<cuvs::util::file_descriptor> &mapping_fd(
Get the mapping file descriptor (for disk-backed index)
- inline std::optional<raft::device_vector_view<const float, int64_t>> dataset_norms(
Dataset norms for cosine distance [size]
- inline index(
- raft::resources const &res,
- cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded
Construct an empty index.
-
template<typename data_accessor, typename graph_accessor>
inline index(
- raft::resources const &res,
- cuvs::distance::DistanceType metric,
- raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, data_accessor> dataset,
- raft::mdspan<const graph_index_type, raft::matrix_extent<int64_t>, raft::row_major, graph_accessor> knn_graph
Construct an index from dataset and knn_graph arrays
If the dataset and graph are already in GPU memory, then the index is just a thin wrapper around these that stores a non-owning reference to the arrays.
The constructor also accepts host arrays. In that case they are copied to the device, and the device arrays will be owned by the index.
If the dataset rows are not 16-byte aligned, a padded copy is created in device memory to ensure alignment for vectorized loads.
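The alignment condition is simple arithmetic on the row size in bytes. The sketch below shows when padding would be needed and what the padded row length would be; the helper names are hypothetical, and the base pointer is assumed to be 16-byte aligned already.

```cpp
#include <cassert>
#include <cstddef>

// True if a row of `dim` elements of type T occupies a multiple of 16 bytes.
template <typename T>
constexpr bool rows_are_16B_aligned(std::size_t dim)
{
  return (dim * sizeof(T)) % 16 == 0;
}

// Row length in elements after rounding the row size up to a multiple of 16 bytes.
template <typename T>
constexpr std::size_t padded_row_elems(std::size_t dim)
{
  static_assert(16 % sizeof(T) == 0, "element size must divide 16 bytes");
  constexpr std::size_t elems_per_16B = 16 / sizeof(T);
  return (dim + elems_per_16B - 1) / elems_per_16B * elems_per_16B;
}
```

A float dataset with 10 columns (40 bytes per row) would be padded to 12 columns, while 128 columns need no padding.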
Usage examples:
A CAGRA index is normally created by cagra::build:
using namespace cuvs::neighbors;
auto dataset = raft::make_host_matrix<float, int64_t>(n_rows, n_cols);
load_dataset(dataset.view());
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t, int64_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float, int64_t>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
In the above example, we have passed a host dataset to build. The returned index will own a device copy of the dataset and the knn_graph. In contrast, if we pass the dataset as a device_mdspan to build, then it will only store a reference to it.
Constructing index using existing knn-graph
using namespace cuvs::neighbors;
auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
auto knn_graph = raft::make_device_matrix<uint32_t, int64_t>(res, n_rows, graph_degree);
// custom loading and graph creation
// load_dataset(dataset.view());
// create_knn_graph(knn_graph.view());
// Wrap the existing device arrays into an index structure
cagra::index<T, IdxT> index(res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(knn_graph.view()));
// Both knn_graph and dataset objects have to be in scope while the index is used because
// the index only stores a reference to these.
cagra::search(res, search_params, index, queries, neighbors, distances);
- inline void update_dataset(
- raft::resources const &res,
- raft::device_matrix_view<const T, int64_t, raft::row_major> dataset
Replace the dataset with a new dataset.
If the new dataset rows are aligned on 16 bytes, then only a reference is stored to the dataset. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. It is expected that the same set of vectors are used for update_dataset and index build.
Note: This will clear any precomputed dataset norms.
- inline void update_dataset(
- raft::resources const &res,
- raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset
Set the dataset reference explicitly to a device matrix view with padding.
- inline void update_dataset(
- raft::resources const &res,
- raft::host_matrix_view<const T, int64_t, raft::row_major> dataset
Replace the dataset with a new dataset.
We create a copy of the dataset on the device. The index manages the lifetime of this copy. It is expected that the same set of vectors are used for update_dataset and index build.
Note: This will clear any precomputed dataset norms.
-
template<typename DatasetT>
inline std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>> update_dataset(
- raft::resources const &res,
- DatasetT &&dataset
Replace the dataset with a new dataset. It is expected that the same set of vectors are used for update_dataset and index build.
Note: This will clear any precomputed dataset norms.
- inline void update_graph(
- raft::resources const &res,
- raft::device_matrix_view<const graph_index_type, int64_t, raft::row_major> knn_graph
Replace the graph with a new graph.
Since the new graph is a device array, we store a reference to that, and it is the caller’s responsibility to ensure that knn_graph stays alive as long as the index.
- inline void update_graph(
- raft::resources const &res,
- raft::host_matrix_view<const graph_index_type, int64_t, raft::row_major> knn_graph
Replace the graph with a new graph.
We create a copy of the graph on the device. The index manages the lifetime of this copy.
- inline void update_source_indices(
- raft::device_vector<index_type, int64_t> &&source_indices
Replace the source indices with new source indices, taking ownership of the passed vector.
-
template<typename Accessor>
inline void update_source_indices(
- raft::resources const &res,
- raft::mdspan<const index_type, raft::vector_extent<int64_t>, raft::row_major, Accessor> source_indices
Copy the provided source indices into the index.
- inline void update_dataset(
- raft::resources const &res,
- cuvs::util::file_descriptor &&fd
Update the dataset from a disk file using a file descriptor.
This method configures the index to use a disk-based dataset. The dataset file should contain a numpy header followed by vectors in row-major format. The number of rows and dimensionality are read from the numpy header.
- Parameters:
res – [in] raft resources
fd – [in] File descriptor (will be moved into the index for lifetime management)
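To make the expected file layout concrete, the sketch below reads the number of rows and columns from a NumPy header held in memory. It assumes the standard .npy layout in format version 1.x (6-byte magic, two version bytes, a little-endian 16-bit header length, then a Python dict literal); the function name is hypothetical and the parser used by cuVS may accept more variants.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>

// Parse (n_rows, n_cols) from an in-memory .npy buffer, format version 1.x.
// Returns {0, 0} when the buffer is not a 2-D .npy header in that format.
inline std::pair<std::int64_t, std::int64_t> parse_npy_shape_2d(const std::string& buf)
{
  const std::pair<std::int64_t, std::int64_t> invalid{0, 0};
  const auto npos = std::string::npos;
  if (buf.size() < 10 || std::memcmp(buf.data(), "\x93NUMPY", 6) != 0 || buf[6] != 1) {
    return invalid;
  }
  // Bytes 8 and 9 hold the header length as a little-endian uint16.
  const auto lo = static_cast<unsigned char>(buf[8]);
  const auto hi = static_cast<unsigned char>(buf[9]);
  const std::size_t header_len = lo | (static_cast<std::size_t>(hi) << 8);
  if (buf.size() < 10 + header_len) { return invalid; }
  const std::string header = buf.substr(10, header_len);

  // The header looks like: {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4), }
  const auto key = header.find("'shape'");
  if (key == npos) { return invalid; }
  const auto open = header.find('(', key);
  if (open == npos) { return invalid; }
  const auto close = header.find(')', open);
  if (close == npos) { return invalid; }
  const std::string dims = header.substr(open + 1, close - open - 1);  // e.g. "3, 4"
  const auto comma = dims.find(',');
  if (comma == npos) { return invalid; }
  const std::string first  = dims.substr(0, comma);
  const std::string second = dims.substr(comma + 1);
  const char* digits       = "0123456789";
  // Reject 1-D shapes such as "(3,)" and shapes with more than two dimensions.
  if (first.find_first_of(digits) == npos || second.find_first_of(digits) == npos ||
      second.find(',') != npos) {
    return invalid;
  }
  return {std::stoll(first), std::stoll(second)};
}
```

The vectors themselves start right after the header, at byte offset 10 + header length, stored in row-major order when `fortran_order` is `False`.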
- inline void update_graph(
- raft::resources const &res,
- cuvs::util::file_descriptor &&fd
Update the graph from a disk file using a file descriptor.
This method configures the index to use a disk-based graph. The graph file should contain a numpy header followed by neighbor indices in row-major format. The number of rows and graph degree are read from the numpy header.
- Parameters:
res – [in] raft resources
fd – [in] File descriptor (will be moved into the index for lifetime management)
- inline void update_mapping(
- raft::resources const &res,
- cuvs::util::file_descriptor &&fd
Update the dataset mapping from a disk file using a file descriptor.
This method configures the index to use a disk-based dataset mapping. The mapping file should contain a numpy header followed by index mappings.
- Parameters:
res – [in] raft resources
fd – [in] File descriptor (will be moved into the index for lifetime management)
Index build#
- cuvs::neighbors::cagra::index<float, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::device_matrix_view<const float, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (device) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
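The first of the two build steps produces a dense [n_rows, degree] matrix of neighbor ids. The CPU brute-force sketch below shows only that data layout on a tiny dataset; cuVS builds the intermediate graph approximately on the GPU (for example with IVF-PQ or NN Descent), and the second step that optimizes the graph down to graph_degree is not shown. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>

// Exact kNN graph by brute force: row i holds the ids of the `degree` nearest
// other rows of a row-major [n, dim] dataset, closest first (squared L2).
inline std::vector<std::uint32_t> brute_force_knn_graph(const std::vector<float>& dataset,
                                                        std::size_t dim,
                                                        std::size_t degree)
{
  const std::size_t n = dataset.size() / dim;
  assert(degree < n);  // each node needs `degree` neighbors other than itself
  std::vector<std::uint32_t> graph(n * degree);
  std::vector<std::uint32_t> ids(n);
  std::vector<float> dist(n);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float s = 0.f;
      for (std::size_t d = 0; d < dim; ++d) {
        const float diff = dataset[i * dim + d] - dataset[j * dim + d];
        s += diff * diff;
      }
      dist[j] = s;
    }
    dist[i] = std::numeric_limits<float>::infinity();  // never pick the node itself
    std::iota(ids.begin(), ids.end(), std::uint32_t{0});
    const auto mid = ids.begin() + static_cast<std::ptrdiff_t>(degree);
    std::partial_sort(ids.begin(), mid, ids.end(), [&dist](std::uint32_t a, std::uint32_t b) {
      return dist[a] < dist[b];
    });
    std::copy(ids.begin(), mid, graph.begin() + static_cast<std::ptrdiff_t>(i * degree));
  }
  return graph;
}
```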
- cuvs::neighbors::cagra::index<float, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::host_matrix_view<const float, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (host) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<half, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::device_matrix_view<const half, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (device) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<half, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::host_matrix_view<const half, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (host) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<int8_t, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::device_matrix_view<const int8_t, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (device) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<int8_t, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::host_matrix_view<const int8_t, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (host) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (device) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> build(
- raft::resources const &res,
- const cuvs::neighbors::cagra::index_params &params,
- raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> dataset
Build the index from the dataset for efficient search.
The build consists of two steps: build an intermediate knn-graph, and optimize it to create the final graph. The index_params struct controls the node degree of these graphs.
The following distance metrics are supported:
L2
InnerProduct (currently only supported with IVF-PQ as the build algorithm)
CosineExpanded (dataset norms are computed as float regardless of input data type)
Usage example:
using namespace cuvs::neighbors;
// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [N, D] dataset
auto index = cagra::build(res, index_params, dataset);
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbours
auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
auto distances = raft::make_device_matrix<float>(res, n_queries, k);
cagra::search(res, search_params, index, queries, neighbors.view(), distances.view());
- Parameters:
res – [in] raft resources
params – [in] parameters for building the index
dataset – [in] a matrix view (host) to a row-major matrix [n_rows, dim]
- Returns:
the constructed cagra index
Index search#
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
- raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
- raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
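The "greenlight" semantics of a sample filter can be shown on the CPU with a plain callback: only samples for which the filter returns true may appear in the result. This is a sketch only; cuVS filters are device-side objects, and `toy_sample_filter` and `filtered_top_k` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical CPU stand-in for a sample filter: returns true if the sample is
// allowed for the given query.
using toy_sample_filter = std::function<bool(std::uint32_t query_ix, std::uint32_t sample_ix)>;

// Return the ids of the k closest allowed samples, closest first.
inline std::vector<std::uint32_t> filtered_top_k(const std::vector<float>& distances_to_query,
                                                 std::uint32_t query_ix,
                                                 std::size_t k,
                                                 const toy_sample_filter& allow)
{
  std::vector<std::uint32_t> ids;
  for (std::uint32_t i = 0; i < distances_to_query.size(); ++i) {
    if (allow(query_ix, i)) { ids.push_back(i); }
  }
  std::sort(ids.begin(), ids.end(), [&](std::uint32_t a, std::uint32_t b) {
    return distances_to_query[a] < distances_to_query[b];
  });
  if (ids.size() > k) { ids.resize(k); }
  return ids;
}
```

In this sketch, fewer than k ids are returned when the filter allows fewer than k samples.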
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
- raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
- raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
- raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- raft::device_matrix_view<const half, int64_t, raft::row_major> queries,
- raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
- raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
- void search(
- raft::resources const &res,
- cuvs::neighbors::cagra::search_params const &params,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
- raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
- raft::device_matrix_view<float, int64_t, raft::row_major> distances,
- const cuvs::neighbors::filtering::base_filter &sample_filter = cuvs::neighbors::filtering::none_sample_filter{}
Search ANN using the constructed index.
See the cagra::build documentation for a usage example.
- Parameters:
res – [in] raft resources
params – [in] configure the search
index – [in] cagra index
queries – [in] a device matrix view to a row-major matrix [n_queries, index->dim()]
neighbors – [out] a device matrix view to the indices of the neighbors in the source dataset [n_queries, k]
distances – [out] a device matrix view to the distances to the selected neighbors [n_queries, k]
sample_filter – [in] an optional device filter function object that greenlights samples for a given query. (none_sample_filter for no filtering)
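Every search overload writes its results into two row-major matrices of shape [n_queries, k]. A minimal sketch of how such results could be read once they have been copied to the host is shown below. The struct `host_search_result` and the accessors are hypothetical helpers for illustration, not part of cuVS, and the copy from device to host is assumed to have happened already.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side copy of one search result (not a cuVS type).
struct host_search_result {
  std::size_t n_queries;
  std::size_t k;
  std::vector<std::uint32_t> neighbors;  // n_queries * k dataset indices, row-major
  std::vector<float> distances;          // n_queries * k distances, row-major
};

// Dataset index of the j-th returned neighbor of query q.
inline std::uint32_t neighbor_at(const host_search_result& r, std::size_t q, std::size_t j)
{
  assert(q < r.n_queries && j < r.k);
  return r.neighbors[q * r.k + j];
}

// Distance from query q to its j-th returned neighbor.
inline float distance_at(const host_search_result& r, std::size_t q, std::size_t j)
{
  assert(q < r.n_queries && j < r.k);
  return r.distances[q * r.k + j];
}
```

Row q of each matrix holds the k results for query q, so entry (q, j) sits at flat offset `q * k + j`.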
Index extend#
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::device_matrix_view<const float, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<float, uint32_t> &idx,
- std::optional<raft::device_matrix_view<float, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_device_matrix<float, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on device memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
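The shape requirements on the two optional buffers can be summarized with a little arithmetic. The sketch below is plain C++ with hypothetical names (`extend_buffer_shapes`, `shapes_for_extend`) that are not part of cuVS; it only restates the rules from the parameter descriptions above and assumes the caller already knows the original row count, row stride, and graph degree.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper type for illustration (not a cuVS type).
struct extend_buffer_shapes {
  std::int64_t dataset_rows;        // original rows + additional rows
  std::int64_t dataset_cols;        // dataset dimension
  std::int64_t dataset_row_stride;  // must match the original index dataset
  std::int64_t graph_rows;          // original rows + additional rows
  std::int64_t graph_cols;          // graph degree
};

// Compute the shapes that user-managed buffers would need before calling extend.
inline extend_buffer_shapes shapes_for_extend(std::int64_t original_rows,
                                              std::int64_t additional_rows,
                                              std::int64_t dim,
                                              std::int64_t original_row_stride,
                                              std::int64_t graph_degree)
{
  assert(original_row_stride >= dim);  // a row cannot be shorter than its data
  const std::int64_t total_rows = original_rows + additional_rows;
  return {total_rows, dim, original_row_stride, total_rows, graph_degree};
}
```

For instance, extending a 1000-row index of dimension 96 (row stride 128, graph degree 64) with 24 new vectors would call for a 1024 x 96 dataset view with stride 128 and a 1024 x 64 graph view.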
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::host_matrix_view<const float, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<float, uint32_t> &idx,
- std::optional<raft::device_matrix_view<float, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_host_matrix<float, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on host memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::device_matrix_view<const int8_t, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<int8_t, uint32_t> &idx,
- std::optional<raft::device_matrix_view<int8_t, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_device_matrix<int8_t, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on device memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::host_matrix_view<const int8_t, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<int8_t, uint32_t> &idx,
- std::optional<raft::device_matrix_view<int8_t, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_host_matrix<int8_t, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on host memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> &idx,
- std::optional<raft::device_matrix_view<uint8_t, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_device_matrix<uint8_t, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on device memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
- void extend(
- raft::resources const &handle,
- const cagra::extend_params &params,
- raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> additional_dataset,
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> &idx,
- std::optional<raft::device_matrix_view<uint8_t, int64_t, raft::layout_stride>> new_dataset_buffer_view = std::nullopt,
- std::optional<raft::device_matrix_view<uint32_t, int64_t>> new_graph_buffer_view = std::nullopt
Add new vectors to a CAGRA index.
Usage example:
using namespace cuvs::neighbors;
auto additional_dataset = raft::make_host_matrix<uint8_t, int64_t>(handle, add_size, dim);
// set_additional_dataset(additional_dataset.view());
cagra::extend_params params;
cagra::extend(handle, params, raft::make_const_mdspan(additional_dataset.view()), index);
- Parameters:
handle – [in] raft resources
params – [in] extend params
additional_dataset – [in] additional dataset on host memory
idx – [inout] CAGRA index
new_dataset_buffer_view – [out] memory buffer view for the dataset including the additional part. The data will be copied from the current index in this function. The num rows must be the sum of the original and additional datasets, cols must be the dimension of the dataset, and the stride must be the same as the original index dataset. This view will be stored in the output index. It is the caller’s responsibility to ensure that dataset stays alive as long as the index. This option is useful when users want to manage the memory space for the dataset themselves.
new_graph_buffer_view – [out] memory buffer view for the graph including the additional part. The data will be copied from the current index in this function. The number of rows must be the sum of the original and additional datasets, and the number of columns must be the graph degree. This view will be stored in the output index. It is the caller’s responsibility to ensure that the graph buffer stays alive as long as the index. This option is useful when users want to manage the memory space for the graph themselves.
Index serialize#
- void serialize(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- bool include_dataset = true
Save the index to file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the file.
- void deserialize(
- raft::resources const &handle,
- const std::string &filename,
- cuvs::neighbors::cagra::index<float, uint32_t> *index
Load index from file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
cuvs::neighbors::cagra::index<float, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, filename, &index);
- Parameters:
handle – [in] the raft handle
filename – [in] the name of the file that stores the index
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- bool include_dataset = true
Write the index to an output stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the stream.
- void deserialize(
- raft::resources const &handle,
- std::istream &is,
- cuvs::neighbors::cagra::index<float, uint32_t> *index
Load the index from an input stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an input stream
std::istream is(std::cin.rdbuf());
cuvs::neighbors::cagra::index<float, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, is, &index);
- Parameters:
handle – [in] the raft handle
is – [in] input stream
index – [out] the cagra index
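The stream overloads accept any std::ostream or std::istream, so an in-memory buffer can stand in for a file when testing a save/load round trip. The sketch below shows that pattern with the standard library only. `write_blob` and `read_blob` are hypothetical stand-ins for the serialize and deserialize calls, and the length-prefixed layout is an assumption made for this illustration; it is not the CAGRA serialization format.

```cpp
#include <cassert>
#include <cstdint>
#include <istream>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical stand-in for a serialize(handle, os, index) call:
// writes a 64-bit element count followed by the raw payload.
inline void write_blob(std::ostream& os, const std::vector<std::uint32_t>& payload)
{
  const std::uint64_t n = payload.size();
  os.write(reinterpret_cast<const char*>(&n), sizeof(n));
  os.write(reinterpret_cast<const char*>(payload.data()),
           static_cast<std::streamsize>(n * sizeof(std::uint32_t)));
}

// Hypothetical stand-in for a deserialize(handle, is, &index) call.
inline std::vector<std::uint32_t> read_blob(std::istream& is)
{
  std::uint64_t n = 0;
  is.read(reinterpret_cast<char*>(&n), sizeof(n));
  std::vector<std::uint32_t> payload(n);
  is.read(reinterpret_cast<char*>(payload.data()),
          static_cast<std::streamsize>(n * sizeof(std::uint32_t)));
  return payload;
}

// Round-trip through a binary in-memory stream instead of a file.
inline bool round_trips(const std::vector<std::uint32_t>& payload)
{
  std::stringstream buffer(std::ios::in | std::ios::out | std::ios::binary);
  write_blob(buffer, payload);
  return buffer.good() && read_blob(buffer) == payload;
}
```

The same structure would apply with a std::ofstream and std::ifstream opened with std::ios::binary when the data should end up on disk.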
- void serialize(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- bool include_dataset = true
Save the index to file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the file.
- void deserialize(
- raft::resources const &handle,
- const std::string &filename,
- cuvs::neighbors::cagra::index<half, uint32_t> *index
Load index from file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
cuvs::neighbors::cagra::index<half, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, filename, &index);
- Parameters:
handle – [in] the raft handle
filename – [in] the name of the file that stores the index
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- bool include_dataset = true
Write the index to an output stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the stream.
- void deserialize(
- raft::resources const &handle,
- std::istream &is,
- cuvs::neighbors::cagra::index<half, uint32_t> *index
Load the index from an input stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an input stream
std::istream is(std::cin.rdbuf());
cuvs::neighbors::cagra::index<half, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, is, &index);
- Parameters:
handle – [in] the raft handle
is – [in] input stream
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- bool include_dataset = true
Save the index to file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the file.
- void deserialize(
- raft::resources const &handle,
- const std::string &filename,
- cuvs::neighbors::cagra::index<int8_t, uint32_t> *index
Load index from file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
cuvs::neighbors::cagra::index<int8_t, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, filename, &index);
- Parameters:
handle – [in] the raft handle
filename – [in] the name of the file that stores the index
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- bool include_dataset = true
Write the index to an output stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the stream.
- void deserialize(
- raft::resources const &handle,
- std::istream &is,
- cuvs::neighbors::cagra::index<int8_t, uint32_t> *index
Load the index from an input stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an input stream
std::istream is(std::cin.rdbuf());
cuvs::neighbors::cagra::index<int8_t, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, is, &index);
- Parameters:
handle – [in] the raft handle
is – [in] input stream
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- bool include_dataset = true
Save the index to file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the file.
- void deserialize(
- raft::resources const &handle,
- const std::string &filename,
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> *index
Load index from file.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
cuvs::neighbors::cagra::index<uint8_t, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, filename, &index);
- Parameters:
handle – [in] the raft handle
filename – [in] the name of the file that stores the index
index – [out] the cagra index
- void serialize(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- bool include_dataset = true
Write the index to an output stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
include_dataset – [in] Whether or not to write out the dataset to the stream.
- void deserialize(
- raft::resources const &handle,
- std::istream &is,
- cuvs::neighbors::cagra::index<uint8_t, uint32_t> *index
Load the index from an input stream.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an input stream
std::istream is(std::cin.rdbuf());
cuvs::neighbors::cagra::index<uint8_t, uint32_t> index;
cuvs::neighbors::cagra::deserialize(handle, is, &index);
- Parameters:
handle – [in] the raft handle
is – [in] input stream
index – [out] the cagra index
- void serialize_to_hnswlib(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset = std::nullopt
Write the CAGRA-built index as a base-layer HNSW index to an output stream.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<float, uint32_t> &index,
- std::optional<raft::host_matrix_view<const float, int64_t, raft::row_major>> dataset = std::nullopt
Save a CAGRA-built index in hnswlib base-layer-only serialized format.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- std::optional<raft::host_matrix_view<const half, int64_t, raft::row_major>> dataset = std::nullopt
Write the CAGRA-built index as a base-layer HNSW index to an output stream.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental, both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>
raft::resources handle;
// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<half, uint32_t> &index,
- std::optional<raft::host_matrix_view<const half, int64_t, raft::row_major>> dataset = std::nullopt
Save a CAGRA-built index in hnswlib base-layer-only serialized format.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental: both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>

raft::resources handle;

// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset = std::nullopt
Write the CAGRA-built index as a base-layer HNSW index to an output stream.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental: both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>

raft::resources handle;

// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<int8_t, uint32_t> &index,
- std::optional<raft::host_matrix_view<const int8_t, int64_t, raft::row_major>> dataset = std::nullopt
Save a CAGRA-built index in hnswlib base-layer-only serialized format.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental: both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>

raft::resources handle;

// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- std::ostream &os,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset = std::nullopt
Write the CAGRA-built index as a base-layer HNSW index to an output stream.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental: both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>

raft::resources handle;

// create an output stream
std::ostream os(std::cout.rdbuf());
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index);
- Parameters:
handle – [in] the raft handle
os – [in] output stream
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.
- void serialize_to_hnswlib(
- raft::resources const &handle,
- const std::string &filename,
- const cuvs::neighbors::cagra::index<uint8_t, uint32_t> &index,
- std::optional<raft::host_matrix_view<const uint8_t, int64_t, raft::row_major>> dataset = std::nullopt
Save a CAGRA-built index in hnswlib base-layer-only serialized format.
NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, as the serialization format is not compatible with the original hnswlib.
Experimental: both the API and the serialization format are subject to change.
#include <raft/core/resources.hpp>
#include <cuvs/neighbors/cagra.hpp>

raft::resources handle;

// create a string with a filepath
std::string filename("/path/to/index");
// create an index with `auto index = cuvs::neighbors::cagra::build(...);`
cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index);
- Parameters:
handle – [in] the raft handle
filename – [in] the file name for saving the index
index – [in] CAGRA index
dataset – [in] [optional] host array that stores the dataset, required if the index does not contain the dataset.