pylibwholegraph.torch.embedding.create_embedding#
- pylibwholegraph.torch.embedding.create_embedding(comm: WholeMemoryCommunicator, memory_type: str, memory_location: str, dtype: dtype, sizes: List[int], *, cache_policy: WholeMemoryCachePolicy | None = None, embedding_entry_partition: List[int] | None = None, random_init: bool = False, gather_sms: int = -1, round_robin_size: int = 0)#
Create embedding.

- Parameters:
comm – WholeMemoryCommunicator
memory_type – WholeMemory type, should be continuous, chunked or distributed
memory_location – WholeMemory location, should be cpu or cuda
dtype – data type of the embedding
sizes – size of the embedding, must be 2D
cache_policy – cache policy
embedding_entry_partition – rank partition based on entry; embedding_entry_partition[i] determines the entry count of rank i and should be a positive integer; the sum of embedding_entry_partition should equal the total entry count; entries are partitioned equally if None
gather_sms – the number of SMs used in the gather process
round_robin_size – continuous embedding size of a rank when using the round-robin shard strategy
- Returns:
WholeMemoryEmbedding
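
Below is a minimal usage sketch, not taken from this page. It assumes a WholeMemoryCommunicator has already been created during WholeGraph initialization (the communicator setup, e.g. via a helper like get_global_communicator(), is an assumption and is not covered here), and creates a 2D float32 embedding table of shape [1000000, 128] in chunked GPU memory with entries partitioned equally across ranks.

```python
# Minimal sketch. Assumes `comm` is an already-initialized WholeMemoryCommunicator;
# the helper referenced below is an assumption, not documented on this page.
import torch
from pylibwholegraph.torch.embedding import create_embedding

# comm = ...  # e.g. obtained from the library's communicator setup during initialization

# A 2D embedding table with 1,000,000 entries of dimension 128, stored as
# chunked WholeMemory on GPU. embedding_entry_partition is left as None,
# so entries are partitioned equally across ranks.
embedding = create_embedding(
    comm,
    memory_type="chunked",
    memory_location="cuda",
    dtype=torch.float32,
    sizes=[1000000, 128],
    cache_policy=None,   # no caching
    random_init=True,    # initialize entries with random values
)
```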