sparse.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include "base.hpp"
19 
20 #include <raft/core/handle.hpp>
21 #include <raft/linalg/add.cuh>
22 #include <raft/linalg/map_then_reduce.cuh>
23 #include <raft/linalg/norm.cuh>
24 #include <raft/linalg/ternary_op.cuh>
25 #include <raft/linalg/unary_op.cuh>
26 #include <raft/sparse/detail/cusparse_wrappers.h>
27 #include <raft/util/cuda_utils.cuh>
28 #include <raft/util/cudart_utils.hpp>
29 
30 #include <rmm/device_uvector.hpp>
31 
32 #include <iostream>
33 #include <vector>
34 
35 namespace ML {
36 
47 template <typename T, typename I = int>
50  T* values;
51  I* cols;
52  I* row_ids;
53  I nnz;
54 
55  SimpleSparseMat() : Super(0, 0), values(nullptr), cols(nullptr), row_ids(nullptr), nnz(0) {}
56 
57  SimpleSparseMat(T* values, I* cols, I* row_ids, I nnz, int m, int n)
59  {
60  check_csr(*this, 0);
61  }
62 
63  void print(std::ostream& oss) const override { oss << (*this) << std::endl; }
64 
65  void operator=(const SimpleSparseMat<T, I>& other) = delete;
66 
67  inline void gemmb(const raft::handle_t& handle,
68  const T alpha,
69  const SimpleDenseMat<T>& A,
70  const bool transA,
71  const bool transB,
72  const T beta,
74  cudaStream_t stream) const override
75  {
76  const SimpleSparseMat<T, I>& B = *this;
77  int kA = A.n;
78  int kB = B.m;
79 
80  if (transA) {
81  ASSERT(A.n == C.m, "GEMM invalid dims: m");
82  kA = A.m;
83  } else {
84  ASSERT(A.m == C.m, "GEMM invalid dims: m");
85  }
86 
87  if (transB) {
88  ASSERT(B.m == C.n, "GEMM invalid dims: n");
89  kB = B.n;
90  } else {
91  ASSERT(B.n == C.n, "GEMM invalid dims: n");
92  }
93  ASSERT(kA == kB, "GEMM invalid dims: k");
94 
95  // matrix C must change the order and be transposed, because we need
96  // to swap arguments A and B in cusparseSpMM.
97  cusparseDnMatDescr_t descrC;
98  auto order = C.ord == COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
99  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
100  &descrC, C.n, C.m, order == CUSPARSE_ORDER_COL ? C.n : C.m, C.data, order));
101 
102  /*
103  The matrix A must have the same order as the matrix C in the input
104  of function cusparseSpMM (i.e. swapped order w.r.t. original C).
105  To account this requirement, I may need to flip transA (whether to transpose A).
106 
107  C C' rowsC' colsC' ldC' A A' rowsA' colsA' ldA' flipTransA
108  c r n m m c r n m m x
109  c r n m m r r m n n o
110  r c n m n c c m n m o
111  r c n m n r c n m n x
112 
113  where:
114  c/r - column/row major order
115  A,C - input to gemmb
116  A', C' - input to cusparseSpMM
117  ldX' - leading dimension - m or n, depending on order and transX
118  */
119  cusparseDnMatDescr_t descrA;
120  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA,
121  C.ord == A.ord ? A.n : A.m,
122  C.ord == A.ord ? A.m : A.n,
123  A.ord == COL_MAJOR ? A.m : A.n,
124  A.data,
125  order));
126  auto opA =
127  transA ^ (C.ord == A.ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
128 
129  cusparseSpMatDescr_t descrB;
130  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
131  &descrB, B.m, B.n, B.nnz, B.row_ids, B.cols, B.values));
132  auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
133 
134  auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2;
135 
136  size_t bufferSize;
137  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
138  opB,
139  opA,
140  &alpha,
141  descrB,
142  descrA,
143  &beta,
144  descrC,
145  alg,
146  &bufferSize,
147  stream));
148 
150  rmm::device_uvector<T> tmp(bufferSize, stream);
151 
152  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
153  opB,
154  opA,
155  &alpha,
156  descrB,
157  descrA,
158  &beta,
159  descrC,
160  alg,
161  tmp.data(),
162  stream));
163 
164  RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA));
165  RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB));
166  RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC));
167  }
168 };
169 
170 template <typename T, typename I = int>
171 inline void check_csr(const SimpleSparseMat<T, I>& mat, cudaStream_t stream)
172 {
173  I row_ids_nnz;
174  raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream);
176  ASSERT(row_ids_nnz == mat.nnz,
177  "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
178  "the last element must be equal nnz.");
179 }
180 
181 template <typename T, typename I = int>
182 std::ostream& operator<<(std::ostream& os, const SimpleSparseMat<T, I>& mat)
183 {
184  check_csr(mat, 0);
185  os << "SimpleSparseMat (CSR)"
186  << "\n";
187  std::vector<T> values(mat.nnz);
188  std::vector<I> cols(mat.nnz);
189  std::vector<I> row_ids(mat.m + 1);
190  raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default);
191  raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default);
192  raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default);
193  raft::interruptible::synchronize(rmm::cuda_stream_view());
194 
195  int i, row_end = 0;
196  for (int row = 0; row < mat.m; row++) {
197  i = row_end;
198  row_end = row_ids[row + 1];
199  for (int col = 0; col < mat.n; col++) {
200  if (i >= row_end || col < cols[i]) {
201  os << "0";
202  } else {
203  os << values[i];
204  i++;
205  }
206  if (col < mat.n - 1) os << ",";
207  }
208 
209  os << std::endl;
210  }
211 
212  return os;
213 }
214 
215 }; // namespace ML
Definition: dbscan.hpp:30
std::ostream & operator<<(std::ostream &os, const SimpleVec< T > &v)
Definition: dense.hpp:335
void check_csr(const SimpleSparseMat< T, I > &mat, cudaStream_t stream)
Definition: sparse.hpp:171
@ COL_MAJOR
Definition: dense.hpp:38
void synchronize(cuda_stream stream)
Definition: cuda_stream.hpp:27
Definition: dense.hpp:41
T * data
Definition: dense.hpp:44
STORAGE_ORDER ord
Definition: dense.hpp:46
Definition: base.hpp:28
int m
Definition: base.hpp:29
int n
Definition: base.hpp:29
Definition: sparse.hpp:48
SimpleSparseMat(T *values, I *cols, I *row_ids, I nnz, int m, int n)
Definition: sparse.hpp:57
void print(std::ostream &oss) const override
Definition: sparse.hpp:63
T * values
Definition: sparse.hpp:50
SimpleSparseMat()
Definition: sparse.hpp:55
I * row_ids
Definition: sparse.hpp:52
I nnz
Definition: sparse.hpp:53
SimpleMat< T > Super
Definition: sparse.hpp:49
void operator=(const SimpleSparseMat< T, I > &other)=delete
I * cols
Definition: sparse.hpp:51
void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const override
Definition: sparse.hpp:67