Loading...
Searching...
No Matches
test_util.cuh
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#pragma once

#include <cuspatial/traits.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/for_each.h>
#include <thrust/host_vector.h>
#include <thrust/optional.h>

#include <cstdio>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
31
namespace cuspatial {

namespace test {

/**
 * @brief Copy the contents of a device container into a new `thrust::host_vector`.
 *
 * When `Vector` is exactly `rmm::device_uvector<T>`, the elements are copied with
 * `cudaMemcpyAsync` on the vector's own stream and that stream is then synchronized, so the
 * returned host vector can be read immediately. Any other `Vector` is passed to the
 * `thrust::host_vector<T>` converting constructor.
 *
 * @tparam T Element type of the returned host vector.
 * @tparam Vector Type of the device container to copy from.
 * @param dvec The device container to copy from.
 * @return A host vector containing a copy of the elements of `dvec`.
 * @throws std::runtime_error if `cudaMemcpyAsync` reports an error for an
 *         `rmm::device_uvector` input.
 */
template <typename T, typename Vector>
thrust::host_vector<T> to_host(Vector const& dvec)
{
  if constexpr (std::is_same_v<Vector, rmm::device_uvector<T>>) {
    thrust::host_vector<T> hvec(dvec.size());
    auto const status = cudaMemcpyAsync(hvec.data(),
                                        dvec.data(),
                                        dvec.size() * sizeof(T),
                                        cudaMemcpyKind::cudaMemcpyDeviceToHost,
                                        dvec.stream());
    // Without this check a failed copy goes unnoticed and the caller receives `hvec`
    // without the device data in it.
    if (status != cudaSuccess) {
      throw std::runtime_error(
        std::string{"cuspatial::test::to_host: cudaMemcpyAsync failed: "} +
        cudaGetErrorString(status));
    }
    // The copy was enqueued asynchronously on dvec's stream; wait for it before returning.
    dvec.stream().synchronize();
    return hvec;
  } else {
    return thrust::host_vector<T>(dvec);
  }
}

/**
 * @brief Copy the elements of the iterator range `[begin, end)` into a new
 * `thrust::host_vector`.
 *
 * @tparam Iter Iterator type of the range.
 * @tparam T Element type of the returned host vector; defaults to the iterator's value type.
 * @param begin Iterator to the first element.
 * @param end Iterator one past the last element.
 * @return A host vector containing a copy of the elements in `[begin, end)`.
 *
 * NOTE(review): the copy is done by the `thrust::host_vector` range constructor, which picks
 * the source memory system from the iterator type; raw pointers are presumably treated as
 * host memory, so device data likely needs device-aware iterators (e.g. `thrust::device_ptr`)
 * — confirm against callers.
 */
template <typename Iter, typename T = cuspatial::iterator_value_type<Iter>>
thrust::host_vector<T> to_host(Iter begin, Iter end)
{
  return thrust::host_vector<T>(begin, end);
}

/**
 * @brief Print the elements of an iterator range to `std::cout`.
 *
 * The range is first copied to host with `to_host(begin, end)`. The output is `pre`, then
 * every element followed by a single space, then `post`.
 *
 * @tparam Iter Iterator type of the range; its elements must be streamable with `operator<<`.
 * @param begin Iterator to the first element.
 * @param end Iterator one past the last element.
 * @param pre Text written before the elements.
 * @param post Text written after the elements (a newline by default).
 */
template <typename Iter>
void print_device_range(Iter begin,
                        Iter end,
                        std::string_view pre  = "",
                        std::string_view post = "\n")
{
  auto hvec = to_host(begin, end);

  std::cout << pre;
  for (auto const& x : hvec) {
    std::cout << x << " ";
  }
  std::cout << post;
}

/**
 * @brief Print the elements of a device container to `std::cout`.
 *
 * The container is first copied to host with `to_host<typename Vector::value_type>(vec)`.
 * The output is `pre`, then every element followed by a single space, then `post`.
 *
 * @tparam Vector Container type exposing a `value_type` member type; its elements must be
 *         streamable with `operator<<`.
 * @param vec The container to print.
 * @param pre Text written before the elements.
 * @param post Text written after the elements (a newline by default).
 */
template <typename Vector>
void print_device_vector(Vector const& vec, std::string_view pre = "", std::string_view post = "\n")
{
  using T   = typename Vector::value_type;
  auto hvec = to_host<T>(vec);

  std::cout << pre;
  for (auto const& x : hvec) {
    std::cout << x << " ";
  }
  std::cout << post;
}

}  // namespace test
}  // namespace cuspatial