9 #include <raft/core/handle.hpp>
10 #include <raft/linalg/add.cuh>
11 #include <raft/linalg/map_then_reduce.cuh>
12 #include <raft/linalg/norm.cuh>
13 #include <raft/linalg/ternary_op.cuh>
14 #include <raft/linalg/unary_op.cuh>
15 #include <raft/sparse/detail/cusparse_wrappers.h>
16 #include <raft/util/cuda_utils.cuh>
17 #include <raft/util/cudart_utils.hpp>
19 #include <rmm/device_uvector.hpp>
36 template <
typename T,
typename I =
int>
52 void print(std::ostream& oss)
const override { oss << (*this) << std::endl; }
56 inline void gemmb(
const raft::handle_t& handle,
63 cudaStream_t stream)
const override
70 ASSERT(A.
n == C.
m,
"GEMM invalid dims: m");
73 ASSERT(A.
m == C.
m,
"GEMM invalid dims: m");
77 ASSERT(B.
m == C.
n,
"GEMM invalid dims: n");
80 ASSERT(B.
n == C.
n,
"GEMM invalid dims: n");
82 ASSERT(kA == kB,
"GEMM invalid dims: k");
86 cusparseDnMatDescr_t descrC;
87 auto order = C.
ord ==
COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
88 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
89 &descrC, C.
n, C.
m, order == CUSPARSE_ORDER_COL ? C.
n : C.
m, C.
data, order));
108 cusparseDnMatDescr_t descrA;
109 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA,
116 transA ^ (C.
ord == A.
ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
118 cusparseSpMatDescr_t descrB;
119 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
121 auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
123 auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2;
126 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
139 rmm::device_uvector<T> tmp(bufferSize, stream);
141 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
153 RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA));
154 RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB));
155 RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC));
159 template <
typename T,
typename I =
int>
163 raft::update_host(&row_ids_nnz, &mat.
row_ids[mat.
m], 1, stream);
165 ASSERT(row_ids_nnz == mat.
nnz,
166 "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
167 "the last element must be equal nnz.");
170 template <
typename T,
typename I =
int>
174 os <<
"SimpleSparseMat (CSR)"
176 std::vector<T> values(mat.
nnz);
177 std::vector<I> cols(mat.
nnz);
178 std::vector<I> row_ids(mat.
m + 1);
179 raft::update_host(&values[0], mat.
values, mat.
nnz, rmm::cuda_stream_default);
180 raft::update_host(&cols[0], mat.
cols, mat.
nnz, rmm::cuda_stream_default);
181 raft::update_host(&row_ids[0], mat.
row_ids, mat.
m + 1, rmm::cuda_stream_default);
185 for (
int row = 0; row < mat.
m; row++) {
187 row_end = row_ids[row + 1];
188 for (
int col = 0; col < mat.
n; col++) {
189 if (i >= row_end || col < cols[i]) {
195 if (col < mat.
n - 1) os <<
",";
Definition: dbscan.hpp:18
std::ostream & operator<<(std::ostream &os, const SimpleVec< T > &v)
Definition: dense.hpp:324
void check_csr(const SimpleSparseMat< T, I > &mat, cudaStream_t stream)
Definition: sparse.hpp:160
@ COL_MAJOR
Definition: dense.hpp:27
void synchronize(cuda_stream stream)
Definition: cuda_stream.hpp:16
T * data
Definition: dense.hpp:33
STORAGE_ORDER ord
Definition: dense.hpp:35
int m
Definition: base.hpp:18
int n
Definition: base.hpp:18
Definition: sparse.hpp:37
SimpleSparseMat(T *values, I *cols, I *row_ids, I nnz, int m, int n)
Definition: sparse.hpp:46
void print(std::ostream &oss) const override
Definition: sparse.hpp:52
T * values
Definition: sparse.hpp:39
SimpleSparseMat()
Definition: sparse.hpp:44
I * row_ids
Definition: sparse.hpp:41
I nnz
Definition: sparse.hpp:42
SimpleMat< T > Super
Definition: sparse.hpp:38
void operator=(const SimpleSparseMat< T, I > &other)=delete
I * cols
Definition: sparse.hpp:40
void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const override
Definition: sparse.hpp:56