dense.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "base.hpp"
19 
20 #include <raft/core/handle.hpp>
21 #include <raft/linalg/add.cuh>
22 #include <raft/linalg/ternary_op.cuh>
23 #include <raft/util/cuda_utils.cuh>
24 #include <raft/util/cudart_utils.hpp>
25 
26 #include <iostream>
27 #include <vector>
28 // #TODO: Replace with public header when ready
29 #include <raft/linalg/detail/cublas_wrappers.hpp>
30 #include <raft/linalg/map_then_reduce.cuh>
31 #include <raft/linalg/norm.cuh>
32 #include <raft/linalg/unary_op.cuh>
33 
34 #include <rmm/device_uvector.hpp>
35 
36 namespace ML {
37 
/** Memory layout tag carried by the dense matrix types in this file. */
enum STORAGE_ORDER {
  COL_MAJOR = 0,  // column-major layout
  ROW_MAJOR = 1   // row-major layout
};
39 
40 template <typename T>
43  int len;
44  T* data;
45 
46  STORAGE_ORDER ord; // storage order: runtime param for compile time sake
47 
48  SimpleDenseMat(STORAGE_ORDER order = COL_MAJOR) : Super(0, 0), data(nullptr), len(0), ord(order)
49  {
50  }
51 
52  SimpleDenseMat(T* data, int m, int n, STORAGE_ORDER order = COL_MAJOR)
53  : Super(m, n), data(data), len(m * n), ord(order)
54  {
55  }
56 
57  void reset(T* data_, int m_, int n_)
58  {
59  this->m = m_;
60  this->n = n_;
61  data = data_;
62  len = m_ * n_;
63  }
64 
65  // Implemented GEMM as a static method here to improve readability
66  inline static void gemm(const raft::handle_t& handle,
67  const T alpha,
68  const SimpleDenseMat<T>& A,
69  const bool transA,
70  const SimpleDenseMat<T>& B,
71  const bool transB,
72  const T beta,
74  cudaStream_t stream)
75  {
76  int kA = A.n;
77  int kB = B.m;
78 
79  if (transA) {
80  ASSERT(A.n == C.m, "GEMM invalid dims: m");
81  kA = A.m;
82  } else {
83  ASSERT(A.m == C.m, "GEMM invalid dims: m");
84  }
85 
86  if (transB) {
87  ASSERT(B.m == C.n, "GEMM invalid dims: n");
88  kB = B.n;
89  } else {
90  ASSERT(B.n == C.n, "GEMM invalid dims: n");
91  }
92  ASSERT(kA == kB, "GEMM invalid dims: k");
93 
94  if (A.ord == COL_MAJOR && B.ord == COL_MAJOR && C.ord == COL_MAJOR) {
95  // #TODO: Call from public API when ready
96  raft::linalg::detail::cublasgemm(handle.get_cublas_handle(), // handle
97  transA ? CUBLAS_OP_T : CUBLAS_OP_N, // transA
98  transB ? CUBLAS_OP_T : CUBLAS_OP_N, // transB
99  C.m,
100  C.n,
101  kA, // dimensions m,n,k
102  &alpha,
103  A.data,
104  A.m, // lda
105  B.data,
106  B.m, // ldb
107  &beta,
108  C.data,
109  C.m, // ldc,
110  stream);
111  return;
112  }
113  if (A.ord == ROW_MAJOR) {
114  const SimpleDenseMat<T> Acm(A.data, A.n, A.m, COL_MAJOR);
115  gemm(handle, alpha, Acm, !transA, B, transB, beta, C, stream);
116  return;
117  }
118  if (B.ord == ROW_MAJOR) {
119  const SimpleDenseMat<T> Bcm(B.data, B.n, B.m, COL_MAJOR);
120  gemm(handle, alpha, A, transA, Bcm, !transB, beta, C, stream);
121  return;
122  }
123  if (C.ord == ROW_MAJOR) {
124  SimpleDenseMat<T> Ccm(C.data, C.n, C.m, COL_MAJOR);
125  gemm(handle, alpha, B, !transB, A, !transA, beta, Ccm, stream);
126  return;
127  }
128  }
129 
130  inline void gemmb(const raft::handle_t& handle,
131  const T alpha,
132  const SimpleDenseMat<T>& A,
133  const bool transA,
134  const bool transB,
135  const T beta,
137  cudaStream_t stream) const override
138  {
139  SimpleDenseMat<T>::gemm(handle, alpha, A, transA, *this, transB, beta, C, stream);
140  }
141 
149  inline void assign_gemm(const raft::handle_t& handle,
150  const T alpha,
151  const SimpleDenseMat<T>& A,
152  const bool transA,
153  const SimpleMat<T>& B,
154  const bool transB,
155  const T beta,
156  cudaStream_t stream)
157  {
158  B.gemmb(handle, alpha, A, transA, transB, beta, *this, stream);
159  }
160 
161  // this = a*x
162  inline void ax(const T a, const SimpleDenseMat<T>& x, cudaStream_t stream)
163  {
164  ASSERT(ord == x.ord, "SimpleDenseMat::ax: Storage orders must match");
165 
166  auto scale = [a] __device__(const T x) { return a * x; };
167  raft::linalg::unaryOp(data, x.data, len, scale, stream);
168  }
169 
170  // this = a*x + y
171  inline void axpy(const T a,
172  const SimpleDenseMat<T>& x,
173  const SimpleDenseMat<T>& y,
174  cudaStream_t stream)
175  {
176  ASSERT(ord == x.ord, "SimpleDenseMat::axpy: Storage orders must match");
177  ASSERT(ord == y.ord, "SimpleDenseMat::axpy: Storage orders must match");
178 
179  auto axpy = [a] __device__(const T x, const T y) { return a * x + y; };
180  raft::linalg::binaryOp(data, x.data, y.data, len, axpy, stream);
181  }
182 
183  template <typename Lambda>
184  inline void assign_unary(const SimpleDenseMat<T>& other, Lambda f, cudaStream_t stream)
185  {
186  ASSERT(ord == other.ord, "SimpleDenseMat::assign_unary: Storage orders must match");
187 
188  raft::linalg::unaryOp(data, other.data, len, f, stream);
189  }
190 
191  template <typename Lambda>
192  inline void assign_binary(const SimpleDenseMat<T>& other1,
193  const SimpleDenseMat<T>& other2,
194  Lambda& f,
195  cudaStream_t stream)
196  {
197  ASSERT(ord == other1.ord, "SimpleDenseMat::assign_binary: Storage orders must match");
198  ASSERT(ord == other2.ord, "SimpleDenseMat::assign_binary: Storage orders must match");
199 
200  raft::linalg::binaryOp(data, other1.data, other2.data, len, f, stream);
201  }
202 
203  template <typename Lambda>
204  inline void assign_ternary(const SimpleDenseMat<T>& other1,
205  const SimpleDenseMat<T>& other2,
206  const SimpleDenseMat<T>& other3,
207  Lambda& f,
208  cudaStream_t stream)
209  {
210  ASSERT(ord == other1.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
211  ASSERT(ord == other2.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
212  ASSERT(ord == other3.ord, "SimpleDenseMat::assign_ternary: Storage orders must match");
213 
214  raft::linalg::ternaryOp(data, other1.data, other2.data, other3.data, len, f, stream);
215  }
216 
217  inline void fill(const T val, cudaStream_t stream)
218  {
219  // TODO this reads data unnecessary, though it's mostly used for testing
220  auto f = [val] __device__(const T x) { return val; };
221  raft::linalg::unaryOp(data, data, len, f, stream);
222  }
223 
224  inline void copy_async(const SimpleDenseMat<T>& other, cudaStream_t stream)
225  {
226  ASSERT((ord == other.ord) && (this->m == other.m) && (this->n == other.n),
227  "SimpleDenseMat::copy: matrices not compatible");
228 
229  RAFT_CUDA_TRY(
230  cudaMemcpyAsync(data, other.data, len * sizeof(T), cudaMemcpyDeviceToDevice, stream));
231  }
232 
233  void print(std::ostream& oss) const override { oss << (*this) << std::endl; }
234 
235  void operator=(const SimpleDenseMat<T>& other) = delete;
236 };
237 
238 template <typename T>
241 
242  SimpleVec(T* data, const int n) : Super(data, n, 1, COL_MAJOR) {}
243  // this = alpha * A * x + beta * this
244  void assign_gemv(const raft::handle_t& handle,
245  const T alpha,
246  const SimpleDenseMat<T>& A,
247  bool transA,
248  const SimpleVec<T>& x,
249  const T beta,
250  cudaStream_t stream)
251  {
252  Super::assign_gemm(handle, alpha, A, transA, x, false, beta, stream);
253  }
254 
256 
257  inline void reset(T* new_data, int n) { Super::reset(new_data, n, 1); }
258 };
259 
260 template <typename T>
261 inline void col_ref(const SimpleDenseMat<T>& mat, SimpleVec<T>& mask_vec, int c)
262 {
263  ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats");
264  T* tmp = &mat.data[mat.m * c];
265  mask_vec.reset(tmp, mat.m);
266 }
267 
268 template <typename T>
269 inline void col_slice(const SimpleDenseMat<T>& mat,
270  SimpleDenseMat<T>& mask_mat,
271  int c_from,
272  int c_to)
273 {
274  ASSERT(c_from >= 0 && c_from < mat.n, "col_slice: invalid from");
275  ASSERT(c_to >= 0 && c_to <= mat.n, "col_slice: invalid to");
276 
277  ASSERT(mat.ord == COL_MAJOR, "col_ref only available for column major mats");
278  ASSERT(mask_mat.ord == COL_MAJOR, "col_ref only available for column major mask");
279  T* tmp = &mat.data[mat.m * c_from];
280  mask_mat.reset(tmp, mat.m, c_to - c_from);
281 }
282 
283 // Reductions such as dot or norm require an additional location in dev mem
284 // to hold the result. We don't want to deal with this in the SimpleVec class
285 // as it impedes thread safety and constness
286 
287 template <typename T>
288 inline T dot(const SimpleVec<T>& u, const SimpleVec<T>& v, T* tmp_dev, cudaStream_t stream)
289 {
290  auto f = [] __device__(const T x, const T y) { return x * y; };
291  raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data);
292  T tmp_host;
293  raft::update_host(&tmp_host, tmp_dev, 1, stream);
294 
296  return tmp_host;
297 }
298 
299 template <typename T>
300 inline T squaredNorm(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
301 {
302  return dot(u, u, tmp_dev, stream);
303 }
304 
305 template <typename T>
306 inline T nrmMax(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
307 {
308  auto f = [] __device__(const T x) { return raft::abs<T>(x); };
309  auto r = [] __device__(const T x, const T y) { return raft::max<T>(x, y); };
310  raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
311  T tmp_host;
312  raft::update_host(&tmp_host, tmp_dev, 1, stream);
314  return tmp_host;
315 }
316 
317 template <typename T>
318 inline T nrm2(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
319 {
320  return raft::mySqrt<T>(squaredNorm(u, tmp_dev, stream));
321 }
322 
323 template <typename T>
324 inline T nrm1(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
325 {
326  raft::linalg::rowNorm(
327  tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop<T>());
328  T tmp_host;
329  raft::update_host(&tmp_host, tmp_dev, 1, stream);
331  return tmp_host;
332 }
333 
334 template <typename T>
335 std::ostream& operator<<(std::ostream& os, const SimpleVec<T>& v)
336 {
337  std::vector<T> out(v.len);
338  raft::update_host(&out[0], v.data, v.len, 0);
339  raft::interruptible::synchronize(rmm::cuda_stream_view());
340  int it = 0;
341  for (; it < v.len - 1;) {
342  os << out[it] << " ";
343  it++;
344  }
345  os << out[it];
346  return os;
347 }
348 
349 template <typename T>
350 std::ostream& operator<<(std::ostream& os, const SimpleDenseMat<T>& mat)
351 {
352  os << "ord=" << (mat.ord == COL_MAJOR ? "CM" : "RM") << "\n";
353  std::vector<T> out(mat.len);
354  raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default);
355  raft::interruptible::synchronize(rmm::cuda_stream_view());
356  if (mat.ord == COL_MAJOR) {
357  for (int r = 0; r < mat.m; r++) {
358  int idx = r;
359  for (int c = 0; c < mat.n - 1; c++) {
360  os << out[idx] << ",";
361  idx += mat.m;
362  }
363  os << out[idx] << std::endl;
364  }
365  } else {
366  for (int c = 0; c < mat.m; c++) {
367  int idx = c * mat.n;
368  for (int r = 0; r < mat.n - 1; r++) {
369  os << out[idx] << ",";
370  idx += 1;
371  }
372  os << out[idx] << std::endl;
373  }
374  }
375 
376  return os;
377 }
378 
379 template <typename T>
382  typedef rmm::device_uvector<T> Buffer;
384 
385  SimpleVecOwning() = delete;
386 
387  SimpleVecOwning(int n, cudaStream_t stream) : Super(), buf(n, stream)
388  {
389  Super::reset(buf.data(), n);
390  }
391 
392  void operator=(const SimpleVec<T>& other) = delete;
393 };
394 
395 template <typename T>
398  typedef rmm::device_uvector<T> Buffer;
400  using Super::m;
401  using Super::n;
402  using Super::ord;
403 
404  SimpleMatOwning() = delete;
405 
406  SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order = COL_MAJOR)
407  : Super(order), buf(m * n, stream)
408  {
409  Super::reset(buf.data(), m, n);
410  }
411 
412  void operator=(const SimpleVec<T>& other) = delete;
413 };
414 
415 }; // namespace ML
Definition: dbscan.hpp:30
void col_slice(const SimpleDenseMat< T > &mat, SimpleDenseMat< T > &mask_mat, int c_from, int c_to)
Definition: dense.hpp:269
T nrm1(const SimpleVec< T > &u, T *tmp_dev, cudaStream_t stream)
Definition: dense.hpp:324
std::ostream & operator<<(std::ostream &os, const SimpleVec< T > &v)
Definition: dense.hpp:335
T nrmMax(const SimpleVec< T > &u, T *tmp_dev, cudaStream_t stream)
Definition: dense.hpp:306
T squaredNorm(const SimpleVec< T > &u, T *tmp_dev, cudaStream_t stream)
Definition: dense.hpp:300
T dot(const SimpleVec< T > &u, const SimpleVec< T > &v, T *tmp_dev, cudaStream_t stream)
Definition: dense.hpp:288
T nrm2(const SimpleVec< T > &u, T *tmp_dev, cudaStream_t stream)
Definition: dense.hpp:318
STORAGE_ORDER
Definition: dense.hpp:38
@ ROW_MAJOR
Definition: dense.hpp:38
@ COL_MAJOR
Definition: dense.hpp:38
void col_ref(const SimpleDenseMat< T > &mat, SimpleVec< T > &mask_vec, int c)
Definition: dense.hpp:261
void synchronize(cuda_stream stream)
Definition: cuda_stream.hpp:27
Definition: dense.hpp:41
void fill(const T val, cudaStream_t stream)
Definition: dense.hpp:217
void assign_binary(const SimpleDenseMat< T > &other1, const SimpleDenseMat< T > &other2, Lambda &f, cudaStream_t stream)
Definition: dense.hpp:192
void assign_gemm(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const SimpleMat< T > &B, const bool transB, const T beta, cudaStream_t stream)
Definition: dense.hpp:149
static void gemm(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const SimpleDenseMat< T > &B, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream)
Definition: dense.hpp:66
void ax(const T a, const SimpleDenseMat< T > &x, cudaStream_t stream)
Definition: dense.hpp:162
void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const override
Definition: dense.hpp:130
void assign_unary(const SimpleDenseMat< T > &other, Lambda f, cudaStream_t stream)
Definition: dense.hpp:184
SimpleDenseMat(T *data, int m, int n, STORAGE_ORDER order=COL_MAJOR)
Definition: dense.hpp:52
void axpy(const T a, const SimpleDenseMat< T > &x, const SimpleDenseMat< T > &y, cudaStream_t stream)
Definition: dense.hpp:171
SimpleDenseMat(STORAGE_ORDER order=COL_MAJOR)
Definition: dense.hpp:48
void assign_ternary(const SimpleDenseMat< T > &other1, const SimpleDenseMat< T > &other2, const SimpleDenseMat< T > &other3, Lambda &f, cudaStream_t stream)
Definition: dense.hpp:204
int len
Definition: dense.hpp:43
void operator=(const SimpleDenseMat< T > &other)=delete
void copy_async(const SimpleDenseMat< T > &other, cudaStream_t stream)
Definition: dense.hpp:224
T * data
Definition: dense.hpp:44
void print(std::ostream &oss) const override
Definition: dense.hpp:233
SimpleMat< T > Super
Definition: dense.hpp:42
void reset(T *data_, int m_, int n_)
Definition: dense.hpp:57
STORAGE_ORDER ord
Definition: dense.hpp:46
Definition: dense.hpp:396
int m
Definition: base.hpp:29
SimpleMatOwning(int m, int n, cudaStream_t stream, STORAGE_ORDER order=COL_MAJOR)
Definition: dense.hpp:406
SimpleMatOwning()=delete
Buffer buf
Definition: dense.hpp:399
int n
Definition: base.hpp:29
rmm::device_uvector< T > Buffer
Definition: dense.hpp:398
SimpleDenseMat< T > Super
Definition: dense.hpp:397
void operator=(const SimpleVec< T > &other)=delete
Definition: base.hpp:28
int m
Definition: base.hpp:29
int n
Definition: base.hpp:29
virtual void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const =0
Definition: dense.hpp:380
void operator=(const SimpleVec< T > &other)=delete
SimpleVecOwning()=delete
SimpleVecOwning(int n, cudaStream_t stream)
Definition: dense.hpp:387
SimpleVec< T > Super
Definition: dense.hpp:381
rmm::device_uvector< T > Buffer
Definition: dense.hpp:382
Buffer buf
Definition: dense.hpp:383
Definition: dense.hpp:239
SimpleDenseMat< T > Super
Definition: dense.hpp:240
void reset(T *new_data, int n)
Definition: dense.hpp:257
void assign_gemv(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, bool transA, const SimpleVec< T > &x, const T beta, cudaStream_t stream)
Definition: dense.hpp:244
SimpleVec(T *data, const int n)
Definition: dense.hpp:242
SimpleVec()
Definition: dense.hpp:255