Loading [MathJax]/jax/output/HTML-CSS/config.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
cd_mg.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 #include <cumlprims/opg/matrix/data.hpp>
21 #include <cumlprims/opg/matrix/part_descriptor.hpp>
22 
23 namespace ML {
24 namespace CD {
25 namespace opg {
26 
/**
 * @brief MNMG fit entry point, single-precision (float) overload.
 *
 * NOTE(review): the generated summary attached to this declaration reads
 * "performs MNMG fit operation for the ridge regression", yet the function is
 * declared in namespace ML::CD::opg and accepts both `alpha` and `l1_ratio`.
 * Confirm the intended model wording against the implementation.
 *
 * Only the declaration is visible here; the parameter notes below that go
 * beyond the spelled-out types are assumptions to be verified.
 *
 * @param handle        raft::handle_t, taken by non-const reference.
 * @param input_data    vector of pointers to MLCommon::Matrix::Data<float>;
 *                      presumably the partitions of the training matrix - TODO confirm.
 * @param input_desc    MLCommon::Matrix::PartDescriptor, by non-const reference;
 *                      presumably describes how input_data is partitioned - TODO confirm.
 * @param labels        vector of pointers to MLCommon::Matrix::Data<float>;
 *                      presumably the target values, partitioned like input_data - TODO confirm.
 * @param coef          raw float pointer; presumably receives the fitted coefficients.
 *                      Required length and memory space are not visible here.
 * @param intercept     raw float pointer; presumably receives the fitted intercept.
 * @param fit_intercept bool flag; presumably enables fitting of an intercept term.
 * @param normalize     bool flag; presumably requests normalization of the input - TODO confirm.
 * @param epochs        int; presumably the maximum number of passes - TODO confirm.
 * @param alpha         float; presumably the regularization strength - TODO confirm.
 * @param l1_ratio      float; presumably the L1/L2 mixing ratio - TODO confirm.
 * @param shuffle       bool flag; presumably randomizes the update order - TODO confirm.
 * @param tol           float; presumably the convergence tolerance - TODO confirm.
 * @param verbose       bool flag; presumably enables diagnostic output.
 */
44 void fit(raft::handle_t& handle,
45  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
46  MLCommon::Matrix::PartDescriptor& input_desc,
47  std::vector<MLCommon::Matrix::Data<float>*>& labels,
48  float* coef,
49  float* intercept,
50  bool fit_intercept,
51  bool normalize,
52  int epochs,
53  float alpha,
54  float l1_ratio,
55  bool shuffle,
56  float tol,
57  bool verbose);
58 
/**
 * @brief MNMG fit entry point, double-precision overload.
 *
 * Same parameter names and order as the float overload of fit in this
 * namespace; every float-typed parameter (including the Data<> element type)
 * is double here. Only the declaration is visible, so the parameter notes
 * below that go beyond the spelled-out types are assumptions to be verified.
 *
 * @param handle        raft::handle_t, taken by non-const reference.
 * @param input_data    vector of pointers to MLCommon::Matrix::Data<double>;
 *                      presumably the partitions of the training matrix - TODO confirm.
 * @param input_desc    MLCommon::Matrix::PartDescriptor, by non-const reference;
 *                      presumably describes how input_data is partitioned - TODO confirm.
 * @param labels        vector of pointers to MLCommon::Matrix::Data<double>;
 *                      presumably the target values - TODO confirm.
 * @param coef          raw double pointer; presumably receives the fitted coefficients.
 *                      Required length and memory space are not visible here.
 * @param intercept     raw double pointer; presumably receives the fitted intercept.
 * @param fit_intercept bool flag; presumably enables fitting of an intercept term.
 * @param normalize     bool flag; presumably requests normalization of the input - TODO confirm.
 * @param epochs        int; presumably the maximum number of passes - TODO confirm.
 * @param alpha         double; presumably the regularization strength - TODO confirm.
 * @param l1_ratio      double; presumably the L1/L2 mixing ratio - TODO confirm.
 * @param shuffle       bool flag; presumably randomizes the update order - TODO confirm.
 * @param tol           double; presumably the convergence tolerance - TODO confirm.
 * @param verbose       bool flag; presumably enables diagnostic output.
 */
59 void fit(raft::handle_t& handle,
60  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
61  MLCommon::Matrix::PartDescriptor& input_desc,
62  std::vector<MLCommon::Matrix::Data<double>*>& labels,
63  double* coef,
64  double* intercept,
65  bool fit_intercept,
66  bool normalize,
67  int epochs,
68  double alpha,
69  double l1_ratio,
70  bool shuffle,
71  double tol,
72  bool verbose);
73 
/**
 * @brief MNMG prediction entry point, single-precision (float) overload.
 *
 * NOTE(review): the generated summary attached to this declaration reads
 * "performs MNMG prediction for OLS", yet the function is declared in
 * namespace ML::CD::opg. Confirm the intended wording.
 *
 * Unlike fit, this interface takes raw pointer-to-pointer arrays plus an
 * explicit count instead of std::vector / PartDescriptor, and takes
 * `intercept` by value rather than by pointer.
 *
 * @param handle     raft::handle_t, taken by non-const reference.
 * @param rank_sizes array of MLCommon::Matrix::RankSizePair pointers;
 *                   presumably one entry per partition (n_parts) - TODO confirm.
 * @param n_parts    size_t; presumably the number of entries in rank_sizes - TODO confirm.
 * @param input      array of MLCommon::Matrix::Data<float> pointers;
 *                   presumably the input partitions - TODO confirm.
 * @param n_rows     size_t; presumably the total row count of the input - TODO confirm
 *                   whether this is global or per-rank.
 * @param n_cols     size_t; presumably the number of feature columns.
 * @param coef       raw float pointer; presumably the fitted coefficients (read-only use
 *                   is not expressed in the type).
 * @param intercept  float, passed by value.
 * @param preds      array of MLCommon::Matrix::Data<float> pointers;
 *                   presumably receives the predictions - TODO confirm.
 * @param verbose    bool flag; presumably enables diagnostic output.
 */
87 void predict(raft::handle_t& handle,
88  MLCommon::Matrix::RankSizePair** rank_sizes,
89  size_t n_parts,
90  MLCommon::Matrix::Data<float>** input,
91  size_t n_rows,
92  size_t n_cols,
93  float* coef,
94  float intercept,
95  MLCommon::Matrix::Data<float>** preds,
96  bool verbose);
97 
/**
 * @brief MNMG prediction entry point, double-precision overload.
 *
 * Same parameter names and order as the float overload of predict in this
 * namespace; every float-typed parameter (including the Data<> element type)
 * is double here. `intercept` is passed by value. Only the declaration is
 * visible, so the parameter notes below that go beyond the spelled-out types
 * are assumptions to be verified.
 *
 * @param handle     raft::handle_t, taken by non-const reference.
 * @param rank_sizes array of MLCommon::Matrix::RankSizePair pointers;
 *                   presumably one entry per partition (n_parts) - TODO confirm.
 * @param n_parts    size_t; presumably the number of entries in rank_sizes - TODO confirm.
 * @param input      array of MLCommon::Matrix::Data<double> pointers;
 *                   presumably the input partitions - TODO confirm.
 * @param n_rows     size_t; presumably the total row count of the input - TODO confirm
 *                   whether this is global or per-rank.
 * @param n_cols     size_t; presumably the number of feature columns.
 * @param coef       raw double pointer; presumably the fitted coefficients.
 * @param intercept  double, passed by value.
 * @param preds      array of MLCommon::Matrix::Data<double> pointers;
 *                   presumably receives the predictions - TODO confirm.
 * @param verbose    bool flag; presumably enables diagnostic output.
 */
98 void predict(raft::handle_t& handle,
99  MLCommon::Matrix::RankSizePair** rank_sizes,
100  size_t n_parts,
101  MLCommon::Matrix::Data<double>** input,
102  size_t n_rows,
103  size_t n_cols,
104  double* coef,
105  double intercept,
106  MLCommon::Matrix::Data<double>** preds,
107  bool verbose);
108 
109 }; // end namespace opg
110 }; // end namespace CD
111 }; // end namespace ML
void predict(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, size_t n_parts, MLCommon::Matrix::Data< float > **input, size_t n_rows, size_t n_cols, float *coef, float intercept, MLCommon::Matrix::Data< float > **preds, bool verbose)
performs MNMG prediction for OLS
void fit(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &labels, float *coef, float *intercept, bool fit_intercept, bool normalize, int epochs, float alpha, float l1_ratio, bool shuffle, float tol, bool verbose)
performs MNMG fit operation for the ridge regression
void normalize(value_t *data, value_idx n, size_t m, cudaStream_t stream)
Definition: utils.h:196
void shuffle(std::vector< math_t > &rand_indices, std::mt19937 &g)
Definition: shuffle.h:35
Definition: dbscan.hpp:27