gemm.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include "../matrix/data.hpp"
8 #include "../matrix/part_descriptor.hpp"
9 
10 #include <raft/core/handle.hpp>
11 
12 namespace MLCommon {
13 namespace LinAlg {
14 namespace opg {
15 
/**
 * @brief Multi-GPU generalized matrix multiplication (single precision).
 *
 * Performs Z = X * Y, where each matrix is passed as a list of data parts
 * together with a part descriptor.
 *
 * NOTE(review): only the declaration is visible here. The parameter notes
 * below are inferred from names and types -- confirm against the definition.
 *
 * @param h         RAFT handle supplied by the caller.
 * @param outZParts Data parts of the output matrix Z (presumably the parts
 *                  owned by this rank -- TODO confirm).
 * @param outZDesc  Part descriptor for Z.
 * @param inXParts  Data parts of the left operand X.
 * @param inXDesc   Part descriptor for X.
 * @param inYParts  Data parts of the right operand Y.
 * @param inYDesc   Part descriptor for Y.
 * @param myRank    Rank of the calling worker (presumably its index within
 *                  the communicator -- TODO confirm).
 * @param stream    CUDA stream argument (presumably the stream the work is
 *                  enqueued on -- TODO confirm).
 */
39 void gemm(const raft::handle_t& h,
40  std::vector<Matrix::Data<float>*>& outZParts,
41  Matrix::PartDescriptor& outZDesc,
42  std::vector<Matrix::Data<float>*>& inXParts,
43  Matrix::PartDescriptor& inXDesc,
44  std::vector<Matrix::Data<float>*>& inYParts,
45  Matrix::PartDescriptor& inYDesc,
46  int myRank,
47  cudaStream_t stream);
48 
/**
 * @brief Double-precision overload of the multi-GPU gemm.
 *
 * NOTE(review): only the declaration is visible here. Presumably this
 * performs the same Z = X * Y product as the float overload, on
 * Matrix::Data<double> parts -- confirm against the definition.
 *
 * @param h         RAFT handle supplied by the caller.
 * @param outZParts Data parts of the output matrix Z (presumably the parts
 *                  owned by this rank -- TODO confirm).
 * @param outZDesc  Part descriptor for Z.
 * @param inXParts  Data parts of the left operand X.
 * @param inXDesc   Part descriptor for X.
 * @param inYParts  Data parts of the right operand Y.
 * @param inYDesc   Part descriptor for Y.
 * @param myRank    Rank of the calling worker (presumably its index within
 *                  the communicator -- TODO confirm).
 * @param stream    CUDA stream argument (presumably the stream the work is
 *                  enqueued on -- TODO confirm).
 */
49 void gemm(const raft::handle_t& h,
50  std::vector<Matrix::Data<double>*>& outZParts,
51  Matrix::PartDescriptor& outZDesc,
52  std::vector<Matrix::Data<double>*>& inXParts,
53  Matrix::PartDescriptor& inXDesc,
54  std::vector<Matrix::Data<double>*>& inYParts,
55  Matrix::PartDescriptor& inYDesc,
56  int myRank,
57  cudaStream_t stream);
58 
59 } // end namespace opg
60 } // end namespace LinAlg
61 } // end namespace MLCommon
void gemm(const raft::handle_t &h, std::vector< Matrix::Data< float > * > &outZParts, Matrix::PartDescriptor &outZDesc, std::vector< Matrix::Data< float > * > &inXParts, Matrix::PartDescriptor &inXDesc, std::vector< Matrix::Data< float > * > &inYParts, Matrix::PartDescriptor &inYDesc, int myRank, cudaStream_t stream)
A multi-gpu generalized matrix multiplication function. This function performs Z = X * Y. The X and Y ...
Definition: comm_utils.h:11
This is a helper wrapper around the multi-gpu data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40