Scaling up hyperparameter optimization with multi-GPU workload on Kubernetes

Choosing an optimal set of hyperparameters is a daunting task, especially for algorithms like XGBoost that have many hyperparameters to tune. In this notebook, we will speed up hyperparameter optimization by running multiple training jobs in parallel on a Kubernetes cluster. To handle larger datasets, each training job splits its data across multiple GPU devices.

Prerequisites

Please follow the instructions in Dask Operator: Installation to install the Dask operator on top of a GPU-enabled Kubernetes cluster. (For the purpose of this example, you may ignore the other sections of the linked document.)

Optional: Kubeflow

Kubeflow provides a notebook environment in which you can run this notebook inside the Kubernetes cluster. Install Kubeflow by following the instructions in Installing Kubeflow. You may choose any installation method; we tested this example after installing Kubeflow from manifests.

Install extra Python modules

We’ll need a few extra Python modules.

!pip install dask_kubernetes optuna
Successfully installed Mako-1.3.3 alembic-1.13.1 asyncache-0.3.1 colorlog-6.8.2 cryptography-42.0.7 dask_kubernetes-2024.5.0 google-auth-2.29.0 greenlet-3.0.3 httpx-ws-0.6.0 iso8601-2.1.0 kopf-1.37.2 kr8s-0.14.4 kubernetes-29.0.0 kubernetes-asyncio-29.0.0 oauthlib-3.2.2 optuna-3.6.1 pyasn1-0.6.0 pyasn1-modules-0.4.0 pykube-ng-23.6.0 python-box-7.1.1 python-jsonpath-1.1.1 requests-oauthlib-2.0.0 rsa-4.9 sqlalchemy-2.0.30 wsproto-1.2.0
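Before moving on, you can confirm that the kernel can see the newly installed packages, along with other libraries that the import cell below relies on. The following is a minimal sketch using only the standard library's importlib.util.find_spec; the helper name missing_modules and the module list are our own choices, based on the import cell.

```python
import importlib
import importlib.util


def missing_modules(names):
    """Return the top-level module names in `names` that the import system cannot find."""
    # Drop cached directory listings so packages installed in this session are seen
    importlib.invalidate_caches()
    return [name for name in names if importlib.util.find_spec(name) is None]


# Packages installed above plus ones imported later in this notebook (assumed list)
required = ["dask_kubernetes", "optuna", "dask_ml", "xgboost"]
print("Missing modules:", missing_modules(required) or "none")
```

If anything is reported missing, install it with pip before running the import cell.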

Import Python modules

import threading
import warnings

import cupy as cp
import cuspatial
import dask_cudf
import optuna
from cuml.dask.common import utils as dask_utils
from dask.distributed import Client, wait
from dask_kubernetes.operator import KubeCluster
from dask_ml.metrics import mean_squared_error
from dask_ml.model_selection import KFold
from xgboost import dask as dxgb

Set up multiple Dask clusters

To run multi-GPU training jobs in parallel, we will create multiple Dask clusters, each controlling its own share of the GPUs. It’s best to think of each Dask cluster as a slice of the Kubernetes cluster’s compute resources.

Fill in the following variables:

# Number of nodes in the Kubernetes cluster.
# Each node is assumed to have a single NVIDIA GPU attached
n_nodes = 7

# Number of worker nodes to be assigned to each Dask cluster
n_worker_per_dask_cluster = 2

# Number of nodes to be assigned to each Dask cluster
# 1 is added since the Dask cluster's scheduler process needs to be mapped to its own node
n_node_per_dask_cluster = n_worker_per_dask_cluster + 1

# Number of Dask clusters to be created
# Subtract 1 to account for the notebook pod (it requires its own node)
n_clusters = (n_nodes - 1) // n_node_per_dask_cluster

print(f"{n_clusters=}")
if n_clusters == 0:
    raise ValueError(
        "No cluster can be created. Reduce `n_worker_per_dask_cluster` or create more compute nodes"
    )
print(f"{n_worker_per_dask_cluster=}")
print(f"{n_node_per_dask_cluster=}")

n_node_active = n_clusters * n_node_per_dask_cluster + 1
if n_node_active != n_nodes:
    n_idle = n_nodes - n_node_active
    warnings.warn(f"{n_idle} node(s) will not be used", stacklevel=2)
n_clusters=2
n_worker_per_dask_cluster=2
n_node_per_dask_cluster=3
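The sizing arithmetic above can be wrapped in a small function to see how other node counts would be divided before you resize the Kubernetes cluster. This is a minimal sketch that keeps the same assumptions as the comments above (one GPU per node, one node per Dask scheduler, one node reserved for the notebook pod); the name plan_dask_clusters is hypothetical, and the loop variables are named so that they do not overwrite n_nodes or n_clusters.

```python
def plan_dask_clusters(n_nodes, n_worker_per_dask_cluster):
    """Return (number of Dask clusters, number of idle nodes) for a node budget.

    Assumes one GPU per node, one extra node per cluster for its scheduler,
    and one node reserved for the notebook pod.
    """
    n_node_per_dask_cluster = n_worker_per_dask_cluster + 1
    n_clusters = (n_nodes - 1) // n_node_per_dask_cluster
    n_idle = n_nodes - (n_clusters * n_node_per_dask_cluster + 1)
    return n_clusters, n_idle


for nodes in (3, 7, 9):
    demo_n_clusters, demo_n_idle = plan_dask_clusters(nodes, n_worker_per_dask_cluster=2)
    print(f"{nodes} nodes -> {demo_n_clusters} cluster(s), {demo_n_idle} idle node(s)")
```

With the values used in this notebook (7 nodes, 2 workers per cluster), it returns 2 clusters and 0 idle nodes, matching the output above.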

Now that we’ve determined the number of Dask clusters and their size, we are ready to launch them:

# Choose the same RAPIDS image you used for launching the notebook session
rapids_image = "rapidsai/notebooks:24.04-cuda12.2-py3.11"
clusters = []
for i in range(n_clusters):
    print(f"Launching cluster {i}...")
    clusters.append(
        KubeCluster(
            name=f"rapids-dask{i}",
            image=rapids_image,
            worker_command="dask-cuda-worker",
            n_workers=n_worker_per_dask_cluster,  # keep in sync with the sizing above
            resources={"limits": {"nvidia.com/gpu": "1"}},
            env={"EXTRA_PIP_PACKAGES": "optuna"},
        )
    )
Launching cluster 0...

Launching cluster 1...

Set up Hyperparameter Optimization Task with NYC Taxi data

Anaconda has graciously made some of the NYC Taxi dataset available in a public Google Cloud Storage bucket. We’ll use our Dask clusters to process it and train a model that predicts the fare amount.

col_dtype = {
    "VendorID": "int32",
    "tpep_pickup_datetime": "datetime64[ms]",
    "tpep_dropoff_datetime": "datetime64[ms]",
    "passenger_count": "int32",
    "trip_distance": "float32",
    "pickup_longitude": "float32",
    "pickup_latitude": "float32",
    "RatecodeID": "int32",
    "store_and_fwd_flag": "int32",
    "dropoff_longitude": "float32",
    "dropoff_latitude": "float32",
    "payment_type": "int32",
    "fare_amount": "float32",
    "extra": "float32",
    "mta_tax": "float32",
    "tip_amount": "float32",
    "total_amount": "float32",
    "tolls_amount": "float32",
    "improvement_surcharge": "float32",
}


must_haves = {
    "pickup_datetime": "datetime64[ms]",
    "dropoff_datetime": "datetime64[ms]",
    "passenger_count": "int32",
    "trip_distance": "float32",
    "pickup_longitude": "float32",
    "pickup_latitude": "float32",
    "rate_code": "int32",
    "dropoff_longitude": "float32",
    "dropoff_latitude": "float32",
    "fare_amount": "float32",
}


def compute_haversine_distance(df):
    pickup = cuspatial.GeoSeries.from_points_xy(
        df[["pickup_longitude", "pickup_latitude"]].interleave_columns()
    )
    dropoff = cuspatial.GeoSeries.from_points_xy(
        df[["dropoff_longitude", "dropoff_latitude"]].interleave_columns()
    )
    df["haversine_distance"] = cuspatial.haversine_distance(pickup, dropoff)
    df["haversine_distance"] = df["haversine_distance"].astype("float32")
    return df
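compute_haversine_distance adds the great-circle distance between each pickup and dropoff point. To spot-check a few rows on the CPU, a plain-Python version of the haversine formula is sketched below. It assumes coordinates in degrees, as in the taxi data, and an Earth radius of 6371 km; we believe cuSpatial returns kilometres, but confirm its radius convention in the cuSpatial documentation before expecting exact agreement. The argument order (longitude, latitude) mirrors the interleaved (x, y) order used above, and the sample coordinates are illustrative.

```python
import math

# Assumed mean Earth radius in km; confirm the constant cuSpatial uses before comparing.
EARTH_RADIUS_KM = 6371.0


def haversine_km(lon1, lat1, lon2, lat2, radius=EARTH_RADIUS_KM):
    """Great-circle distance between two (longitude, latitude) points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    d_phi = phi2 - phi1
    d_lambda = math.radians(lon2 - lon1)
    # haversine formula: a = sin^2(dphi/2) + cos(phi1) cos(phi2) sin^2(dlambda/2)
    a = (
        math.sin(d_phi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    )
    return 2 * radius * math.asin(math.sqrt(a))


# Approximate coordinates near Times Square and near Wall Street (illustrative values)
print(f"{haversine_km(-73.9855, 40.7580, -74.0090, 40.7060):.2f} km")
```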


def clean(ddf, must_haves):
    # strip extraneous spaces from the column names and convert them to lowercase
    tmp = {col: col.strip().lower() for col in list(ddf.columns)}
    ddf = ddf.rename(columns=tmp)

    ddf = ddf.rename(
        columns={
            "tpep_pickup_datetime": "pickup_datetime",
            "tpep_dropoff_datetime": "dropoff_datetime",
            "ratecodeid": "rate_code",
        }
    )

    ddf["pickup_datetime"] = ddf["pickup_datetime"].astype("datetime64[ms]")
    ddf["dropoff_datetime"] = ddf["dropoff_datetime"].astype("datetime64[ms]")

    for col in ddf.columns:
        if col not in must_haves:
            ddf = ddf.drop(columns=col)
            continue
        if ddf[col].dtype == "object":
            # Fixing error: could not convert arg to str
            ddf = ddf.drop(columns=col)
        else:
            # downcast from 64-bit to 32-bit types
            # Tesla T4 GPUs are faster at 32-bit operations
            if "int" in str(ddf[col].dtype):
                ddf[col] = ddf[col].astype("int32")
            if "float" in str(ddf[col].dtype):
                ddf[col] = ddf[col].astype("float32")
            ddf[col] = ddf[col].fillna(-1)

    return ddf
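The column-naming part of clean can be checked without a GPU. The sketch below applies the same strip, lowercase, and rename steps to a plain list of header names and reports which columns survive the must_haves filter. It only mirrors the naming logic: it ignores the dtype casts, fillna, and the dropping of object-dtype columns. The helper surviving_columns and the constants are hypothetical names, and the header list is copied from the keys of col_dtype above.

```python
# Renames applied in clean after stripping and lowercasing the header names
RENAMES = {
    "tpep_pickup_datetime": "pickup_datetime",
    "tpep_dropoff_datetime": "dropoff_datetime",
    "ratecodeid": "rate_code",
}
# Same names as the keys of must_haves above (repeated so this cell is self-contained)
MUST_HAVE_NAMES = {
    "pickup_datetime", "dropoff_datetime", "passenger_count", "trip_distance",
    "pickup_longitude", "pickup_latitude", "rate_code", "dropoff_longitude",
    "dropoff_latitude", "fare_amount",
}


def surviving_columns(raw_columns, keep=MUST_HAVE_NAMES):
    """Normalize and rename header names the way clean does, then keep only `keep`."""
    normalized = [col.strip().lower() for col in raw_columns]
    renamed = [RENAMES.get(col, col) for col in normalized]
    return [col for col in renamed if col in keep]


# Header names taken from the keys of col_dtype above
RAW_COLUMNS = [
    "VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime", "passenger_count",
    "trip_distance", "pickup_longitude", "pickup_latitude", "RatecodeID",
    "store_and_fwd_flag", "dropoff_longitude", "dropoff_latitude", "payment_type",
    "fare_amount", "extra", "mta_tax", "tip_amount", "total_amount", "tolls_amount",
    "improvement_surcharge",
]
print(surviving_columns(RAW_COLUMNS))
```

Because RatecodeID only matches "ratecodeid" after lowercasing, the order of the two rename steps in clean matters.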


def prepare_data(client):
    taxi_df = dask_cudf.read_csv(
        "https://storage.googleapis.com/anaconda-public-data/nyc-taxi/csv/2016/yellow_tripdata_2016-02.csv",
        dtype=col_dtype,
    )
    taxi_df = taxi_df.map_partitions(clean, must_haves, meta=must_haves)

    # add features
    taxi_df["hour"] = taxi_df["pickup_datetime"].dt.hour.astype("int32")
    taxi_df["year"] = taxi_df["pickup_datetime"].dt.year.astype("int32")
    taxi_df["month"] = taxi_df["pickup_datetime"].dt.month.astype("int32")
    taxi_df["day"] = taxi_df["pickup_datetime"].dt.day.astype("int32")
    taxi_df["day_of_week"] = taxi_df["pickup_datetime"].dt.weekday.astype("int32")
    taxi_df["is_weekend"] = (taxi_df["day_of_week"] >= 5).astype("int32")

    # calculate the trip duration in seconds: the datetime64[ms] values are
    # milliseconds since the epoch, which overflow int32, so subtract as int64
    taxi_df["diff"] = taxi_df["dropoff_datetime"].astype("int64") - taxi_df[
        "pickup_datetime"
    ].astype("int64")
    taxi_df["diff"] = (taxi_df["diff"] / 1000).astype("int32")

    taxi_df["pickup_latitude_r"] = taxi_df["pickup_latitude"] // 0.01 * 0.01
    taxi_df["pickup_longitude_r"] = taxi_df["pickup_longitude"] // 0.01 * 0.01
    taxi_df["dropoff_latitude_r"] = taxi_df["dropoff_latitude"] // 0.01 * 0.01
    taxi_df["dropoff_longitude_r"] = taxi_df["dropoff_longitude"] // 0.01 * 0.01

    taxi_df = taxi_df.drop("pickup_datetime", axis=1)
    taxi_df = taxi_df.drop("dropoff_datetime", axis=1)

    taxi_df = taxi_df.map_partitions(compute_haversine_distance)

    X = (
        taxi_df.drop(["fare_amount"], axis=1)
        .astype("float32")
        .to_dask_array(lengths=True)
    )
    y = taxi_df["fare_amount"].astype("float32").to_dask_array(lengths=True)

    X._meta = cp.asarray(X._meta)
    y._meta = cp.asarray(y._meta)

    X, y = dask_utils.persist_across_workers(client, [X, y])
    return X, y
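To see what some of the engineered features look like for a single trip, the sketch below recomputes them in plain Python with datetime. It follows the logic of prepare_data (hour, weekday, weekend flag, trip duration in seconds, and coordinates floored to a 0.01-degree grid) but is only an illustration: the helper name trip_features and the sample trip are made up. It also checks that epoch timestamps in milliseconds for 2016 exceed the int32 range, so they need a 64-bit integer type before subtraction.

```python
from datetime import datetime, timezone

INT32_MAX = 2**31 - 1


def trip_features(pickup, dropoff, pickup_lat, pickup_lon):
    """Plain-Python analogue of some features built in prepare_data (illustration only)."""
    pickup_ms = int(pickup.timestamp() * 1000)  # epoch milliseconds, like datetime64[ms]
    dropoff_ms = int(dropoff.timestamp() * 1000)
    return {
        "hour": pickup.hour,
        "day_of_week": pickup.weekday(),  # Monday=0 ... Sunday=6
        "is_weekend": int(pickup.weekday() >= 5),
        "diff": (dropoff_ms - pickup_ms) // 1000,  # trip duration in seconds
        # floor to a 0.01-degree grid; negative longitudes move away from zero
        "pickup_latitude_r": pickup_lat // 0.01 * 0.01,
        "pickup_longitude_r": pickup_lon // 0.01 * 0.01,
    }


sample_pickup = datetime(2016, 2, 6, 23, 50, tzinfo=timezone.utc)
sample_dropoff = datetime(2016, 2, 7, 0, 5, tzinfo=timezone.utc)
print(trip_features(sample_pickup, sample_dropoff, 40.7589, -73.9851))
print("epoch ms exceeds int32 range:", int(sample_pickup.timestamp() * 1000) > INT32_MAX)
```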


def train_model(params):
    cluster = get_cluster(threading.get_ident())

    default_params = {
        "objective": "reg:squarederror",
        "eval_metric": "rmse",
        "verbosity": 0,
        "tree_method": "hist",
        "device": "cuda",
    }
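    params = dict(default_params, **params)
    # dxgb.train does not read "n_estimators" from params (it is a scikit-learn
    # API name), so pass the tuned value as num_boost_round instead
    num_boost_round = params.pop("n_estimators", 10)

    with Client(cluster) as client:
        X, y = prepare_data(client)
        wait([X, y])

        scores = []
        kfold = KFold(n_splits=5, shuffle=False)
        for train_index, test_index in kfold.split(X, y):
            dtrain = dxgb.DaskQuantileDMatrix(client, X[train_index, :], y[train_index])
            # quantize the test fold with the training fold's bins (ref=dtrain)
            dtest = dxgb.DaskQuantileDMatrix(
                client, X[test_index, :], y[test_index], ref=dtrain
            )
            model = dxgb.train(
                client,
                params,
                dtrain,
                num_boost_round=num_boost_round,
                verbose_eval=False,
            )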
            y_test_pred = dxgb.predict(client, model, dtest).to_backend("cupy")
            rmse_score = mean_squared_error(y[test_index], y_test_pred, squared=False)
            scores.append(rmse_score)
        return sum(scores) / len(scores)


def objective(trial):
    params = {
        "n_estimators": trial.suggest_int("n_estimators", 2, 4),
        "learning_rate": trial.suggest_float("learning_rate", 0.5, 0.7),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1),
        "colsample_bynode": trial.suggest_float("colsample_bynode", 0.5, 1),
        "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.5, 1),
        "reg_lambda": trial.suggest_float("reg_lambda", 0, 1),
        "max_depth": trial.suggest_int("max_depth", 1, 6),
        "max_leaves": trial.suggest_int("max_leaves", 0, 2),
        "max_cat_to_onehot": trial.suggest_int("max_cat_to_onehot", 1, 10),
    }
    return train_model(params)
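train_model scores each parameter set with 5-fold cross-validation: each fold serves once as the test set while the model is trained on the remaining rows, and the per-fold RMSE values are averaged (with shuffle=False we expect the folds to be consecutive blocks of rows). The CPU-only sketch below walks through the same procedure on a short list of made-up fares, with a mean predictor standing in for XGBoost. The helper names are hypothetical, and the fold sizing (remainder rows go to the first folds) follows scikit-learn's KFold convention; we have not confirmed that dask-ml sizes its folds identically.

```python
import math


def kfold_indices(n_samples, n_splits):
    """Yield (train, test) index lists for consecutive, unshuffled folds."""
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1  # spread the remainder over the first folds
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


def rmse(y_true, y_pred):
    """Root mean squared error of two equal-length sequences."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def cross_validated_rmse(y, n_splits=5):
    """Average per-fold RMSE of a mean predictor (stand-in for the XGBoost model)."""
    scores = []
    for train, test in kfold_indices(len(y), n_splits):
        prediction = sum(y[i] for i in train) / len(train)  # "fit" on the training folds
        scores.append(rmse([y[i] for i in test], [prediction] * len(test)))
    return sum(scores) / len(scores)


fares = [5.0, 7.5, 12.0, 8.0, 30.0, 6.5, 9.0, 15.0, 7.0, 10.0, 11.5]
print(round(cross_validated_rmse(fares), 3))
```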

To kick off multiple training jobs in parallel, we will launch multiple threads, so that each thread controls a Dask cluster. One important utility function is get_cluster, which returns the Dask cluster that’s mapped to a given thread.

# Map each thread's integer ID to a sequential number (0, 1, 2 ...)
thread_id_map: dict[int, int] = {}
thread_id_map_lock = threading.Lock()


def get_cluster(thread_id: int) -> KubeCluster:
    with thread_id_map_lock:
        try:
            return clusters[thread_id_map[thread_id]]
        except KeyError:
            seq_id = len(thread_id_map)
            thread_id_map[thread_id] = seq_id
            return clusters[seq_id]
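The thread-to-cluster mapping can be exercised without Kubernetes. In the sketch below, plain strings stand in for the KubeCluster objects and a ThreadPoolExecutor stands in for Optuna’s worker threads (our understanding is that n_jobs runs trials on a thread pool of that size). All names are prefixed with demo_ so that running the cell does not overwrite clusters or get_cluster. Because the pool never has more than max_workers threads, every task resolves to one of the first max_workers clusters, and a given thread always gets the same one. One caveat this makes visible: the map only grows, so if a later study.optimize call runs on new threads with new IDs, get_cluster could compute an index past the end of clusters.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Stand-ins for the KubeCluster objects (hypothetical names, demo only)
demo_clusters = ["rapids-dask0", "rapids-dask1"]
demo_thread_map: dict[int, int] = {}
demo_lock = threading.Lock()


def demo_get_cluster(thread_id: int) -> str:
    """Same logic as get_cluster: first-come, first-served sequential IDs per thread."""
    with demo_lock:
        if thread_id not in demo_thread_map:
            demo_thread_map[thread_id] = len(demo_thread_map)
        return demo_clusters[demo_thread_map[thread_id]]


def fake_trial(_):
    """Report which thread ran the trial and which cluster it was routed to."""
    thread_id = threading.get_ident()
    return thread_id, demo_get_cluster(thread_id)


with ThreadPoolExecutor(max_workers=len(demo_clusters)) as pool:
    assignments = list(pool.map(fake_trial, range(10)))

for thread_id, cluster_name in sorted(set(assignments)):
    print(thread_id, "->", cluster_name)
```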

Now we are ready to start hyperparameter optimization.

n_trials = (
    10  # set to a low number so that the demo finishes quickly. Feel free to adjust
)
study = optuna.create_study(direction="minimize")
[I 2024-05-09 07:53:00,718] A new study created in memory with name: no-name-da830427-bce3-4e42-98e6-c98c0c3da0d7
# With the n_jobs parameter, Optuna runs trials on n_clusters threads internally
# Each thread will deploy a training job to a Dask cluster
study.optimize(objective, n_trials=n_trials, n_jobs=n_clusters)
[I 2024-05-09 07:54:10,229] Trial 1 finished with value: 59.449462890625 and parameters: {'n_estimators': 4, 'learning_rate': 0.6399993857892183, 'colsample_bytree': 0.7020623988319513, 'colsample_bynode': 0.777468318546648, 'colsample_bylevel': 0.7890749134903386, 'reg_lambda': 0.4464953694744921, 'max_depth': 3, 'max_leaves': 0, 'max_cat_to_onehot': 9}. Best is trial 1 with value: 59.449462890625.
[I 2024-05-09 07:54:19,507] Trial 0 finished with value: 57.77985763549805 and parameters: {'n_estimators': 4, 'learning_rate': 0.674087333032356, 'colsample_bytree': 0.557642421113256, 'colsample_bynode': 0.9719449711676733, 'colsample_bylevel': 0.6984302171973646, 'reg_lambda': 0.7201514298169174, 'max_depth': 4, 'max_leaves': 1, 'max_cat_to_onehot': 4}. Best is trial 0 with value: 57.77985763549805.
[I 2024-05-09 07:54:59,524] Trial 2 finished with value: 57.77985763549805 and parameters: {'n_estimators': 2, 'learning_rate': 0.6894880267544121, 'colsample_bytree': 0.8171662437182604, 'colsample_bynode': 0.549527686217645, 'colsample_bylevel': 0.890212178266078, 'reg_lambda': 0.5847298606135033, 'max_depth': 2, 'max_leaves': 1, 'max_cat_to_onehot': 5}. Best is trial 0 with value: 57.77985763549805.
[I 2024-05-09 07:55:22,013] Trial 3 finished with value: 55.01234817504883 and parameters: {'n_estimators': 4, 'learning_rate': 0.6597614733926671, 'colsample_bytree': 0.8437061126308156, 'colsample_bynode': 0.621479934699203, 'colsample_bylevel': 0.8330951489228277, 'reg_lambda': 0.7830102753448884, 'max_depth': 2, 'max_leaves': 2, 'max_cat_to_onehot': 2}. Best is trial 3 with value: 55.01234817504883.
[I 2024-05-09 07:56:00,678] Trial 4 finished with value: 57.77985763549805 and parameters: {'n_estimators': 4, 'learning_rate': 0.5994587326401378, 'colsample_bytree': 0.9799078215504886, 'colsample_bynode': 0.9766955839079614, 'colsample_bylevel': 0.5088864363378924, 'reg_lambda': 0.18103184809548734, 'max_depth': 3, 'max_leaves': 1, 'max_cat_to_onehot': 4}. Best is trial 3 with value: 55.01234817504883.
[I 2024-05-09 07:56:11,773] Trial 5 finished with value: 54.936126708984375 and parameters: {'n_estimators': 2, 'learning_rate': 0.5208827661289628, 'colsample_bytree': 0.866258912492528, 'colsample_bynode': 0.6368815844513638, 'colsample_bylevel': 0.9539603435186208, 'reg_lambda': 0.21390618865079458, 'max_depth': 4, 'max_leaves': 2, 'max_cat_to_onehot': 4}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:56:48,737] Trial 6 finished with value: 57.77985763549805 and parameters: {'n_estimators': 2, 'learning_rate': 0.6137888371528442, 'colsample_bytree': 0.9621063205689744, 'colsample_bynode': 0.5306812468481084, 'colsample_bylevel': 0.8527827651989199, 'reg_lambda': 0.3315799968401767, 'max_depth': 6, 'max_leaves': 1, 'max_cat_to_onehot': 9}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:56:59,261] Trial 7 finished with value: 55.204200744628906 and parameters: {'n_estimators': 3, 'learning_rate': 0.6831416027240611, 'colsample_bytree': 0.5311840770388268, 'colsample_bynode': 0.9572535535110238, 'colsample_bylevel': 0.6846894032354778, 'reg_lambda': 0.6091211134408249, 'max_depth': 3, 'max_leaves': 2, 'max_cat_to_onehot': 5}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:57:37,674] Trial 8 finished with value: 54.93584442138672 and parameters: {'n_estimators': 4, 'learning_rate': 0.620742285616388, 'colsample_bytree': 0.7969398985157778, 'colsample_bynode': 0.9049707375663323, 'colsample_bylevel': 0.7209693969245297, 'reg_lambda': 0.6158847054585023, 'max_depth': 1, 'max_leaves': 0, 'max_cat_to_onehot': 10}. Best is trial 8 with value: 54.93584442138672.
[I 2024-05-09 07:57:50,310] Trial 9 finished with value: 57.76123809814453 and parameters: {'n_estimators': 3, 'learning_rate': 0.5475197727057007, 'colsample_bytree': 0.5381502848057452, 'colsample_bynode': 0.8514705732161596, 'colsample_bylevel': 0.9139277684007088, 'reg_lambda': 0.5117732009332318, 'max_depth': 4, 'max_leaves': 0, 'max_cat_to_onehot': 5}. Best is trial 8 with value: 54.93584442138672.