cuML on GPU and CPU#

cuML is a Scikit-learn-like suite of fast, GPU-accelerated machine learning algorithms designed for data science and analytical tasks.

Starting with version 23.10, cuML provides both GPU-based and CPU-based execution, with zero code changes required to switch between them. This unified CPU/GPU cuML:

  • Allows users to prototype in systems without GPUs.

  • Allows library integrations without the need for dispatching and boilerplate code.

  • Allows users to train on one type of system and infer with the other for a subset of estimators (that will expand over time).

  • Provides compatibility with the broader GPU/CPU open source pydata ecosystem.

The majority of estimators of cuML can run in both CPU and GPU systems, with a subset of them supporting exporting models between GPU and CPU systems. The following table shows support for the most common estimators:

| Category | Algorithm | Supports Execution on CPU | Supports Exporting between CPU and GPU |
|---|---|---|---|
| Clustering | Density-Based Spatial Clustering of Applications with Noise (DBSCAN) | Yes | No |
| | Hierarchical Density-Based Spatial Clustering of Applications with Noise (HDBSCAN) | Yes | Partial |
| | K-Means | Yes | No |
| | Single-Linkage Agglomerative Clustering | No | No |
| Dimensionality Reduction | Principal Components Analysis (PCA) | Yes | Yes |
| | Incremental PCA | No | No |
| | Truncated Singular Value Decomposition (tSVD) | Yes | Yes |
| | Uniform Manifold Approximation and Projection (UMAP) | Yes | Partial |
| | Random Projection | No | No |
| | t-Distributed Stochastic Neighbor Embedding (TSNE) | No | No |
| Linear Models for Regression or Classification | Linear Regression (OLS) | Yes | Yes |
| | Linear Regression with Lasso or Ridge Regularization | Yes | Yes |
| | ElasticNet Regression | Yes | Yes |
| | LARS Regression | No | No |
| | Logistic Regression | Yes | Yes |
| | Naive Bayes | No | No |
| | Solvers | Yes | |
| Nonlinear Models for Regression or Classification | Random Forest (RF) Classification | No | Partial |
| | Random Forest (RF) Regression | No | Partial |
| | Inference for decision tree-based models | No | No |
| | Nearest Neighbors (NN) | Yes | Yes |
| | K-Nearest Neighbors (KNN) Classification | Yes | Yes |
| | K-Nearest Neighbors (KNN) Regression | Yes | Yes |
| | Support Vector Machine Classifier (SVC) | No | No |
| | Epsilon-Support Vector Regression (SVR) | No | No |
| Time Series | Holt-Winters Exponential Smoothing | No | No |
| | Auto-regressive Integrated Moving Average (ARIMA) | No | No |

For estimators that support execution on CPU, the same code is guaranteed to run on both GPU and CPU systems. Version 23.12 is scheduled to add the following algorithms:

  • Random Forest

  • Support Vector Machine estimators
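An integration that needs to branch on this information could encode the support table as plain data. The following is a minimal sketch: `SUPPORT`, `runs_on_cpu`, and `exportable` are illustrative names that are not part of cuML, and the entries are a hand-copied subset of the table above.

```python
# Illustrative only: a hand-copied subset of the support table above.
# SUPPORT, runs_on_cpu and exportable are hypothetical names, not cuML APIs.
SUPPORT = {
    # algorithm: (supports execution on CPU, supports CPU<->GPU export)
    "DBSCAN": ("Yes", "No"),
    "HDBSCAN": ("Yes", "Partial"),
    "K-Means": ("Yes", "No"),
    "PCA": ("Yes", "Yes"),
    "UMAP": ("Yes", "Partial"),
    "Linear Regression (OLS)": ("Yes", "Yes"),
    "Logistic Regression": ("Yes", "Yes"),
    "Random Forest (RF) Classification": ("No", "Partial"),
    "Support Vector Machine Classifier (SVC)": ("No", "No"),
}


def runs_on_cpu(algorithm: str) -> bool:
    """True if the table lists CPU execution as supported."""
    return SUPPORT[algorithm][0] == "Yes"


def exportable(algorithm: str) -> bool:
    """True only for full export support; 'Partial' is treated as False."""
    return SUPPORT[algorithm][1] == "Yes"


# Algorithms that can both run on CPU and be exported between devices.
cpu_and_export = sorted(
    name for name in SUPPORT if runs_on_cpu(name) and exportable(name)
)
print(cpu_and_export)
```

Treating "Partial" as unsupported is a conservative choice for this sketch; a real integration might want to handle those estimators case by case.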

Installation#

For GPU systems, cuML still follows the RAPIDS requirements. The cuML package and wheels are universal and can run in both GPU and CPU modes. To use cuML in CPU-only systems, you can install using conda/mamba with:

mamba install -c rapidsai -c nvidia -c conda-forge cuml-cpu=23.10
# mamba install -c rapidsai-nightly -c nvidia -c conda-forge cuml-cpu=23.12 # for nightly builds

  • cuML 23.10 supports Linux and WSL2 on GPU and CPU systems using conda.

  • cuML 23.12 will bring support for pip wheels and macOS support for CPU execution.
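After installing, a script can confirm that a `cuml` module is importable before relying on it. This is a minimal sketch using only the standard library; `cuml_is_installed` is a hypothetical helper name, and the check does not distinguish between the GPU package and cuml-cpu.

```python
import importlib.util


def cuml_is_installed() -> bool:
    """Return True if a module named `cuml` can be found, without importing it."""
    return importlib.util.find_spec("cuml") is not None


if cuml_is_installed():
    print("cuml is importable (GPU package or cuml-cpu)")
else:
    print("cuml not found; install it with the command above")
```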

How to Use#

There are two main ways to use the CPU capabilities of cuML:

1. Using CPU Package directly#

The CPU package, cuml-cpu, is a subset of the cuml package, so zero code changes are required to run the code on a CPU-only system. For example, the following script runs both on a system with a GPU and cuml and on a system without a GPU that has cuml-cpu installed:

[1]:
import cuml  # no change is needed, even for the import!
import pandas as pd

from cuml.manifold.umap import UMAP
from sklearn import datasets
from sklearn.manifold import trustworthiness

# Load the iris dataset from sklearn and wrap the features in a DataFrame
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)

# Define the cuML UMAP model and use fit_transform to obtain the
# low-dimensional embedding of the input dataset
embedding = UMAP(
    n_neighbors=10, min_dist=0.01, init="random"
).fit_transform(iris_df)

# Calculate the trustworthiness of the embedding obtained from cuML UMAP
trust = trustworthiness(iris_df, embedding)
print(trust)
0.9822816901408451

This allows easy prototyping on CPU systems and running production code on GPU servers, or the other way around. Some estimators support training on one type of system and then exporting models to the other type, as noted in the table above and shown by example in the section on cross-device training and inference.

2. Managing Execution Platform with GPU package#

In addition to zero-code-change execution on CPU systems, users can also manually control which device executes parts of the code when using a system with the full cuML package.

For example, using the following data:

[2]:
import cuml
from cuml.neighbors import NearestNeighbors
from cuml.datasets import make_regression, make_blobs
from cuml.model_selection import train_test_split

# Clustering-style data, used by the nearest-neighbors and UMAP examples
X_blobs, y_blobs = make_blobs(n_samples=2000, n_features=20)
X_train_blobs, X_test_blobs, y_train_blobs, y_test_blobs = train_test_split(
    X_blobs, y_blobs, test_size=0.2, shuffle=True
)

# Regression data, used by the linear regression example
X_reg, y_reg = make_regression(n_samples=2000, n_features=20)
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
    X_reg, y_reg, test_size=0.2, shuffle=True
)

There are two ways to control the execution of the code:

a) using_device_type context manager#

[3]:
from cuml.neighbors import NearestNeighbors
from cuml.common.device_selection import using_device_type

nn = NearestNeighbors()
with using_device_type('cpu'):
    nn.fit(X_train_blobs)
    nearest_neighbors = nn.kneighbors(X_test_blobs)

This makes it easy to prototype and run different estimators on different devices, for example when the data is so small that the cost of moving it to the GPU would outweigh any acceleration.
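The usual contract of a context manager of this kind is that the setting applies only inside the `with` block and the earlier value is restored afterwards, even if the body raises. The pure-Python sketch below illustrates that general pattern. It is not cuML's implementation, and `_current_device`, `current_device`, and `using_device` are hypothetical names.

```python
from contextlib import contextmanager

_current_device = "gpu"  # hypothetical module-level default


def current_device() -> str:
    """Return the (hypothetical) active device setting."""
    return _current_device


@contextmanager
def using_device(device: str):
    """Temporarily switch the (hypothetical) device setting."""
    global _current_device
    if device not in ("cpu", "gpu"):
        raise ValueError(f"unknown device: {device!r}")
    previous = _current_device
    _current_device = device
    try:
        yield
    finally:
        # Restore the earlier setting even if the body raised.
        _current_device = previous


with using_device("cpu"):
    inside = current_device()
after = current_device()
print(inside, after)
```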

The context manager also allows running estimators with parameters that the GPU implementation does not support:

from cuml.manifold import UMAP

umap_model = UMAP(angular_rp_forest=True) # `angular_rp_forest` hyperparameter only available in UMAP library
with using_device_type('cpu'):
    umap_model.fit(X_train_blobs) # will run the UMAP library with the hyperparameter
with using_device_type('gpu'):
    transformed = umap_model.transform(X_test_blobs) # will run the cuML implementation of UMAP, ignoring the unsupported parameter.

An upcoming feature will allow this dispatch to occur automatically under the hood. This can be very useful when integrating cuML into other libraries: if users pass parameters not supported on GPUs, the code will automatically dispatch to a CPU implementation.
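Until such a feature exists, a calling library could approximate the idea by comparing the user's parameters against the set that the GPU implementation accepts. The sketch below is purely illustrative: `GPU_SUPPORTED_PARAMS` and `choose_device` are hypothetical, the parameter set is made up for the example, and nothing here reflects how cuML implements or plans to implement automatic dispatch.

```python
# Hypothetical set of hyperparameters that a GPU implementation accepts.
GPU_SUPPORTED_PARAMS = {"n_neighbors", "min_dist", "init"}


def choose_device(params: dict) -> str:
    """Return 'gpu' if every parameter is GPU-supported, otherwise 'cpu'."""
    unsupported = set(params) - GPU_SUPPORTED_PARAMS
    return "cpu" if unsupported else "gpu"


print(choose_device({"n_neighbors": 10, "min_dist": 0.01}))
print(choose_device({"angular_rp_forest": True}))
```

The returned string has the same form as the argument passed to `using_device_type` in the examples above.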

b) Global configuration with set_global_device_type#

By default, cuML executes estimators on the GPU (the "device"). A global configuration option changes this default, which can be useful on shared systems where cuML runs alongside deep learning frameworks that occupy most of a GPU. This is done with the set_global_device_type function:

[4]:
from cuml.common.device_selection import set_global_device_type, get_global_device_type

initial_device_type = get_global_device_type()
print('default execution device:', initial_device_type)
default execution device: DeviceType.device
[5]:
set_global_device_type('cpu')
print('new device type:', get_global_device_type())
new device type: DeviceType.host
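On shared systems it can be convenient to choose the global device type per deployment rather than hard-coding it. The following is a minimal sketch that reads the choice from an environment variable; the variable name `MY_APP_DEVICE` and the helper `device_type_from_env` are hypothetical and not part of cuML.

```python
import os


def device_type_from_env(default: str = "gpu") -> str:
    """Read the desired device type from MY_APP_DEVICE (a hypothetical variable)."""
    value = os.environ.get("MY_APP_DEVICE", default).strip().lower()
    if value not in ("cpu", "gpu"):
        raise ValueError(f"MY_APP_DEVICE must be 'cpu' or 'gpu', got {value!r}")
    return value


device = device_type_from_env()
print("selected device type:", device)
```

The resulting string can then be passed to `set_global_device_type`, as in the cell above.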

Cross Device Training and Inference Serialization#

As stated above, a subset of the estimators support training on one type of device (CPU or GPU), serializing the trained model, and then deserializing and executing it on the other type of device.

To do this, a simple API is provided. For example, to train a model on GPU but deploy it on CPU, first train the estimator on the device and save it to disk:

[6]:
import pickle
from cuml.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X_train_reg, y_train_reg)

# Serialize the trained model; the context manager closes the file handle
with open("lin_reg.pkl", "wb") as f:
    pickle.dump(lin_reg, f)
del lin_reg
/opt/conda/envs/docs/lib/python3.11/site-packages/cuml/internals/api_decorators.py:382: UserWarning: Starting from version 23.08, the new 'copy_X' parameter defaults to 'True', ensuring a copy of X is created after passing it to fit(), preventing any changes to the input, but with increased memory usage. This represents a change in behavior from previous versions. With `copy_X=False` a copy might still be created if necessary. Explicitly set 'copy_X' to either True or False to suppress this warning.
  return init_func(self, *args, **filtered_kwargs)

Then, on the target machine (for example, a server node with cuml-cpu installed), recover the estimator and run inference:

[7]:
# Deserialize the model saved earlier; the context manager closes the file handle
with open("lin_reg.pkl", "rb") as f:
    recovered_lin_reg = pickle.load(f)
predictions = recovered_lin_reg.predict(X_test_reg)
print(predictions[0:10])
[[ -12.657121]
 [-127.28822 ]
 [-139.03052 ]
 [ 125.29277 ]
 [  92.38602 ]
 [-191.14426 ]
 [ -71.97271 ]
 [ 121.1872  ]
 [  21.012989]
 [  29.913351]]
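The save and load steps can be wrapped in small helpers that close the file handle deterministically. This is a minimal sketch using only the standard library: `save_model` and `load_model` are hypothetical names, and a plain dictionary stands in for a trained estimator so that the snippet runs without cuML installed.

```python
import os
import pickle
import tempfile


def save_model(model, path: str) -> None:
    """Pickle `model` to `path`, closing the file when done."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path: str):
    """Load and return a pickled object from `path`."""
    with open(path, "rb") as f:
        return pickle.load(f)


# A plain dict stands in for a trained estimator in this sketch.
stand_in = {"coef_": [1.5, -2.0], "intercept_": 0.25}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    save_model(stand_in, path)
    restored = load_model(path)

print(restored == stand_in)
```

With a real estimator, the machine that calls `load_model` needs a cuML package installed (cuml or cuml-cpu) so that the estimator class can be resolved during unpickling.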

Conclusion#

cuML’s CPU capabilities are designed to facilitate different use cases, lower the barriers to using the capabilities of cuML, and streamline integrating cuML into other tools and deploying models.

Upcoming versions of cuML will expand the set of supported estimators, both for CPU execution and for serializing and exporting models between systems with and without GPUs.