Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
common.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <stdint.h>
20 
21 namespace ML {
22 
// Dense input uses int64_t until FAISS is updated
typedef int64_t knn_indices_dense_t;

// Sparse counterpart of knn_indices_dense_t.
typedef int knn_indices_sparse_t;

/**
 * Plain holder for the pointers and dimensions that describe a kNN graph.
 *
 * The struct only stores the pointers it is handed: it has no destructor and
 * never allocates, copies, or frees the memory they refer to.
 *
 * @tparam value_idx type pointed to by knn_indices (also used for n_rows)
 * @tparam value_t   type pointed to by knn_dists
 */
template <typename value_idx, typename value_t>
struct knn_graph {
  /**
   * Creates a descriptor with both array pointers set to nullptr.
   *
   * @param n_rows_      value stored in n_rows
   * @param n_neighbors_ value stored in n_neighbors
   */
  knn_graph(value_idx n_rows_, int n_neighbors_)
    : knn_indices(nullptr), knn_dists(nullptr), n_rows(n_rows_), n_neighbors(n_neighbors_)
  {
  }

  /**
   * Creates a descriptor that refers to caller-provided arrays.
   *
   * @param n_rows_      value stored in n_rows
   * @param n_neighbors_ value stored in n_neighbors
   * @param knn_indices_ pointer stored in knn_indices (not copied)
   * @param knn_dists_   pointer stored in knn_dists (not copied)
   */
  knn_graph(value_idx n_rows_, int n_neighbors_, value_idx* knn_indices_, value_t* knn_dists_)
    : knn_indices(knn_indices_), knn_dists(knn_dists_), n_rows(n_rows_), n_neighbors(n_neighbors_)
  {
  }

  // Members are initialised in declaration order; the mem-initializer lists
  // above are written in the same order to avoid -Wreorder diagnostics.
  value_idx* knn_indices;
  value_t* knn_dists;

  value_idx n_rows;
  int n_neighbors;
};
51 
/**
 * Abstract base for the input bundles declared in this header.
 *
 * Stores one pointer and two integer dimensions, and lets each concrete input
 * kind report, through alloc_knn_graph(), whether a kNN graph should be
 * allocated for it.
 *
 * @tparam T element type pointed to by y
 */
template <typename T>
struct manifold_inputs_t {
  T* y;   // stored as given; presumably labels/targets - TODO confirm with callers
  int n;  // presumably the number of rows/samples - TODO confirm with callers
  int d;  // presumably the number of columns/features - TODO confirm with callers

  /** Stores the three arguments verbatim; nothing is copied or allocated. */
  manifold_inputs_t(T* y_, int n_, int d_) : y(y_), n(n_), d(d_) {}

  // Polymorphic base: a virtual destructor keeps deletion through a base
  // pointer well-defined.
  virtual ~manifold_inputs_t() = default;

  /** Implemented by every concrete input type; see the derived structs for the value returned. */
  virtual bool alloc_knn_graph() const = 0;
};
67 
72 template <typename T>
74  T* X;
75 
76  manifold_dense_inputs_t(T* x_, T* y_, int n_, int d_) : manifold_inputs_t<T>(y_, n_, d_), X(x_) {}
77 
78  bool alloc_knn_graph() const { return true; }
79 };
80 
86 template <typename value_idx, typename T>
88  value_idx* indptr;
89  value_idx* indices;
90  T* data;
91 
92  size_t nnz;
93 
95  value_idx* indptr_, value_idx* indices_, T* data_, T* y_, size_t nnz_, int n_, int d_)
96  : manifold_inputs_t<T>(y_, n_, d_), indptr(indptr_), indices(indices_), data(data_), nnz(nnz_)
97  {
98  }
99 
100  bool alloc_knn_graph() const { return true; }
101 };
102 
108 template <typename value_idx, typename value_t>
111  value_idx* knn_indices_, value_t* knn_dists_, value_t* y_, int n_, int d_, int n_neighbors_)
112  : manifold_inputs_t<value_t>(y_, n_, d_), knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_)
113  {
114  }
115 
117 
118  bool alloc_knn_graph() const { return false; }
119 };
120 
121 }; // end namespace ML
Definition: dbscan.hpp:30
int64_t knn_indices_dense_t
Definition: common.hpp:24
int knn_indices_sparse_t
Definition: common.hpp:26
knn_graph
Definition: common.hpp:34
int n_neighbors
Definition: common.hpp:49
value_idx n_rows
Definition: common.hpp:48
value_idx * knn_indices
Definition: common.hpp:45
knn_graph(value_idx n_rows_, int n_neighbors_, value_idx *knn_indices_, value_t *knn_dists_)
Definition: common.hpp:40
value_t * knn_dists
Definition: common.hpp:46
knn_graph(value_idx n_rows_, int n_neighbors_)
Definition: common.hpp:35
manifold_dense_inputs_t
Definition: common.hpp:73
bool alloc_knn_graph() const
Definition: common.hpp:78
T * X
Definition: common.hpp:74
manifold_dense_inputs_t(T *x_, T *y_, int n_, int d_)
Definition: common.hpp:76
manifold_inputs_t
Definition: common.hpp:58
virtual bool alloc_knn_graph() const =0
manifold_inputs_t(T *y_, int n_, int d_)
Definition: common.hpp:63
T * y
Definition: common.hpp:59
int d
Definition: common.hpp:61
int n
Definition: common.hpp:60
Definition: common.hpp:109
bool alloc_knn_graph() const
Definition: common.hpp:118
knn_graph< value_idx, value_t > knn_graph
Definition: common.hpp:116
manifold_sparse_inputs_t
Definition: common.hpp:87
value_idx * indices
Definition: common.hpp:89
size_t nnz
Definition: common.hpp:92
manifold_sparse_inputs_t(value_idx *indptr_, value_idx *indices_, T *data_, T *y_, size_t nnz_, int n_, int d_)
Definition: common.hpp:94
bool alloc_knn_graph() const
Definition: common.hpp:100
value_idx * indptr
Definition: common.hpp:88
T * data
Definition: common.hpp:90