Model Serialization and Persistence

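This notebook demonstrates how to save and load cuML models with pickle and joblib, how to serialize models trained with Dask, and how to prepare models for deployment on other platforms by converting them to scikit-learn or exporting Random Forests with Treelite.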

Single GPU Model Serialization

All single-GPU cuML estimators support serialization using standard Python libraries. This section demonstrates:

  1. Training a model on synthetic data

  2. Saving the model using pickle and joblib

  3. Loading the model for future use

Trained single-GPU models can also be used for distributed inference on Dask clusters, as described in the Distributed Model Serialization section.

[1]:
from cuml.cluster import KMeans
from cuml.datasets import make_blobs

# Generate synthetic dataset for clustering
X, y = make_blobs(
    n_samples=50, n_features=10, centers=5, cluster_std=0.4, random_state=0
)
# Initialize and fit KMeans model
kmeans = KMeans(n_clusters=5).fit(X)

Recommendation: Use pickle protocol 5 (available since Python 3.8) when saving models. Protocol 5 adds support for out-of-band buffers (PEP 574), which can avoid extra memory copies when serializing large contiguous arrays, such as the NumPy arrays inside models with large parameter sets.

[2]:
import pickle

# Save the fitted model to disk
with open("kmeans_model.pkl", "wb") as output_file:
    pickle.dump(kmeans, output_file, protocol=5)
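Where the protocol 5 gain comes from can be seen with a plain NumPy array, used here as an assumption so the sketch runs without a GPU. Given a buffer_callback, pickle.dumps hands large contiguous buffers to the caller instead of copying them into the byte stream (PEP 574), and pickle.loads takes them back through its buffers argument. The helper names dumps_out_of_band and loads_out_of_band are hypothetical, and whether a particular cuML model's GPU attributes take this path has not been checked here.

```python
import pickle

import numpy as np


def dumps_out_of_band(obj):
    """Pickle obj with protocol 5, collecting large buffers out-of-band (PEP 574)."""
    buffers = []
    # buffers.append returns None, which tells pickle to keep each buffer out-of-band
    payload = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
    return payload, buffers


def loads_out_of_band(payload, buffers):
    """Rebuild the object from the small pickle payload plus the separate buffers."""
    return pickle.loads(payload, buffers=buffers)


# Hypothetical stand-in for a large model parameter array (about 4 MB of float32)
weights = np.arange(1_000_000, dtype=np.float32)

payload, buffers = dumps_out_of_band(weights)
restored = loads_out_of_band(payload, buffers)

print(f"in-band payload: {len(payload)} bytes, out-of-band buffers: {len(buffers)}")
```

The payload holds only metadata; the array data travels in the buffers list, which a caller could write or transfer separately without copying it into the pickle stream first.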

Important: Unpickling a cuML model requires the same cuML version that was used to save it. If you need to load models across different cuML versions, consider the scikit-learn conversion approach described in Converting Between cuML and scikit-learn Models.

[3]:
# Load the model from disk
with open("kmeans_model.pkl", "rb") as input_file:
    kmeans_loaded_model = pickle.load(input_file)

# Display the loaded model's cluster centers
kmeans_loaded_model.cluster_centers_
[3]:
array([[-2.922248  ,  4.7528377 , -4.3529677 ,  2.2710595 ,  1.7184176 ,
        -2.5451765 , -5.50611   , -1.7181125 , -8.245671  ,  2.8203053 ],
       [-5.8374496 ,  2.0425208 , -3.8477435 , -1.8293772 , -5.257385  ,
         7.710398  ,  2.9743197 ,  8.42101   ,  1.5094917 ,  1.0263587 ],
       [ 4.781445  ,  8.392482  , -9.312664  ,  9.438168  ,  8.540471  ,
        -1.0861524 ,  3.437934  , -8.072111  , -0.6570339 ,  0.2782365 ],
       [ 5.317544  , -4.372343  ,  4.2193136 , -2.7930846 ,  3.766153  ,
        -4.301045  , -3.730563  ,  6.330142  , -6.965777  , -1.1038128 ],
       [-4.0034103 ,  5.5426564 , -5.8204336 , -1.8451873 , -9.445931  ,
         0.72651756,  4.2096705 , -2.5796611 , -5.0424485 ,  9.633467  ]],
      dtype=float32)
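Because a pickled cuML model must be loaded with the same cuML version that saved it, one option is to store the version next to the model and check it before unpickling. The sketch below is a hypothetical pattern, not a cuML API: save_with_version writes a small header pickle followed by the model pickle in one file, and load_with_version reads the header first and stops with a clear error on a mismatch instead of attempting to unpickle the model. It relies only on the module's __name__ and __version__ attributes.

```python
import pickle


def save_with_version(model, path, library):
    """Write a version header, then the model, as two consecutive pickles in one file."""
    header = {"library": library.__name__, "version": library.__version__}
    with open(path, "wb") as output_file:
        pickle.dump(header, output_file, protocol=5)
        pickle.dump(model, output_file, protocol=5)


def load_with_version(path, library):
    """Check the header against the installed library before unpickling the model."""
    with open(path, "rb") as input_file:
        header = pickle.load(input_file)
        installed = {"library": library.__name__, "version": library.__version__}
        if header != installed:
            raise RuntimeError(
                f"{path} was saved with {header['library']} {header['version']}, "
                f"but {installed['library']} {installed['version']} is installed"
            )
        # The second pickle in the file is the model itself
        return pickle.load(input_file)
```

In this notebook that could look like save_with_version(kmeans, "kmeans_model_versioned.pkl", cuml) after import cuml, followed later by load_with_version("kmeans_model_versioned.pkl", cuml). A single plain pickle.load on such a file returns the header, not the model.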

Using joblib for Model Serialization

joblib is an alternative to pickle designed for Python objects that contain large NumPy arrays, which includes many machine learning models. It offers:

  • Efficient serialization of large NumPy arrays, which can benefit models with large parameter arrays

  • Optional compression through the compress argument of joblib.dump

  • Memory-mapped loading of NumPy arrays through the mmap_mode argument of joblib.load (uncompressed files only)

Note: Load a file with the same library that saved it: joblib.load for files written by joblib.dump, and pickle.load for files written by pickle.dump.
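The compression and memory-mapping options listed above are controlled by the compress argument of joblib.dump and the mmap_mode argument of joblib.load. The sketch uses a dict holding a NumPy array as a hypothetical stand-in for a model so it runs without a GPU. Memory mapping only works for uncompressed files, and how much it helps for a cuML model depends on how that model stores its attributes, which is not shown here.

```python
import os
import tempfile

import joblib
import numpy as np

# Hypothetical stand-in for a fitted model: a dict holding a large NumPy array
rng = np.random.default_rng(0)
model_state = {"weights": rng.standard_normal((1000, 100)).astype(np.float32)}

workdir = tempfile.mkdtemp()

# 1) Compressed file: smaller on disk, fully decompressed into memory on load
compressed_path = os.path.join(workdir, "state_compressed.joblib")
joblib.dump(model_state, compressed_path, compress=3)
loaded_compressed = joblib.load(compressed_path)

# 2) Uncompressed file loaded with mmap_mode: arrays come back as read-only
#    memory maps, so their data is read from disk on access
plain_path = os.path.join(workdir, "state_plain.joblib")
joblib.dump(model_state, plain_path)
loaded_mmap = joblib.load(plain_path, mmap_mode="r")

print(type(loaded_mmap["weights"]))
```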

[4]:
import joblib

joblib.dump(kmeans, "kmeans_model.joblib")
[4]:
['kmeans_model.joblib']

Then reload the model with joblib.

[5]:
kmeans_loaded_model = joblib.load("kmeans_model.joblib")
kmeans_loaded_model.cluster_centers_
[5]:
array([[-2.922248  ,  4.7528377 , -4.3529677 ,  2.2710595 ,  1.7184176 ,
        -2.5451765 , -5.50611   , -1.7181125 , -8.245671  ,  2.8203053 ],
       [-5.8374496 ,  2.0425208 , -3.8477435 , -1.8293772 , -5.257385  ,
         7.710398  ,  2.9743197 ,  8.42101   ,  1.5094917 ,  1.0263587 ],
       [ 4.781445  ,  8.392482  , -9.312664  ,  9.438168  ,  8.540471  ,
        -1.0861524 ,  3.437934  , -8.072111  , -0.6570339 ,  0.2782365 ],
       [ 5.317544  , -4.372343  ,  4.2193136 , -2.7930846 ,  3.766153  ,
        -4.301045  , -3.730563  ,  6.330142  , -6.965777  , -1.1038128 ],
       [-4.0034103 ,  5.5426564 , -5.8204336 , -1.8451873 , -9.445931  ,
         0.72651756,  4.2096705 , -2.5796611 , -5.0424485 ,  9.633467  ]],
      dtype=float32)

Distributed Model Serialization

When working with distributed cuML models using Dask, the distributed estimator wrappers in cuml.dask are not designed to be pickled directly. Instead, cuML provides a specialized workflow:

Workflow Steps

  1. Extract the combined model: Use get_combined_model() to extract a single-GPU version of the trained distributed model

  2. Serialize the combined model: Save the extracted model using pickle or joblib (same as any cuML model)

  3. Flexible inference: Use the saved model in multiple ways:

    • Single-GPU inference: Load directly for single-GPU predictions

    • Distributed inference: Use ParallelPostFit from Dask-ML to distribute inference across a Dask cluster

This approach allows you to choose the optimal resources for both training and inference phases.

[6]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

# Set up Dask cluster
cluster = LocalCUDACluster()
client = Client(cluster)
[7]:
from cuml.dask.datasets import make_blobs
from cuml.dask.cluster import KMeans as DistributedKMeans

# Get number of workers
n_workers = client.scheduler_info()["n_workers"]

# Generate distributed dataset
X, y = make_blobs(
    n_samples=5000,
    n_features=30,
    centers=5,
    cluster_std=0.4,
    random_state=0,
    # 5 parts per worker to demonstrate distributed inference
    n_parts=n_workers * 5,
)

# Initialize and train the distributed KMeans model
distributed_kmeans = DistributedKMeans(n_clusters=5).fit(X)

To save the distributed model with pickle, first combine it into a single-GPU model with get_combined_model(); the combined model can then be pickled like any other cuML model.

[8]:
# Extract single-GPU model and save it
combined_kmeans = distributed_kmeans.get_combined_model()

with open("kmeans_model.pkl", "wb") as output_file:
    pickle.dump(combined_kmeans, output_file, protocol=5)

And we can reload this model just like before.

[9]:
# Load the single-GPU model
with open("kmeans_model.pkl", "rb") as input_file:
    combined_kmeans_loaded_model = pickle.load(input_file)

# Display the first 3 rows of the loaded model's cluster centers
combined_kmeans_loaded_model.cluster_centers_[:3]
[9]:
array([[ 4.821157  ,  8.41044   , -9.220306  ,  9.358317  ,  8.497288  ,
        -1.0601722 ,  3.3365138 , -7.795003  , -0.6003679 ,  0.25098842,
         5.5205603 , -4.1234083 ,  4.297074  , -2.8325503 ,  3.6292467 ,
        -4.1684513 , -3.614415  ,  6.212969  , -6.915091  , -1.0821886 ,
        -5.8427916 ,  2.2023475 , -3.8588023 , -1.6982683 , -5.287985  ,
         7.592434  ,  2.9302495 ,  8.511086  ,  1.5768247 ,  1.0933212 ],
       [-2.8774009 ,  4.479107  , -4.435228  ,  2.3590753 ,  1.7383353 ,
        -2.5148928 , -5.1813664 , -1.6930894 , -8.126738  ,  2.658165  ,
        -4.2927976 ,  5.5758367 , -5.730178  , -1.744792  , -9.35795   ,
         0.7091946 ,  4.4193583 , -2.9347353 , -4.933291  ,  9.705753  ,
         8.379933  , -6.276512  , -6.3580914 ,  1.9855402 ,  4.153807  ,
        -9.153121  ,  4.6185417 ,  8.818425  ,  6.8634167 ,  2.2497067 ],
       [-6.9409394 , -9.775783  , -6.5518556 , -0.43954796,  6.0999207 ,
         3.7421818 , -3.965521  ,  6.1366067 , -1.8634121 ,  5.0342364 ,
        -6.826796  ,  1.342927  ,  9.008171  , -1.00592   ,  9.645004  ,
         9.789133  , -8.619173  ,  5.9947176 ,  2.212121  , -3.6181026 ,
         7.083663  , -7.378212  , -5.3021903 , -6.9675446 , -7.9429984 ,
         6.653303  , -5.58039   ,  7.1386843 ,  6.6048436 , -8.308932  ]],
      dtype=float32)
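ParallelPostFit distributes inference by calling the fitted estimator's predict on each block of a Dask array independently. The sketch below reproduces that block-wise idea serially with NumPy and a CPU scikit-learn KMeans, stand-ins chosen so it runs without Dask or a GPU; blockwise_predict is a hypothetical helper, not part of Dask-ML. Because KMeans assigns each row to its nearest center independently, predicting block by block gives the same labels as predicting all rows at once.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Stand-in for the combined model: a CPU scikit-learn KMeans on similar data
X_cpu, _ = make_blobs(
    n_samples=5000, n_features=30, centers=5, cluster_std=0.4, random_state=0
)
model = KMeans(n_clusters=5, n_init=10, random_state=0).fit(X_cpu)


def blockwise_predict(estimator, data, n_blocks):
    """Apply estimator.predict to each row block independently and join the results."""
    blocks = np.array_split(data, n_blocks)
    return np.concatenate([estimator.predict(block) for block in blocks])


blockwise = blockwise_predict(model, X_cpu, n_blocks=10)
full = model.predict(X_cpu)
```

With Dask-ML installed, the corresponding call in this notebook would be along the lines of ParallelPostFit(estimator=combined_kmeans_loaded_model).predict(X), where X is the Dask array created above, so that each partition is predicted on a worker.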

Converting Between cuML and scikit-learn Models

Many cuML estimators provide as_sklearn() and from_sklearn() methods for converting between cuML and scikit-learn estimators.

Use Cases

  • Cross-platform deployment: Train on GPU systems, deploy on CPU-only machines

  • Maximum compatibility: Use standard scikit-learn serialization tools

  • Hybrid workflows: Mix cuML and scikit-learn in the same pipeline

  • Legacy integration: Convert existing scikit-learn models to cuML for GPU acceleration

With this approach, deployment machines need only scikit-learn, not cuML or a GPU.

[10]:
import pickle

from cuml.cluster import KMeans
from cuml.datasets import make_blobs
from cuml.metrics.cluster import adjusted_rand_score

# Generate synthetic dataset for clustering
X, y = make_blobs(
    n_samples=1000, n_features=20, centers=5, cluster_std=0.5, random_state=42
)

# Train cuML KMeans
kmeans = KMeans(n_clusters=5, random_state=42).fit(X)

# Make predictions with cuML model
predictions = kmeans.predict(X)
score = adjusted_rand_score(y, predictions)
print(f"cuML KMeans ARI score: {score:.4f}")
print(f"cuML KMeans cluster centers shape: {kmeans.cluster_centers_.shape}")
cuML KMeans ARI score: 1.0000
cuML KMeans cluster centers shape: (5, 20)

We can convert this cuML model into a native scikit-learn estimator using the as_sklearn() method. The result can be serialized with standard scikit-learn tooling and deployed in any Python environment that has scikit-learn installed.

[11]:
# Convert cuML model to scikit-learn model
kmeans_sklearn = kmeans.as_sklearn()
print(f"Converted to scikit-learn model: {type(kmeans_sklearn)}")

# Save scikit-learn model to disk (the context manager closes the file)
with open("kmeans_model_sklearn.pkl", "wb") as output_file:
    pickle.dump(kmeans_sklearn, output_file, protocol=5)
print("scikit-learn KMeans model saved with pickle")
Converted to scikit-learn model: <class 'sklearn.cluster._kmeans.KMeans'>
scikit-learn KMeans model saved with pickle

The pickled scikit-learn model can be loaded and run in any Python environment with scikit-learn installed; no cuML or GPU is required. As with any scikit-learn pickle, load it with the same scikit-learn version that saved it.

[12]:
from cupy import asnumpy

# Load scikit-learn model and verify prediction quality
with open("kmeans_model_sklearn.pkl", "rb") as input_file:
    kmeans_loaded_sklearn = pickle.load(input_file)

# X is a CuPy (GPU) array; scikit-learn expects host memory, so copy it to NumPy
sklearn_predictions = kmeans_loaded_sklearn.predict(asnumpy(X))
sklearn_score = adjusted_rand_score(y, sklearn_predictions)
print(f"Loaded sklearn KMeans ARI score: {sklearn_score:.4f}")
Loaded sklearn KMeans ARI score: 1.0000

You can also reconstruct a cuML model from a scikit-learn model using from_sklearn(). This is particularly useful for:

  • Pre-trained models: Convert existing scikit-learn models for GPU acceleration

  • Performance optimization: Run faster inference on GPU hardware

  • Hybrid workflows: Switch between CPU and GPU execution as needed

[13]:
# Re-construct the cuML model from the scikit-learn model
kmeans_from_sklearn = KMeans.from_sklearn(kmeans_loaded_sklearn)
predictions = kmeans_from_sklearn.predict(X)
print("Re-constructed cuML KMeans ARI Score: ", adjusted_rand_score(y, predictions))
Re-constructed cuML KMeans ARI Score:  1.0

Exporting Random Forest Models for CPU-Only Deployment

You can export cuML Random Forest models for deployment on machines without NVIDIA GPUs using the Treelite library.

Benefits

  • CPU-only deployment: Run trained models on machines without NVIDIA GPUs

  • CPU inference: Treelite's GTIL module (treelite.gtil) runs predictions directly from the loaded model

  • Small footprint: Only Treelite is required; no cuML or GPU dependencies

  • Portable checkpoints: serialize() writes the whole model to a single file that can be copied to other machines

Export Process

  1. Convert to Treelite format: Use as_treelite() to transform your cuML Random Forest model

  2. Serialize the model: Call .serialize() to create a portable checkpoint file

  3. Deploy anywhere: Install Treelite on the target machine and load the model for inference

[14]:
import numpy as np
from cuml.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# Load and prepare iris dataset
X, y = load_iris(return_X_y=True)
X, y = X.astype(np.float32), y.astype(np.int32)

# Train Random Forest model
random_forest = RandomForestClassifier(
    max_depth=3, random_state=0, n_estimators=10
).fit(X, y)

# Export cuML RF model as Treelite checkpoint
treelite_checkpoint_path = "./checkpoint.tl"
random_forest.as_treelite().serialize(treelite_checkpoint_path)

Deployment Steps

  1. Copy the checkpoint file: Transfer checkpoint.tl to your target machine

  2. Install Treelite: Run pip install treelite or conda install -c conda-forge treelite

    • No NVIDIA GPUs required

    • No cuML installation needed

  3. Load and use the model: Run the code below on the target machine
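Before loading a checkpoint that was copied to another machine, it can help to confirm the copy is intact by comparing a checksum computed on both machines. This is a general-purpose sketch with Python's hashlib, not a Treelite feature; file_sha256 is a hypothetical helper name.

```python
import hashlib


def file_sha256(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Running file_sha256("checkpoint.tl") on the training machine and again on the target machine should give identical strings; a difference means the file changed in transit.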

[15]:
import treelite

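# Load the Treelite model (the checkpoint file has been copied to this machine)
treelite_checkpoint_path = "./checkpoint.tl"
treelite_model = treelite.Model.deserialize(treelite_checkpoint_path)

# Make predictions with Treelite's GTIL. X must be a NumPy float32 array with the
# same features as the training data (here, the iris features loaded above).
# pred_margin=True returns raw margin scores instead of transformed outputs.
predictions = treelite.gtil.predict(treelite_model, X, pred_margin=True)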