gpu.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
10 
11 #include <rmm/device_buffer.hpp>
12 
13 #include <cuda_runtime_api.h>
14 
15 #include <type_traits>
16 
17 namespace raft_proto {
18 namespace detail {
19 template <typename T>
21  // TODO(wphicks): Assess need for buffers of const T
22  using value_type = std::remove_const_t<T>;
23  owning_buffer() : data_{} {}
24 
26  std::size_t size,
27  cudaStream_t stream) noexcept(false)
28  : data_{[&device_id, &size, &stream]() {
29  auto device_context = device_setter{device_id};
30  return rmm::device_buffer{size * sizeof(value_type), rmm::cuda_stream_view{stream}};
31  }()}
32  {
33  }
34 
35  auto* get() const { return reinterpret_cast<T*>(data_.data()); }
36 
37  private:
38  mutable rmm::device_buffer data_;
39 };
40 } // namespace detail
41 } // namespace raft_proto
Definition: buffer.hpp:24
device_type
Definition: device_type.hpp:7
Definition: base.hpp:11
Definition: base.hpp:14
std::remove_const_t<T> value_type
Definition: gpu.hpp:22
owning_buffer(device_id<device_type::gpu> device_id, std::size_t size, cudaStream_t stream) noexcept(false)
Definition: gpu.hpp:25
Definition: base.hpp:16