Attention

The vector search and clustering algorithms in RAFT are being migrated to a new library dedicated to vector search called cuVS. We will continue to support the vector search algorithms in RAFT during this move, but will no longer update them after the RAPIDS 24.06 (June) release. We plan to complete the migration by the RAPIDS 24.08 (August) release.

memory_type_dispatcher

#include <raft/util/memory_type_dispatcher.cuh>

template<typename lambda_t, typename mdbuffer_type, enable_if_mdbuffer<mdbuffer_type>* = nullptr>
decltype(auto) memory_type_dispatcher(raft::resources const &res, lambda_t &&f, mdbuffer_type &&buf)

Dispatch an mdspan to one of several specializations of a functor, chosen based on the mdspan’s memory type.

This function template is used to dispatch to one or more implementations of a function based on memory type. For instance, if a functor has been implemented with an operator that accepts only a device_mdspan, input data can be passed to that functor with minimal copies or allocations by wrapping the functor in this template.

More specifically, host memory data will be copied to device before being passed to the functor as a device_mdspan. Device, managed, and pinned data will be passed directly to the functor as a device_mdspan.

If the functor’s operator were also specialized for host_mdspan, then this wrapper would pass an input host_mdspan directly to the corresponding specialization.

If a functor explicitly specializes for managed/pinned memory and receives managed/pinned input, the corresponding specialization will be invoked. If the functor does not specialize for either, it will preferentially invoke the device specialization if available and then the host specialization. Managed input will never be dispatched to an explicit specialization for pinned memory and vice versa.
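The precedence rules above can be modeled in a few lines of standalone C++. This is a toy sketch, not RAFT's implementation: `view_tag`, `accepts_v`, `supports`, `accessible`, `choose_overload`, and `dispatch_result` are hypothetical names, with empty tag types standing in for host, device, managed, and pinned mdspans. The model also assumes that device input reaching a host-only functor is copied to host, which the text above does not state.

```cpp
#include <cassert>
#include <stdexcept>
#include <type_traits>

// Hypothetical toy model, not RAFT code: one empty tag type per memory
// type stands in for host_mdspan, device_mdspan, managed_mdspan, etc.
enum class memory_type { host, device, managed, pinned };

template <memory_type M>
struct view_tag {};

// True if functor F has an operator() overload for memory type M.
template <typename F, memory_type M>
inline constexpr bool accepts_v = std::is_invocable_v<F, view_tag<M>>;

template <typename F>
constexpr bool supports(memory_type m) {
  switch (m) {
    case memory_type::host: return accepts_v<F, memory_type::host>;
    case memory_type::device: return accepts_v<F, memory_type::device>;
    case memory_type::managed: return accepts_v<F, memory_type::managed>;
    case memory_type::pinned: return accepts_v<F, memory_type::pinned>;
  }
  return false;
}

// Whether data in `from` memory can go to an overload for `to` without a
// copy. Device overloads can read device, managed, and pinned memory; host
// overloads can read host, managed, and pinned memory (assumed here).
constexpr bool accessible(memory_type from, memory_type to) {
  if (from == to) return true;
  if (to == memory_type::device) return from != memory_type::host;
  if (to == memory_type::host) return from != memory_type::device;
  return false;  // managed/pinned overloads take only their own memory type
}

struct dispatch_result {
  memory_type target;  // overload that would be invoked
  bool copied;         // whether the input would be copied first
};

// Precedence: exact memory-type match, then device, then host. Because the
// first step only matches the input's own memory type, managed input never
// reaches a pinned overload and vice versa.
template <typename F>
dispatch_result choose_overload(memory_type input) {
  memory_type target;
  if (supports<F>(input)) {
    target = input;
  } else if (supports<F>(memory_type::device)) {
    target = memory_type::device;
  } else if (supports<F>(memory_type::host)) {
    target = memory_type::host;
  } else {
    throw std::logic_error("functor accepts no compatible view type");
  }
  return {target, !accessible(input, target)};
}

// Mirrors the functor in the usage example: device and managed overloads.
struct example_functor {
  void operator()(view_tag<memory_type::device>) const {}
  void operator()(view_tag<memory_type::managed>) const {}
};
```

With `example_functor`, host input resolves to the device overload after a copy, device and pinned input reach the device overload without a copy, and managed input reaches the managed overload.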

Dispatching is performed by coercing the input mdspan to an mdbuffer of the correct type. If the input data must also be converted to a different data type (e.g. floats to doubles) or to a different memory layout, pass an explicit mdbuffer type as the first template argument of memory_type_dispatcher.
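The "copy only when necessary" idea behind this coercion can be illustrated with a host-only toy buffer. This is a sketch under stated assumptions, not how `raft::mdbuffer` is implemented: the class `toy_buffer` and its members are hypothetical, and it handles only element-type conversion, not memory type or layout.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <variant>
#include <vector>

// Hypothetical toy, not raft::mdbuffer: wraps caller data as a non-owning
// view when the element type already matches, and otherwise allocates an
// owning, element-wise converted copy.
template <typename T>
class toy_buffer {
 public:
  template <typename U>
  toy_buffer(const U* ptr, std::size_t n) {
    if constexpr (std::is_same_v<T, U>) {
      storage_ = view{ptr, n};  // no allocation, no copy
    } else {
      storage_ = std::vector<T>(ptr, ptr + n);  // converting copy
    }
  }

  // True if the buffer had to allocate its own converted copy.
  bool owns_copy() const {
    return std::holds_alternative<std::vector<T>>(storage_);
  }

  const T* data() const {
    if (const view* v = std::get_if<view>(&storage_)) return v->ptr;
    return std::get<std::vector<T>>(storage_).data();
  }

  std::size_t size() const {
    if (const view* v = std::get_if<view>(&storage_)) return v->n;
    return std::get<std::vector<T>>(storage_).size();
  }

 private:
  struct view {
    const T* ptr;
    std::size_t n;
  };
  std::variant<view, std::vector<T>> storage_;
};
```

Wrapping doubles in a `toy_buffer<double>` aliases the caller's memory, while wrapping floats forces an allocation, which is the trade-off the explicit mdbuffer type parameter lets callers opt into.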

Usage example:

#include <raft/core/device_resources.hpp>
#include <raft/util/memory_type_dispatcher.cuh>

// Functor which accepts only a `device_mdspan` or `managed_mdspan` of
// doubles in C-contiguous layout with `int` indices. We wish to be able to
// call this functor on any compatible data, regardless of data type,
// memory type, or layout.
struct functor {
  void operator()(raft::device_matrix_view<double, int> data) {
    // Do something with data on device
  }
  void operator()(raft::managed_matrix_view<double, int> data) {
    // Do something with data, taking advantage of knowledge that
    // underlying memory is managed
  }
};

int main() {
  auto rows = 3;
  auto cols = 5;
  auto res = raft::device_resources{};

  auto host_data = raft::make_host_matrix<double, int>(rows, cols);
  // functor{}(host_data.view()); // This would fail to compile
  auto device_data = raft::make_device_matrix<double, int>(res, rows, cols);
  functor{}(device_data.view()); // Functor accepts device mdspan
  auto managed_data = raft::make_managed_matrix<double, int>(res, rows, cols);
  functor{}(managed_data.view()); // Functor accepts managed mdspan
  auto pinned_data = raft::make_pinned_matrix<double, int>(res, rows, cols);
  // functor{}(pinned_data.view()); // This would fail to compile
  auto float_data = raft::make_device_matrix<float, int>(res, rows, cols);
  // functor{}(float_data.view()); // This would fail to compile
  auto f_data = raft::make_device_matrix<double, int, raft::layout_f_contiguous>(res, rows, cols);
  // functor{}(f_data.view()); // This would fail to compile

  // `memory_type_dispatcher` lets us call this functor on all of the above
  raft::memory_type_dispatcher(res, functor{}, host_data.view());
  raft::memory_type_dispatcher(res, functor{}, device_data.view());
  raft::memory_type_dispatcher(res, functor{}, managed_data.view());
  raft::memory_type_dispatcher(res, functor{}, pinned_data.view());
  // Here, we use the mdbuffer type template parameter to ensure that the data
  // type and layout are as expected by the functor
  raft::memory_type_dispatcher<raft::mdbuffer<double, raft::matrix_extents<int>>>(
    res, functor{}, float_data.view());
  raft::memory_type_dispatcher<raft::mdbuffer<double, raft::matrix_extents<int>>>(
    res, functor{}, f_data.view());
}

As this example shows, memory_type_dispatcher can be used to dispatch any compatible mdspan input to a functor, regardless of the mdspan type(s) that functor supports.

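template<typename lambda_t, typename mdspan_type, enable_if_mdspan<mdspan_type>* = nullptr>
decltype(auto) memory_type_dispatcher(raft::resources const &res, lambda_t &&f, mdspan_type view)

Overload that accepts an mdspan directly.

template<typename mdbuffer_type, typename lambda_t, typename mdspan_type, enable_if_mdbuffer<mdbuffer_type>* = nullptr, enable_if_mdspan<mdspan_type>* = nullptr>
decltype(auto) memory_type_dispatcher(raft::resources const &res, lambda_t &&f, mdspan_type view)

Overload that accepts an mdspan and coerces it to the explicitly specified mdbuffer_type (for example, to change its data type or layout), as in the float_data and f_data calls in the example above.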