pylibwholegraph.torch.embedding.create_embedding

pylibwholegraph.torch.embedding.create_embedding(comm: WholeMemoryCommunicator, memory_type: str, memory_location: str, dtype: dtype, sizes: List[int], *, optimizer: Optional[WholeMemoryOptimizer] = None, cache_policy: Optional[WholeMemoryCachePolicy] = None, random_init: bool = False, gather_sms: int = -1, round_robin_size=0)

Create a WholeMemory embedding.

Parameters:
- comm – the WholeMemoryCommunicator to create the embedding in.
- memory_type – WholeMemory type; one of "continuous", "chunked", or "distributed".
- memory_location – WholeMemory location; either "cpu" or "cuda".
- dtype – data type of the embedding.
- sizes – size of the embedding; must be 2D, i.e. [row_count, embedding_dim].
- optimizer – optimizer to attach to the embedding, if any.
- cache_policy – cache policy for the embedding, if any.
- random_init – if True, initialize the embedding with random values.
- gather_sms – the number of SMs used in the gather process; -1 uses the default.
- round_robin_size – continuous embedding size of a rank when using the round-robin shard strategy.

Returns: the created WholeMemoryEmbedding.
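A minimal usage sketch, not taken from this page: it assumes torch.distributed is already initialized (e.g. via torchrun), that get_global_communicator from pylibwholegraph.torch.comm returns the communicator spanning all ranks, and that WholeMemoryEmbedding.gather accepts a CUDA tensor of row indices, as in typical WholeGraph examples. Adapt the setup calls to your environment.

```python
import torch
from pylibwholegraph.torch.comm import get_global_communicator
from pylibwholegraph.torch.embedding import create_embedding

# Assumes torch.distributed and WholeGraph's per-process state are already
# initialized for this world (e.g. a torchrun launch).
comm = get_global_communicator()

# A 2D embedding table: 100_000 rows of dimension 128, sharded across all
# ranks ("distributed") and placed in GPU memory ("cuda").
embedding = create_embedding(
    comm,
    "distributed",    # memory_type: "continuous", "chunked" or "distributed"
    "cuda",           # memory_location: "cpu" or "cuda"
    torch.float32,    # dtype of the embedding values
    [100_000, 128],   # sizes: must be 2D -> [row_count, embedding_dim]
)

# Gather a batch of rows by index; indices live on the GPU.
indices = torch.arange(16, dtype=torch.int64, device="cuda")
features = embedding.gather(indices)   # expected shape: [16, 128]
```

If the embedding is meant to be trained, a WholeMemoryOptimizer can be passed via the optimizer keyword at creation time, per the signature above.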