Loading [MathJax]/extensions/tex2jax.js
cuML C++ API  23.12
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
svr.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cublas_v2.h>
21 
22 namespace ML {
23 namespace SVM {
24 
25 template <typename math_t>
26 struct SvmModel;
27 struct SvmParameter;
28 
29 // Forward declarations of the stateless API
49 template <typename math_t>
50 void svrFit(const raft::handle_t& handle,
51  math_t* X,
52  int n_rows,
53  int n_cols,
54  math_t* y,
55  const SvmParameter& param,
56  MLCommon::Matrix::KernelParams& kernel_params,
57  SvmModel<math_t>& model,
58  const math_t* sample_weight = nullptr);
59 
81 template <typename math_t>
82 void svrFitSparse(const raft::handle_t& handle,
83  int* indptr,
84  int* indices,
85  math_t* data,
86  int n_rows,
87  int n_cols,
88  int nnz,
89  math_t* y,
90  const SvmParameter& param,
91  raft::distance::kernels::KernelParams& kernel_params,
92  SvmModel<math_t>& model,
93  const math_t* sample_weight = nullptr);
94 
95 // For prediction we use svcPredict
96 
97 }; // end namespace SVM
98 }; // end namespace ML
void svrFit(const raft::handle_t &handle, math_t *X, int n_rows, int n_cols, math_t *y, const SvmParameter &param, MLCommon::Matrix::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
void svrFitSparse(const raft::handle_t &handle, int *indptr, int *indices, math_t *data, int n_rows, int n_cols, int nnz, math_t *y, const SvmParameter &param, raft::distance::kernels::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
namespace ML — Definition: dbscan.hpp:27
struct SvmModel — Definition: svm_model.h:35
struct SvmParameter — Definition: svm_parameter.h:34