Source code for dask_cuda.explicit_comms.comms

import asyncio
import concurrent.futures
import contextlib
import time
import uuid
from typing import Any, Dict, Hashable, Iterable, List, Optional

import distributed.comm
from distributed import Client, Worker, default_client, get_worker
from distributed.comm.addressing import parse_address, parse_host_port, unparse_address

_default_comms = None


def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
    """Return either a MultiLock or a NULL context

    Parameters
    ----------
    multi_lock_context: bool
        If True, return a MultiLock context, else return a NULL context that
        doesn't do anything

    *args, **kwargs:
        Arguments passed to the MultiLock creation

    Returns
    -------
    context: context
        Either `MultiLock(*args, **kwargs)` or a NULL context
    """
    if multi_lock_context:
        from distributed import MultiLock

        return MultiLock(*args, **kwargs)
    else:
        return contextlib.nullcontext()


def default_comms(client: Optional[Client] = None) -> "CommsContext":
    """Return the default comms object

    Creates a new default comms object if none exists.

    Parameters
    ----------
    client: Client, optional
        If no default comms object exists, create a new one using `client`
        (or the default client if None).

    Returns
    -------
    comms: CommsContext
        The default comms object
    """
    global _default_comms
    if _default_comms is None:
        _default_comms = CommsContext(client=client)
    return _default_comms
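
# Hedged usage sketch (not part of this module): assuming a Dask cluster is
# already running, `default_comms` lazily creates a single CommsContext that
# is reused by subsequent calls.
#
#     from distributed import Client
#     from dask_cuda.explicit_comms.comms import default_comms
#
#     client = Client("tcp://scheduler:8786")  # hypothetical scheduler address
#     comms = default_comms(client)            # created on first call
#     assert comms is default_comms()          # later calls return the same object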


def worker_state(sessionId: Optional[int] = None) -> dict:
    """Retrieve the state(s) of the current worker

    Parameters
    ----------
    sessionId: int, optional
        Worker session state ID. If None, all states of the worker
        are returned.

    Returns
    -------
    state: dict
        Either a single state dict or a dict of state dicts
    """
    worker: Any = get_worker()
    if not hasattr(worker, "_explicit_comm_state"):
        worker._explicit_comm_state = {}
    if sessionId is not None:
        if sessionId not in worker._explicit_comm_state:
            worker._explicit_comm_state[sessionId] = {
                "ts": time.time(),
                "eps": {},
                "loop": worker.loop.asyncio_loop,
                "worker": worker,
            }
        return worker._explicit_comm_state[sessionId]
    return worker._explicit_comm_state
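
# Hedged sketch: coroutines launched via CommsContext.run/submit receive the
# per-session state dict created above. The fields used below are the ones
# initialized by this function and by _create_listeners/_create_endpoints.
#
#     async def _peers(session_state):
#         return session_state["rank"], sorted(session_state["eps"])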


def _run_coroutine_on_worker(sessionId, coroutine, args):
    session_state = worker_state(sessionId)

    def _run():
        future = asyncio.run_coroutine_threadsafe(
            coroutine(session_state, *args), session_state["loop"]
        )
        return future.result()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(_run).result()


async def _create_listeners(session_state, nworkers, rank):
    assert session_state["loop"] is asyncio.get_event_loop()
    assert "nworkers" not in session_state
    session_state["nworkers"] = nworkers
    assert "rank" not in session_state
    session_state["rank"] = rank

    async def server_handler(ep):
        peer_rank = await ep.read()
        session_state["eps"][peer_rank] = ep

    # Listen on the same protocol and host as the worker's address
    # (but on a separate port chosen by the listener)
    protocol, address = parse_address(session_state["worker"].address)
    address = parse_host_port(address)[0]
    address = unparse_address(protocol, address)

    session_state["lf"] = distributed.comm.listen(address, server_handler)
    await session_state["lf"].start()
    return session_state["lf"].listen_address


async def _create_endpoints(session_state, peers):
    """Each worker creates a UCX endpoint to all workers with greater rank"""
    assert session_state["loop"] is asyncio.get_event_loop()

    myrank = session_state["rank"]
    peers = list(enumerate(peers))

    # Create endpoints to workers with a greater rank than my rank
    for rank, address in peers[myrank + 1 :]:
        ep = await distributed.comm.connect(address)
        await ep.write(session_state["rank"])
        session_state["eps"][rank] = ep

    # Block until all endpoints have been created
    while len(session_state["eps"]) < session_state["nworkers"] - 1:
        await asyncio.sleep(0.1)


async def _stop_ucp_listeners(session_state):
    assert len(session_state["eps"]) == session_state["nworkers"] - 1
    assert session_state["loop"] is asyncio.get_event_loop()
    session_state["lf"].stop()
    del session_state["lf"]


async def _stage_keys(session_state: dict, name: str, keys: set):
    worker: Worker = session_state["worker"]
    data = worker.data
    my_keys = keys.intersection(data)

    stages = session_state.get("stages", {})
    stage = stages.get(name, {})
    for k in my_keys:
        stage[k] = data[k]
    stages[name] = stage
    session_state["stages"] = stages
    return (session_state["rank"], my_keys)


class CommsContext:
    """Communication handler for explicit communication

    Parameters
    ----------
    client: Client, optional
        Specify client to use for communication. If None, use the default
        client.
    """

    client: Client
    sessionId: int
    worker_addresses: List[str]

    def __init__(self, client: Optional[Client] = None):
        self.client = client if client is not None else default_client()
        self.sessionId = uuid.uuid4().int

        # Get address of all workers (not Nanny addresses)
        self.worker_addresses = list(self.client.scheduler_info()["workers"].keys())

        # Make all workers listen and get all listen addresses
        self.worker_direct_addresses = []
        for rank, address in enumerate(self.worker_addresses):
            self.worker_direct_addresses.append(
                self.submit(
                    address,
                    _create_listeners,
                    len(self.worker_addresses),
                    rank,
                    wait=True,
                )
            )

        # Each worker creates an endpoint to all workers with greater rank
        self.run(_create_endpoints, self.worker_direct_addresses)

        # At this point all workers should have a rank and endpoints to
        # all other workers, thus we can now stop the listening.
        self.run(_stop_ucp_listeners)
    def submit(self, worker, coroutine, *args, wait=False):
        """Run a coroutine on a single worker

        The coroutine is given the worker's state dict as the first argument
        and ``*args`` as the following arguments.

        Parameters
        ----------
        worker: str
            Worker to run the ``coroutine``
        coroutine: coroutine
            The function to run on the worker
        *args:
            Arguments for ``coroutine``
        wait: boolean, optional
            If True, wait for the coroutine to finish before returning.

        Returns
        -------
        ret: object or Future
            If wait=True, the result of `coroutine`
            If wait=False, a Future that can be waited on later.
        """
        ret = self.client.submit(
            _run_coroutine_on_worker,
            self.sessionId,
            coroutine,
            args,
            workers=[worker],
            pure=False,
        )
        return ret.result() if wait else ret
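
    # Hedged usage sketch (not part of this module): `submit` targets a single
    # worker address. The coroutine name below is illustrative; it simply
    # reports the worker's rank from the session state created during
    # CommsContext initialization.
    #
    #     async def _report_rank(session_state):
    #         return session_state["rank"]
    #
    #     comms = CommsContext(client)
    #     rank = comms.submit(comms.worker_addresses[0], _report_rank, wait=True)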
    def run(self, coroutine, *args, workers=None, lock_workers=False):
        """Run a coroutine on multiple workers

        The coroutine is given the worker's state dict as the first argument
        and ``*args`` as the following arguments.

        Parameters
        ----------
        coroutine: coroutine
            The function to run on each worker
        *args:
            Arguments for ``coroutine``
        workers: list, optional
            List of workers. Default is all workers
        lock_workers: bool, optional
            Use distributed.MultiLock to get exclusive access to the workers.
            Use this flag to support parallel runs.

        Returns
        -------
        ret: list
            List of the output from each worker
        """
        if workers is None:
            workers = self.worker_addresses
        with get_multi_lock_or_null_context(lock_workers, workers):
            ret = []
            for worker in workers:
                ret.append(
                    self.client.submit(
                        _run_coroutine_on_worker,
                        self.sessionId,
                        coroutine,
                        args,
                        workers=[worker],
                        pure=False,
                    )
                )
            return self.client.gather(ret)
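
    # Hedged usage sketch: `run` executes the coroutine once per worker (all
    # workers by default, or the subset given by ``workers``). The coroutine
    # name is illustrative; it counts the endpoints each worker holds, which
    # should be ``nworkers - 1`` after initialization.
    #
    #     async def _count_endpoints(session_state):
    #         return len(session_state["eps"])
    #
    #     counts = comms.run(_count_endpoints)  # one result per worker
    #     some = comms.run(_count_endpoints, workers=comms.worker_addresses[:2])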
    def stage_keys(self, name: str, keys: Iterable[Hashable]) -> Dict[int, set]:
        """Stage keys on workers under the given name

        In an explicit-comms task, use `pop_staging_area(..., name)` to access
        the staged keys and the associated data.

        Notes
        -----
        In the context of explicit-comms, staging is the act of duplicating the
        responsibility of Dask keys. When staging a key, the worker owning the
        key (as assigned by the Dask scheduler) saves a reference to the key
        and the associated data in its local staging area. From this point on,
        even if the scheduler cancels the key, the worker (and the task running
        on the worker) retains exclusive access to the key and the associated
        data. This way, staging makes it possible for long-running
        explicit-comms tasks to free input data ASAP.

        Parameters
        ----------
        name: str
            Name for the staging area
        keys: iterable
            The keys to stage

        Returns
        -------
        dict
            dict that maps each worker rank to the worker's set of staged keys
        """
        return dict(self.run(_stage_keys, name, set(keys)))

def pop_staging_area(session_state: dict, name: str) -> Dict[str, Any]:
    """Pop the staging area called `name`

    This function must be called within a running explicit-comms task.

    Parameters
    ----------
    session_state: dict
        Worker session state
    name: str
        Name for the staging area

    Returns
    -------
    dict
        The staging area, which is a dict that maps keys to their data.
    """
    return session_state["stages"].pop(name)
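
# Hedged end-to-end sketch of the staging workflow (names are illustrative):
# the client stages the keys of a persisted Dask collection under a name, and
# an explicit-comms coroutine later pops that staging area to take exclusive
# ownership of the locally held partitions.
#
#     async def _consume(session_state, name):
#         local_data = pop_staging_area(session_state, name)  # {key: data}
#         return sorted(local_data.keys())
#
#     name = "my-staging-area"                    # hypothetical name
#     comms.stage_keys(name, df.__dask_keys__())  # df: persisted Dask collection
#     comms.run(_consume, name)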