cugraph_pyg.sampler.distributed_sampler.BaseDistributedSampler

class cugraph_pyg.sampler.distributed_sampler.BaseDistributedSampler(graph: SGGraph | MGGraph, local_seeds_per_call: int, retain_original_seeds: bool = False)

Base class for distributed graph sampling using cuGraph.

This abstract base class is the foundation for distributed graph sampling built on cuGraph through pylibcugraph. It synchronizes sampling across multiple workers/GPUs in a distributed environment and processes many batches in each sampling call.

Subclasses should implement the sample_batches() method to define a specific sampling strategy (e.g., neighbor sampling or random walks).

Attributes:
is_multi_gpu

Methods

get_start_batch_offset(local_num_batches[, ...])

Gets the starting batch offset to ensure each rank's set of batch ids is disjoint.

sample_batches(seeds, batch_id_offsets[, ...])

For a single call group of seeds and associated batch ids, performs sampling.

sample_from_edges(edges, *[, batch_size, ...])

Performs sampling starting from seed edges.

sample_from_nodes(nodes, *[, batch_size, ...])

Performs node-based sampling.
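
The disjointness guarantee described for get_start_batch_offset() can be pictured as an exclusive prefix sum over the number of batches each rank holds. The sketch below is a minimal, single-process illustration of that idea only; the helper name start_batch_offsets and the list of per-rank batch counts are assumptions made for this example, and the library's own implementation may differ.

```python
def start_batch_offsets(num_batches_per_rank):
    """Exclusive prefix sum: the first batch id assigned to each rank.

    Hypothetical helper for illustration; not part of cugraph_pyg.
    """
    offsets = []
    running = 0
    for n in num_batches_per_rank:
        offsets.append(running)  # this rank starts where earlier ranks end
        running += n
    return offsets


# Assumed example: three ranks holding 3, 5, and 2 local batches.
num_batches_per_rank = [3, 5, 2]
offsets = start_batch_offsets(num_batches_per_rank)

# Each rank numbers its batches starting from its own offset.
batch_ids = [
    list(range(off, off + n))
    for off, n in zip(offsets, num_batches_per_rank)
]

print(offsets)    # [0, 3, 8]
print(batch_ids)  # [[0, 1, 2], [3, 4, 5, 6, 7], [8, 9]]
```

Because every rank starts numbering where the previous ranks stop, no batch id can appear on two ranks.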

Examples

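>>> from cugraph_pyg.sampler.distributed_sampler import DistributedNeighborSampler
>>>
>>> # Create a distributed neighbor sampler.
>>> # mg_graph (an MGGraph) and seed_nodes are assumed to exist already.
>>> sampler = DistributedNeighborSampler(
...     graph=mg_graph,
...     fanout=[25, 10],
...     local_seeds_per_call=1024
... )
>>>
>>> # Sample from nodes
>>> for batch in sampler.sample_from_nodes(nodes=seed_nodes):
...     # Process batch
...     pass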
__init__(graph: SGGraph | MGGraph, local_seeds_per_call: int, retain_original_seeds: bool = False)
Parameters:
graph: SGGraph or MGGraph (required)

The pylibcugraph graph object that will be sampled.

local_seeds_per_call: int

The number of seeds on this rank that the sampler processes in a single sampling call. Batches are split across multiple sampling calls based on this parameter, which must be the same on all ranks. The total number of seeds processed per sampling call is this parameter times the world size. Subclasses should generally calculate the appropriate number of seeds.
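
To make the arithmetic concrete, the following sketch splits one rank's seeds into sampling calls and computes the total seeds per call across all ranks. It is an illustration under stated assumptions, not the library's code: plan_sampling_calls is a hypothetical helper, the seed count and world size are made-up values, and the real sampler may also align calls to batch boundaries, which this sketch ignores.

```python
def plan_sampling_calls(num_local_seeds, local_seeds_per_call, world_size):
    """Return (calls, total_seeds_per_call) for one rank.

    Hypothetical helper for illustration; not part of cugraph_pyg.
    `calls` is a list of (start, end) slices into this rank's seed array.
    """
    if local_seeds_per_call <= 0:
        raise ValueError("local_seeds_per_call must be positive")

    # Each call consumes up to local_seeds_per_call seeds from this rank.
    calls = [
        (start, min(start + local_seeds_per_call, num_local_seeds))
        for start in range(0, num_local_seeds, local_seeds_per_call)
    ]

    # Every rank uses the same value, so one synchronized call covers
    # up to local_seeds_per_call * world_size seeds in total.
    total_seeds_per_call = local_seeds_per_call * world_size
    return calls, total_seeds_per_call


# Assumed example: 2500 seeds on this rank, 4 ranks in total.
calls, total = plan_sampling_calls(
    num_local_seeds=2500, local_seeds_per_call=1024, world_size=4
)
print(calls)  # [(0, 1024), (1024, 2048), (2048, 2500)]
print(total)  # 4096
```

The last call is smaller than the others because the seed count is not a multiple of local_seeds_per_call.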

retain_original_seeds: bool (optional, default=False)

Whether to retain the original seeds even if they do not appear in the output minibatch. This will affect the output renumber map and CSR/CSC graph if applicable.
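
The effect on the renumber map can be shown with a small pure-Python model. This is a sketch under assumptions, not how pylibcugraph computes the map: build_renumber_map is a hypothetical helper, the seed-first ordering is a choice made for this example, and the vertex ids are invented. It shows only that a seed with no sampled edges appears in the map, and shifts the local ids, when seeds are retained.

```python
def build_renumber_map(seeds, sampled_edges, retain_original_seeds=False):
    """Return (renumber_map, local_edges).

    Hypothetical helper for illustration; not part of cugraph_pyg.
    renumber_map[local_id] is the original vertex id.
    """
    in_output = {v for edge in sampled_edges for v in edge}
    ordered = []

    # Seeds come first in this sketch. A seed with no sampled edges is
    # kept only when retain_original_seeds is True.
    for s in seeds:
        if (retain_original_seeds or s in in_output) and s not in ordered:
            ordered.append(s)

    # Remaining vertices follow in order of first appearance.
    for src, dst in sampled_edges:
        for v in (src, dst):
            if v not in ordered:
                ordered.append(v)

    local_id = {orig: i for i, orig in enumerate(ordered)}
    local_edges = [(local_id[s], local_id[d]) for s, d in sampled_edges]
    return ordered, local_edges


# Assumed example: seed 42 has no sampled neighbors.
seeds = [10, 42]
sampled_edges = [(10, 7), (10, 3)]

print(build_renumber_map(seeds, sampled_edges))
# ([10, 7, 3], [(0, 1), (0, 2)])
print(build_renumber_map(seeds, sampled_edges, retain_original_seeds=True))
# ([10, 42, 7, 3], [(0, 2), (0, 3)])
```

With retention enabled, the map gains an entry for seed 42 and the local edge endpoints shift, which is why any CSR/CSC structure built from the local ids changes as well.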
