thrust_allocator_adaptor.hpp
1 /*
2  * Copyright (c) 2019-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <rmm/mr/device/device_memory_resource.hpp>
21 
22 #include <thrust/detail/type_traits/pointer_traits.h>
23 #include <thrust/device_malloc_allocator.h>
24 
25 namespace rmm::mr {
36 template <typename T>
37 class thrust_allocator : public thrust::device_malloc_allocator<T> {
38  public:
39  using Base = thrust::device_malloc_allocator<T>;
40  using pointer = typename Base::pointer;
41  using size_type = typename Base::size_type;
42 
49  template <typename U>
50  struct rebind {
51  using other = thrust_allocator<U>;
52  };
53 
58  thrust_allocator() = default;
59 
67 
76  {
77  }
78 
84  template <typename U>
86  : _mr(other.resource()), _stream{other.stream()}
87  {
88  }
89 
96  pointer allocate(size_type num)
97  {
98  return thrust::device_pointer_cast(static_cast<T*>(_mr->allocate(num * sizeof(T), _stream)));
99  }
100 
108  void deallocate(pointer ptr, size_type num)
109  {
110  return _mr->deallocate(thrust::raw_pointer_cast(ptr), num * sizeof(T), _stream);
111  }
112 
116  [[nodiscard]] device_memory_resource* resource() const noexcept { return _mr; }
117 
121  [[nodiscard]] cuda_stream_view stream() const noexcept { return _stream; }
122 
123  private:
124  cuda_stream_view _stream{};
125  device_memory_resource* _mr{rmm::mr::get_current_device_resource()};
126 };
127 } // namespace rmm::mr
rmm::mr::thrust_allocator::thrust_allocator
thrust_allocator(cuda_stream_view stream, device_memory_resource *mr)
Constructs a thrust_allocator using a device memory resource and stream.
Definition: thrust_allocator_adaptor.hpp:75
rmm::mr::thrust_allocator
An allocator compatible with Thrust containers and algorithms using a device_memory_resource for memory (de)allocation.
Definition: thrust_allocator_adaptor.hpp:37
rmm::mr::device_memory_resource::allocate
void * allocate(std::size_t bytes, cuda_stream_view stream=cuda_stream_view{})
Allocates at least `bytes` bytes of memory.
Definition: device_memory_resource.hpp:106
per_device_resource.hpp
Management of per-device device_memory_resources.
rmm::mr::thrust_allocator::allocate
pointer allocate(size_type num)
Allocates storage for `num` objects of type T.
Definition: thrust_allocator_adaptor.hpp:96
rmm::cuda_stream_view
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:34
rmm::mr::thrust_allocator::thrust_allocator
thrust_allocator(thrust_allocator< U > const &other)
Converting constructor from a thrust_allocator of another element type U (as used by rebind). Copies the resource pointer and stream.
Definition: thrust_allocator_adaptor.hpp:85
rmm::mr::device_memory_resource::deallocate
void deallocate(void *ptr, std::size_t bytes, cuda_stream_view stream=cuda_stream_view{})
Deallocates the memory pointed to by `ptr`.
Definition: device_memory_resource.hpp:129
rmm::mr::thrust_allocator::rebind
Provides the type of a thrust_allocator instantiated with another type.
Definition: thrust_allocator_adaptor.hpp:50
rmm::mr::thrust_allocator::stream
cuda_stream_view stream() const noexcept
Returns the stream used by this allocator.
Definition: thrust_allocator_adaptor.hpp:121
rmm::mr::thrust_allocator::thrust_allocator
thrust_allocator()=default
Default constructor creates an allocator using the default memory resource and default stream.
rmm::mr::thrust_allocator::thrust_allocator
thrust_allocator(cuda_stream_view stream)
Constructs a thrust_allocator using the default device memory resource and the specified stream.
Definition: thrust_allocator_adaptor.hpp:66
rmm::mr::thrust_allocator::deallocate
void deallocate(pointer ptr, size_type num)
Deallocates storage for `num` objects of type T.
Definition: thrust_allocator_adaptor.hpp:108
rmm::mr::device_memory_resource
Base class for all RMM device memory resources (device memory allocation).
Definition: device_memory_resource.hpp:82
rmm::mr::thrust_allocator::resource
device_memory_resource * resource() const noexcept
Returns the device memory resource used by this allocator.
Definition: thrust_allocator_adaptor.hpp:116