Model Serialization and Persistence#
This notebook demonstrates how to save and load cuML models using standard serialization tools such as pickle and joblib, along with cross-platform deployment strategies.
Single GPU Model Serialization#
All single-GPU cuML estimators support serialization using standard Python libraries. This section demonstrates:
Training a model on synthetic data
Saving the model using pickle and joblib
Loading the model for future use
Trained single-GPU models can also be used for distributed inference on Dask clusters, as shown in the Distributed Model Serialization section.
[1]:
from cuml.cluster import KMeans
from cuml.datasets import make_blobs
# Generate synthetic dataset for clustering
X, y = make_blobs(
    n_samples=50, n_features=10, centers=5, cluster_std=0.4, random_state=0
)
# Initialize and fit KMeans model
kmeans = KMeans(n_clusters=5).fit(X)
Recommendation: Use Pickle protocol 5 for better performance with large arrays and models. Protocol 5 provides significant speed improvements for NumPy arrays and cuML models with large parameter sets.
[2]:
import pickle
# Save the fitted model to disk
with open("kmeans_model.pkl", "wb") as output_file:
    pickle.dump(kmeans, output_file, protocol=5)
Important: The model can be restored with pickle, but only in an environment with the same cuML version that was used for training. If you need to load models across different cuML versions, consider using the scikit-learn conversion approach described later in this notebook instead.
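One lightweight safeguard, sketched here rather than taken from the original workflow, is to record the training-time cuML version next to the pickle file so it can be checked before loading (the metadata file name is illustrative):

# Sketch: store the training-time cuML version alongside the model so a
# loader can verify compatibility before unpickling (file name is illustrative).
import json
import cuml

with open("kmeans_model_meta.json", "w") as meta_file:
    json.dump({"cuml_version": cuml.__version__}, meta_file)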
[3]:
# Load the model from disk
with open("kmeans_model.pkl", "rb") as input_file:
    kmeans_loaded_model = pickle.load(input_file)
# Display the loaded model's cluster centers
kmeans_loaded_model.cluster_centers_
[3]:
array([[-2.922248 , 4.7528377 , -4.3529677 , 2.2710595 , 1.7184176 ,
-2.5451765 , -5.50611 , -1.7181125 , -8.245671 , 2.8203053 ],
[-5.8374496 , 2.0425208 , -3.8477435 , -1.8293772 , -5.257385 ,
7.710398 , 2.9743197 , 8.42101 , 1.5094917 , 1.0263587 ],
[ 4.781445 , 8.392482 , -9.312664 , 9.438168 , 8.540471 ,
-1.0861524 , 3.437934 , -8.072111 , -0.6570339 , 0.2782365 ],
[ 5.317544 , -4.372343 , 4.2193136 , -2.7930846 , 3.766153 ,
-4.301045 , -3.730563 , 6.330142 , -6.965777 , -1.1038128 ],
[-4.0034103 , 5.5426564 , -5.8204336 , -1.8451873 , -9.445931 ,
0.72651756, 4.2096705 , -2.5796611 , -5.0424485 , 9.633467 ]],
dtype=float32)
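As a quick sanity check, and as a sketch rather than an original notebook cell, the reloaded estimator behaves like any fitted cuML model and can label data directly:

# Sketch: the reloaded model is a regular fitted cuML estimator, so it can
# assign cluster labels to the training data (or new data) right away.
labels = kmeans_loaded_model.predict(X)
print(labels[:10])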
Using joblib for Model Serialization#
joblib is an optimized alternative to pickle for machine learning models, offering:
Better performance for large NumPy arrays and cuML models
Efficient compression for models with many parameters
Memory mapping for faster loading of large models
Optimized serialization specifically designed for ML workloads
Note: While pickle and joblib files are often compatible, we recommend using the same library for both saving and loading to ensure reliability.
[4]:
import joblib
joblib.dump(kmeans, "kmeans_model.joblib")
[4]:
['kmeans_model.joblib']
Then reload the model with joblib.
[5]:
kmeans_loaded_model = joblib.load("kmeans_model.joblib")
kmeans_loaded_model.cluster_centers_
[5]:
array([[-2.922248 , 4.7528377 , -4.3529677 , 2.2710595 , 1.7184176 ,
-2.5451765 , -5.50611 , -1.7181125 , -8.245671 , 2.8203053 ],
[-5.8374496 , 2.0425208 , -3.8477435 , -1.8293772 , -5.257385 ,
7.710398 , 2.9743197 , 8.42101 , 1.5094917 , 1.0263587 ],
[ 4.781445 , 8.392482 , -9.312664 , 9.438168 , 8.540471 ,
-1.0861524 , 3.437934 , -8.072111 , -0.6570339 , 0.2782365 ],
[ 5.317544 , -4.372343 , 4.2193136 , -2.7930846 , 3.766153 ,
-4.301045 , -3.730563 , 6.330142 , -6.965777 , -1.1038128 ],
[-4.0034103 , 5.5426564 , -5.8204336 , -1.8451873 , -9.445931 ,
0.72651756, 4.2096705 , -2.5796611 , -5.0424485 , 9.633467 ]],
dtype=float32)
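The compression and memory-mapping features listed above are opt-in. The following is a minimal sketch; the compression level and file names are illustrative choices, and memory mapping only affects host (NumPy) arrays inside the file, not GPU-resident buffers:

# Sketch: compression trades disk space for save/load time
# (compress=3 is an arbitrary example value, not a cuML recommendation).
joblib.dump(kmeans, "kmeans_model_compressed.joblib", compress=3)

# Sketch: memory mapping applies to uncompressed files and lets large
# host-side arrays be read lazily from disk instead of loaded up front.
joblib.dump(kmeans, "kmeans_model_uncompressed.joblib")
kmeans_mmap_model = joblib.load("kmeans_model_uncompressed.joblib", mmap_mode="r")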
Distributed Model Serialization#
When working with distributed cuML models using Dask, the distributed estimator wrappers in cuml.dask
are not designed to be pickled directly. Instead, cuML provides a specialized workflow:
Workflow Steps#
Extract the combined model: Use get_combined_model() to extract a single-GPU version of the trained distributed model
Serialize the combined model: Save the extracted model using pickle or joblib (same as any cuML model)
Flexible inference: Use the saved model in multiple ways:
Single-GPU inference: Load directly for single-GPU predictions
Distributed inference: Use ParallelPostFit from Dask-ML to distribute inference across a Dask cluster (a sketch follows the loaded-model output below)
This approach allows you to choose the optimal resources for both training and inference phases.
[6]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
# Set up Dask cluster
cluster = LocalCUDACluster()
client = Client(cluster)
[7]:
from cuml.dask.datasets import make_blobs
from cuml.dask.cluster import KMeans as DistributedKMeans
# Get number of workers
n_workers = client.scheduler_info()["n_workers"]
# Generate distributed dataset
X, y = make_blobs(
    n_samples=5000,
    n_features=30,
    centers=5,
    cluster_std=0.4,
    random_state=0,
    # 5 parts per worker to demonstrate distributed inference
    n_parts=n_workers * 5,
)
# Initialize and train the distributed KMeans model
distributed_kmeans = DistributedKMeans(n_clusters=5).fit(X)
Now we can save it with pickle like before, but we have to combine it into a non-distributed model first.
[8]:
# Extract single-GPU model and save it
combined_kmeans = distributed_kmeans.get_combined_model()
with open("kmeans_model.pkl", "wb") as output_file:
    pickle.dump(combined_kmeans, output_file, protocol=5)
And we can reload this model just like before.
[9]:
# Load the single-GPU model
with open("kmeans_model.pkl", "rb") as input_file:
    combined_kmeans_loaded_model = pickle.load(input_file)
# Display the first 3 rows of the loaded model's cluster centers
combined_kmeans_loaded_model.cluster_centers_[:3]
[9]:
array([[ 4.821157 , 8.41044 , -9.220306 , 9.358317 , 8.497288 ,
-1.0601722 , 3.3365138 , -7.795003 , -0.6003679 , 0.25098842,
5.5205603 , -4.1234083 , 4.297074 , -2.8325503 , 3.6292467 ,
-4.1684513 , -3.614415 , 6.212969 , -6.915091 , -1.0821886 ,
-5.8427916 , 2.2023475 , -3.8588023 , -1.6982683 , -5.287985 ,
7.592434 , 2.9302495 , 8.511086 , 1.5768247 , 1.0933212 ],
[-2.8774009 , 4.479107 , -4.435228 , 2.3590753 , 1.7383353 ,
-2.5148928 , -5.1813664 , -1.6930894 , -8.126738 , 2.658165 ,
-4.2927976 , 5.5758367 , -5.730178 , -1.744792 , -9.35795 ,
0.7091946 , 4.4193583 , -2.9347353 , -4.933291 , 9.705753 ,
8.379933 , -6.276512 , -6.3580914 , 1.9855402 , 4.153807 ,
-9.153121 , 4.6185417 , 8.818425 , 6.8634167 , 2.2497067 ],
[-6.9409394 , -9.775783 , -6.5518556 , -0.43954796, 6.0999207 ,
3.7421818 , -3.965521 , 6.1366067 , -1.8634121 , 5.0342364 ,
-6.826796 , 1.342927 , 9.008171 , -1.00592 , 9.645004 ,
9.789133 , -8.619173 , 5.9947176 , 2.212121 , -3.6181026 ,
7.083663 , -7.378212 , -5.3021903 , -6.9675446 , -7.9429984 ,
6.653303 , -5.58039 , 7.1386843 , 6.6048436 , -8.308932 ]],
dtype=float32)
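To run distributed inference with this model, as mentioned in the workflow above, the combined estimator can be wrapped with ParallelPostFit from Dask-ML. The following is a minimal sketch, assuming the dask-ml package is installed; it reuses the distributed X array created during training, and depending on your versions you may also need to pass predict_meta to ParallelPostFit:

# Sketch: wrap the single-GPU model so predict() runs block-wise across
# the Dask cluster (requires the dask-ml package).
from dask_ml.wrappers import ParallelPostFit

parallel_model = ParallelPostFit(estimator=combined_kmeans_loaded_model)
# X is the distributed Dask array generated earlier with cuml.dask.datasets.make_blobs.
distributed_predictions = parallel_model.predict(X)
distributed_predictions.compute()  # materialize the predicted labels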
Converting Between cuML and scikit-learn Models#
Many cuML estimators provide as_sklearn() and from_sklearn() methods for seamless conversion between cuML and scikit-learn formats.
Use Cases#
Cross-platform deployment: Train on GPU systems, deploy on CPU-only machines
Maximum compatibility: Use standard scikit-learn serialization tools
Hybrid workflows: Mix cuML and scikit-learn in the same pipeline
Legacy integration: Convert existing scikit-learn models to cuML for GPU acceleration
This approach eliminates the need to install cuML on deployment machines while maintaining model compatibility.
[10]:
import pickle
from cuml.cluster import KMeans
from cuml.datasets import make_blobs
from cuml.metrics.cluster import adjusted_rand_score
# Generate synthetic dataset for clustering
X, y = make_blobs(
    n_samples=1000, n_features=20, centers=5, cluster_std=0.5, random_state=42
)
# Train cuML KMeans
kmeans = KMeans(n_clusters=5, random_state=42).fit(X)
# Make predictions with cuML model
predictions = kmeans.predict(X)
score = adjusted_rand_score(y, predictions)
print(f"cuML KMeans ARI score: {score:.4f}")
print(f"cuML KMeans cluster centers shape: {kmeans.cluster_centers_.shape}")
cuML KMeans ARI score: 1.0000
cuML KMeans cluster centers shape: (5, 20)
We can convert this cuML model into a native scikit-learn estimator using the as_sklearn() method. This enables standard scikit-learn serialization and deployment in any Python environment.
[11]:
# Convert cuML model to scikit-learn model
kmeans_sklearn = kmeans.as_sklearn()
print(f"Converted to scikit-learn model: {type(kmeans_sklearn)}")
# Save scikit-learn model to disk
with open("kmeans_model_sklearn.pkl", "wb") as output_file:
    pickle.dump(kmeans_sklearn, output_file, protocol=5)
print("scikit-learn KMeans model saved with pickle")
Converted to scikit-learn model: <class 'sklearn.cluster._kmeans.KMeans'>
scikit-learn KMeans model saved with pickle
The pickled scikit-learn model can be loaded and executed on any Python environment with only scikit-learn installed – no cuML or GPU required.
[12]:
from cupy import asnumpy
# Load scikit-learn model and verify prediction quality
with open("kmeans_model_sklearn.pkl", "rb") as input_file:
    kmeans_loaded_sklearn = pickle.load(input_file)
sklearn_predictions = kmeans_loaded_sklearn.predict(asnumpy(X))
sklearn_score = adjusted_rand_score(y, sklearn_predictions)
print(f"Loaded sklearn KMeans ARI score: {sklearn_score:.4f}")
Loaded sklearn KMeans ARI score: 1.0000
You can also reconstruct a cuML model from a scikit-learn model using from_sklearn(). This is particularly useful for:
Pre-trained models: Convert existing scikit-learn models for GPU acceleration
Performance optimization: Run faster inference on GPU hardware
Hybrid workflows: Switch between CPU and GPU execution as needed
[13]:
# Re-construct the cuML model from the scikit-learn model
kmeans_from_sklearn = KMeans.from_sklearn(kmeans_loaded_sklearn)
predictions = kmeans_from_sklearn.predict(X)
print("Re-constructed cuML KMeans ARI Score: ", adjusted_rand_score(y, predictions))
Re-constructed cuML KMeans ARI Score: 1.0
Exporting Random Forest Models for CPU-Only Deployment#
You can export cuML Random Forest models for deployment on machines without NVIDIA GPUs using the Treelite library.
Benefits#
CPU-only deployment: Run trained models on any machine
Optimized inference: Treelite provides highly optimized CPU inference
Small footprint: No cuML or GPU dependencies required
Production ready: Efficient serialization and fast loading
Export Process#
Convert to Treelite format: Use as_treelite() to transform your cuML Random Forest model
Serialize the model: Call .serialize() to create a portable checkpoint file
Deploy anywhere: Install Treelite on the target machine and load the model for inference
[14]:
import numpy as np
from cuml.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# Load and prepare iris dataset
X, y = load_iris(return_X_y=True)
X, y = X.astype(np.float32), y.astype(np.int32)
# Train Random Forest model
random_forest = RandomForestClassifier(
    max_depth=3, random_state=0, n_estimators=10
).fit(X, y)
# Export cuML RF model as Treelite checkpoint
treelite_checkpoint_path = "./checkpoint.tl"
random_forest.as_treelite().serialize(treelite_checkpoint_path)
Deployment Steps#
Copy the checkpoint file: Transfer checkpoint.tl to your target machine
Install Treelite: Run pip install treelite or conda install -c conda-forge treelite
No NVIDIA GPUs required
No cuML installation needed
Load and use the model: Run the code below on the target machine
[15]:
import treelite
# Load the Treelite model (checkpoint file has been copied over)
treelite_checkpoint_path = "./checkpoint.tl"
treelite_model = treelite.Model.deserialize(treelite_checkpoint_path)
# Make predictions using Treelite
predictions = treelite.gtil.predict(treelite_model, X, pred_margin=True)
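With pred_margin=True, treelite.gtil.predict returns raw per-class scores rather than class labels. As a minimal sketch, labels can be recovered by taking the argmax over the class axis; the exact output shape depends on the Treelite version (recent versions include an extra target dimension), so the squeeze below is a defensive, illustrative choice:

import numpy as np

# Sketch: reduce per-class scores to predicted labels. The output may have
# shape (n_rows, n_classes) or (n_rows, n_targets, n_classes) depending on
# the Treelite version; argmax over the last axis handles both.
predicted_labels = np.argmax(predictions, axis=-1).squeeze()
print(predicted_labels[:10])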