gemm.hpp File Reference
#include "../matrix/data.hpp"
#include "../matrix/part_descriptor.hpp"
#include <raft/core/handle.hpp>

Namespaces

 MLCommon
 
 MLCommon::LinAlg
 
 MLCommon::LinAlg::opg
 

Functions

void MLCommon::LinAlg::opg::gemm (const raft::handle_t &h, std::vector< Matrix::Data< float > * > &outZParts, Matrix::PartDescriptor &outZDesc, std::vector< Matrix::Data< float > * > &inXParts, Matrix::PartDescriptor &inXDesc, std::vector< Matrix::Data< float > * > &inYParts, Matrix::PartDescriptor &inYDesc, int myRank, cudaStream_t stream)
 A multi-GPU general matrix multiplication (GEMM) function that computes Z = X * Y. The X and Y matrices are distributed in blocks across ranks. The Y matrix is first duplicated on every rank, and each rank then multiplies its local blocks of X by that copy of Y.
 
void MLCommon::LinAlg::opg::gemm (const raft::handle_t &h, std::vector< Matrix::Data< double > * > &outZParts, Matrix::PartDescriptor &outZDesc, std::vector< Matrix::Data< double > * > &inXParts, Matrix::PartDescriptor &inXDesc, std::vector< Matrix::Data< double > * > &inYParts, Matrix::PartDescriptor &inYDesc, int myRank, cudaStream_t stream)
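
The scheme described for gemm (duplicate Y on every rank, then multiply it by the blocks of X local to that rank) can be illustrated with a small CPU-only sketch. This is not the implementation behind this header: DenseMatrix, multiplyBlock, and simulateDistributedGemm are hypothetical names, the ranks are simulated by a sequential loop, and the sketch assumes X is split by rows, which the description above does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major dense matrix used only for this illustration.
// It is NOT the Matrix::Data type declared in data.hpp.
struct DenseMatrix {
  std::size_t rows;
  std::size_t cols;
  std::vector<double> values;  // rows * cols entries, row-major
};

// Multiply one local row block of X (blockRows x K) by the full Y (K x N).
DenseMatrix multiplyBlock(const DenseMatrix& xBlock, const DenseMatrix& y) {
  assert(xBlock.cols == y.rows);  // inner dimensions must agree
  DenseMatrix z{xBlock.rows, y.cols,
                std::vector<double>(xBlock.rows * y.cols, 0.0)};
  for (std::size_t i = 0; i < xBlock.rows; ++i) {
    for (std::size_t k = 0; k < xBlock.cols; ++k) {
      const double x = xBlock.values[i * xBlock.cols + k];
      for (std::size_t j = 0; j < y.cols; ++j) {
        z.values[i * z.cols + j] += x * y.values[k * y.cols + j];
      }
    }
  }
  return z;
}

// Simulate the ranks one after another. Each simulated rank owns one row
// block of X (assumption: row-wise partitioning) and sees a full copy of Y,
// so it can produce its own block of Z without further communication.
std::vector<DenseMatrix> simulateDistributedGemm(
    const std::vector<DenseMatrix>& xBlocks, const DenseMatrix& y) {
  std::vector<DenseMatrix> zBlocks;
  zBlocks.reserve(xBlocks.size());
  for (const DenseMatrix& xBlock : xBlocks) {
    zBlocks.push_back(multiplyBlock(xBlock, y));
  }
  return zBlocks;
}
```

Stacking the returned blocks in rank order gives the same Z as a single full multiplication. Replicating Y keeps the local multiply simple, at the cost of every rank holding a complete copy of Y.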