Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
umapparams.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <cuml/common/callback.hpp>
#include <cuml/common/logger.hpp>

#include <raft/distance/distance_types.hpp>

#include <cstdint>
22 
23 namespace ML {
24 
25 class UMAPParams {
26  public:
28 
35  int n_neighbors = 15;
36 
40  int n_components = 2;
41 
46  int n_epochs = 0;
47 
51  float learning_rate = 1.0;
52 
61  float min_dist = 0.1;
62 
67  float spread = 1.0;
68 
77  float set_op_mix_ratio = 1.0;
78 
86  float local_connectivity = 1.0;
87 
93  float repulsion_strength = 1.0;
94 
102 
109  float transform_queue_size = 4.0;
110 
115 
121  float a = -1.0;
122 
128  float b = -1.0;
129 
133  float initial_alpha = 1.0;
134 
140  int init = 1;
141 
147 
149 
150  float target_weight = 0.5;
151 
152  uint64_t random_state = 0;
153 
159  bool deterministic = true;
160 
161  raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded;
162 
163  float p = 2.0;
164 
166 };
167 
168 } // namespace ML
Definition: callback.hpp:29
class UMAPParams
Definition: umapparams.h:25
bool deterministic
Definition: umapparams.h:159
float min_dist
Definition: umapparams.h:61
float repulsion_strength
Definition: umapparams.h:93
float local_connectivity
Definition: umapparams.h:86
float set_op_mix_ratio
Definition: umapparams.h:77
float initial_alpha
Definition: umapparams.h:133
float target_weight
Definition: umapparams.h:150
float spread
Definition: umapparams.h:67
raft::distance::DistanceType metric
Definition: umapparams.h:161
float a
Definition: umapparams.h:121
int n_components
Definition: umapparams.h:40
int n_neighbors
Definition: umapparams.h:35
float transform_queue_size
Definition: umapparams.h:109
MetricType target_metric
Definition: umapparams.h:148
float p
Definition: umapparams.h:163
int negative_sample_rate
Definition: umapparams.h:101
float b
Definition: umapparams.h:128
int n_epochs
Definition: umapparams.h:46
int verbosity
Definition: umapparams.h:114
int init
Definition: umapparams.h:140
MetricType
Definition: umapparams.h:27
@ EUCLIDEAN
Definition: umapparams.h:27
@ CATEGORICAL
Definition: umapparams.h:27
Internals::GraphBasedDimRedCallback * callback
Definition: umapparams.h:165
uint64_t random_state
Definition: umapparams.h:152
float learning_rate
Definition: umapparams.h:51
int target_n_neighbors
Definition: umapparams.h:146
#define CUML_LEVEL_INFO
Definition: log_levels.hpp:28
Definition: dbscan.hpp:27