cuML C++ API  24.04
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
knn.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <cstdint>
#include <vector>

#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/ann_common.h>
#include <raft/spatial/knn/ball_cover_types.hpp>
22 
// Forward declaration only: every API below takes raft::handle_t by reference,
// so the complete type is not needed in this header.
namespace raft {
class handle_t;
}  // namespace raft
26 
27 namespace ML {
28 
50 void brute_force_knn(const raft::handle_t& handle,
51  std::vector<float*>& input,
52  std::vector<int>& sizes,
53  int D,
54  float* search_items,
55  int n,
56  int64_t* res_I,
57  float* res_D,
58  int k,
59  bool rowMajorIndex = false,
60  bool rowMajorQuery = false,
61  raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
62  float metric_arg = 2.0f);
63 
64 void rbc_build_index(const raft::handle_t& handle,
65  raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index);
66 
67 void rbc_knn_query(const raft::handle_t& handle,
68  raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index,
69  uint32_t k,
70  const float* search_items,
71  uint32_t n_search_items,
72  int64_t* out_inds,
73  float* out_dists);
87 void approx_knn_build_index(raft::handle_t& handle,
88  raft::spatial::knn::knnIndex* index,
89  raft::spatial::knn::knnIndexParam* params,
90  raft::distance::DistanceType metric,
91  float metricArg,
92  float* index_array,
93  int n,
94  int D);
95 
109 void approx_knn_search(raft::handle_t& handle,
110  float* distances,
111  int64_t* indices,
112  raft::spatial::knn::knnIndex* index,
113  int k,
114  float* query_array,
115  int n);
116 
131 void knn_classify(raft::handle_t& handle,
132  int* out,
133  int64_t* knn_indices,
134  std::vector<int*>& y,
135  size_t n_index_rows,
136  size_t n_query_rows,
137  int k);
138 
153 void knn_regress(raft::handle_t& handle,
154  float* out,
155  int64_t* knn_indices,
156  std::vector<float*>& y,
157  size_t n_index_rows,
158  size_t n_query_rows,
159  int k);
160 
175 void knn_class_proba(raft::handle_t& handle,
176  std::vector<float*>& out,
177  int64_t* knn_indices,
178  std::vector<int*>& y,
179  size_t n_index_rows,
180  size_t n_query_rows,
181  int k);
182 }; // namespace ML
Definition: params.hpp:34
Definition: dbscan.hpp:30
void approx_knn_build_index(raft::handle_t &handle, raft::spatial::knn::knnIndex *index, raft::spatial::knn::knnIndexParam *params, raft::distance::DistanceType metric, float metricArg, float *index_array, int n, int D)
Flat C++ API function to build an approximate nearest neighbors index from an index array and a set o...
void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn classification using a given vector of label arrays....
void brute_force_knn(const raft::handle_t &handle, std::vector< float * > &input, std::vector< int > &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex=false, bool rowMajorQuery=false, raft::distance::DistanceType metric=raft::distance::DistanceType::L2Expanded, float metric_arg=2.0f)
Flat C++ API function to perform a brute force knn on a series of input arrays and combine the result...
void approx_knn_search(raft::handle_t &handle, float *distances, int64_t *indices, raft::spatial::knn::knnIndex *index, int k, float *query_array, int n)
Flat C++ API function to perform an approximate nearest neighbors search from previously built index ...
void knn_class_proba(raft::handle_t &handle, std::vector< float * > &out, int64_t *knn_indices, std::vector< int * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to compute knn class probabilities using a vector of device arrays containing d...
void knn_regress(raft::handle_t &handle, float *out, int64_t *knn_indices, std::vector< float * > &y, size_t n_index_rows, size_t n_query_rows, int k)
Flat C++ API function to perform a knn regression using a given vector of label arrays....
void rbc_build_index(const raft::handle_t &handle, raft::spatial::knn::BallCoverIndex< int64_t, float, uint32_t > &index)
void rbc_knn_query(const raft::handle_t &handle, raft::spatial::knn::BallCoverIndex< int64_t, float, uint32_t > &index, uint32_t k, const float *search_items, uint32_t n_search_items, int64_t *out_inds, float *out_dists)
Definition: dbscan.hpp:26