{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Dask Multi-GPU Guide\n", "\n", "This guide demonstrates how to use cuML estimators in multi-GPU and multi-node contexts using Dask. While cuML's single-GPU implementations are highly optimized, distributed computing with Dask enables you to:\n", "\n", "- **Scale beyond single GPU memory**: Process datasets larger than what fits on a single GPU\n", "- **Accelerate training**: Distribute computation across multiple GPUs for faster model training\n", "- **Handle production workloads**: Deploy models that serve high-throughput prediction requests\n", "\n", "cuML's Dask integration uses a **One Process Per GPU (OPG)** architecture, where each Dask worker manages a single GPU. This design maximizes GPU utilization and simplifies memory management across the cluster.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup and Configuration\n", "\n", "### Installing Dependencies\n", "\n", "To use cuML with Dask, you need to install additional dependencies. If you haven't already, install them using conda or pip:\n", "\n", "```bash\n", "# Install with Conda:\n", "conda install rapids-dask-dependency dask-cudf raft-dask\n", "\n", "# Or install with pip (replace cu13 with your CUDA version):\n", "pip install cuml-cu13[dask]\n", "```\n", "\n", "### Setting Up a CUDA Cluster\n", "\n", "For single-node, multi-GPU execution, use `LocalCUDACluster` from `dask-cuda`. This automatically creates one worker per available GPU. 
For detailed information on configuring local CUDA clusters, including advanced networking options (UCX, InfiniBand) and multi-node cluster setup, see the [RAPIDS dask-cuda documentation](https://docs.rapids.ai/api/dask-cuda/stable/examples/ucx/).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dask.distributed import Client\n", "from dask_cuda import LocalCUDACluster\n", "\n", "# Create a local cluster with one worker per GPU\n", "cluster = LocalCUDACluster()\n", "client = Client(cluster)\n", "\n", "# Display cluster information\n", "print(f\"Cluster dashboard available at: {client.dashboard_link}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 1: K-Means Clustering\n", "\n", "K-Means is one of the most commonly used clustering algorithms. cuML's distributed implementation parallelizes the fit operation for each iteration, sharing only the centroids between iterations.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from cuml.dask.cluster import KMeans\n", "from cuml.dask.datasets import make_blobs\n", "from cuml.metrics import adjusted_rand_score\n", "\n", "# Get number of workers for data partitioning\n", "n_workers = len(client.scheduler_info()['workers'])\n", "\n", "# Generate distributed synthetic data\n", "X, y = make_blobs(\n", " n_samples=10000,\n", " n_features=20,\n", " centers=5,\n", " cluster_std=0.5,\n", " random_state=42,\n", " n_parts=n_workers * 2 # Multiple partitions per worker\n", ")\n", "\n", "print(f\"Generated data with {len(X.to_delayed())} partitions\")\n", "print(f\"Data type: {type(X)}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can train a distributed K-Means model. 
The API is nearly identical to the single-GPU version:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train distributed K-Means\n", "kmeans = KMeans(n_clusters=5, random_state=42)\n", "kmeans.fit(X)\n", "\n", "# Make predictions\n", "labels = kmeans.predict(X)\n", "\n", "# Evaluate clustering quality\n", "score = adjusted_rand_score(y.compute(), labels.compute())\n", "print(f\"Adjusted Rand Score: {score:.4f}\")\n", "\n", "# View cluster centers\n", "print(f\"\\nCluster centers shape: {kmeans.cluster_centers_.shape}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Convert into a single-GPU model\n", "\n", "We can use the distributed model for inference directly (as shown above) or convert it into a single-GPU model when that better suits your needs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract single-GPU model from distributed model\n", "combined_model = kmeans.get_combined_model()\n", "\n", "# The single-GPU model works on data held on one GPU,\n", "# so materialize the Dask array before predicting\n", "combined_model.predict(X.compute())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Serialization\n", "\n", "Distributed models cannot be pickled directly. To pickle a model, first extract the single-GPU version as shown above, then serialize it as usual:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "\n", "# Save the model\n", "with open(\"kmeans_model.pkl\", \"wb\") as f:\n", "    pickle.dump(combined_model, f, protocol=5)\n", "\n", "# Load and use the model\n", "with open(\"kmeans_model.pkl\", \"rb\") as f:\n", "    loaded_model = pickle.load(f)\n", "\n", "print(f\"Loaded model type: {type(loaded_model)}\")\n", "print(f\"Model has {loaded_model.n_clusters} clusters\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loaded model is a single-GPU estimator that can be used for inference. 
For distributed inference across a Dask cluster, consider using [Dask-ML's ParallelPostFit](https://ml.dask.org/meta-estimators.html) meta-estimator.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 2: Random Forest Classification\n", "\n", "Random Forest is an ensemble learning method that builds multiple decision trees. cuML's distributed implementation uses embarrassingly parallel training: for a forest with N trees trained by W workers, each worker builds N/W trees.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from cuml.dask.ensemble import RandomForestClassifier\n", "from cuml.dask.datasets import make_classification\n", "from cuml.metrics import accuracy_score\n", "\n", "# Generate classification dataset\n", "n_samples = 5000\n", "n_features = 30\n", "n_classes = 3\n", "\n", "X, y = make_classification(\n", " n_samples=n_samples,\n", " n_features=n_features,\n", " n_informative=int(n_features * 0.7),\n", " n_redundant=int(n_features * 0.2),\n", " n_classes=n_classes,\n", " random_state=42,\n", " n_parts=n_workers * 2\n", ")\n", "\n", "print(f\"Generated classification dataset:\")\n", "print(f\" Samples: {n_samples}\")\n", "print(f\" Features: {n_features}\")\n", "print(f\" Classes: {n_classes}\")\n", "print(f\" Partitions: {len(X.to_delayed())}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Distribution Best Practices\n", "\n", "For Random Forest, data distribution is critical for model accuracy:\n", "\n", "- **Option 1: Well-shuffled data**: Distribute shuffled data evenly across workers (used above)\n", "- **Option 2: Replicated data**: If memory allows, replicate the entire dataset to all workers for training that most closely matches single-GPU behavior\n", "\n", "Both approaches ensure each worker sees a representative sample of the data distribution.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": 
[ "# Train distributed Random Forest\n", "rf = RandomForestClassifier(\n", " n_estimators=100,\n", " max_depth=16,\n", " n_bins=32,\n", " random_state=42\n", ")\n", "\n", "rf.fit(X, y)\n", "\n", "# Make predictions\n", "predictions = rf.predict(X)\n", "\n", "# Evaluate accuracy\n", "accuracy = accuracy_score(y.compute(), predictions.compute())\n", "print(f\"Training accuracy: {accuracy:.4f}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performance Notes\n", "\n", "Random Forest training performance is heavily influenced by:\n", "\n", "- `max_depth`: Lower values significantly speed up training but may reduce accuracy. Start with 12-16 for balanced performance.\n", "- `n_estimators`: More trees improve accuracy but increase training time linearly. The work is distributed across workers.\n", "- `n_bins`: Controls histogram granularity for split finding. Lower values (8-32) are faster but may miss optimal splits.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example 3: Linear Regression\n", "\n", "Linear models in cuML support both Dask DataFrame and Dask Array inputs. 
This example demonstrates a distributed linear regression workflow.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from cuml.dask.linear_model import LinearRegression\n", "from cuml.dask.datasets import make_regression\n", "from cuml.metrics import r2_score\n", "\n", "# Generate regression dataset\n", "X, y = make_regression(\n", " n_samples=10000,\n", " n_features=50,\n", " n_informative=40,\n", " random_state=42,\n", " n_parts=n_workers * 2\n", ")\n", "\n", "print(f\"Generated regression dataset with {X.shape[0]:,} samples\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train distributed Linear Regression\n", "lr = LinearRegression()\n", "lr.fit(X, y)\n", "\n", "# Make predictions\n", "predictions = lr.predict(X)\n", "\n", "# Evaluate model\n", "r2 = r2_score(y.compute(), predictions.compute())\n", "print(f\"R² score: {r2:.4f}\")\n", "print(f\"Number of coefficients: {len(lr.coef_)}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using Dask DataFrames\n", "\n", "You can also use Dask cuDF DataFrames as input, which is useful when loading data from files:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dask_cudf\n", "import cudf\n", "\n", "# Convert Dask Array to Dask DataFrame\n", "# In practice, you'd often load data directly: dask_cudf.read_csv(\"data.csv\")\n", "X_computed = X.compute()\n", "y_computed = y.compute()\n", "\n", "# Create cuDF DataFrame\n", "df = cudf.DataFrame(X_computed)\n", "target = cudf.Series(y_computed)\n", "\n", "# Convert to Dask DataFrame\n", "ddf = dask_cudf.from_cudf(df, npartitions=n_workers * 2)\n", "dtarget = dask_cudf.from_cudf(target, npartitions=n_workers * 2)\n", "\n", "# Train with Dask DataFrame input\n", "lr_df = LinearRegression()\n", "lr_df.fit(ddf, dtarget)\n", "\n", "print(f\"Model trained successfully with Dask DataFrame 
input\")\n", "print(f\"Predictions type: {type(lr_df.predict(ddf))}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Available Dask-Enabled Estimators\n", "\n", "cuML provides Dask implementations for many popular algorithms including clustering (KMeans, DBSCAN), ensemble methods (Random Forest), linear models (LinearRegression, Logistic Regression, Ridge, Lasso), decomposition (PCA, TruncatedSVD), nearest neighbors, and more. All are available in the `cuml.dask` module.\n", "\n", "For a comprehensive list of supported algorithms and their documentation, see the [Multi-Node, Multi-GPU Algorithms](https://docs.rapids.ai/api/cuml/stable/api.html#multi-node-multi-gpu-algorithms) section of the API reference." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Performance Considerations and Best Practices\n", "\n", "### When to Use Multi-GPU\n", "\n", "Use distributed computing with Dask when:\n", "\n", "- **Dataset exceeds single GPU memory**: Your data doesn't fit on a single GPU\n", "- **Training time is critical**: Multiple GPUs can accelerate model training\n", "- **Scaling to production**: You need to handle high-throughput inference workloads\n", "\n", "Single-GPU implementations are often sufficient and simpler for:\n", "\n", "- **Small to medium datasets**: Data that comfortably fits in single GPU memory (typically < 80% of GPU RAM)\n", "- **Rapid prototyping**: Simpler setup and debugging\n", "- **Latency-sensitive inference**: Single-GPU inference has lower overhead\n", "\n", "### Data Partitioning Strategy\n", "\n", "The `n_parts` parameter controls how data is distributed across workers:\n", "\n", "```python\n", "# Rule of thumb: 2-4 partitions per worker\n", "n_parts = n_workers * 2 # Good starting point\n", "\n", "# More partitions: Better load balancing, more overhead\n", "# Fewer partitions: Lower overhead, potential load imbalance\n", "```\n", "\n", "### Network Optimization\n", "\n", "For multi-node clusters, 
high-performance networking options are available:\n", "\n", "- **NVLink**: For multi-GPU communication on the same node\n", "- **InfiniBand**: For fast inter-node communication\n", "- **UCX protocol**: Unified communication framework for optimal performance\n", "\n", "For detailed configuration examples and setup instructions, see the [RAPIDS dask-cuda documentation](https://docs.rapids.ai/api/dask-cuda/stable/examples/ucx/).\n", "\n", "### Memory Management\n", "\n", "- **RMM pool size**: Pre-allocate GPU memory to reduce allocation overhead\n", "- **Worker memory limit**: Set limits to prevent out-of-memory errors\n", "- **Partition size**: Keep partitions small enough to fit comfortably in GPU memory\n", "\n", "### Performance Profiling\n", "\n", "Monitor your distributed computation:\n", "\n", "1. **Dask Dashboard**: Access at `client.dashboard_link` to visualize task execution\n", "2. **NVIDIA tools**: Use `nvidia-smi` to monitor GPU utilization\n", "3. **RAPIDS Memory Manager**: Enable RMM logging for memory profiling\n", "\n", "### Input Data Types\n", "\n", "cuML's Dask estimators accept multiple input formats:\n", "\n", "- **Dask Array**: Use `dask.array` with CuPy backend for array operations\n", "- **Dask DataFrame**: Use `dask_cudf.DataFrame` for structured data\n", "\n", "Choose based on your workflow:\n", "\n", "```python\n", "# Dask Array - good for numerical operations\n", "import cupy as cp\n", "import dask.array as da\n", "\n", "# Use a CuPy-backed RandomState so that each chunk lives in GPU memory\n", "rs = da.random.RandomState(RandomState=cp.random.RandomState)\n", "X = rs.random_sample((10000, 50), chunks=(1000, 50))\n", "\n", "# Dask DataFrame - good for mixed types and data loading\n", "import dask_cudf\n", "df = dask_cudf.read_csv(\"data.csv\")\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Clean Up\n", "\n", "Close your Dask client and cluster when finished to free up resources.\n", "This step is optional: both are shut down automatically when the notebook kernel exits.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "# Close the client and cluster\n", "client.close()\n", "cluster.close()\n", "\n", "print(\"Cluster shut down successfully\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional Resources\n", "\n", "- **API Reference**: [Multi-Node, Multi-GPU Algorithms](https://docs.rapids.ai/api/cuml/stable/api.html#multi-node-multi-gpu-algorithms)\n", "- **Cluster Setup**: [RAPIDS dask-cuda Examples (UCX, multi-node configuration)](https://docs.rapids.ai/api/dask-cuda/stable/examples/ucx/)\n", "- **Dask Documentation**: [Dask Distributed](https://distributed.dask.org/)\n", "- **Dask-CUDA**: [Dask-CUDA Documentation](https://docs.rapids.ai/api/dask-cuda/stable/)\n" ] } ], "metadata": { "kernelspec": { "display_name": "cuml-work2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.9" } }, "nbformat": 4, "nbformat_minor": 2 }