{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exporting cuml.accel Models to ONNX" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Models trained under `cuml.accel` can be exported to the [ONNX](https://onnx.ai/) format using [sklearn-onnx](https://onnx.ai/sklearn-onnx/). Because `sklearn-onnx` recognizes `cuml.accel` proxy objects as scikit-learn estimators, you can pass them directly to `convert_sklearn()` without converting them first.\n", "\n", "The resulting `.onnx` file can then be loaded with [ONNX Runtime](https://onnxruntime.ai/) for inference on CPU or GPU, with no cuML dependency at inference time.\n", "\n", "## Supported Estimators\n", "\n", "The following `cuml.accel` estimators have been tested with `sklearn-onnx`:\n", "\n", "- **Classifiers:** `RandomForestClassifier`, `KNeighborsClassifier`, `LinearSVC`\n", "- **Regressors:** `RandomForestRegressor`, `KNeighborsRegressor`, `LinearSVR`\n", "- **Transformers/Clusterers:** `PCA`, `TruncatedSVD`, `KMeans`\n", "\n", "Estimators not listed here (for example `SVC`, `SVR`, `DBSCAN`, `TSNE`, `HDBSCAN`, `UMAP`) are **not currently supported** for ONNX export." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext cuml.accel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a Model\n", "\n", "Train a `RandomForestClassifier` on the Iris dataset. Under `cuml.accel`, this runs on the GPU."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import load_iris\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "\n", "# The ONNX graph below declares a float32 input (FloatTensorType), so cast\n", "# the features to float32 before training and inference.\n", "X, y = load_iris(return_X_y=True)\n", "X = X.astype(np.float32)\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "clf = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0)\n", "clf.fit(X_train, y_train)\n", "\n", "# Keep the cuml.accel predictions to compare against ONNX Runtime later\n", "accel_predictions = clf.predict(X_test)\n", "print(f\"Training accuracy: {clf.score(X_train, y_train):.4f}\")\n", "print(f\"Test accuracy: {clf.score(X_test, y_test):.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exporting to ONNX\n", "\n", "Convert the fitted model to ONNX with `skl2onnx.convert_sklearn()`. The `cuml.accel` proxy object is passed directly; no `as_sklearn()` call is needed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skl2onnx import convert_sklearn\n", "from skl2onnx.common.data_types import FloatTensorType\n", "\n", "# Declare one float32 input with a dynamic batch dimension (None)\n", "# and one column per feature.\n", "initial_type = [(\"float_input\", FloatTensorType([None, X_train.shape[1]]))]\n", "\n", "# zipmap=False returns class probabilities as a 2D array instead of a\n", "# list of dicts. 
This option only applies to classifiers with predict_proba.\n", "onnx_model = convert_sklearn(\n", "    clf, initial_types=initial_type, options={\"zipmap\": False}\n", ")\n", "\n", "# Serialize once and reuse the bytes for the file and the size report\n", "onnx_bytes = onnx_model.SerializeToString()\n", "onnx_path = \"./rf_classifier.onnx\"\n", "with open(onnx_path, \"wb\") as f:\n", "    f.write(onnx_bytes)\n", "\n", "print(f\"ONNX model saved to {onnx_path} ({len(onnx_bytes)} bytes)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference with ONNX Runtime\n", "\n", "The saved `.onnx` file can be loaded on any machine with `onnxruntime` installed; neither cuML nor `sklearn-onnx` is needed. Use the `onnxruntime` package for CPU inference, or `onnxruntime-gpu` for GPU inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import onnxruntime as ort\n", "\n", "# Use the CUDA execution provider when onnxruntime-gpu is installed,\n", "# otherwise fall back to the CPU provider.\n", "available = ort.get_available_providers()\n", "providers = [p for p in (\"CUDAExecutionProvider\", \"CPUExecutionProvider\") if p in available]\n", "\n", "sess = ort.InferenceSession(onnx_path, providers=providers)\n", "input_name = sess.get_inputs()[0].name\n", "onnx_results = sess.run(None, {input_name: X_test})\n", "\n", "# With zipmap=False the outputs are [labels, probabilities]\n", "onnx_predictions = onnx_results[0]\n", "onnx_probabilities = onnx_results[1]\n", "\n", "print(f\"ONNX predictions shape: {onnx_predictions.shape}\")\n", "print(f\"ONNX probabilities shape: {onnx_probabilities.shape}\")\n", "\n", "# Fraction of test samples where ONNX Runtime agrees with cuml.accel\n", "agreement = np.mean(onnx_predictions == accel_predictions)\n", "print(f\"Agreement with cuml.accel predictions: {agreement:.4f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.13.7" } }, "nbformat": 4, "nbformat_minor": 4 }