Loading [MathJax]/jax/output/HTML-CSS/config.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
kmeans.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <cstdint>

#include <raft/cluster/kmeans_types.hpp>
21 
// Forward declaration only: every signature in this header takes the handle
// by const reference, so the complete type is not required here.
namespace raft {
class handle_t;
}  // namespace raft
25 
26 namespace ML {
27 
28 namespace kmeans {
29 
31 
56 void fit_predict(const raft::handle_t& handle,
57  const KMeansParams& params,
58  const float* X,
59  int n_samples,
60  int n_features,
61  const float* sample_weight,
62  float* centroids,
63  int* labels,
64  float& inertia,
65  int& n_iter);
66 
67 void fit_predict(const raft::handle_t& handle,
68  const KMeansParams& params,
69  const double* X,
70  int n_samples,
71  int n_features,
72  const double* sample_weight,
73  double* centroids,
74  int* labels,
75  double& inertia,
76  int& n_iter);
77 void fit_predict(const raft::handle_t& handle,
78  const KMeansParams& params,
79  const float* X,
80  int64_t n_samples,
81  int64_t n_features,
82  const float* sample_weight,
83  float* centroids,
84  int64_t* labels,
85  float& inertia,
86  int64_t& n_iter);
87 
88 void fit_predict(const raft::handle_t& handle,
89  const KMeansParams& params,
90  const double* X,
91  int64_t n_samples,
92  int64_t n_features,
93  const double* sample_weight,
94  double* centroids,
95  int64_t* labels,
96  double& inertia,
97  int64_t& n_iter);
98 
121 void predict(const raft::handle_t& handle,
122  const KMeansParams& params,
123  const float* centroids,
124  const float* X,
125  int n_samples,
126  int n_features,
127  const float* sample_weight,
128  bool normalize_weights,
129  int* labels,
130  float& inertia);
131 
132 void predict(const raft::handle_t& handle,
133  const KMeansParams& params,
134  const double* centroids,
135  const double* X,
136  int n_samples,
137  int n_features,
138  const double* sample_weight,
139  bool normalize_weights,
140  int* labels,
141  double& inertia);
142 void predict(const raft::handle_t& handle,
143  const KMeansParams& params,
144  const float* centroids,
145  const float* X,
146  int64_t n_samples,
147  int64_t n_features,
148  const float* sample_weight,
149  bool normalize_weights,
150  int64_t* labels,
151  float& inertia);
152 
153 void predict(const raft::handle_t& handle,
154  const KMeansParams& params,
155  const double* centroids,
156  const double* X,
157  int64_t n_samples,
158  int64_t n_features,
159  const double* sample_weight,
160  bool normalize_weights,
161  int64_t* labels,
162  double& inertia);
180 void transform(const raft::handle_t& handle,
181  const KMeansParams& params,
182  const float* centroids,
183  const float* X,
184  int n_samples,
185  int n_features,
186  float* X_new);
187 
188 void transform(const raft::handle_t& handle,
189  const KMeansParams& params,
190  const double* centroids,
191  const double* X,
192  int n_samples,
193  int n_features,
194  double* X_new);
195 void transform(const raft::handle_t& handle,
196  const KMeansParams& params,
197  const float* centroids,
198  const float* X,
199  int64_t n_samples,
200  int64_t n_features,
201  float* X_new);
202 
203 void transform(const raft::handle_t& handle,
204  const KMeansParams& params,
205  const double* centroids,
206  const double* X,
207  int64_t n_samples,
208  int64_t n_features,
209  double* X_new);
210 }; // end namespace kmeans
211 }; // end namespace ML
Definition: params.hpp:34
void fit_predict(const raft::handle_t &handle, const KMeansParams &params, const float *X, int n_samples, int n_features, const float *sample_weight, float *centroids, int *labels, float &inertia, int &n_iter)
Compute k-means clustering and predict the cluster index for each sample in the input.
raft::cluster::KMeansParams KMeansParams
Definition: kmeans.hpp:30
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
void predict(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, const float *sample_weight, bool normalize_weights, int *labels, float &inertia)
Predict the closest cluster each sample in X belongs to.
Definition: dbscan.hpp:27
Definition: dbscan.hpp:23