#include "../matrix/data.hpp"
#include "../matrix/part_descriptor.hpp"
#include <raft/core/handle.hpp>
Go to the source code of this file.
|
| void | MLCommon::LinAlg::opg::gemm (const raft::handle_t &h, std::vector< Matrix::Data< float > * > &outZParts, Matrix::PartDescriptor &outZDesc, std::vector< Matrix::Data< float > * > &inXParts, Matrix::PartDescriptor &inXDesc, std::vector< Matrix::Data< float > * > &inYParts, Matrix::PartDescriptor &inYDesc, int myRank, cudaStream_t stream) |
| | A multi-GPU generalized matrix multiplication function. This function performs Z = X * Y. The X and Y matrices are distributed in blocks across different ranks. First, the Y matrix is duplicated on each rank; it is then multiplied with the blocks of X that are local to that rank. More...
|
| |
| void | MLCommon::LinAlg::opg::gemm (const raft::handle_t &h, std::vector< Matrix::Data< double > * > &outZParts, Matrix::PartDescriptor &outZDesc, std::vector< Matrix::Data< double > * > &inXParts, Matrix::PartDescriptor &inXDesc, std::vector< Matrix::Data< double > * > &inYParts, Matrix::PartDescriptor &inYDesc, int myRank, cudaStream_t stream) |
| |