Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
handle.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <algorithm>
18 #include <cstddef>
20 #ifdef CUML_ENABLE_GPU
21 #include <raft/core/handle.hpp>
22 #endif
23 
24 namespace raft_proto {
25 #ifdef CUML_ENABLE_GPU
26 struct handle_t {
27  handle_t(raft::handle_t const* handle_ptr = nullptr) : raft_handle_{handle_ptr} {}
28  handle_t(raft::handle_t const& raft_handle) : raft_handle_{&raft_handle} {}
29  auto get_next_usable_stream() const
30  {
31  return raft_proto::cuda_stream{raft_handle_->get_next_usable_stream().value()};
32  }
33  auto get_stream_pool_size() const { return raft_handle_->get_stream_pool_size(); }
34  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
35  void synchronize() const
36  {
37  raft_handle_->sync_stream_pool();
38  raft_handle_->sync_stream();
39  }
40 
41  private:
42  // Have to store a pointer because handle is not movable
43  raft::handle_t const* raft_handle_;
44 };
45 #else
46 struct handle_t {
48  auto get_stream_pool_size() const { return std::size_t{}; }
49  auto get_usable_stream_count() const { return std::max(get_stream_pool_size(), std::size_t{1}); }
50  void synchronize() const {}
51 };
52 #endif
53 } // namespace raft_proto
math_t max(math_t a, math_t b)
Definition: learning_rate.h:26
Definition: buffer.hpp:33
int cuda_stream
Definition: cuda_stream.hpp:25
Definition: handle.hpp:46
auto get_usable_stream_count() const
Definition: handle.hpp:49
auto get_next_usable_stream() const
Definition: handle.hpp:47
auto get_stream_pool_size() const
Definition: handle.hpp:48
void synchronize() const
Definition: handle.hpp:50