umap.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <raft/sparse/coo.hpp>
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <memory>
26 
// raft::handle_t is only ever taken by const reference in this header, so an
// incomplete (forward) declaration suffices; its full definition is not needed here.
namespace raft { class handle_t; }  // namespace raft
30 
31 namespace ML {
32 class UMAPParams;
33 namespace UMAP {
34 
42 void find_ab(const raft::handle_t& handle, UMAPParams* params);
43 
57 std::unique_ptr<raft::sparse::COO<float, int>> get_graph(const raft::handle_t& handle,
58  float* X, // input matrix
59  float* y, // labels
60  int n,
61  int d,
62  int64_t* knn_indices,
63  float* knn_dists,
65 
79 void refine(const raft::handle_t& handle,
80  float* X,
81  int n,
82  int d,
83  raft::sparse::COO<float, int>* graph,
85  float* embeddings);
86 
100 void init_and_refine(const raft::handle_t& handle,
101  float* X,
102  int n,
103  int d,
104  raft::sparse::COO<float, int>* graph,
106  float* embeddings);
107 
122 void fit(const raft::handle_t& handle,
123  float* X,
124  float* y,
125  int n,
126  int d,
127  int64_t* knn_indices,
128  float* knn_dists,
130  float* embeddings,
131  raft::sparse::COO<float, int>* graph);
132 
150 void fit_sparse(const raft::handle_t& handle,
151  int* indptr,
152  int* indices,
153  float* data,
154  size_t nnz,
155  float* y,
156  int n,
157  int d,
158  int* knn_indices,
159  float* knn_dists,
161  float* embeddings,
162  raft::sparse::COO<float, int>* graph);
163 
178 void transform(const raft::handle_t& handle,
179  float* X,
180  int n,
181  int d,
182  float* orig_X,
183  int orig_n,
184  float* embedding,
185  int embedding_n,
187  float* transformed);
188 
209 void transform_sparse(const raft::handle_t& handle,
210  int* indptr,
211  int* indices,
212  float* data,
213  size_t nnz,
214  int n,
215  int d,
216  int* orig_x_indptr,
217  int* orig_x_indices,
218  float* orig_x_data,
219  size_t orig_nnz,
220  int orig_n,
221  float* embedding,
222  int embedding_n,
224  float* transformed);
225 
226 } // namespace UMAP
227 } // namespace ML
Definition: umapparams.h:30
Definition: params.hpp:34
void refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
std::unique_ptr< raft::sparse::COO< float, int > > get_graph(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params)
void init_and_refine(const raft::handle_t &handle, float *X, int n, int d, raft::sparse::COO< float, int > *graph, UMAPParams *params, float *embeddings)
void fit_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, float *y, int n, int d, int *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void fit(const raft::handle_t &handle, float *X, float *y, int n, int d, int64_t *knn_indices, float *knn_dists, UMAPParams *params, float *embeddings, raft::sparse::COO< float, int > *graph)
void find_ab(const raft::handle_t &handle, UMAPParams *params)
void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices, float *data, size_t nnz, int n, int d, int *orig_x_indptr, int *orig_x_indices, float *orig_x_data, size_t orig_nnz, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
void transform(const raft::handle_t &handle, float *X, int n, int d, float *orig_X, int orig_n, float *embedding, int embedding_n, UMAPParams *params, float *transformed)
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26