Pickling cuML Models for Persistence

This notebook demonstrates how to pickle both single-GPU and multi-GPU cuML models for persistence.

[1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

Single GPU Model Pickling

All single-GPU estimators are pickleable. The following example creates a synthetic dataset, trains a model on it, and pickles the resulting model for storage. Trained single-GPU models can also be used to distribute inference on a Dask cluster, which the Distributed Model Pickling section below demonstrates.

[2]:
from cuml.datasets import make_blobs

X, y = make_blobs(n_samples=50,
                  n_features=10,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0)
[3]:
from cuml.cluster import KMeans

model = KMeans(n_clusters=5)

model.fit(X)
[3]:
KMeans()
[4]:
import pickle

# Use a context manager so the file is closed once the model is written
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(model, f)
[5]:
with open("kmeans_model.pkl", "rb") as f:
    model = pickle.load(f)
[6]:
model.cluster_centers_
[6]:
array([[ 4.6749854 ,  8.213466  , -9.075721  ,  9.568374  ,  8.454808  ,
        -1.2327975 ,  3.390371  , -7.8282413 , -0.8454461 ,  0.62885725],
       [-4.243999  ,  5.610707  , -5.669777  , -1.7957243 , -9.255528  ,
         0.7177438 ,  4.4435897 , -2.874715  , -5.0900965 ,  9.684121  ],
       [ 5.2615476 , -4.0487256 ,  4.464928  , -2.9367518 ,  3.5061095 ,
        -4.016832  , -3.4638855 ,  6.078449  , -6.953326  , -1.004144  ],
       [-3.008261  ,  4.6259604 , -4.4832487 ,  2.228457  ,  1.6435319 ,
        -2.4505196 , -5.258201  , -1.6679403 , -7.9857535 ,  2.8311467 ],
       [-5.6072407 ,  2.2695985 , -3.7516537 , -1.8182003 , -5.1430283 ,
         7.599363  ,  2.8252368 ,  8.773042  ,  1.6198314 ,  1.1772048 ]],
      dtype=float32)
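
The dump and load steps above can be wrapped in a pair of small helpers. The sketch below is a minimal illustration and not part of cuML: `save_model` and `load_model` are hypothetical names, and a `types.SimpleNamespace` object stands in for a fitted estimator so the example runs without a GPU. Any picklable object, such as a fitted single-GPU cuML model, could be passed instead.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def save_model(model, path):
    """Pickle `model` to `path`; the file is closed even if pickling fails."""
    with open(path, "wb") as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path):
    """Load a pickled model from `path`. Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in for a fitted estimator (assumption: not a real cuML model).
original = SimpleNamespace(cluster_centers_=[[0.0, 1.0], [2.0, 3.0]])

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "model.pkl")
    save_model(original, model_path)
    restored = load_model(model_path)

# The restored object is a separate copy carrying the same fitted state.
round_trip_ok = restored.cluster_centers_ == original.cluster_centers_
print(round_trip_ok)
```

Comparing a fitted attribute before and after the round trip is a cheap sanity check that the stored file holds the model you expect.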

Distributed Model Pickling

The distributed estimator wrappers in cuml.dask are not intended to be pickled directly. Instead, each Dask cuML estimator provides a get_combined_model() method, which returns the trained single-GPU model for pickling. The combined model can be used for inference on a single GPU, and the ParallelPostFit wrapper from the Dask-ML library can be used to perform distributed inference on a Dask cluster.

[7]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster()
client = Client(cluster)
client
2022-10-17 18:03:42,611 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2022-10-17 18:03:42,611 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-10-17 18:03:42,612 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2022-10-17 18:03:42,612 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
[7]:

Client

Client-0957760b-4e46-11ed-8498-0242ac110002

Connection method: Cluster object
Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status

[8]:
from cuml.dask.datasets import make_blobs

n_workers = len(client.scheduler_info()["workers"].keys())

X, y = make_blobs(n_samples=5000,
                  n_features=30,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0,
                  n_parts=n_workers*5)

X = X.persist()
y = y.persist()
[9]:
from cuml.dask.cluster import KMeans

dist_model = KMeans(n_clusters=5)
[10]:
dist_model.fit(X)
[10]:
<cuml.dask.cluster.kmeans.KMeans at 0x7ff932c54f10>
[11]:
import pickle

single_gpu_model = dist_model.get_combined_model()
# Pickle the combined single-GPU model, not the distributed wrapper
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(single_gpu_model, f)
[12]:
with open("kmeans_model.pkl", "rb") as f:
    single_gpu_model = pickle.load(f)
[13]:
single_gpu_model.cluster_centers_
[13]:
array([[-4.6557665 , -9.605046  ,  6.6638994 ,  4.435602  ,  2.156688  ,
         2.599673  ,  0.6010148 ,  6.262877  , -8.829993  , -0.39445522,
         9.801517  ,  7.584996  , 10.004724  , -5.8716373 , -1.2833364 ,
        -2.5475292 , -1.0870931 , -5.2439094 , -9.321915  ,  4.6081567 ,
        -0.10159794, -3.9462247 ,  6.1869664 , -7.401992  ,  5.6567736 ,
        -8.548569  , -7.5288224 , -5.5547514 ,  4.849038  ,  2.5301983 ],
       [ 4.7991467 ,  8.40242   , -9.21459   ,  9.392471  ,  8.512869  ,
        -1.0980052 ,  3.3258238 , -7.80285   , -0.5990244 ,  0.25806776,
         5.5174656 , -4.113201  ,  4.29229   , -2.8411753 ,  3.6327324 ,
        -4.173102  , -3.6205482 ,  6.2173705 , -6.9105277 , -1.084521  ,
        -5.8539176 ,  2.2375815 , -3.8543427 , -1.6783282 , -5.322575  ,
         7.575617  ,  2.9321434 ,  8.521326  ,  1.5875131 ,  1.0917971 ],
       [-2.872203  ,  4.469733  , -4.431363  ,  2.3996613 ,  1.7438418 ,
        -2.4938552 , -5.2212667 , -1.7067925 , -8.130271  ,  2.6409218 ,
        -4.307933  ,  5.579306  , -5.741948  , -1.7193332 , -9.359335  ,
         0.7162489 ,  4.4438004 , -2.917387  , -4.9321446 ,  9.692951  ,
         8.393692  , -6.2387233 , -6.363846  ,  1.963377  ,  4.162584  ,
        -9.159682  ,  4.6117425 ,  8.8011265 ,  6.8551817 ,  2.2458148 ],
       [-6.928107  , -9.766996  , -6.5138397 , -0.43525633,  6.100162  ,
         3.7533102 , -3.9653108 ,  6.1827755 , -1.850568  ,  5.028263  ,
        -6.843763  ,  1.3515666 ,  8.996503  , -1.0031245 ,  9.674829  ,
         9.7697115 , -8.616943  ,  5.982676  ,  2.2226048 , -3.6281207 ,
         7.0979915 , -7.3974366 , -5.3140364 , -6.9729123 , -7.9171224 ,
         6.6703353 , -5.5767226 ,  7.13434   ,  6.606858  , -8.299497  ],
       [ 6.2599006 ,  9.218424  ,  8.374798  ,  9.035377  ,  7.7094774 ,
        -1.0123167 , -6.2563047 ,  1.3844215 , -6.956254  , -5.965097  ,
         1.0701916 , -0.02766196,  2.811688  ,  1.8430991 , -8.250472  ,
         3.0570164 , -8.49589   ,  9.738966  , -7.748305  ,  3.4321895 ,
        -3.9439018 , -4.113308  ,  2.6874824 ,  1.2842503 ,  1.019016  ,
         5.26193   , -1.6500072 ,  6.1615205 , -6.911384  , -9.656554  ]],
      dtype=float32)
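
Distributed inference with a combined model amounts to applying the same fitted model to each data partition independently. The sketch below illustrates that idea in plain Python only. It does not use Dask, Dask-ML, or cuML, and `nearest_center` and `predict_partitions` are hypothetical helpers with made-up values. In a real deployment, ParallelPostFit and the Dask scheduler would run the per-partition calls across workers.

```python
def nearest_center(point, centers):
    """Return the index of the center closest to `point` (squared Euclidean)."""
    distances = [
        sum((p - c) ** 2 for p, c in zip(point, center))
        for center in centers
    ]
    return distances.index(min(distances))


def predict_partitions(centers, partitions):
    """Label each partition independently using the same fitted centers."""
    return [
        [nearest_center(point, centers) for point in partition]
        for partition in partitions
    ]


# Illustrative values only: two centers and two small partitions.
centers = [[0.0, 0.0], [10.0, 10.0]]
partitions = [
    [[0.5, -0.2], [9.0, 11.0]],
    [[10.5, 9.5], [-1.0, 0.3], [0.1, 0.1]],
]

labels = predict_partitions(centers, partitions)
print(labels)  # [[0, 1], [1, 0, 0]]
```

Because no partition depends on another, the per-partition calls can run in any order or in parallel, which is what makes a single combined model usable across a cluster.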

Exporting cuML Random Forest models for inferencing on machines without GPUs

Starting with cuML version 21.06, you can export cuML Random Forest models and run predictions with them on machines without an NVIDIA GPU. The Treelite package defines an efficient exchange format that lets you move cuML Random Forest models to other machines. We will refer to files in this exchange format as “checkpoints.”

Here are the steps to export the model:

  1. Call convert_to_treelite_model() on the fitted cuML Random Forest model, then call to_treelite_checkpoint() on the result to write the checkpoint file.

[14]:
from cuml.ensemble import RandomForestClassifier as cumlRandomForestClassifier
from sklearn.datasets import load_iris
import numpy as np

X, y = load_iris(return_X_y=True)
X, y = X.astype(np.float32), y.astype(np.int32)
clf = cumlRandomForestClassifier(max_depth=3, random_state=0, n_estimators=10)
clf.fit(X, y)

checkpoint_path = './checkpoint.tl'
# Export cuML RF model as Treelite checkpoint
clf.convert_to_treelite_model().to_treelite_checkpoint(checkpoint_path)
/opt/conda/envs/rapids/lib/python3.9/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams=1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
  2. Copy the generated checkpoint file checkpoint.tl to the machine on which you’d like to run predictions.

  3. On the target machine, install Treelite by running pip install treelite or conda install -c conda-forge treelite. The machine does not need an NVIDIA GPU and does not need cuML installed.

  4. You can now load the model from the checkpoint by running the following on the target machine:

[15]:
import treelite

# The checkpoint file has been copied over
checkpoint_path = './checkpoint.tl'
tl_model = treelite.Model.deserialize(checkpoint_path)
# X is the feature array from the previous cell; on a separate machine, load your own input data
out_prob = treelite.gtil.predict(tl_model, X, pred_margin=True)
print(out_prob)
[[1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [1.         0.         0.        ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.8317397  0.16826029]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.16169165 0.83830833]
 [0.         0.9941856  0.0058144 ]
 [0.         0.8317397  0.16826029]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.8317397  0.16826029]
 [0.         0.5163647  0.48363543]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.3457792  0.65422076]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.9841856  0.0158144 ]
 [0.         0.9941856  0.0058144 ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.02447553 0.9755244 ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.62521124 0.37478876]
 [0.         0.02447553 0.9755244 ]
 [0.         0.02447553 0.9755244 ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.04102564 0.95897436]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.02447553 0.9755244 ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.600339   0.399661  ]
 [0.         0.         1.        ]
 [0.         0.10388279 0.8961172 ]
 [0.         0.         1.        ]
 [0.         0.12835832 0.87164164]
 [0.         0.         1.        ]
 [0.         0.02447553 0.9755244 ]
 [0.         0.16169165 0.83830833]
 [0.         0.12835832 0.87164164]
 [0.         0.         1.        ]
 [0.         0.3457792  0.65422076]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.3457792  0.65422076]
 [0.         0.3457792  0.65422076]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.02447553 0.9755244 ]
 [0.         0.16169165 0.83830833]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.04102564 0.95897436]
 [0.         0.         1.        ]
 [0.         0.         1.        ]
 [0.         0.02447553 0.9755244 ]]
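
Each row of the printed array holds one score per class, so the predicted class for a sample is the index of the largest value in its row. The sketch below shows that step in plain Python on three rows copied from the output above; `rows` and `predicted` are illustrative names. With NumPy available, `np.argmax(out_prob, axis=1)` does the same for the whole array.

```python
# Three rows copied from the printed output, one per Iris class.
rows = [
    [1.0, 0.0, 0.0],
    [0.0, 0.9941856, 0.0058144],
    [0.0, 0.3457792, 0.65422076],
]

# The predicted class is the column index holding the largest score.
predicted = [row.index(max(row)) for row in rows]
print(predicted)  # [0, 1, 2]
```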