
The vector search and clustering algorithms in RAFT are being migrated to a new library dedicated to vector search called cuVS. We will continue to support the vector search algorithms in RAFT during this move, but will no longer update them after the RAPIDS 24.06 (June) release. We plan to complete the migration by the RAPIDS 24.10 (October) release, and the algorithms will be removed from RAFT altogether in the 24.12 (December) release.


This page provides pylibraft API references for the publicly exposed elements of the pylibraft.matrix package.

pylibraft.matrix.select_k(dataset, k=None, distances=None, indices=None, select_min=True, handle=None)

Selects the top k items (the k smallest or the k largest, depending on select_min) from each row of a matrix.

Parameters:

dataset : array interface compliant matrix

Row-major layout, shape (n_rows, dim). Supported dtype: float.


Number of items to return for each row. Optional if indices or distances arrays are given (in which case their second dimension is k).

distances : optional array interface compliant matrix

Shape (n_rows, k), dtype float. If supplied, distances will be written here in place. (default None)

indices : optional array interface compliant matrix

Shape (n_rows, k), dtype int64_t. If supplied, the column indices of the selected items will be written here in place. (default None)


Whether to select the minimum or maximum K items

handle : optional RAFT resource handle for reusing CUDA resources

If a handle isn’t supplied, CUDA resources will be allocated inside this function and synchronized before the function exits. If a handle is supplied, you must synchronize explicitly by calling handle.sync() before accessing the output.

Returns:

A tuple (distances, indices):

distances : array interface compliant object containing the resulting distances (the selected values), shape (n_rows, k)

indices : array interface compliant object containing the resulting column indices, shape (n_rows, k)

Examples:

>>> import cupy as cp
>>> from pylibraft.matrix import select_k
>>> n_features = 50
>>> n_rows = 1000
>>> queries = cp.random.random_sample((n_rows, n_features),
...                                   dtype=cp.float32)
>>> k = 40
>>> distances, ids = select_k(queries, k)
>>> distances = cp.asarray(distances)
>>> ids = cp.asarray(ids)
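The example above only converts the outputs to CuPy arrays. To sanity-check them, one option is a host-side validator like the following sketch. check_select_k is a hypothetical helper, not a pylibraft function; it assumes all inputs have already been copied to host memory, and it ignores the order of items within each row because this sketch does not assume the GPU output is sorted.

```python
import numpy as np


def check_select_k(data, values, indices, k, select_min=True):
    """Hypothetical validator: True if (values, indices) is a valid top-k.

    All inputs are host (NumPy-convertible) arrays. The order of items
    within a row is ignored.
    """
    data, values, indices = (np.asarray(a) for a in (data, values, indices))
    n_rows, dim = data.shape
    # Both outputs must have shape (n_rows, k).
    if values.shape != (n_rows, k) or indices.shape != (n_rows, k):
        return False
    # Indices must be in-range, distinct column positions within each row.
    if indices.min() < 0 or indices.max() >= dim:
        return False
    if any(len(set(row.tolist())) != k for row in indices):
        return False
    # Each returned value must be the element its index points at.
    if not np.array_equal(np.take_along_axis(data, indices, axis=1), values):
        return False
    # The selected values must be exactly the k smallest (or largest) per row.
    ordered = np.sort(data, axis=1)
    expected = ordered[:, :k] if select_min else ordered[:, dim - k:]
    return bool(np.array_equal(np.sort(values, axis=1), expected))
```

With the arrays from the example, a call might look like check_select_k(cp.asnumpy(queries), cp.asnumpy(distances), cp.asnumpy(ids), k). Exact equality is appropriate here because selection copies input values rather than computing new ones.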