utils.h
/*
 * Copyright (c) 2021-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "../condensed_hierarchy.cu"

#include <common/fast_int_div.cuh>

#include <cuml/cluster/hdbscan.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/op/sort.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <cub/cub.cuh>
#include <cuda/functional>  // for cuda::proclaim_return_type used below
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/transform_reduce.h>
#include <thrust/tuple.h>

#include <algorithm>

namespace ML {
namespace HDBSCAN {
namespace detail {
namespace Utils {

/**
 * Wrapper around a CUB segmented-reduce primitive: it runs the reduction
 * twice, first to query the temporary-storage requirement and then to perform
 * the actual per-segment reduction over `n_segments` segments delimited by
 * `offsets` / `offsets + 1`.
 */
template <typename value_idx, typename value_t, typename CUBReduceFunc>
void cub_segmented_reduce(const value_t* in,
                          value_t* out,
                          int n_segments,
                          const value_idx* offsets,
                          cudaStream_t stream,
                          CUBReduceFunc cub_reduce_func)
{
  rmm::device_uvector<char> d_temp_storage(0, stream);
  size_t temp_storage_bytes = 0;
  // First pass: query how much temporary storage the reduction needs.
  cub_reduce_func(
    nullptr, temp_storage_bytes, in, out, n_segments, offsets, offsets + 1, stream, false);
  d_temp_storage.resize(temp_storage_bytes, stream);

  // Second pass: perform the segmented reduction with the allocated scratch space.
  cub_reduce_func(d_temp_storage.data(),
                  temp_storage_bytes,
                  in,
                  out,
                  n_segments,
                  offsets,
                  offsets + 1,
                  stream,
                  false);
}

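// Usage sketch (illustrative only, not part of the original header). The functor
// must match the nine-argument call shape used above, so a small lambda adapter
// around e.g. cub::DeviceSegmentedReduce::Sum works; d_values, d_segment_sums and
// d_segment_offsets are assumed to be existing device buffers (offsets of length
// n_segments + 1), and the trailing bool mirrors the debug_synchronous flag of
// older CUB releases:
//
//   auto segmented_sum = [](void* d_temp, size_t& temp_bytes, const value_t* d_in,
//                           value_t* d_out, int n, const value_idx* begin,
//                           const value_idx* end, cudaStream_t s, bool /*debug*/) {
//     cub::DeviceSegmentedReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, begin, end, s);
//   };
//   Utils::cub_segmented_reduce(d_values, d_segment_sums, n_segments,
//                               d_segment_offsets, stream, segmented_sum);
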
/**
 * Builds the cluster tree from the condensed hierarchy by keeping only the
 * edges whose child is itself a cluster (size > 1), i.e. dropping all leaf
 * edges, and re-indexing parents and children to start at 0.
 */
template <typename value_idx, typename value_t>
Common::CondensedHierarchy<value_idx, value_t> make_cluster_tree(
  const raft::handle_t& handle, Common::CondensedHierarchy<value_idx, value_t>& condensed_tree)
{
  auto stream        = handle.get_stream();
  auto thrust_policy = handle.get_thrust_policy();
  auto parents       = condensed_tree.get_parents();
  auto children      = condensed_tree.get_children();
  auto lambdas       = condensed_tree.get_lambdas();
  auto sizes         = condensed_tree.get_sizes();

  // count the edges whose child has size > 1 (cluster edges)
  value_idx cluster_tree_edges = thrust::transform_reduce(
    thrust_policy,
    sizes,
    sizes + condensed_tree.get_n_edges(),
    cuda::proclaim_return_type<bool>([=] __device__(value_idx a) -> bool { return a > 1; }),
    0,
    thrust::plus<value_idx>());

  // remove leaves from condensed tree
  rmm::device_uvector<value_idx> cluster_parents(cluster_tree_edges, stream);
  rmm::device_uvector<value_idx> cluster_children(cluster_tree_edges, stream);
  rmm::device_uvector<value_t> cluster_lambdas(cluster_tree_edges, stream);
  rmm::device_uvector<value_idx> cluster_sizes(cluster_tree_edges, stream);

  auto in = thrust::make_zip_iterator(thrust::make_tuple(parents, children, lambdas, sizes));

  auto out = thrust::make_zip_iterator(thrust::make_tuple(
    cluster_parents.data(), cluster_children.data(), cluster_lambdas.data(), cluster_sizes.data()));

  thrust::copy_if(thrust_policy,
                  in,
                  in + (condensed_tree.get_n_edges()),
                  sizes,
                  out,
                  [=] __device__(value_idx a) { return a > 1; });

  // shift parent and child ids down by n_leaves so clusters are 0-indexed
  auto n_leaves = condensed_tree.get_n_leaves();
  thrust::transform(thrust_policy,
                    cluster_parents.begin(),
                    cluster_parents.end(),
                    cluster_parents.begin(),
                    [n_leaves] __device__(value_idx a) { return a - n_leaves; });
  thrust::transform(thrust_policy,
                    cluster_children.begin(),
                    cluster_children.end(),
                    cluster_children.begin(),
                    [n_leaves] __device__(value_idx a) { return a - n_leaves; });

  return Common::CondensedHierarchy<value_idx, value_t>(handle,
                                                        condensed_tree.get_n_leaves(),
                                                        cluster_tree_edges,
                                                        condensed_tree.get_n_clusters(),
                                                        std::move(cluster_parents),
                                                        std::move(cluster_children),
                                                        std::move(cluster_lambdas),
                                                        std::move(cluster_sizes));
}

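// Usage sketch (illustrative only): given a condensed hierarchy produced by the
// HDBSCAN condense step, the returned tree contains only cluster-to-cluster edges,
// with all point (leaf) edges removed.
//
//   auto cluster_tree = Utils::make_cluster_tree(handle, condensed_tree);
//   // cluster_tree.get_n_edges() == number of condensed-tree edges with size > 1
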
/**
 * Builds a CSR indptr array over the condensed-tree edges grouped by parent:
 * `sorted_parents` is shifted down by `n_leaves` so cluster ids start at 0,
 * then the sorted COO rows are converted into CSR offsets of length
 * `n_clusters + 1`.
 */
template <typename value_idx, typename value_t>
void parent_csr(const raft::handle_t& handle,
                Common::CondensedHierarchy<value_idx, value_t>& condensed_tree,
                value_idx* sorted_parents,
                value_idx* indptr)
{
  auto stream        = handle.get_stream();
  auto thrust_policy = handle.get_thrust_policy();

  auto children   = condensed_tree.get_children();
  auto sizes      = condensed_tree.get_sizes();
  auto n_edges    = condensed_tree.get_n_edges();
  auto n_leaves   = condensed_tree.get_n_leaves();
  auto n_clusters = condensed_tree.get_n_clusters();

  // 0-index sorted parents by subtracting n_leaves for offsets and birth/stability indexing
  auto index_op = [n_leaves] __device__(const auto& x) { return x - n_leaves; };
  thrust::transform(
    thrust_policy, sorted_parents, sorted_parents + n_edges, sorted_parents, index_op);

  raft::sparse::convert::sorted_coo_to_csr(sorted_parents, n_edges, indptr, n_clusters + 1, stream);
}

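// Usage sketch (illustrative only; buffer names are placeholders). `sorted_parents`
// must already be sorted in ascending order; callers that need the children or
// lambdas kept aligned would sort with thrust::sort_by_key instead.
//
//   rmm::device_uvector<value_idx> sorted_parents(n_edges, stream);
//   raft::copy(sorted_parents.data(), condensed_tree.get_parents(), n_edges, stream);
//   thrust::sort(handle.get_thrust_policy(), sorted_parents.begin(), sorted_parents.end());
//
//   rmm::device_uvector<value_idx> indptr(n_clusters + 1, stream);
//   Utils::parent_csr(handle, condensed_tree, sorted_parents.data(), indptr.data());
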
template <typename value_idx, typename value_t>
void normalize(value_t* data, value_idx n, size_t m, cudaStream_t stream)
{
  rmm::device_uvector<value_t> sums(m, stream);

  // Compute row sums (L1 norms of the m rows)
  raft::linalg::rowNorm<value_t, size_t>(
    sums.data(), data, (size_t)n, m, raft::linalg::L1Norm, true, stream);

  // Divide each row by its sum (modify in place)
  raft::linalg::matrixVectorOp(
    data,
    const_cast<value_t*>(data),
    sums.data(),
    n,
    (value_idx)m,
    true,
    false,
    [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
    stream);
}

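// Worked example (assuming non-negative entries, as with membership probabilities):
// for the 2 x 3 row-major matrix
//
//   [0.5  1.0  0.5]            [0.25  0.50  0.25]
//   [2.0  2.0  4.0]   becomes  [0.25  0.25  0.50]
//
// after normalize(data, /*n=*/3, /*m=*/2, stream), since each row is divided by
// its L1 norm (2.0 and 8.0 respectively), so every row then sums to 1.
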
/**
 * Applies a numerically-shifted exponential to each element of the row-major
 * m x n matrix `data` in place: every entry becomes exp(x - linf_norm(row)).
 * The result is not normalized here; see normalize() above.
 */
template <typename value_idx, typename value_t>
void softmax(const raft::handle_t& handle, value_t* data, value_idx n, size_t m)
{
  rmm::device_uvector<value_t> linf_norm(m, handle.get_stream());

  auto data_const_view =
    raft::make_device_matrix_view<const value_t, value_idx, raft::row_major>(data, (int)m, n);
  auto data_view =
    raft::make_device_matrix_view<value_t, value_idx, raft::row_major>(data, (int)m, n);
  auto linf_norm_const_view =
    raft::make_device_vector_view<const value_t, value_idx>(linf_norm.data(), (int)m);
  auto linf_norm_view = raft::make_device_vector_view<value_t, value_idx>(linf_norm.data(), (int)m);

  // Per-row L-infinity norms, used to shift the exponent for numerical stability
  raft::linalg::norm(handle,
                     data_const_view,
                     linf_norm_view,
                     raft::linalg::LinfNorm,
                     raft::linalg::Apply::ALONG_ROWS);

  raft::linalg::matrix_vector_op(
    handle,
    data_const_view,
    linf_norm_const_view,
    data_view,
    raft::linalg::Apply::ALONG_COLUMNS,
    [] __device__(value_t mat_in, value_t vec_in) { return exp(mat_in - vec_in); });
}

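// Example (illustrative): for a row [0, 1, 2] the L-infinity norm is 2, so the row
// becomes [exp(-2), exp(-1), exp(0)]. Following this with normalize() divides each
// row by its sum, which yields proper softmax probabilities.
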
};  // namespace Utils
};  // namespace detail
};  // namespace HDBSCAN
};  // namespace ML