20 #include <raft/core/handle.hpp>
21 #include <raft/linalg/add.cuh>
22 #include <raft/linalg/map_then_reduce.cuh>
23 #include <raft/linalg/norm.cuh>
24 #include <raft/linalg/ternary_op.cuh>
25 #include <raft/linalg/unary_op.cuh>
26 #include <raft/sparse/detail/cusparse_wrappers.h>
27 #include <raft/util/cuda_utils.cuh>
28 #include <raft/util/cudart_utils.hpp>
30 #include <rmm/device_uvector.hpp>
47 template <
typename T,
typename I =
int>
// Overrides the base-class print hook: streams this matrix into `oss` through
// the `operator<<` overload selected for `*this`, then appends `std::endl`
// (which writes a newline and flushes the stream).
63 void print(std::ostream& oss)
const override { oss << (*this) << std::endl; }
67 inline void gemmb(
const raft::handle_t& handle,
74 cudaStream_t stream)
const override
81 ASSERT(A.
n == C.
m,
"GEMM invalid dims: m");
84 ASSERT(A.
m == C.
m,
"GEMM invalid dims: m");
88 ASSERT(B.
m == C.
n,
"GEMM invalid dims: n");
91 ASSERT(B.
n == C.
n,
"GEMM invalid dims: n");
93 ASSERT(kA == kB,
"GEMM invalid dims: k");
97 cusparseDnMatDescr_t descrC;
98 auto order = C.
ord ==
COL_MAJOR ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
99 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
100 &descrC, C.
n, C.
m, order == CUSPARSE_ORDER_COL ? C.
n : C.
m, C.
data, order));
119 cusparseDnMatDescr_t descrA;
120 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(&descrA,
127 transA ^ (C.
ord == A.
ord) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
129 cusparseSpMatDescr_t descrB;
130 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
132 auto opB = transB ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
134 auto alg = order == CUSPARSE_ORDER_COL ? CUSPARSE_SPMM_CSR_ALG1 : CUSPARSE_SPMM_CSR_ALG2;
137 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(),
150 rmm::device_uvector<T> tmp(bufferSize, stream);
152 RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(),
164 RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrA));
165 RAFT_CUSPARSE_TRY(cusparseDestroySpMat(descrB));
166 RAFT_CUSPARSE_TRY(cusparseDestroyDnMat(descrC));
170 template <
typename T,
typename I =
int>
174 raft::update_host(&row_ids_nnz, &mat.
row_ids[mat.
m], 1, stream);
176 ASSERT(row_ids_nnz == mat.
nnz,
177 "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
178 "the last element must be equal nnz.");
181 template <
typename T,
typename I =
int>
185 os <<
"SimpleSparseMat (CSR)"
187 std::vector<T> values(mat.
nnz);
188 std::vector<I> cols(mat.
nnz);
189 std::vector<I> row_ids(mat.
m + 1);
190 raft::update_host(&values[0], mat.
values, mat.
nnz, rmm::cuda_stream_default);
191 raft::update_host(&cols[0], mat.
cols, mat.
nnz, rmm::cuda_stream_default);
192 raft::update_host(&row_ids[0], mat.
row_ids, mat.
m + 1, rmm::cuda_stream_default);
196 for (
int row = 0; row < mat.
m; row++) {
198 row_end = row_ids[row + 1];
199 for (
int col = 0; col < mat.
n; col++) {
200 if (i >= row_end || col < cols[i]) {
206 if (col < mat.
n - 1) os <<
",";
Definition: dbscan.hpp:30
std::ostream & operator<<(std::ostream &os, const SimpleVec< T > &v)
Definition: dense.hpp:335
void check_csr(const SimpleSparseMat< T, I > &mat, cudaStream_t stream)
Definition: sparse.hpp:171
@ COL_MAJOR
Definition: dense.hpp:38
void synchronize(cuda_stream stream)
Definition: cuda_stream.hpp:27
T * data
Definition: dense.hpp:44
STORAGE_ORDER ord
Definition: dense.hpp:46
int m
Definition: base.hpp:29
int n
Definition: base.hpp:29
Definition: sparse.hpp:48
SimpleSparseMat(T *values, I *cols, I *row_ids, I nnz, int m, int n)
Definition: sparse.hpp:57
void print(std::ostream &oss) const override
Definition: sparse.hpp:63
T * values
Definition: sparse.hpp:50
SimpleSparseMat()
Definition: sparse.hpp:55
I * row_ids
Definition: sparse.hpp:52
I nnz
Definition: sparse.hpp:53
SimpleMat< T > Super
Definition: sparse.hpp:49
void operator=(const SimpleSparseMat< T, I > &other)=delete
I * cols
Definition: sparse.hpp:51
void gemmb(const raft::handle_t &handle, const T alpha, const SimpleDenseMat< T > &A, const bool transA, const bool transB, const T beta, SimpleDenseMat< T > &C, cudaStream_t stream) const override
Definition: sparse.hpp:67