Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
umap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/sparse/coo.hpp>
20 
21 #include <cstddef>
22 #include <cstdint>
24 #include <memory>
25 
// Forward declaration only: the declarations in this header take raft::handle_t solely by
// const reference, so an incomplete type is sufficient and the full raft handle header does not
// need to be included here.
namespace raft {
class handle_t;
}  // namespace raft
29 
30 namespace ML {
31 class UMAPParams;
32 namespace UMAP {
33 
41 void find_ab(const raft::handle_t& handle, UMAPParams* params);
42 
56 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
57  float* X, // input matrix
58  float* y, // labels
59  int n,
60  int d,
61  int64_t* knn_indices,
62  float* knn_dists,
64 
78 void refine(const raft::handle_t& handle,
79  float* X,
80  int n,
81  int d,
82  raft::sparse::COO<float, int>* graph,
84  float* embeddings);
85 
100 void fit(const raft::handle_t& handle,
101  float* X,
102  float* y,
103  int n,
104  int d,
105  int64_t* knn_indices,
106  float* knn_dists,
108  float* embeddings,
109  raft::sparse::COO<float, int>* graph);
110 
128 void fit_sparse(const raft::handle_t& handle,
129  int* indptr,
130  int* indices,
131  float* data,
132  size_t nnz,
133  float* y,
134  int n,
135  int d,
136  int* knn_indices,
137  float* knn_dists,
139  float* embeddings,
140  raft::sparse::COO<float, int>* graph);
141 
156 void transform(const raft::handle_t& handle,
157  float* X,
158  int n,
159  int d,
160  float* orig_X,
161  int orig_n,
162  float* embedding,
163  int embedding_n,
165  float* transformed);
166 
187 void transform_sparse(const raft::handle_t& handle,
188  int* indptr,
189  int* indices,
190  float* data,
191  size_t nnz,
192  int n,
193  int d,
194  int* orig_x_indptr,
195  int* orig_x_indices,
196  float* orig_x_data,
197  size_t orig_nnz,
198  int orig_n,
199  float* embedding,
200  int embedding_n,
202  float* transformed);
203 
204 } // namespace UMAP
205 } // namespace ML
Definition: umapparams.h:25
Definition: params.hpp:34
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:27
Definition: dbscan.hpp:23