handle.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
7 
8 #include <algorithm>
9 #include <cstddef>
10 #ifdef CUML_ENABLE_GPU
11 #include <raft/core/handle.hpp>
12 #endif
13 
14 namespace raft_proto {
15 #ifdef CUML_ENABLE_GPU
16 struct handle_t {
17  handle_t(raft::handle_t const* handle_ptr = nullptr) : raft_handle_{handle_ptr} {}
18  handle_t(raft::handle_t const& raft_handle) : raft_handle_{&raft_handle} {}
19  auto get_next_usable_stream() const
20  {
21  return raft_proto::cuda_stream{raft_handle_->get_next_usable_stream().value()};
22  }
23  auto get_stream_pool_size() const { return raft_handle_->get_stream_pool_size(); }
24  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
25  void synchronize() const
26  {
27  raft_handle_->sync_stream_pool();
28  raft_handle_->sync_stream();
29  }
30 
31  private:
32  // Have to store a pointer because handle is not movable
33  raft::handle_t const* raft_handle_;
34 };
35 #else
36 struct handle_t {
38  auto get_stream_pool_size() const { return std::size_t{}; }
39  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
40  void synchronize() const {}
41 };
42 #endif
43 } // namespace raft_proto
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
Definition: buffer.hpp:24
int cuda_stream
Definition: cuda_stream.hpp:14
Definition: handle.hpp:36
auto get_usable_stream_count() const
Definition: handle.hpp:39
auto get_next_usable_stream() const
Definition: handle.hpp:37
auto get_stream_pool_size() const
Definition: handle.hpp:38
void synchronize() const
Definition: handle.hpp:40