umap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <raft/core/host_coo_matrix.hpp>
22 #include <raft/sparse/coo.hpp>
23 
24 #include <rmm/device_buffer.hpp>
25 
26 #include <cstddef>
27 #include <cstdint>
28 #include <memory>
29 
// raft::handle_t only appears as `const raft::handle_t&` in the declarations of
// this header, so an incomplete (forward-declared) type is all that is needed.
namespace raft { class handle_t; }  // namespace raft
33 
34 namespace ML {
35 class UMAPParams;
36 namespace UMAP {
37 
45 void find_ab(const raft::handle_t& handle, UMAPParams* params);
46 
60 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
61  float* X, // input matrix
62  float* y, // labels
63  int n,
64  int d,
65  int64_t* knn_indices,
66  float* knn_dists,
68 
82 void refine(const raft::handle_t& handle,
83  float* X,
84  int n,
85  int d,
86  raft::sparse::COO<float, int>* graph,
88  float* embeddings);
89 
103 void init_and_refine(const raft::handle_t& handle,
104  float* X,
105  int n,
106  int d,
107  raft::sparse::COO<float, int>* graph,
109  float* embeddings);
110 
126 void fit(const raft::handle_t& handle,
127  float* X,
128  float* y,
129  int n,
130  int d,
131  int64_t* knn_indices,
132  float* knn_dists,
134  std::unique_ptr<rmm::device_buffer>& embeddings,
135  raft::host_coo_matrix<float, int, int, uint64_t>& graph);
136 
155 void fit_sparse(const raft::handle_t& handle,
156  int* indptr,
157  int* indices,
158  float* data,
159  size_t nnz,
160  float* y,
161  int n,
162  int d,
163  int* knn_indices,
164  float* knn_dists,
166  std::unique_ptr<rmm::device_buffer>& embeddings,
167  raft::host_coo_matrix<float, int, int, uint64_t>& graph);
168 
183 void transform(const raft::handle_t& handle,
184  float* X,
185  int n,
186  int d,
187  float* orig_X,
188  int orig_n,
189  float* embedding,
190  int embedding_n,
192  float* transformed);
193 
214 void transform_sparse(const raft::handle_t& handle,
215  int* indptr,
216  int* indices,
217  float* data,
218  size_t nnz,
219  int n,
220  int d,
221  int* orig_x_indptr,
222  int* orig_x_indices,
223  float* orig_x_data,
224  size_t orig_nnz,
225  int orig_n,
226  float* embedding,
227  int embedding_n,
229  float* transformed);
230 
231 } // namespace UMAP
232 } // namespace ML
Definition: umapparams.h:75
Definition: params.hpp:34
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void init_and_refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, std::unique_ptr< rmm::device_buffer > &embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, std::unique_ptr< rmm::device_buffer > &embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:29
Definition: dbscan.hpp:25