umap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <raft/core/host_coo_matrix.hpp>
22 #include <raft/sparse/coo.hpp>
23 
24 #include <cstddef>
25 #include <cstdint>
26 #include <memory>
27 
// Forward declaration only: this header uses raft::handle_t exclusively through
// `const raft::handle_t&` parameters, so an incomplete type suffices here.
namespace raft { class handle_t; }  // namespace raft
31 
32 namespace ML {
33 class UMAPParams;
34 namespace UMAP {
35 
43 void find_ab(const raft::handle_t& handle, UMAPParams* params);
44 
58 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
59  float* X, // input matrix
60  float* y, // labels
61  int n,
62  int d,
63  int64_t* knn_indices,
64  float* knn_dists,
66 
80 void refine(const raft::handle_t& handle,
81  float* X,
82  int n,
83  int d,
84  raft::sparse::COO<float, int>* graph,
86  float* embeddings);
87 
101 void init_and_refine(const raft::handle_t& handle,
102  float* X,
103  int n,
104  int d,
105  raft::sparse::COO<float, int>* graph,
107  float* embeddings);
108 
123 void fit(const raft::handle_t& handle,
124  float* X,
125  float* y,
126  int n,
127  int d,
128  int64_t* knn_indices,
129  float* knn_dists,
131  float* embeddings,
132  raft::host_coo_matrix<float, int, int, uint64_t>& graph);
133 
151 void fit_sparse(const raft::handle_t& handle,
152  int* indptr,
153  int* indices,
154  float* data,
155  size_t nnz,
156  float* y,
157  int n,
158  int d,
159  int* knn_indices,
160  float* knn_dists,
162  float* embeddings,
163  raft::host_coo_matrix<float, int, int, uint64_t>& graph);
164 
179 void transform(const raft::handle_t& handle,
180  float* X,
181  int n,
182  int d,
183  float* orig_X,
184  int orig_n,
185  float* embedding,
186  int embedding_n,
188  float* transformed);
189 
210 void transform_sparse(const raft::handle_t& handle,
211  int* indptr,
212  int* indices,
213  float* data,
214  size_t nnz,
215  int n,
216  int d,
217  int* orig_x_indptr,
218  int* orig_x_indices,
219  float* orig_x_data,
220  size_t orig_nnz,
221  int orig_n,
222  float* embedding,
223  int embedding_n,
225  float* transformed);
226 
227 } // namespace UMAP
228 } // namespace ML
Definition: umapparams.h:75
Definition: params.hpp:34
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph)
void init_and_refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:29
Definition: dbscan.hpp:25