Scaling up Hyperparameter Optimization with Multi-GPU Workload on Kubernetes#
Choosing an optimal set of hyperparameters is a daunting task, especially for algorithms like XGBoost that have many hyperparameters to tune. In this notebook, we will speed up hyperparameter optimization by running multiple training jobs in parallel on a Kubernetes cluster. We will also handle larger data sets by distributing the data across multiple GPU devices.
Prerequisites#
Please follow the instructions in Dask Operator: Installation to install the Dask operator on a GPU-enabled Kubernetes cluster. (For the purposes of this example, you may ignore the other sections of the linked document.)
Optional: Kubeflow#
Kubeflow gives you a nice notebook environment to run this notebook within the k8s cluster. Install Kubeflow by following instructions in Installing Kubeflow. You may choose any method; we tested this example after installing Kubeflow from manifests.
Install extra Python modules#
We’ll need a few extra Python modules.
!pip install dask_kubernetes optuna
Collecting dask_kubernetes
  Downloading dask_kubernetes-2024.5.0-py3-none-any.whl.metadata (4.2 kB)
Collecting optuna
  Downloading optuna-3.6.1-py3-none-any.whl.metadata (17 kB)
...
Successfully installed Mako-1.3.3 alembic-1.13.1 asyncache-0.3.1 colorlog-6.8.2 cryptography-42.0.7 dask_kubernetes-2024.5.0 google-auth-2.29.0 greenlet-3.0.3 httpx-ws-0.6.0 iso8601-2.1.0 kopf-1.37.2 kr8s-0.14.4 kubernetes-29.0.0 kubernetes-asyncio-29.0.0 oauthlib-3.2.2 optuna-3.6.1 pyasn1-0.6.0 pyasn1-modules-0.4.0 pykube-ng-23.6.0 python-box-7.1.1 python-jsonpath-1.1.1 requests-oauthlib-2.0.0 rsa-4.9 sqlalchemy-2.0.30 wsproto-1.2.0
Import Python modules#
import threading
import warnings
import cupy as cp
import cuspatial
import dask_cudf
import optuna
from cuml.dask.common import utils as dask_utils
from dask.distributed import Client, wait
from dask_kubernetes.operator import KubeCluster
from dask_ml.metrics import mean_squared_error
from dask_ml.model_selection import KFold
from xgboost import dask as dxgb
Set up multiple Dask clusters#
To run multi-GPU training jobs in parallel, we will create multiple Dask clusters each controlling its share of GPUs. It’s best to think of each Dask cluster as a portion of the compute resource of the Kubernetes cluster.
Fill in the following variables:
# Number of nodes in the Kubernetes cluster.
# Each node is assumed to have a single NVIDIA GPU attached
n_nodes = 7
# Number of worker nodes to be assigned to each Dask cluster
n_worker_per_dask_cluster = 2
# Number of nodes to be assigned to each Dask cluster
# 1 is added since the Dask cluster's scheduler process needs to be mapped to its own node
n_node_per_dask_cluster = n_worker_per_dask_cluster + 1
# Number of Dask clusters to be created
# Subtract 1 to account for the notebook pod (it requires its own node)
n_clusters = (n_nodes - 1) // n_node_per_dask_cluster
print(f"{n_clusters=}")
if n_clusters == 0:
raise ValueError(
"No cluster can be created. Reduce `n_worker_per_dask_cluster` or create more compute nodes"
)
print(f"{n_worker_per_dask_cluster=}")
print(f"{n_node_per_dask_cluster=}")
n_node_active = n_clusters * n_node_per_dask_cluster + 1
if n_node_active != n_nodes:
n_idle = n_nodes - n_node_active
warnings.warn(f"{n_idle} node(s) will not be used", stacklevel=2)
n_clusters=2
n_worker_per_dask_cluster=2
n_node_per_dask_cluster=3
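To make the arithmetic concrete: with n_nodes = 7 and n_worker_per_dask_cluster = 2, each Dask cluster occupies 3 nodes (2 workers plus 1 scheduler) and the notebook pod takes 1 node, so (7 - 1) // 3 = 2 clusters fit. Since 2 * 3 + 1 = 7 nodes are then active, no nodes are left idle and no warning is raised.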
Once we’ve determined the number of Dask clusters and their size, we are now ready to launch them:
# Choose the same RAPIDS image you used for launching the notebook session
rapids_image = ""
clusters = []
for i in range(n_clusters):
print(f"Launching cluster {i}...")
clusters.append(
KubeCluster(
name=f"rapids-dask{i}",
image=rapids_image,
worker_command="dask-cuda-worker",
n_workers=n_worker_per_dask_cluster,
resources={"limits": {"nvidia.com/gpu": "1"}},
env={"EXTRA_PIP_PACKAGES": "optuna"},
)
)
Launching cluster 0...
Launching cluster 1...
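Optionally, you can confirm that each cluster's workers have come up before proceeding. Here is a minimal sketch (worker pods may take a few minutes to be scheduled):

for cluster in clusters:
    with Client(cluster) as client:
        # Block until the expected number of dask-cuda workers has connected
        client.wait_for_workers(n_worker_per_dask_cluster)
        print(f"{cluster.name}: {len(client.scheduler_info()['workers'])} worker(s) ready")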
Set up Hyperparameter Optimization Task with NYC Taxi data#
Anaconda has graciously made some of the NYC Taxi dataset available in a public Google Cloud Storage bucket. We'll use our Dask clusters to process it and train a model that predicts the fare amount.
col_dtype = {
"VendorID": "int32",
"tpep_pickup_datetime": "datetime64[ms]",
"tpep_dropoff_datetime": "datetime64[ms]",
"passenger_count": "int32",
"trip_distance": "float32",
"pickup_longitude": "float32",
"pickup_latitude": "float32",
"RatecodeID": "int32",
"store_and_fwd_flag": "int32",
"dropoff_longitude": "float32",
"dropoff_latitude": "float32",
"payment_type": "int32",
"fare_amount": "float32",
"extra": "float32",
"mta_tax": "float32",
"tip_amount": "float32",
"total_amount": "float32",
"tolls_amount": "float32",
"improvement_surcharge": "float32",
}
must_haves = {
"pickup_datetime": "datetime64[ms]",
"dropoff_datetime": "datetime64[ms]",
"passenger_count": "int32",
"trip_distance": "float32",
"pickup_longitude": "float32",
"pickup_latitude": "float32",
"rate_code": "int32",
"dropoff_longitude": "float32",
"dropoff_latitude": "float32",
"fare_amount": "float32",
}
def compute_haversine_distance(df):
pickup = cuspatial.GeoSeries.from_points_xy(
df[["pickup_longitude", "pickup_latitude"]].interleave_columns()
)
dropoff = cuspatial.GeoSeries.from_points_xy(
df[["dropoff_longitude", "dropoff_latitude"]].interleave_columns()
)
df["haversine_distance"] = cuspatial.haversine_distance(pickup, dropoff)
df["haversine_distance"] = df["haversine_distance"].astype("float32")
return df
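As a quick sanity check (a minimal sketch with illustrative coordinates; requires a GPU with cudf and cuspatial installed), the helper can be applied directly to a small cuDF frame:

import cudf

sample = cudf.DataFrame(
    {
        "pickup_longitude": [-73.99],
        "pickup_latitude": [40.73],
        "dropoff_longitude": [-73.95],
        "dropoff_latitude": [40.77],
    }
)
# cuspatial.haversine_distance returns distances in kilometers
print(compute_haversine_distance(sample)["haversine_distance"])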
def clean(ddf, must_haves):
# strip extraneous spaces from column names and convert them to lowercase
tmp = {col: col.strip().lower() for col in list(ddf.columns)}
ddf = ddf.rename(columns=tmp)
ddf = ddf.rename(
columns={
"tpep_pickup_datetime": "pickup_datetime",
"tpep_dropoff_datetime": "dropoff_datetime",
"ratecodeid": "rate_code",
}
)
ddf["pickup_datetime"] = ddf["pickup_datetime"].astype("datetime64[ms]")
ddf["dropoff_datetime"] = ddf["dropoff_datetime"].astype("datetime64[ms]")
for col in ddf.columns:
if col not in must_haves:
ddf = ddf.drop(columns=col)
continue
if ddf[col].dtype == "object":
# Fixing error: could not convert arg to str
ddf = ddf.drop(columns=col)
else:
# downcast from 64bit to 32bit types
# Tesla T4 are faster on 32bit ops
if "int" in str(ddf[col].dtype):
ddf[col] = ddf[col].astype("int32")
if "float" in str(ddf[col].dtype):
ddf[col] = ddf[col].astype("float32")
ddf[col] = ddf[col].fillna(-1)
return ddf
def prepare_data(client):
taxi_df = dask_cudf.read_csv(
"https://storage.googleapis.com/anaconda-public-data/nyc-taxi/csv/2016/yellow_tripdata_2016-02.csv",
dtype=col_dtype,
)
taxi_df = taxi_df.map_partitions(clean, must_haves, meta=must_haves)
## add features
taxi_df["hour"] = taxi_df["pickup_datetime"].dt.hour.astype("int32")
taxi_df["year"] = taxi_df["pickup_datetime"].dt.year.astype("int32")
taxi_df["month"] = taxi_df["pickup_datetime"].dt.month.astype("int32")
taxi_df["day"] = taxi_df["pickup_datetime"].dt.day.astype("int32")
taxi_df["day_of_week"] = taxi_df["pickup_datetime"].dt.weekday.astype("int32")
taxi_df["is_weekend"] = (taxi_df["day_of_week"] >= 5).astype("int32")
# calculate the time difference (in seconds) between dropoff and pickup
# cast to int64 first: epoch milliseconds overflow a 32-bit integer
taxi_df["diff"] = taxi_df["dropoff_datetime"].astype("int64") - taxi_df[
"pickup_datetime"
].astype("int64")
taxi_df["diff"] = (taxi_df["diff"] / 1000).astype("int32")
taxi_df["pickup_latitude_r"] = taxi_df["pickup_latitude"] // 0.01 * 0.01
taxi_df["pickup_longitude_r"] = taxi_df["pickup_longitude"] // 0.01 * 0.01
taxi_df["dropoff_latitude_r"] = taxi_df["dropoff_latitude"] // 0.01 * 0.01
taxi_df["dropoff_longitude_r"] = taxi_df["dropoff_longitude"] // 0.01 * 0.01
taxi_df = taxi_df.drop("pickup_datetime", axis=1)
taxi_df = taxi_df.drop("dropoff_datetime", axis=1)
taxi_df = taxi_df.map_partitions(compute_haversine_distance)
X = (
taxi_df.drop(["fare_amount"], axis=1)
.astype("float32")
.to_dask_array(lengths=True)
)
y = taxi_df["fare_amount"].astype("float32").to_dask_array(lengths=True)
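# Swap the array metadata for CuPy arrays so Dask treats X and y as GPU-backed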
X._meta = cp.asarray(X._meta)
y._meta = cp.asarray(y._meta)
X, y = dask_utils.persist_across_workers(client, [X, y])
return X, y
def train_model(params):
cluster = get_cluster(threading.get_ident())
default_params = {
"objective": "reg:squarederror",
"eval_metric": "rmse",
"verbosity": 0,
"tree_method": "hist",
"device": "cuda",
}
params = dict(default_params, **params)
# "n_estimators" is not a native XGBoost training parameter;
# pop it and use it as the number of boosting rounds instead
num_boost_round = params.pop("n_estimators", 10)
with Client(cluster) as client:
X, y = prepare_data(client)
wait([X, y])
scores = []
kfold = KFold(n_splits=5, shuffle=False)
for train_index, test_index in kfold.split(X, y):
dtrain = dxgb.DaskQuantileDMatrix(client, X[train_index, :], y[train_index])
dtest = dxgb.DaskQuantileDMatrix(client, X[test_index, :], y[test_index])
model = dxgb.train(
client,
params,
dtrain,
num_boost_round=num_boost_round,
verbose_eval=False,
)
y_test_pred = dxgb.predict(client, model, dtest).to_backend("cupy")
rmse_score = mean_squared_error(y[test_index], y_test_pred, squared=False)
scores.append(rmse_score)
return sum(scores) / len(scores)
def objective(trial):
params = {
"n_estimators": trial.suggest_int("n_estimators", 2, 4),
"learning_rate": trial.suggest_float("learning_rate", 0.5, 0.7),
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1),
"colsample_bynode": trial.suggest_float("colsample_bynode", 0.5, 1),
"colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.5, 1),
"reg_lambda": trial.suggest_float("reg_lambda", 0, 1),
"max_depth": trial.suggest_int("max_depth", 1, 6),
"max_leaves": trial.suggest_int("max_leaves", 0, 2),
"max_cat_to_onehot": trial.suggest_int("max_cat_to_onehot", 1, 10),
}
return train_model(params)
To kick off multiple training jobs in parallel, we will launch multiple threads, each controlling its own Dask cluster. One important utility function is get_cluster, which returns the Dask cluster that's mapped to a given thread.
# Map each thread's integer ID to a sequential index (0, 1, 2, ...)
thread_id_map: dict[int, int] = {}
thread_id_map_lock = threading.Lock()
def get_cluster(thread_id: int) -> KubeCluster:
with thread_id_map_lock:
try:
return clusters[thread_id_map[thread_id]]
except KeyError:
seq_id = len(thread_id_map)
thread_id_map[thread_id] = seq_id
return clusters[seq_id]
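As a quick illustration (a minimal sketch, not part of the training flow): the first thread to call get_cluster is assigned clusters[0], the second clusters[1], and so on. The map is cleared afterwards so that the Optuna threads launched below start from a clean slate:

def show_mapping():
    cluster = get_cluster(threading.get_ident())
    print(f"Thread {threading.get_ident()} -> {cluster.name}")

demo_threads = [threading.Thread(target=show_mapping) for _ in range(n_clusters)]
for t in demo_threads:
    t.start()
for t in demo_threads:
    t.join()

# Reset the mapping; otherwise the Optuna worker threads would be assigned
# sequential indices past the end of the clusters list
thread_id_map.clear()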
Now we are ready to start hyperparameter optimization.
n_trials = 10  # Set to a low number so that the demo finishes quickly; feel free to adjust
study = optuna.create_study(direction="minimize")
[I 2024-05-09 07:53:00,718] A new study created in memory with name: no-name-da830427-bce3-4e42-98e6-c98c0c3da0d7
# With the n_jobs parameter, Optuna will launch [n_clusters] threads internally
# Each thread will deploy a training job to a Dask cluster
study.optimize(objective, n_trials=n_trials, n_jobs=n_clusters)
[I 2024-05-09 07:54:10,229] Trial 1 finished with value: 59.449462890625 and parameters: {'n_estimators': 4, 'learning_rate': 0.6399993857892183, 'colsample_bytree': 0.7020623988319513, 'colsample_bynode': 0.777468318546648, 'colsample_bylevel': 0.7890749134903386, 'reg_lambda': 0.4464953694744921, 'max_depth': 3, 'max_leaves': 0, 'max_cat_to_onehot': 9}. Best is trial 1 with value: 59.449462890625.
[I 2024-05-09 07:54:19,507] Trial 0 finished with value: 57.77985763549805 and parameters: {'n_estimators': 4, 'learning_rate': 0.674087333032356, 'colsample_bytree': 0.557642421113256, 'colsample_bynode': 0.9719449711676733, 'colsample_bylevel': 0.6984302171973646, 'reg_lambda': 0.7201514298169174, 'max_depth': 4, 'max_leaves': 1, 'max_cat_to_onehot': 4}. Best is trial 0 with value: 57.77985763549805.
[I 2024-05-09 07:54:59,524] Trial 2 finished with value: 57.77985763549805 and parameters: {'n_estimators': 2, 'learning_rate': 0.6894880267544121, 'colsample_bytree': 0.8171662437182604, 'colsample_bynode': 0.549527686217645, 'colsample_bylevel': 0.890212178266078, 'reg_lambda': 0.5847298606135033, 'max_depth': 2, 'max_leaves': 1, 'max_cat_to_onehot': 5}. Best is trial 0 with value: 57.77985763549805.
[I 2024-05-09 07:55:22,013] Trial 3 finished with value: 55.01234817504883 and parameters: {'n_estimators': 4, 'learning_rate': 0.6597614733926671, 'colsample_bytree': 0.8437061126308156, 'colsample_bynode': 0.621479934699203, 'colsample_bylevel': 0.8330951489228277, 'reg_lambda': 0.7830102753448884, 'max_depth': 2, 'max_leaves': 2, 'max_cat_to_onehot': 2}. Best is trial 3 with value: 55.01234817504883.
[I 2024-05-09 07:56:00,678] Trial 4 finished with value: 57.77985763549805 and parameters: {'n_estimators': 4, 'learning_rate': 0.5994587326401378, 'colsample_bytree': 0.9799078215504886, 'colsample_bynode': 0.9766955839079614, 'colsample_bylevel': 0.5088864363378924, 'reg_lambda': 0.18103184809548734, 'max_depth': 3, 'max_leaves': 1, 'max_cat_to_onehot': 4}. Best is trial 3 with value: 55.01234817504883.
[I 2024-05-09 07:56:11,773] Trial 5 finished with value: 54.936126708984375 and parameters: {'n_estimators': 2, 'learning_rate': 0.5208827661289628, 'colsample_bytree': 0.866258912492528, 'colsample_bynode': 0.6368815844513638, 'colsample_bylevel': 0.9539603435186208, 'reg_lambda': 0.21390618865079458, 'max_depth': 4, 'max_leaves': 2, 'max_cat_to_onehot': 4}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:56:48,737] Trial 6 finished with value: 57.77985763549805 and parameters: {'n_estimators': 2, 'learning_rate': 0.6137888371528442, 'colsample_bytree': 0.9621063205689744, 'colsample_bynode': 0.5306812468481084, 'colsample_bylevel': 0.8527827651989199, 'reg_lambda': 0.3315799968401767, 'max_depth': 6, 'max_leaves': 1, 'max_cat_to_onehot': 9}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:56:59,261] Trial 7 finished with value: 55.204200744628906 and parameters: {'n_estimators': 3, 'learning_rate': 0.6831416027240611, 'colsample_bytree': 0.5311840770388268, 'colsample_bynode': 0.9572535535110238, 'colsample_bylevel': 0.6846894032354778, 'reg_lambda': 0.6091211134408249, 'max_depth': 3, 'max_leaves': 2, 'max_cat_to_onehot': 5}. Best is trial 5 with value: 54.936126708984375.
[I 2024-05-09 07:57:37,674] Trial 8 finished with value: 54.93584442138672 and parameters: {'n_estimators': 4, 'learning_rate': 0.620742285616388, 'colsample_bytree': 0.7969398985157778, 'colsample_bynode': 0.9049707375663323, 'colsample_bylevel': 0.7209693969245297, 'reg_lambda': 0.6158847054585023, 'max_depth': 1, 'max_leaves': 0, 'max_cat_to_onehot': 10}. Best is trial 8 with value: 54.93584442138672.
[I 2024-05-09 07:57:50,310] Trial 9 finished with value: 57.76123809814453 and parameters: {'n_estimators': 3, 'learning_rate': 0.5475197727057007, 'colsample_bytree': 0.5381502848057452, 'colsample_bynode': 0.8514705732161596, 'colsample_bylevel': 0.9139277684007088, 'reg_lambda': 0.5117732009332318, 'max_depth': 4, 'max_leaves': 0, 'max_cat_to_onehot': 5}. Best is trial 8 with value: 54.93584442138672.
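After the study finishes, you can read the best score and hyperparameters off the study object, then close the Dask clusters to return the GPU nodes to Kubernetes:

print(f"Best cross-validated RMSE: {study.best_value}")
print(f"Best hyperparameters: {study.best_params}")

# Shut down the Dask clusters to free their pods
for cluster in clusters:
    cluster.close()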