cuML C++ API 23.12
utils.h
/*
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cub/cub.cuh>

#include <raft/util/cudart_utils.hpp>

#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/op/sort.cuh>

#include <cuml/cluster/hdbscan.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>

#include <algorithm>

#include "../condensed_hierarchy.cu"
#include <common/fast_int_div.cuh>

#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/transform_reduce.h>
#include <thrust/tuple.h>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

namespace ML {
namespace HDBSCAN {
namespace detail {
namespace Utils {

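/**
 * Runs a CUB segmented-reduce primitive using the standard two-phase CUB
 * pattern: the first call (with a null temporary-storage pointer) queries the
 * scratch size, the buffer is allocated, and the second call performs the
 * reduction over `n_segments` segments delimited by the `n_segments + 1`
 * entries of `offsets`.
 *
 * A minimal usage sketch, assuming float data and int offsets; the buffer
 * names and the wrapper lambda around cub::DeviceSegmentedReduce::Sum are
 * illustrative, not part of this header:
 *
 * @code
 * // values:  device array of input floats
 * // offsets: device array of n_segments + 1 segment boundaries
 * // sums:    device array of n_segments floats, one reduction result each
 * auto segmented_sum = [](void* d_temp, size_t& temp_bytes, const float* d_in,
 *                         float* d_out, int num_segments, const int* begin_offsets,
 *                         const int* end_offsets, cudaStream_t s, bool) {
 *   cub::DeviceSegmentedReduce::Sum(
 *     d_temp, temp_bytes, d_in, d_out, num_segments, begin_offsets, end_offsets, s);
 * };
 * cub_segmented_reduce(values, sums, n_segments, offsets, stream, segmented_sum);
 * @endcode
 */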
template <typename value_idx, typename value_t, typename CUBReduceFunc>
void cub_segmented_reduce(const value_t* in,
                          value_t* out,
                          int n_segments,
                          const value_idx* offsets,
                          cudaStream_t stream,
                          CUBReduceFunc cub_reduce_func)
{
  rmm::device_uvector<char> d_temp_storage(0, stream);
  size_t temp_storage_bytes = 0;
  cub_reduce_func(
    nullptr, temp_storage_bytes, in, out, n_segments, offsets, offsets + 1, stream, false);
  d_temp_storage.resize(temp_storage_bytes, stream);

  cub_reduce_func(d_temp_storage.data(),
                  temp_storage_bytes,
                  in,
                  out,
                  n_segments,
                  offsets,
                  offsets + 1,
                  stream,
                  false);
}

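/**
 * Extracts the cluster tree from a condensed hierarchy: only edges whose
 * size is greater than 1 are kept (dropping the per-point leaf edges), and
 * parent/child ids are shifted down by n_leaves so cluster ids start at 0.
 *
 * A minimal usage sketch; the variable names are illustrative:
 *
 * @code
 * // condensed_tree: a populated Common::CondensedHierarchy<int, float>
 * auto cluster_tree = make_cluster_tree(handle, condensed_tree);
 * @endcode
 */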
template <typename value_idx, typename value_t>
Common::CondensedHierarchy<value_idx, value_t> make_cluster_tree(
  const raft::handle_t& handle, Common::CondensedHierarchy<value_idx, value_t>& condensed_tree)
{
  auto stream        = handle.get_stream();
  auto thrust_policy = handle.get_thrust_policy();
  auto parents       = condensed_tree.get_parents();
  auto children      = condensed_tree.get_children();
  auto lambdas       = condensed_tree.get_lambdas();
  auto sizes         = condensed_tree.get_sizes();

  value_idx cluster_tree_edges = thrust::transform_reduce(
    thrust_policy,
    sizes,
    sizes + condensed_tree.get_n_edges(),
    [=] __device__(value_idx a) { return a > 1; },
    0,
    thrust::plus<value_idx>());

  // remove leaves from condensed tree
  rmm::device_uvector<value_idx> cluster_parents(cluster_tree_edges, stream);
  rmm::device_uvector<value_idx> cluster_children(cluster_tree_edges, stream);
  rmm::device_uvector<value_t> cluster_lambdas(cluster_tree_edges, stream);
  rmm::device_uvector<value_idx> cluster_sizes(cluster_tree_edges, stream);

  auto in = thrust::make_zip_iterator(thrust::make_tuple(parents, children, lambdas, sizes));

  auto out = thrust::make_zip_iterator(thrust::make_tuple(
    cluster_parents.data(), cluster_children.data(), cluster_lambdas.data(), cluster_sizes.data()));

  thrust::copy_if(thrust_policy,
                  in,
                  in + (condensed_tree.get_n_edges()),
                  sizes,
                  out,
                  [=] __device__(value_idx a) { return a > 1; });

  auto n_leaves = condensed_tree.get_n_leaves();
  thrust::transform(thrust_policy,
                    cluster_parents.begin(),
                    cluster_parents.end(),
                    cluster_parents.begin(),
                    [n_leaves] __device__(value_idx a) { return a - n_leaves; });
  thrust::transform(thrust_policy,
                    cluster_children.begin(),
                    cluster_children.end(),
                    cluster_children.begin(),
                    [n_leaves] __device__(value_idx a) { return a - n_leaves; });

  return Common::CondensedHierarchy<value_idx, value_t>(handle,
                                                        condensed_tree.get_n_leaves(),
                                                        cluster_tree_edges,
                                                        condensed_tree.get_n_clusters(),
                                                        std::move(cluster_parents),
                                                        std::move(cluster_children),
                                                        std::move(cluster_lambdas),
                                                        std::move(cluster_sizes));
}

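/**
 * Builds a CSR indptr array over the parent column of a condensed/cluster
 * tree: parents are first 0-indexed by subtracting n_leaves, then the sorted
 * COO rows are converted to CSR row offsets.
 *
 * A minimal usage sketch, assuming `sorted_parents` already holds the parent
 * array sorted in ascending order; the buffer names are illustrative:
 *
 * @code
 * rmm::device_uvector<int> indptr(cluster_tree.get_n_clusters() + 1, stream);
 * parent_csr(handle, cluster_tree, sorted_parents.data(), indptr.data());
 * @endcode
 */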
template <typename value_idx, typename value_t>
void parent_csr(const raft::handle_t& handle,
                Common::CondensedHierarchy<value_idx, value_t>& condensed_tree,
                value_idx* sorted_parents,
                value_idx* indptr)
{
  auto stream        = handle.get_stream();
  auto thrust_policy = handle.get_thrust_policy();

  auto children   = condensed_tree.get_children();
  auto sizes      = condensed_tree.get_sizes();
  auto n_edges    = condensed_tree.get_n_edges();
  auto n_leaves   = condensed_tree.get_n_leaves();
  auto n_clusters = condensed_tree.get_n_clusters();

  // 0-index sorted parents by subtracting n_leaves for offsets and birth/stability indexing
  auto index_op = [n_leaves] __device__(const auto& x) { return x - n_leaves; };
  thrust::transform(
    thrust_policy, sorted_parents, sorted_parents + n_edges, sorted_parents, index_op);

  raft::sparse::convert::sorted_coo_to_csr(sorted_parents, n_edges, indptr, n_clusters + 1, stream);
}

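/**
 * L1-normalizes each row of an m x n row-major matrix in place, so that the
 * entries of every row sum to 1: row sums are computed first, then every
 * element is divided by its row's sum.
 *
 * A minimal usage sketch; the buffer names are illustrative:
 *
 * @code
 * // probs: device pointer to an n_rows x n_cols row-major matrix
 * normalize(probs, n_cols, n_rows, stream);
 * @endcode
 */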
template <typename value_idx, typename value_t>
void normalize(value_t* data, value_idx n, size_t m, cudaStream_t stream)
{
  rmm::device_uvector<value_t> sums(m, stream);

  // Compute row sums
  raft::linalg::rowNorm<value_t, size_t>(
    sums.data(), data, (size_t)n, m, raft::linalg::L1Norm, true, stream);

  // Divide each row by its sum (modify in place)
  raft::linalg::matrixVectorOp(
    data,
    const_cast<value_t*>(data),
    sums.data(),
    n,
    (value_idx)m,
    true,
    false,
    [] __device__(value_t mat_in, value_t vec_in) { return mat_in / vec_in; },
    stream);
}

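/**
 * Applies a numerically stabilized exponential to each row of an m x n
 * row-major matrix in place: every entry x is replaced by
 * exp(x - linf_norm(row)), with the row's L-infinity norm (its largest
 * absolute value) acting as the usual max-subtraction stabilizer. Dividing by
 * the row sums afterwards (e.g. via normalize()) turns the result into
 * softmax probabilities.
 *
 * A minimal sketch of combining the two helpers; the buffer names are
 * illustrative:
 *
 * @code
 * softmax(handle, probs, n_cols, n_rows);
 * normalize(probs, n_cols, n_rows, handle.get_stream());
 * @endcode
 */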
template <typename value_idx, typename value_t>
void softmax(const raft::handle_t& handle, value_t* data, value_idx n, size_t m)
{
  rmm::device_uvector<value_t> linf_norm(m, handle.get_stream());

  auto data_const_view =
    raft::make_device_matrix_view<const value_t, value_idx, raft::row_major>(data, (int)m, n);
  auto data_view =
    raft::make_device_matrix_view<value_t, value_idx, raft::row_major>(data, (int)m, n);
  auto linf_norm_const_view =
    raft::make_device_vector_view<const value_t, value_idx>(linf_norm.data(), (int)m);
  auto linf_norm_view = raft::make_device_vector_view<value_t, value_idx>(linf_norm.data(), (int)m);

  raft::linalg::norm(handle,
                     data_const_view,
                     linf_norm_view,
                     raft::linalg::LinfNorm,
                     raft::linalg::Apply::ALONG_ROWS);

  raft::linalg::matrix_vector_op(
    handle,
    data_const_view,
    linf_norm_const_view,
    data_view,
    raft::linalg::Apply::ALONG_COLUMNS,
    [] __device__(value_t mat_in, value_t vec_in) { return exp(mat_in - vec_in); });
}

};  // namespace Utils
};  // namespace detail
};  // namespace HDBSCAN
};  // namespace ML