handle.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
18 
19 #include <algorithm>
20 #include <cstddef>
21 #ifdef CUML_ENABLE_GPU
22 #include <raft/core/handle.hpp>
23 #endif
24 
25 namespace raft_proto {
26 #ifdef CUML_ENABLE_GPU
27 struct handle_t {
28  handle_t(raft::handle_t const* handle_ptr = nullptr) : raft_handle_{handle_ptr} {}
29  handle_t(raft::handle_t const& raft_handle) : raft_handle_{&raft_handle} {}
30  auto get_next_usable_stream() const
31  {
32  return raft_proto::cuda_stream{raft_handle_->get_next_usable_stream().value()};
33  }
34  auto get_stream_pool_size() const { return raft_handle_->get_stream_pool_size(); }
35  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
36  void synchronize() const
37  {
38  raft_handle_->sync_stream_pool();
39  raft_handle_->sync_stream();
40  }
41 
42  private:
43  // Have to store a pointer because handle is not movable
44  raft::handle_t const* raft_handle_;
45 };
46 #else
47 struct handle_t {
49  auto get_stream_pool_size() const { return std::size_t{}; }
50  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
51  void synchronize() const {}
52 };
53 #endif
54 } // namespace raft_proto
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
Definition: buffer.hpp:35
int cuda_stream
Definition: cuda_stream.hpp:25
Definition: handle.hpp:47
auto get_usable_stream_count() const
Definition: handle.hpp:50
auto get_next_usable_stream() const
Definition: handle.hpp:48
auto get_stream_pool_size() const
Definition: handle.hpp:49
void synchronize() const
Definition: handle.hpp:51