buffer.hpp
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

// Project-local raft_proto headers supplying cuda_stream, device_id, device_id_variant,
// device_type, non_owning_buffer, owning_buffer, const_agnostic_same_t, out_of_bounds,
// and the HOST/DEVICE macros are required alongside the standard includes below.

#include <stdint.h>

#include <algorithm>  // std::for_each, used by the iterator-range constructor
#include <cstddef>
#include <iterator>
#include <memory>
#include <utility>
#include <variant>

namespace raft_proto {

/**
 * @brief A container which may or may not own its own data on host or device
 */
template <typename T>
struct buffer {
  using index_type = std::size_t;
  using value_type = T;

  using data_store = std::variant<non_owning_buffer<device_type::cpu, T>,
                                  non_owning_buffer<device_type::gpu, T>,
                                  owning_buffer<device_type::cpu, T>,
                                  owning_buffer<device_type::gpu, T>>;

  buffer() : device_{}, data_{}, size_{}, cached_ptr{nullptr} {}

  /** Construct an owning buffer of the given size in the given memory location */
  buffer(index_type size,
         device_type mem_type = device_type::cpu,
         int device = 0,
         cuda_stream stream = 0)
    : device_{[mem_type, &device]() {
        auto result = device_id_variant{};
        switch (mem_type) {
          case device_type::cpu: result = device_id<device_type::cpu>{device}; break;
          case device_type::gpu: result = device_id<device_type::gpu>{device}; break;
        }
        return result;
      }()},
      data_{[this, mem_type, size, stream]() {
        auto result = data_store{};
        switch (mem_type) {
          case device_type::cpu: result = owning_buffer<device_type::cpu, T>{size}; break;
          case device_type::gpu:
            result = owning_buffer<device_type::gpu, T>{std::get<1>(device_), size, stream};
            break;
        }
        return result;
      }()},
      size_{size},
      cached_ptr{[this]() {
        auto result = static_cast<T*>(nullptr);
        switch (data_.index()) {
          case 0: result = std::get<0>(data_).get(); break;
          case 1: result = std::get<1>(data_).get(); break;
          case 2: result = std::get<2>(data_).get(); break;
          case 3: result = std::get<3>(data_).get(); break;
        }
        return result;
      }()}
  {
  }

  /** Construct a non-owning buffer wrapping existing memory of the given size */
  buffer(T* input_data, index_type size, device_type mem_type = device_type::cpu, int device = 0)
    : device_{[mem_type, &device]() {
        auto result = device_id_variant{};
        switch (mem_type) {
          case device_type::cpu: result = device_id<device_type::cpu>{device}; break;
          case device_type::gpu: result = device_id<device_type::gpu>{device}; break;
        }
        return result;
      }()},
      data_{[input_data, mem_type]() {
        auto result = data_store{};
        switch (mem_type) {
          case device_type::cpu: result = non_owning_buffer<device_type::cpu, T>{input_data}; break;
          case device_type::gpu: result = non_owning_buffer<device_type::gpu, T>{input_data}; break;
        }
        return result;
      }()},
      size_{size},
      cached_ptr{[this]() {
        auto result = static_cast<T*>(nullptr);
        switch (data_.index()) {
          case 0: result = std::get<0>(data_).get(); break;
          case 1: result = std::get<1>(data_).get(); break;
          case 2: result = std::get<2>(data_).get(); break;
          case 3: result = std::get<3>(data_).get(); break;
        }
        return result;
      }()}
  {
  }

  /**
   * @brief Construct one buffer from another in the given memory location
   * (either on host or on device)
   */
  buffer(buffer<T> const& other,
         device_type mem_type,
         int device = 0,
         cuda_stream stream = cuda_stream{})
    : device_{[mem_type, &device]() {
        auto result = device_id_variant{};
        switch (mem_type) {
          case device_type::cpu: result = device_id<device_type::cpu>{device}; break;
          case device_type::gpu: result = device_id<device_type::gpu>{device}; break;
        }
        return result;
      }()},
      data_{[this, &other, mem_type, stream]() {
        auto result = data_store{};
        auto result_data = static_cast<T*>(nullptr);
        if (mem_type == device_type::cpu) {
          auto buf = owning_buffer<device_type::cpu, T>(other.size());
          result_data = buf.get();
          result = std::move(buf);
        } else if (mem_type == device_type::gpu) {
          auto buf = owning_buffer<device_type::gpu, T>(std::get<1>(device_), other.size(), stream);
          result_data = buf.get();
          result = std::move(buf);
        }
        copy(result_data, other.data(), other.size(), mem_type, other.memory_type(), stream);
        return result;
      }()},
      size_{other.size()},
      cached_ptr{[this]() {
        auto result = static_cast<T*>(nullptr);
        switch (data_.index()) {
          case 0: result = std::get<0>(data_).get(); break;
          case 1: result = std::get<1>(data_).get(); break;
          case 2: result = std::get<2>(data_).get(); break;
          case 3: result = std::get<3>(data_).get(); break;
        }
        return result;
      }()}
  {
  }

  /** Create owning copy of existing buffer with given stream */
  buffer(buffer<T> const& other, cuda_stream stream = cuda_stream{})
    : buffer(other, other.memory_type(), other.device_index(), stream)
  {
  }

  friend void swap(buffer<T>& first, buffer<T>& second)
  {
    using std::swap;
    swap(first.device_, second.device_);
    swap(first.data_, second.data_);
    swap(first.size_, second.size_);
    swap(first.cached_ptr, second.cached_ptr);
  }
  buffer<T>& operator=(buffer<T> const& other)
  {
    auto copy = other;
    swap(*this, copy);
    return *this;
  }

  /**
   * @brief Move from existing buffer unless a copy is necessary based on
   * memory location
   */
  buffer(buffer<T>&& other, device_type mem_type, int device, cuda_stream stream)
    : device_{[mem_type, &device]() {
        auto result = device_id_variant{};
        switch (mem_type) {
          case device_type::cpu: result = device_id<device_type::cpu>{device}; break;
          case device_type::gpu: result = device_id<device_type::gpu>{device}; break;
        }
        return result;
      }()},
      data_{[&other, mem_type, device, stream]() {
        auto result = data_store{};
        if (mem_type == other.memory_type() && device == other.device_index()) {
          result = std::move(other.data_);
        } else {
          auto* result_data = static_cast<T*>(nullptr);
          if (mem_type == device_type::cpu) {
            auto buf = owning_buffer<device_type::cpu, T>{other.size()};
            result_data = buf.get();
            result = std::move(buf);
          } else if (mem_type == device_type::gpu) {
            auto buf = owning_buffer<device_type::gpu, T>{device, other.size(), stream};
            result_data = buf.get();
            result = std::move(buf);
          }
          copy(result_data, other.data(), other.size(), mem_type, other.memory_type(), stream);
        }
        return result;
      }()},
      size_{other.size()},
      cached_ptr{[this]() {
        auto result = static_cast<T*>(nullptr);
        switch (data_.index()) {
          case 0: result = std::get<0>(data_).get(); break;
          case 1: result = std::get<1>(data_).get(); break;
          case 2: result = std::get<2>(data_).get(); break;
          case 3: result = std::get<3>(data_).get(); break;
        }
        return result;
      }()}
  {
  }
  buffer(buffer<T>&& other, device_type mem_type, int device)
    : buffer{std::move(other), mem_type, device, cuda_stream{}}
  {
  }
  buffer(buffer<T>&& other, device_type mem_type)
    : buffer{std::move(other), mem_type, 0, cuda_stream{}}
  {
  }

  buffer(buffer<T>&& other) noexcept
    : buffer{std::move(other), other.memory_type(), other.device_index(), cuda_stream{}}
  {
  }
  buffer<T>& operator=(buffer<T>&& other) noexcept
  {
    data_ = std::move(other.data_);
    device_ = std::move(other.device_);
    size_ = std::move(other.size_);
    cached_ptr = std::move(other.cached_ptr);
    return *this;
  }

  template <
    typename iter_t,
    typename = decltype(*std::declval<iter_t&>(), void(), ++std::declval<iter_t&>(), void())>
  buffer(iter_t const& begin, iter_t const& end)
    : buffer{static_cast<size_t>(std::distance(begin, end))}
  {
    auto index = std::size_t{};
    std::for_each(begin, end, [&index, this](auto&& val) { data()[index++] = val; });
  }

  template <
    typename iter_t,
    typename = decltype(*std::declval<iter_t&>(), void(), ++std::declval<iter_t&>(), void())>
  buffer(iter_t const& begin, iter_t const& end, device_type mem_type)
    : buffer{buffer{begin, end}, mem_type}
  {
  }

  template <
    typename iter_t,
    typename = decltype(*std::declval<iter_t&>(), void(), ++std::declval<iter_t&>(), void())>
  buffer(iter_t const& begin,
         iter_t const& end,
         device_type mem_type,
         int device,
         cuda_stream stream = cuda_stream{})
    : buffer{buffer{begin, end}, mem_type, device, stream}
  {
  }

  auto size() const noexcept { return size_; }
  HOST DEVICE auto* data() const noexcept { return cached_ptr; }
  auto memory_type() const noexcept
  {
    auto result = device_type{};
    if (device_.index() == 0) {
      result = device_type::cpu;
    } else {
      result = device_type::gpu;
    }
    return result;
  }

  auto device() const noexcept { return device_; }

  auto device_index() const noexcept
  {
    auto result = int{};
    switch (device_.index()) {
      case 0: result = std::get<0>(device_).value(); break;
      case 1: result = std::get<1>(device_).value(); break;
    }
    return result;
  }
  ~buffer() = default;

 private:
  device_id_variant device_;
  data_store data_;
  index_type size_;
  T* cached_ptr;
};

template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>& dst,
                                 buffer<U> const& src,
                                 typename buffer<T>::index_type dst_offset,
                                 typename buffer<U>::index_type src_offset,
                                 typename buffer<T>::index_type size,
                                 cuda_stream stream)
{
  if constexpr (bounds_check) {
    if (src.size() - src_offset < size || dst.size() - dst_offset < size) {
      throw out_of_bounds("Attempted copy to or from buffer of inadequate size");
    }
  }
  copy(dst.data() + dst_offset,
       src.data() + src_offset,
       size,
       dst.memory_type(),
       src.memory_type(),
       stream);
}

template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>& dst, buffer<U> const& src, cuda_stream stream)
{
  copy<bounds_check>(dst, src, 0, 0, src.size(), stream);
}
template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>& dst, buffer<U> const& src)
{
  copy<bounds_check>(dst, src, 0, 0, src.size(), cuda_stream{});
}

template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>&& dst,
                                 buffer<U>&& src,
                                 typename buffer<T>::index_type dst_offset,
                                 typename buffer<U>::index_type src_offset,
                                 typename buffer<T>::index_type size,
                                 cuda_stream stream)
{
  if constexpr (bounds_check) {
    if (src.size() - src_offset < size || dst.size() - dst_offset < size) {
      throw out_of_bounds("Attempted copy to or from buffer of inadequate size");
    }
  }
  copy(dst.data() + dst_offset,
       src.data() + src_offset,
       size,
       dst.memory_type(),
       src.memory_type(),
       stream);
}

template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>&& dst,
                                 buffer<U>&& src,
                                 typename buffer<T>::index_type dst_offset,
                                 cuda_stream stream)
{
  copy<bounds_check>(dst, src, dst_offset, 0, src.size(), stream);
}

template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>&& dst, buffer<U>&& src, cuda_stream stream)
{
  copy<bounds_check>(dst, src, 0, 0, src.size(), stream);
}
template <bool bounds_check, typename T, typename U>
const_agnostic_same_t<T, U> copy(buffer<T>&& dst, buffer<U>&& src)
{
  copy<bounds_check>(dst, src, 0, 0, src.size(), cuda_stream{});
}

} // namespace raft_proto
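
For illustration only, here is a minimal usage sketch. It is not part of buffer.hpp; it assumes the raft_proto headers above are on the include path and that the GPU code paths were built with CUDA support, and the function and variable names are invented for the example. The sketch builds an owning host buffer from an iterator range, wraps caller-managed memory in a non-owning view, makes an owning copy in device memory, and then moves the data back to the host.

#include <utility>
#include <vector>

// Example only: raft_proto::buffer and its helpers come from the header above.
void buffer_usage_example(raft_proto::cuda_stream stream)
{
  auto host_vals = std::vector<float>{1.0f, 2.0f, 3.0f};

  // Owning host buffer filled element-by-element from an iterator range.
  auto host_buf = raft_proto::buffer<float>(host_vals.begin(), host_vals.end());

  // Non-owning view over memory that the vector continues to own.
  auto view = raft_proto::buffer<float>(host_vals.data(), host_vals.size());

  // Owning copy of host_buf in GPU memory on device 0, copied on the given stream.
  auto device_buf =
    raft_proto::buffer<float>(host_buf, raft_proto::device_type::gpu, 0, stream);

  // Move back to host memory; a copy occurs only because the memory location changes.
  auto back_on_host =
    raft_proto::buffer<float>(std::move(device_buf), raft_proto::device_type::cpu);
}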
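
A similar sketch covers the free copy() helpers (again, example code rather than part of the header, assuming the lvalue-reference overloads shown above). The bounds_check template parameter selects at compile time whether an out_of_bounds exception is thrown when the requested range does not fit in either buffer; the source and destination may live in different memory locations.

// Example only: copies between two existing raft_proto buffers.
void copy_usage_example(raft_proto::buffer<int>& dst,
                        raft_proto::buffer<int> const& src,
                        raft_proto::cuda_stream stream)
{
  // Copy all of src into the start of dst; throws raft_proto::out_of_bounds
  // if dst cannot hold src.size() elements.
  raft_proto::copy<true>(dst, src, stream);

  // Copy four elements starting at src offset 2 into dst at offset 0,
  // with the bounds check compiled out.
  raft_proto::copy<false>(dst, src, 0, 2, 4, stream);
}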