umap.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
10 #include <raft/core/host_coo_matrix.hpp>
11 #include <raft/sparse/coo.hpp>
12 
13 #include <rmm/device_buffer.hpp>
14 
15 #include <cstddef>
16 #include <cstdint>
17 #include <memory>
18 
// raft::handle_t is only ever named through `const raft::handle_t&` in the
// declarations of this header, so an incomplete (forward) declaration suffices.
namespace raft { class handle_t; }  // namespace raft
22 
23 namespace ML {
24 class UMAPParams;
25 namespace UMAP {
26 
34 void find_ab(const raft::handle_t& handle, UMAPParams* params);
35 
49 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
50  float* X, // input matrix
51  float* y, // labels
52  int n,
53  int d,
54  int64_t* knn_indices,
55  float* knn_dists,
57 
71 void refine(const raft::handle_t& handle,
72  float* X,
73  int n,
74  int d,
75  raft::sparse::COO<float, int>* graph,
77  float* embeddings);
78 
92 void init_and_refine(const raft::handle_t& handle,
93  float* X,
94  int n,
95  int d,
96  raft::sparse::COO<float, int>* graph,
98  float* embeddings);
99 
117 void fit(const raft::handle_t& handle,
118  float* X,
119  float* y,
120  int n,
121  int d,
122  int64_t* knn_indices,
123  float* knn_dists,
125  std::unique_ptr<rmm::device_buffer>& embeddings,
126  raft::host_coo_matrix<float, int, int, uint64_t>& graph,
127  float* sigmas = nullptr,
128  float* rhos = nullptr);
129 
148 void fit_sparse(const raft::handle_t& handle,
149  int* indptr,
150  int* indices,
151  float* data,
152  size_t nnz,
153  float* y,
154  int n,
155  int d,
156  int* knn_indices,
157  float* knn_dists,
159  std::unique_ptr<rmm::device_buffer>& embeddings,
160  raft::host_coo_matrix<float, int, int, uint64_t>& graph);
161 
176 void transform(const raft::handle_t& handle,
177  float* X,
178  int n,
179  int d,
180  float* orig_X,
181  int orig_n,
182  float* embedding,
183  int embedding_n,
185  float* transformed);
186 
207 void transform_sparse(const raft::handle_t& handle,
208  int* indptr,
209  int* indices,
210  float* data,
211  size_t nnz,
212  int n,
213  int d,
214  int* orig_x_indptr,
215  int* orig_x_indices,
216  float* orig_x_data,
217  size_t orig_nnz,
218  int orig_n,
219  float* embedding,
220  int embedding_n,
222  float* transformed);
223 
243 void inverse_transform(const raft::handle_t& handle,
244  float* inv_transformed,
245  int n,
246  int n_features,
247  float* orig_X,
248  int orig_n,
249  int* graph_rows,
250  int* graph_cols,
251  float* graph_vals,
252  int nnz,
253  float* sigmas,
254  float* rhos,
256  int n_epochs);
257 
258 } // namespace UMAP
259 } // namespace ML
Definition: umapparams.h:64
Definition: params.hpp:23
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void init_and_refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, std::unique_ptr< rmm::device_buffer > &embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph)
void inverse_transform(const raft::handle_t &handle, float *inv_transformed, int n, int n_features, float *orig_X, int orig_n, int *graph_rows, int *graph_cols, float *graph_vals, int nnz, float *sigmas, float *rhos, UMAPParams *params, int n_epochs)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, std::unique_ptr< rmm::device_buffer > &embeddings, raft::host_coo_matrix< float, int, int, uint64_t > &graph, float *sigmas=nullptr, float *rhos=nullptr)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:18
Definition: dbscan.hpp:14