Pickling cuML Models for Persistence

This notebook demonstrates simple pickling of both single-GPU and multi-GPU cuML models for persistence.

[1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

Single GPU Model Pickling

All single-GPU estimators are pickleable. The following example demonstrates creating a synthetic dataset, training a model, and pickling the resulting model for storage. Trained single-GPU models can also be used to distribute inference across a Dask cluster, as the Distributed Model Pickling section below demonstrates.

[2]:
from cuml.datasets import make_blobs

X, y = make_blobs(n_samples=50,
                  n_features=10,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0)
[3]:
from cuml.cluster import KMeans

model = KMeans(n_clusters=5)

model.fit(X)
[3]:
KMeans()
[4]:
import pickle

with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(model, f)
[5]:
with open("kmeans_model.pkl", "rb") as f:
    model = pickle.load(f)
[6]:
model.cluster_centers_
[6]:
array([[ 5.2615476, -4.0487256,  4.464928 , -2.9367518,  3.5061095,
        -4.016832 , -3.463885 ,  6.078449 , -6.9533257, -1.004144 ],
       [-4.243999 ,  5.610707 , -5.6697764, -1.7957246, -9.255528 ,
         0.7177438,  4.4435897, -2.874715 , -5.0900955,  9.684121 ],
       [ 4.6749854,  8.213466 , -9.075721 ,  9.568374 ,  8.454807 ,
        -1.2327975,  3.3903713, -7.8282413, -0.8454461,  0.6288572],
       [-3.008261 ,  4.625961 , -4.483249 ,  2.228457 ,  1.643532 ,
        -2.4505196, -5.2582016, -1.6679403, -7.9857535,  2.8311467],
       [-5.6072407,  2.2695985, -3.7516537, -1.8182005, -5.1430273,
         7.599364 ,  2.8252368,  8.773042 ,  1.6198314,  1.1772048]],
      dtype=float32)
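
As a quick check, the unpickled estimator behaves just like the original. The snippet below is a minimal sketch that assumes X from the make_blobs cell above is still in scope:

# Predict cluster labels for the synthetic data with the loaded model
labels = model.predict(X)
labels[:10]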

Distributed Model Pickling

The distributed estimator wrappers inside the cuml.dask package are not intended to be pickled directly. Instead, the Dask cuML estimators provide a get_combined_model() function, which returns the trained single-GPU model for pickling. The combined model can be used for inference on a single GPU, and the ParallelPostFit wrapper from the Dask-ML library can be used to perform distributed inference on a Dask cluster (see the sketch at the end of this notebook).

[7]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster()
client = Client(cluster)
client
[7]:

Client (LocalCUDACluster)
  • Workers: 1
  • Cores: 1
  • Memory: 251.80 GiB
[8]:
from cuml.dask.datasets import make_blobs

n_workers = len(client.scheduler_info()["workers"].keys())

X, y = make_blobs(n_samples=5000,
                  n_features=30,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0,
                  n_parts=n_workers*5)

X = X.persist()
y = y.persist()
[9]:
from cuml.dask.cluster import KMeans

dist_model = KMeans(n_clusters=5)
[10]:
dist_model.fit(X)
[10]:
<cuml.dask.cluster.kmeans.KMeans at 0x7fd2b021a050>
[11]:
import pickle

single_gpu_model = dist_model.get_combined_model()
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(single_gpu_model, f)
[12]:
with open("kmeans_model.pkl", "rb") as f:
    single_gpu_model = pickle.load(f)
[13]:
single_gpu_model.cluster_centers_
[13]:
array([[-2.8796387e+00,  4.4348764e+00, -4.4264812e+00,  2.3959143e+00,
         1.7286434e+00, -2.4991984e+00, -5.1794519e+00, -1.6771443e+00,
        -8.1329165e+00,  2.6659224e+00, -4.3131094e+00,  5.5827808e+00,
        -5.7318311e+00, -1.7427168e+00, -9.3456116e+00,  7.1365571e-01,
         4.4255495e+00, -2.9118376e+00, -4.9467444e+00,  9.6786413e+00,
         8.4222746e+00, -6.2710242e+00, -6.3596501e+00,  1.9645507e+00,
         4.1715994e+00, -9.1683636e+00,  4.6156683e+00,  8.7916489e+00,
         6.8754416e+00,  2.2288749e+00],
       [-6.9536943e+00, -9.7635870e+00, -6.5648260e+00, -4.3536153e-01,
         6.0998116e+00,  3.7550371e+00, -3.9558537e+00,  6.1595526e+00,
        -1.8599318e+00,  5.0400310e+00, -6.8397551e+00,  1.3435433e+00,
         8.9749012e+00, -9.9621779e-01,  9.6651945e+00,  9.8009663e+00,
        -8.6188364e+00,  5.9978366e+00,  2.2295928e+00, -3.6477714e+00,
         7.0758510e+00, -7.3772254e+00, -5.3214231e+00, -6.9927959e+00,
        -7.9296322e+00,  6.6705360e+00, -5.5850182e+00,  7.1526051e+00,
         6.5703220e+00, -8.3389406e+00],
       [ 4.8136683e+00,  8.3985281e+00, -9.2161236e+00,  9.4185524e+00,
         8.5280876e+00, -1.0969982e+00,  3.3253176e+00, -7.8064370e+00,
        -5.9660637e-01,  2.5423864e-01,  5.5004082e+00, -4.1162963e+00,
         4.2832375e+00, -2.8173413e+00,  3.6207721e+00, -4.1576214e+00,
        -3.6048706e+00,  6.2125397e+00, -6.9080992e+00, -1.0732135e+00,
        -5.8362112e+00,  2.2357666e+00, -3.8588786e+00, -1.6835877e+00,
        -5.3240366e+00,  7.5769196e+00,  2.9358525e+00,  8.5267372e+00,
         1.5667247e+00,  1.0779675e+00],
       [-4.6475401e+00, -9.5672178e+00,  6.6923518e+00,  4.4359236e+00,
         2.1902738e+00,  2.5834756e+00,  5.9448934e-01,  6.2568665e+00,
        -8.7821655e+00, -4.1232008e-01,  9.8151779e+00,  7.5641570e+00,
         1.0003010e+01, -5.8680439e+00, -1.2743111e+00, -2.5393457e+00,
        -1.0847499e+00, -5.2629204e+00, -9.3071022e+00,  4.6179361e+00,
        -9.7068965e-02, -3.9351211e+00,  6.1767273e+00, -7.4346881e+00,
         5.6496077e+00, -8.5544844e+00, -7.5265574e+00, -5.5195603e+00,
         4.8197627e+00,  2.5235438e+00],
       [ 6.2794294e+00,  9.2293940e+00,  8.3403702e+00,  9.0330496e+00,
         7.6893492e+00, -9.9538225e-01, -6.2780757e+00,  1.3599334e+00,
        -6.9744482e+00, -5.9463458e+00,  1.0695115e+00, -8.0422508e-03,
         2.8183138e+00,  1.8317667e+00, -8.2557344e+00,  3.0514314e+00,
        -8.4958000e+00,  9.7238474e+00, -7.7455082e+00,  3.4521687e+00,
        -3.9248333e+00, -4.1106420e+00,  2.6693089e+00,  1.2985626e+00,
         1.0421573e+00,  5.2490616e+00, -1.6496239e+00,  6.1451659e+00,
        -6.9103327e+00, -9.6390305e+00]], dtype=float32)
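
Finally, as noted above, the combined single-GPU model can be wrapped for distributed inference on the Dask cluster. The snippet below is a minimal sketch, assuming the optional dask_ml package is installed and reusing the Dask array X generated earlier; it is not executed as part of this notebook:

from dask_ml.wrappers import ParallelPostFit

# Wrap the single-GPU model so predict() runs partition-wise across the cluster
wrapped_model = ParallelPostFit(estimator=single_gpu_model)

predictions = wrapped_model.predict(X)   # lazy Dask array of cluster labels
predictions.compute()                    # materialize the labels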