Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  24.04
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
umap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <raft/sparse/coo.hpp>
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <memory>
26 
27 namespace raft {
28 class handle_t;
29 } // namespace raft
30 
31 namespace ML {
32 class UMAPParams;
33 namespace UMAP {
34 
42 void find_ab(const raft::handle_t& handle, UMAPParams* params);
43 
57 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
58  float* X, // input matrix
59  float* y, // labels
60  int n,
61  int d,
62  int64_t* knn_indices,
63  float* knn_dists,
64  UMAPParams* params);
65 
79 void refine(const raft::handle_t& handle,
80  float* X,
81  int n,
82  int d,
83  raft::sparse::COO<float, int>* graph,
84  UMAPParams* params,
85  float* embeddings);
86 
101 void fit(const raft::handle_t& handle,
102  float* X,
103  float* y,
104  int n,
105  int d,
106  int64_t* knn_indices,
107  float* knn_dists,
108  UMAPParams* params,
109  float* embeddings,
110  raft::sparse::COO<float, int>* graph);
111 
129 void fit_sparse(const raft::handle_t& handle,
130  int* indptr,
131  int* indices,
132  float* data,
133  size_t nnz,
134  float* y,
135  int n,
136  int d,
137  int* knn_indices,
138  float* knn_dists,
139  UMAPParams* params,
140  float* embeddings,
141  raft::sparse::COO<float, int>* graph);
142 
157 void transform(const raft::handle_t& handle,
158  float* X,
159  int n,
160  int d,
161  float* orig_X,
162  int orig_n,
163  float* embedding,
164  int embedding_n,
165  UMAPParams* params,
166  float* transformed);
167 
188 void transform_sparse(const raft::handle_t& handle,
189  int* indptr,
190  int* indices,
191  float* data,
192  size_t nnz,
193  int n,
194  int d,
195  int* orig_x_indptr,
196  int* orig_x_indices,
197  float* orig_x_data,
198  size_t orig_nnz,
199  int orig_n,
200  float* embedding,
201  int embedding_n,
202  UMAPParams* params,
203  float* transformed);
204 
205 } // namespace UMAP
206 } // namespace ML
Definition: umapparams.h:26
Definition: params.hpp:34
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26