Pickling cuML Models for Persistence
This notebook demonstrates simple pickling of both single-GPU and multi-GPU cuML models for persistence.
[1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
Single GPU Model Pickling
All single-GPU estimators are pickleable. The following example demonstrates the creation of a synthetic dataset, training, and pickling of the resulting model for storage. Trained single-GPU models can also be used to distribute inference on a Dask cluster, as the Distributed Model Pickling section below demonstrates.
[2]:
from cuml.datasets import make_blobs
X, y = make_blobs(n_samples=50,
                  n_features=10,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0)
[3]:
from cuml.cluster import KMeans
model = KMeans(n_clusters=5)
model.fit(X)
[3]:
KMeans()
[4]:
import pickle
# Use a context manager so the file is closed after writing
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(model, f)
[5]:
with open("kmeans_model.pkl", "rb") as f:
    model = pickle.load(f)
[6]:
model.cluster_centers_
[6]:
array([[ 4.6749854, 8.213466 , -9.075721 , 9.568374 , 8.454807 ,
-1.2327975, 3.3903713, -7.8282413, -0.8454461, 0.6288572],
[-4.243999 , 5.610707 , -5.6697764, -1.7957243, -9.255529 ,
0.7177438, 4.4435897, -2.874715 , -5.0900965, 9.684122 ],
[ 5.2615476, -4.0487256, 4.464928 , -2.9367518, 3.5061095,
-4.016832 , -3.463885 , 6.078449 , -6.9533257, -1.004144 ],
[-3.008261 , 4.625961 , -4.483249 , 2.228457 , 1.643532 ,
-2.4505193, -5.258201 , -1.6679403, -7.9857535, 2.8311467],
[-5.6072407, 2.2695985, -3.7516537, -1.8182003, -5.1430283,
7.599363 , 2.8252366, 8.773042 , 1.6198314, 1.1772047]],
dtype=float32)
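The save/load round trip above relies only on Python's pickle module, so any pickleable object that stores its fitted attributes behaves the same way. The following is a minimal sketch, assuming a hypothetical NearestCentroidModel as a stand-in for a cuML estimator. It also uses hypothetical helpers save_model and load_model that wrap pickle in context managers so file handles are always closed.

```python
import math
import os
import pickle
import tempfile


class NearestCentroidModel:
    """Hypothetical stand-in for a fitted clustering estimator (not cuML code)."""

    def __init__(self, cluster_centers):
        # Fitted state, analogous to KMeans.cluster_centers_
        self.cluster_centers_ = [list(c) for c in cluster_centers]

    def predict(self, X):
        # Assign each row to the index of its closest center (Euclidean distance)
        return [
            min(range(len(self.cluster_centers_)),
                key=lambda k: math.dist(row, self.cluster_centers_[k]))
            for row in X
        ]


def save_model(model, path):
    # The context manager closes the file even if pickling raises
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    # Only unpickle files from trusted sources: loading can execute code
    with open(path, "rb") as f:
        return pickle.load(f)


model = NearestCentroidModel([[0.0, 0.0], [10.0, 10.0]])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_model.pkl")
    save_model(model, path)
    restored = load_model(path)
```

The restored object is a new instance. It carries the same fitted centers and therefore makes the same predictions.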
Distributed Model Pickling
The distributed estimator wrappers inside cuml.dask are not intended to be pickled directly. The Dask cuML estimators provide a function, get_combined_model(), which returns the trained single-GPU model for pickling. The combined model can be used for inference on a single GPU, and the ParallelPostFit wrapper from the Dask-ML library can be used to perform distributed inference on a Dask cluster.
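The idea behind distributed inference with one trained model is roughly this: the same fitted model is applied to each partition of the data independently, and the per-partition results are concatenated in order. The following is a minimal sketch of that idea without Dask. It assumes a thread pool as a stand-in for Dask workers and a hypothetical ThresholdModel in place of the combined cuML model. It is not dask-ml's ParallelPostFit implementation.

```python
from concurrent.futures import ThreadPoolExecutor


class ThresholdModel:
    """Hypothetical fitted model: label 1 if the first feature exceeds a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, rows):
        return [1 if row[0] > self.threshold else 0 for row in rows]


def partition(rows, n_parts):
    # Split rows into n_parts contiguous chunks whose sizes differ by at most one
    size, extra = divmod(len(rows), n_parts)
    chunks, start = [], 0
    for i in range(n_parts):
        end = start + size + (1 if i < extra else 0)
        chunks.append(rows[start:end])
        start = end
    return chunks


def blockwise_predict(model, rows, n_parts=4, max_workers=2):
    # Apply the same fitted model to every partition independently, then
    # concatenate the per-partition results in their original order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(model.predict, partition(rows, n_parts)))
    return [label for part in results for label in part]
```

Because each partition is predicted independently, the block-wise result matches a single predict call over all rows. That independence is what makes a single trained model easy to fan out across workers.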
[7]:
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster()
client = Client(cluster)
client
2022-05-16 13:38:28,763 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-05-16 13:38:28,836 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
[7]:
Client
Client-78596b18-d51d-11ec-8ced-5536f399c3b0
Connection method: Cluster object | Cluster type: dask_cuda.LocalCUDACluster
Dashboard: http://127.0.0.1:8787/status
Cluster Info
LocalCUDACluster
bdc3ff81
Dashboard: http://127.0.0.1:8787/status | Workers: 2
Total threads: 2 | Total memory: 45.79 GiB
Status: running | Using processes: True
Scheduler Info
Scheduler
Scheduler-de13fee9-149d-4340-b35f-a0b4d3b2a407
Comm: tcp://127.0.0.1:41097 | Workers: 2
Dashboard: http://127.0.0.1:8787/status | Total threads: 2
Started: Just now | Total memory: 45.79 GiB
Workers
Worker: 0
Comm: tcp://127.0.0.1:38555 | Total threads: 1
Dashboard: http://127.0.0.1:35565/status | Memory: 22.89 GiB
Nanny: tcp://127.0.0.1:34641
Local directory: /rapids/cuml/docs/source/dask-worker-space/worker-klxmx4pl
GPU: Quadro GV100 | GPU memory: 32.00 GiB
Worker: 1
Comm: tcp://127.0.0.1:34817 | Total threads: 1
Dashboard: http://127.0.0.1:33535/status | Memory: 22.89 GiB
Nanny: tcp://127.0.0.1:37453
Local directory: /rapids/cuml/docs/source/dask-worker-space/worker-m2_axwlo
GPU: Quadro GV100 | GPU memory: 32.00 GiB
[8]:
from cuml.dask.datasets import make_blobs
n_workers = len(client.scheduler_info()["workers"].keys())
X, y = make_blobs(n_samples=5000,
                  n_features=30,
                  centers=5,
                  cluster_std=0.4,
                  random_state=0,
                  n_parts=n_workers*5)
X = X.persist()
y = y.persist()
[9]:
from cuml.dask.cluster import KMeans
dist_model = KMeans(n_clusters=5)
[10]:
dist_model.fit(X)
[10]:
<cuml.dask.cluster.kmeans.KMeans at 0x7f6e7018c3a0>
[11]:
import pickle
single_gpu_model = dist_model.get_combined_model()
# Use a context manager so the file is closed after writing
with open("kmeans_model.pkl", "wb") as f:
    pickle.dump(single_gpu_model, f)
[12]:
with open("kmeans_model.pkl", "rb") as f:
    single_gpu_model = pickle.load(f)
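The pattern of pickling the combined model rather than the wrapper can be illustrated in plain Python. This sketch assumes a hypothetical wrapper that keeps a process-local resource (a threading.Lock) next to its fitted state. Pickling the wrapper fails, while the self-contained object returned by its own get_combined_model() round-trips. This is an analogy only. It does not describe how cuml.dask estimators are implemented internally.

```python
import copy
import pickle
import threading


class CombinedModel:
    """Plain container for fitted state; picklable because it holds only data."""

    def __init__(self, cluster_centers):
        self.cluster_centers_ = cluster_centers


class ToyDistributedModel:
    """Hypothetical distributed wrapper (not cuml.dask code). It keeps a
    process-local lock, standing in for resources that cannot be pickled."""

    def __init__(self, cluster_centers):
        self._cluster_centers = cluster_centers
        self._lock = threading.Lock()

    def get_combined_model(self):
        # Copy only the fitted state into a self-contained, picklable object
        with self._lock:
            return CombinedModel(copy.deepcopy(self._cluster_centers))


dist = ToyDistributedModel([[0.0, 1.0], [5.0, 6.0]])

try:
    pickle.dumps(dist)
    wrapper_pickled = True
except TypeError:
    # threading.Lock objects raise TypeError when pickled
    wrapper_pickled = False

restored = pickle.loads(pickle.dumps(dist.get_combined_model()))
```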
[13]:
single_gpu_model.cluster_centers_
[13]:
array([[-2.8719957 , 4.4702854 , -4.4291744 , 2.3982146 , 1.7456868 ,
-2.492674 , -5.21906 , -1.70623 , -8.130277 , 2.6402013 ,
-4.312218 , 5.576193 , -5.7422023 , -1.7193272 , -9.359728 ,
0.72390175, 4.4446826 , -2.9161723 , -4.9331355 , 9.68974 ,
8.391241 , -6.2390246 , -6.368295 , 1.9610256 , 4.1622376 ,
-9.158494 , 4.6112323 , 8.802037 , 6.856204 , 2.2432854 ],
[ 6.2584553 , 9.214628 , 8.375932 , 9.034932 , 7.707301 ,
-1.0178816 , -6.2575994 , 1.3853358 , -6.9605255 , -5.9653416 ,
1.0682615 , -0.02890242, 2.810602 , 1.8463548 , -8.250103 ,
3.0578778 , -8.495707 , 9.739371 , -7.7505846 , 3.433176 ,
-3.9421573 , -4.1097155 , 2.6854286 , 1.2847115 , 1.0240605 ,
5.2621164 , -1.6469141 , 6.161966 , -6.9160676 , -9.655835 ],
[-6.926763 , -9.769037 , -6.5095406 , -0.43609676, 6.1015596 ,
3.7510426 , -3.9631264 , 6.1858416 , -1.8483105 , 5.030259 ,
-6.8440013 , 1.3500856 , 8.995467 , -1.0015213 , 9.676681 ,
9.767222 , -8.614724 , 5.985621 , 2.2195487 , -3.628886 ,
7.0963764 , -7.395383 , -5.312048 , -6.9717402 , -7.9194913 ,
6.669967 , -5.576503 , 7.13301 , 6.602726 , -8.295306 ],
[ 4.7999144 , 8.405615 , -9.217355 , 9.391141 , 8.513527 ,
-1.0933163 , 3.3241372 , -7.807658 , -0.5957081 , 0.25495645,
5.516926 , -4.113342 , 4.290226 , -2.841153 , 3.6347358 ,
-4.1733875 , -3.6206803 , 6.2202787 , -6.9130416 , -1.084563 ,
-5.854552 , 2.2391484 , -3.8561985 , -1.6751809 , -5.32013 ,
7.5758333 , 2.9306953 , 8.522391 , 1.5873817 , 1.0953335 ],
[-4.657932 , -9.604437 , 6.6668396 , 4.433426 , 2.1555626 ,
2.6000512 , 0.6004752 , 6.2651596 , -8.827596 , -0.39495564,
9.799334 , 7.58164 , 10.005553 , -5.8738556 , -1.2868681 ,
-2.5448568 , -1.0834315 , -5.2448344 , -9.322358 , 4.6100206 ,
-0.09926538, -3.9473362 , 6.1890383 , -7.3990617 , 5.6591 ,
-8.5463705 , -7.526781 , -5.555117 , 4.843207 , 2.5287533 ]],
dtype=float32)
Exporting cuML Random Forest models for inference on machines without GPUs
Starting with cuML version 21.06, you can export cuML Random Forest models and run predictions with them on machines without an NVIDIA GPU. The Treelite package defines an efficient exchange format that lets you portably move the cuML Random Forest models to other machines. We will refer to the exchange format as “checkpoints.”
Here are the steps to export the model:
Call convert_to_treelite_model() on the trained cuML Random Forest model, then call to_treelite_checkpoint() on the result to write the checkpoint file.
[14]:
from cuml.ensemble import RandomForestClassifier as cumlRandomForestClassifier
from sklearn.datasets import load_iris
import numpy as np
X, y = load_iris(return_X_y=True)
X, y = X.astype(np.float32), y.astype(np.int32)
clf = cumlRandomForestClassifier(max_depth=3, random_state=0, n_estimators=10)
clf.fit(X, y)
checkpoint_path = './checkpoint.tl'
# Export cuML RF model as Treelite checkpoint
clf.convert_to_treelite_model().to_treelite_checkpoint(checkpoint_path)
/opt/conda/envs/rapids/lib/python3.9/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams=1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
return func(**kwargs)
Copy the generated checkpoint file checkpoint.tl to another machine on which you’d like to run predictions.
On the target machine, install Treelite by running pip install treelite or conda install -c conda-forge treelite. The machine does not need an NVIDIA GPU and does not need cuML installed.
You can now load the model from the checkpoint by running the following on the target machine:
[15]:
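import treelite
# The checkpoint file has been copied over
checkpoint_path = './checkpoint.tl'
tl_model = treelite.Model.deserialize(checkpoint_path)
# X must also be available on the target machine; here it is the float32
# iris feature matrix from the previous cell
out_prob = treelite.gtil.predict(tl_model, X, pred_margin=True)
print(out_prob)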
[[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[0.9 0.1 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[1. 0. 0. ]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.5274104 0.47258964]
[0. 0.9870021 0.0129979 ]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.7036687 0.29633126]
[0. 0.9981133 0.00188679]
[0. 0.9870021 0.0129979 ]
[0. 0.9536688 0.04633124]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.23634085 0.7636592 ]
[0. 0.9981133 0.00188679]
[0. 0.567213 0.432787 ]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.60401857 0.3959814 ]
[0. 0.41788664 0.5821134 ]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9870021 0.0129979 ]
[0. 0.9870021 0.0129979 ]
[0. 0.9981133 0.00188679]
[0. 0.30794033 0.6920597 ]
[0.1 0.89811325 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9870021 0.0129979 ]
[0. 0.9870021 0.0129979 ]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9536688 0.04633124]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9981133 0.00188679]
[0. 0.9870021 0.0129979 ]
[0. 0.9981133 0.00188679]
[0. 0.01794118 0.9820588 ]
[0. 0.05149381 0.9485062 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01169118 0.9883088 ]
[0. 0.5655735 0.4344265 ]
[0. 0.01169118 0.9883088 ]
[0. 0.05149381 0.9485062 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01794118 0.9820588 ]
[0. 0.05149381 0.9485062 ]
[0. 0.01169118 0.9883088 ]
[0. 0.19290936 0.80709064]
[0. 0.03274381 0.9672562 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01794118 0.9820588 ]
[0. 0.05149381 0.9485062 ]
[0. 0.44935593 0.5506441 ]
[0. 0.01794118 0.9820588 ]
[0. 0.2170165 0.78298354]
[0. 0.03274381 0.9672562 ]
[0. 0.2357665 0.7642335 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01794118 0.9820588 ]
[0. 0.2511435 0.74885654]
[0. 0.19596387 0.80403614]
[0. 0.03274381 0.9672562 ]
[0. 0.10335784 0.8966421 ]
[0. 0.03274381 0.9672562 ]
[0. 0.01794118 0.9820588 ]
[0. 0.03274381 0.9672562 ]
[0. 0.28919035 0.7108097 ]
[0. 0.20982714 0.79017293]
[0. 0.01169118 0.9883088 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01169118 0.9883088 ]
[0. 0.23009086 0.7699092 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01169118 0.9883088 ]
[0. 0.01169118 0.9883088 ]
[0. 0.05149381 0.9485062 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01169118 0.9883088 ]
[0. 0.19290936 0.80709064]
[0. 0.01169118 0.9883088 ]
[0. 0.01794118 0.9820588 ]
[0. 0.01169118 0.9883088 ]]
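The printed rows can be turned into class labels by taking the index of the largest value in each row. Each row above sums to 1. This is consistent with per-class scores averaged over the 10 trees, although that reading is an interpretation of the output rather than something stated here. The sketch below is minimal plain Python using a few rows copied from the output above. With NumPy, out_prob.argmax(axis=1) would give the same labels.

```python
# Rows copied from the printed output above (per-class scores for 3 classes)
out_prob_rows = [
    [1.0, 0.0, 0.0],
    [0.0, 0.5274104, 0.47258964],
    [0.0, 0.41788664, 0.5821134],
    [0.1, 0.89811325, 0.00188679],
]


def predict_labels(prob_rows):
    # Pick the index of the largest score in each row
    return [max(range(len(row)), key=row.__getitem__) for row in prob_rows]


labels = predict_labels(out_prob_rows)
row_sums = [sum(row) for row in out_prob_rows]
```

Rows with close scores, such as [0. 0.5274104 0.47258964], are the samples on which the trees disagree most.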