part_descriptor.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include "data.hpp"
9 
10 #include <stdint.h>
11 
12 #include <ostream>
13 #include <set>
14 #include <vector>
15 
16 namespace MLCommon {
17 namespace Matrix {
18 
/**
 * Layout tag for matrix data; PartDescriptor's constructor takes one and
 * defaults it to LayoutColMajor.
 *
 * NOTE(review): the enumerators are restored in their declaration order
 * (row-major first). No explicit enumerator values are visible, so the
 * implicit values 0 and 1 are assumed — confirm against the original header.
 */
enum Layout {
  /** row-major layout */
  LayoutRowMajor,
  /** column-major layout */
  LayoutColMajor
};
26 
/**
 * Pairs a rank with a size. PartDescriptor keeps a vector of pointers to
 * these, one per part.
 */
struct RankSizePair {
  /** rank of this entry; -1 when default-constructed */
  int rank;

  /** size of this entry; 0 when default-constructed (units are not visible here — presumably rows, TODO confirm) */
  size_t size;

  /** Builds the default entry: rank -1, size 0. */
  RankSizePair() : RankSizePair(-1, 0) {}

  /** Builds an entry from an explicit rank and size. */
  RankSizePair(int _rank, size_t _size) : rank{_rank}, size{_size} {}
};
39 
42  size_t M;
44  size_t N;
45 
46  int rank;
47 
50  std::vector<RankSizePair*> partsToRanks;
51 
63  PartDescriptor(size_t _M,
64  size_t _N,
65  const std::vector<RankSizePair*>& _partsToRanks,
66  int rank,
67  Layout _layout = LayoutColMajor);
68 
70  int totalBlocks() const { return partsToRanks.size(); }
71 
73  int totalBlocksOwnedBy(int rank) const;
74 
75  std::set<int> uniqueRanks();
76 
77  std::vector<size_t> startIndices() const;
78 
79  std::vector<size_t> startIndices(int rank) const;
80 
85  std::vector<RankSizePair*> blocksOwnedBy(int rank) const;
86 
88  size_t totalElementsOwnedBy(int rank) const;
89 
90  friend std::ostream& operator<<(std::ostream& os, const PartDescriptor& desc);
91  friend bool operator==(const PartDescriptor& a, const PartDescriptor& b);
92 };
93 
// Stream-insertion operator for PartDescriptor; also declared as a friend of
// PartDescriptor. Only the declaration appears here.
95 std::ostream& operator<<(std::ostream& os, const PartDescriptor& desc);
96 
// Equality comparison of two PartDescriptor objects; also declared as a friend
// of PartDescriptor. Only the declaration appears here.
98 bool operator==(const PartDescriptor& a, const PartDescriptor& b);
99 
100 }; // end namespace Matrix
101 }; // end namespace MLCommon
bool operator==(const PartDescriptor &a, const PartDescriptor &b)
Layout
Definition: part_descriptor.hpp:20
@ LayoutRowMajor
Definition: part_descriptor.hpp:22
@ LayoutColMajor
Definition: part_descriptor.hpp:24
std::ostream & operator<<(std::ostream &os, const PartDescriptor &desc)
Definition: comm_utils.h:11
MLCommon::Matrix::PartDescriptor
Definition: part_descriptor.hpp:40
friend bool operator==(const PartDescriptor &a, const PartDescriptor &b)
int rank
Definition: part_descriptor.hpp:46
std::vector< size_t > startIndices(int rank) const
int totalBlocks() const
Definition: part_descriptor.hpp:70
std::vector< RankSizePair * > blocksOwnedBy(int rank) const
Returns the vector of blocks (each identified by linearBlockIndex) owned by the given rank.
size_t totalElementsOwnedBy(int rank) const
std::vector< RankSizePair * > partsToRanks
Definition: part_descriptor.hpp:50
std::vector< size_t > startIndices() const
PartDescriptor(size_t _M, size_t _N, const std::vector< RankSizePair * > &_partsToRanks, int rank, Layout _layout=LayoutColMajor)
For a given matrix and block-sizes construct the corresponding descriptor for it. This is useful when...
size_t N
Definition: part_descriptor.hpp:44
int totalBlocksOwnedBy(int rank) const
size_t M
Definition: part_descriptor.hpp:42
Layout layout
Definition: part_descriptor.hpp:48
friend std::ostream & operator<<(std::ostream &os, const PartDescriptor &desc)
Definition: part_descriptor.hpp:27
RankSizePair(int _rank, size_t _size)
Definition: part_descriptor.hpp:30
int rank
Definition: part_descriptor.hpp:32
size_t size
Definition: part_descriptor.hpp:37
RankSizePair()
Definition: part_descriptor.hpp:28