sparse.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include "base.hpp"
8 
9 #include <raft/core/handle.hpp>
10 #include <raft/linalg/add.cuh>
11 #include <raft/linalg/map_then_reduce.cuh>
12 #include <raft/linalg/norm.cuh>
13 #include <raft/linalg/ternary_op.cuh>
14 #include <raft/linalg/unary_op.cuh>
15 #include <raft/sparse/detail/cusparse_wrappers.h>
16 #include <raft/util/cuda_utils.cuh>
17 #include <raft/util/cudart_utils.hpp>
18 
19 #include <rmm/device_uvector.hpp>
20 
21 #include <iostream>
22 #include <vector>
23 
24 namespace ML {
25 
36 template <typename T, typename I = int>
39  T* values;
40  I* cols;
41  I* row_ids;
42  I nnz;
43 
44  SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {}
45 
46  SimpleSparseMat(T* values, I* cols, I* row_ids, I nnz, int m, int n)
48  {
49  check_csr(*this, 0);
50  }
51 
52  void print(std::ostream& oss) const override { oss << (*this) << std::endl; }
53 
54  void operator=(const SimpleSparseMat<T, I>& other) = delete;
55 
56  inline void gemmb(const raft::handle_t& handle,
57  const T alpha,
58  const SimpleDenseMat<T>& A,
59  const bool transA,
60  const bool transB,
61  const T beta,
63  cudaStream_t stream) const override
64  {
65  const SimpleSparseMat<T, I>& B = *this;
66  int kA = A.n;
67  int kB = B.m;
68 
69  if (transA) {
70  ASSERT(A.n == C.m, "GEMM invalid dims: m");
71  kA = A.m;
72  } else {
73  ASSERT(A.m == C.m, "GEMM invalid dims: m");
74  }
75 
76  if (transB) {
77  ASSERT(B.m == C.n, "GEMM invalid dims: n");
78  kB = B.n;
79  } else {
80  ASSERT(B.n == C.n, "GEMM invalid dims: n");
81  }
82  ASSERT(kA == kB, "GEMM invalid dims: k");
83 
84  // matrix C must change the order and be transposed, because we need
85  // to swap arguments A and B in cusparseSpMM.
86  cusparseDnMatDescr_t descrC;
87  auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
88  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
89  &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order));
90 
91  /*
92  The matrix A must have the same order as the matrix C in the input
93  of function cusparseSpMM (i.e. swapped order w.r.t. original C).
94  To account this requirement, I may need to flip transA (whether to transpose A).
95 
96  C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA
97  c r n m m c r n m m x
98  c r n m m r r m n n o
99  r c n m n c c m n m o
100  r c n m n r c n m n x
101 
102  where:
103  c/r - column/row major order
104  A,C - input to gemmb
105  A', C' - input to cusparseSpMM
106  ldX' - leading dimension - m or n, depending on order and transX
107  */
108  cusparseDnMatDescr_t descrA;
109  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA,
110  C.ord == A.ord ? A.n : A.m,
111  C.ord == A.ord ? A.m : A.n,
112  A.ord == COL_MAJOR ? A.m : A.n,
113  A.data,
114  order));
115  auto opA =
116  transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
117 
118  cusparseSpMatDescr_t descrB;
119  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
120  &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values));
121  auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
122 
123  auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2;
124 
125  size_t bufferSize;
126  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
127  opB,
128  opA,
129  &alpha,
130  descrB,
131  descrA,
132  &beta,
133  descrC,
134  alg,
135  &bufferSize,
136  stream));
137 
139  rmm::device_uvector<T> tmp(bufferSize, stream);
140 
141  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
142  opB,
143  opA,
144  &alpha,
145  descrB,
146  descrA,
147  &beta,
148  descrC,
149  alg,
150  tmp.data(),
151  stream));
152 
153  RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA));
154  RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB));
155  RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC));
156  }
157 };
158 
159 template <typename T, typename I = int>
160 inline void check_csr(const SimpleSparseMat<T, I>& mat, cudaStream_t stream)
161 {
162  I row_ids_nnz;
163  raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream);
165  ASSERT(row_ids_nnz == mat.nnz,
166  "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
167  "the last element must be equal nnz.");
168 }
169 
170 template <typename T, typename I = int>
171 std::ostream& operator<<(std::ostream& os, const SimpleSparseMat<T, I>& mat)
172 {
173  check_csr(mat, 0);
174  os << "SimpleSparseMat (CSR)"
175  << "\n";
176  std::vector<T> values(mat.nnz);
177  std::vector<I> cols(mat.nnz);
178  std::vector<I> row_ids(mat.m + 1);
179  raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default);
180  raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default);
181  raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default);
182  raft::interruptible::synchronize(rmm::cuda_stream_view());
183 
184  int i, row_end = 0;
185  for (int row = 0; row < mat.m; row++) {
186  i = row_end;
187  row_end = row_ids[row + 1];
188  for (int col = 0; col < mat.n; col++) {
189  if (i >= row_end || col < cols[i]) {
190  os << "0";
191  } else {
192  os << values[i];
193  i++;
194  }
195  if (col < mat.n - 1) os << ",";
196  }
197 
198  os << std::endl;
199  }
200 
201  return os;
202 }
203 
204 }; // namespace ML
Definition: dbscan.hpp:18
std::ostream & operator<<(std::ostream &os, const SimpleVec< T > &v)
Definition: dense.hpp:324
void check_csr(const SimpleSparseMat< T, I > &mat, cudaStream_t stream)
Definition: sparse.hpp:160
@ COL_MAJOR
Definition: dense.hpp:27
void synchronize(cuda_stream stream)
Definition: cuda_stream.hpp:16
Definition: dense.hpp:30
T * data
Definition: dense.hpp:33
STORAGE_ORDER ord
Definition: dense.hpp:35
Definition: base.hpp:17
int m
Definition: base.hpp:18
int n
Definition: base.hpp:18
Definition: sparse.hpp:37
SimpleSparseMat(T *values, I *cols, I *row_ids, I nnz, int m, int n)
Definition: sparse.hpp:46
void print(std::ostream &oss) const override
Definition: sparse.hpp:52
T * values
Definition: sparse.hpp:39
SimpleSparseMat()
Definition: sparse.hpp:44
I * row_ids
Definition: sparse.hpp:41
I nnz
Definition: sparse.hpp:42
SimpleMat< T > Super
Definition: sparse.hpp:38
void operator=(const SimpleSparseMat< T, I > &other)=delete
I * cols
Definition: sparse.hpp:40
void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const override
Definition: sparse.hpp:56