Attention

The vector search and clustering algorithms in RAFT are being migrated to a new library dedicated to vector search called cuVS. We will continue to support the vector search algorithms in RAFT during this move, but will no longer update them after the RAPIDS 24.06 (June) release. We plan to complete the migration by the RAPIDS 24.08 (August) release.

K-Means#

#include <raft/cluster/kmeans.cuh>

template<typename DataT, typename IndexT>
using raft::cluster::kmeans::SamplingOp = detail::SamplingOp<DataT, IndexT>#

Functor used for sampling centroids.

template<typename IndexT, typename DataT>
using raft::cluster::kmeans::KeyValueIndexOp = detail::KeyValueIndexOp<IndexT, DataT>#

Functor used to extract the index from a KeyValue pair storing both an index and a distance.

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::fit(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight, raft::device_matrix_view<DataT, IndexT> centroids, raft::host_scalar_view<DataT> inertia, raft::host_scalar_view<IndexT> n_iter)#

Find clusters with the k-means algorithm. By default (InitMethod::KMeansPlusPlus), initial centroids are chosen with the k-means++ algorithm. Empty clusters are reinitialized by choosing new centroids with the k-means++ algorithm.

#include <raft/core/resources.hpp>
#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
using namespace raft::cluster;
...
raft::resources handle;
raft::cluster::KMeansParams params;
int n_features = 15, n_iter;
float inertia;
auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);

// Template arguments are explicit: std::nullopt cannot be used to deduce DataT/IndexT.
kmeans::fit<float, int>(handle,
                        params,
                        X,
                        std::nullopt,
                        centroids.view(),
                        raft::make_scalar_view(&inertia),
                        raft::make_scalar_view(&n_iter));
Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle.

  • params[in] Parameters for KMeans model.

  • X[in] Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features]

  • sample_weight[in] Optional weights for each observation in X. [len = n_samples]

  • centroids[inout] [in] When init is InitMethod::Array, use centroids as the initial cluster centers. [out] The generated centroids from the kmeans algorithm are stored at the address pointed to by ‘centroids’. [dim = n_clusters x n_features]

  • inertia[out] Sum of squared distances of samples to their closest cluster center.

  • n_iter[out] Number of iterations run.

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::predict(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight, raft::device_matrix_view<const DataT, IndexT> centroids, raft::device_vector_view<IndexT, IndexT> labels, bool normalize_weight, raft::host_scalar_view<DataT> inertia)#

Predict the closest cluster each sample in X belongs to.

#include <raft/core/resources.hpp>
#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
using namespace raft::cluster;
...
raft::resources handle;
raft::cluster::KMeansParams params;
int n_features = 15, n_iter;
float inertia;
auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);

// Template arguments are explicit: std::nullopt cannot be used to deduce DataT/IndexT.
kmeans::fit<float, int>(handle,
                        params,
                        X,
                        std::nullopt,
                        centroids.view(),
                        raft::make_scalar_view(&inertia),
                        raft::make_scalar_view(&n_iter));
...
auto labels = raft::make_device_vector<int, int>(handle, X.extent(0));

kmeans::predict<float, int>(handle,
                            params,
                            X,
                            std::nullopt,
                            centroids.view(),
                            labels.view(),
                            false,
                            raft::make_scalar_view(&inertia));
Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle.

  • params[in] Parameters for KMeans model.

  • X[in] New data to predict. [dim = n_samples x n_features]

  • sample_weight[in] Optional weights for each observation in X. [len = n_samples]

  • centroids[in] Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features]

  • normalize_weight[in] True if the weights should be normalized.

  • labels[out] Index of the cluster each sample in X belongs to. [len = n_samples]

  • inertia[out] Sum of squared distances of samples to their closest cluster center.

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::fit_predict(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight, std::optional<raft::device_matrix_view<DataT, IndexT>> centroids, raft::device_vector_view<IndexT, IndexT> labels, raft::host_scalar_view<DataT> inertia, raft::host_scalar_view<IndexT> n_iter)#

Compute k-means clustering and predict the cluster index for each sample in the input.

#include <raft/core/resources.hpp>
#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>
using namespace raft::cluster;
...
raft::resources handle;
raft::cluster::KMeansParams params;
int n_features = 15, n_iter;
float inertia;
auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
auto labels = raft::make_device_vector<int, int>(handle, X.extent(0));

// Template arguments are explicit: std::nullopt and the optional centroids
// parameter cannot be used to deduce DataT/IndexT.
kmeans::fit_predict<float, int>(handle,
                                params,
                                X,
                                std::nullopt,
                                centroids.view(),
                                labels.view(),
                                raft::make_scalar_view(&inertia),
                                raft::make_scalar_view(&n_iter));
Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle.

  • params[in] Parameters for KMeans model.

  • X[in] Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features]

  • sample_weight[in] Optional weights for each observation in X. [len = n_samples]

  • centroids[inout] Optional. [in] When init is InitMethod::Array, use centroids as the initial cluster centers. [out] The generated centroids from the kmeans algorithm are stored at the address pointed to by ‘centroids’. [dim = n_clusters x n_features]

  • labels[out] Index of the cluster each sample in X belongs to. [len = n_samples]

  • inertia[out] Sum of squared distances of samples to their closest cluster center.

  • n_iter[out] Number of iterations run.

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::transform(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<const DataT, IndexT> centroids, raft::device_matrix_view<DataT, IndexT> X_new)#

Transform X to a cluster-distance space.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle.

  • params[in] Parameters for KMeans model.

  • X[in] Training instances to cluster. The data must be in row-major format [dim = n_samples x n_features]

  • centroids[in] Cluster centroids. The data must be in row-major format. [dim = n_clusters x n_features]

  • X_new[out] X transformed into the cluster-distance space, i.e. the distance of every sample to every centroid. [dim = n_samples x n_clusters]

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::transform(raft::resources const &handle, const KMeansParams &params, const DataT *X, const DataT *centroids, IndexT n_samples, IndexT n_features, DataT *X_new)#
template<typename idx_t, typename value_t>
void raft::cluster::kmeans::find_k(raft::resources const &handle, raft::device_matrix_view<const value_t, idx_t> X, raft::host_scalar_view<idx_t> best_k, raft::host_scalar_view<value_t> inertia, raft::host_scalar_view<idx_t> n_iter, idx_t kmax, idx_t kmin = 1, idx_t maxiter = 100, value_t tol = 1e-3)#

Automatically find the optimal value of k using a binary search. This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia.

#include <raft/core/resources.hpp>
#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/kmeans_types.hpp>

#include <raft/random/make_blobs.cuh>

using namespace raft::cluster;

raft::resources handle;
int n_samples = 100, n_features = 15, n_clusters = 10;
auto X = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
auto labels = raft::make_device_vector<int, int>(handle, n_samples);

raft::random::make_blobs(handle, X.view(), labels.view(), n_clusters);

auto best_k = raft::make_host_scalar<int>(0);
auto n_iter = raft::make_host_scalar<int>(0);
auto inertia = raft::make_host_scalar<float>(0);

// Template arguments are explicit: a non-const view of X cannot be used to deduce
// the `const value_t` element type of the X parameter.
kmeans::find_k<int, float>(handle, X.view(), best_k.view(), inertia.view(), n_iter.view(), n_clusters + 1);
Template Parameters:
  • idx_t – indexing type (should be integral)

  • value_t – value type (should be floating point)

Parameters:
  • handle – raft handle

  • X – input observations (shape n_samples, n_dims)

  • best_k – best k found from binary search

  • inertia – inertia of best k found

  • n_iter – number of iterations used to find best k

  • kmax – maximum k to try in search

  • kmin – minimum k to try in search (should be >= 1)

  • maxiter – maximum number of iterations to run

  • tol – tolerance for early stopping convergence

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::sample_centroids(raft::resources const &handle, raft::device_matrix_view<const DataT, IndexT> X, raft::device_vector_view<DataT, IndexT> minClusterDistance, raft::device_vector_view<std::uint8_t, IndexT> isSampleCentroid, SamplingOp<DataT, IndexT> &select_op, rmm::device_uvector<DataT> &inRankCp, rmm::device_uvector<char> &workspace)#

Select centroids according to a sampling operation.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • X[in] The data in row-major format [dim = n_samples x n_features]

  • minClusterDistance[in] Distance from every sample to its nearest centroid [dim = n_samples]

  • isSampleCentroid[in] Flag the sample chosen as initial centroid [dim = n_samples]

  • select_op[in] The sampling operation used to select the centroids

  • inRankCp[out] The sampled centroids [dim = n_selected_centroids x n_features]

  • workspace[in] Temporary workspace buffer which can get resized

template<typename DataT, typename IndexT, typename ReductionOpT>
void raft::cluster::kmeans::cluster_cost(raft::resources const &handle, raft::device_vector_view<DataT, IndexT> minClusterDistance, rmm::device_uvector<char> &workspace, raft::device_scalar_view<DataT> clusterCost, ReductionOpT reduction_op)#

Compute cluster cost.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

  • ReductionOpT – the type of the reduction operation functor.

Parameters:
  • handle[in] The raft handle

  • minClusterDistance[in] Distance from every sample to its nearest centroid [dim = n_samples]

  • workspace[in] Temporary workspace buffer which can get resized

  • clusterCost[out] Resulting cluster cost

  • reduction_op[in] The reduction operation used for the cost

template<typename DataT, typename IndexT, typename LabelsIterator>
void raft::cluster::kmeans::update_centroids(raft::resources const &handle, raft::device_matrix_view<const DataT, IndexT, row_major> X, raft::device_vector_view<const DataT, IndexT> sample_weights, raft::device_matrix_view<const DataT, IndexT, row_major> centroids, LabelsIterator labels, raft::device_vector_view<DataT, IndexT> weight_per_cluster, raft::device_matrix_view<DataT, IndexT, row_major> new_centroids)#

Update centroids given the current centroids, the weight of each sample, and the cluster label assigned to each sample. The sum of sample weights per cluster is also written to ‘weight_per_cluster’.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] Raft handle to use for managing library resources

  • X[in] input matrix (size n_samples, n_features)

  • sample_weights[in] weight of each sample in X (size n_samples)

  • centroids[in] matrix of current centroids (size n_clusters, n_features)

  • labels[in] Iterator of labels (can also be a raw pointer)

  • weight_per_cluster[out] sum of sample weights per cluster (size n_clusters)

  • new_centroids[out] output matrix of updated centroids (size n_clusters, n_features)

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::min_cluster_distance(raft::resources const &handle, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<DataT, IndexT> centroids, raft::device_vector_view<DataT, IndexT> minClusterDistance, raft::device_vector_view<DataT, IndexT> L2NormX, rmm::device_uvector<DataT> &L2NormBuf_OR_DistBuf, raft::distance::DistanceType metric, int batch_samples, int batch_centroids, rmm::device_uvector<char> &workspace)#

Compute the distance from every sample to its nearest centroid.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • X[in] The data in row-major format [dim = n_samples x n_features]

  • centroids[in] Centroids data [dim = n_cluster x n_features]

  • minClusterDistance[out] Distance from every sample to its nearest centroid [dim = n_samples]

  • L2NormX[in] L2 norm of X : ||x||^2 [dim = n_samples]

  • L2NormBuf_OR_DistBuf[out] Resizable buffer to store L2 norm of centroids or distance matrix

  • metric[in] Distance metric to use

  • batch_samples[in] batch size for input data samples

  • batch_centroids[in] batch size for input centroids

  • workspace[in] Temporary workspace buffer which can get resized

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::min_cluster_and_distance(raft::resources const &handle, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<const DataT, IndexT> centroids, raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance, raft::device_vector_view<DataT, IndexT> L2NormX, rmm::device_uvector<DataT> &L2NormBuf_OR_DistBuf, raft::distance::DistanceType metric, int batch_samples, int batch_centroids, rmm::device_uvector<char> &workspace)#

Calculates a <key, value> pair for every sample in input ‘X’ where key is an index of one of the ‘centroids’ (index of the nearest centroid) and ‘value’ is the distance between the sample and the ‘centroid[key]’.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • X[in] The data in row-major format [dim = n_samples x n_features]

  • centroids[in] Centroids data [dim = n_cluster x n_features]

  • minClusterAndDistance[out] Distance vector that contains, for every sample, the nearest centroid and its distance [dim = n_samples]

  • L2NormX[in] L2 norm of X : ||x||^2 [dim = n_samples]

  • L2NormBuf_OR_DistBuf[out] Resizable buffer to store L2 norm of centroids or distance matrix

  • metric[in] distance metric

  • batch_samples[in] batch size of data samples

  • batch_centroids[in] batch size of centroids

  • workspace[in] Temporary workspace buffer which can get resized

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::shuffle_and_gather(raft::resources const &handle, raft::device_matrix_view<const DataT, IndexT> in, raft::device_matrix_view<DataT, IndexT> out, uint32_t n_samples_to_gather, uint64_t seed)#

Shuffle and randomly select ‘n_samples_to_gather’ samples from the input ‘in’ and store them in ‘out’. The input is not modified.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • in[in] The data to shuffle and gather [dim = n_samples x n_features]

  • out[out] The sampled data [dim = n_samples_to_gather x n_features]

  • n_samples_to_gather[in] Number of samples to gather

  • seed[in] Seed for the shuffle

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::count_samples_in_cluster(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_vector_view<DataT, IndexT> L2NormX, raft::device_matrix_view<DataT, IndexT> centroids, rmm::device_uvector<char> &workspace, raft::device_vector_view<DataT, IndexT> sampleCountInCluster)#

Count the number of samples in each cluster.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • params[in] The parameters for KMeans

  • X[in] The data in row-major format [dim = n_samples x n_features]

  • L2NormX[in] L2 norm of X : ||x||^2 [dim = n_samples]

  • centroids[in] Centroids data [dim = n_cluster x n_features]

  • workspace[in] Temporary workspace buffer which can get resized

  • sampleCountInCluster[out] The count for each centroid [dim = n_cluster]

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::init_plus_plus(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<DataT, IndexT> centroids, rmm::device_uvector<char> &workspace)#

Selects ‘n_clusters’ samples from the input X using kmeans++ algorithm.

See also

“k-means++: the advantages of careful seeding”. 2007, Arthur, D. and Vassilvitskii, S. ACM-SIAM symposium on Discrete algorithms.

Template Parameters:
  • DataT – the type of data used for weights, distances.

  • IndexT – the type of data used for indexing.

Parameters:
  • handle[in] The raft handle

  • params[in] The parameters for KMeans

  • X[in] The data in row-major format [dim = n_samples x n_features]

  • centroids[out] Centroids data [dim = n_cluster x n_features]

  • workspace[in] Temporary workspace buffer which can get resized

template<typename DataT, typename IndexT>
void raft::cluster::kmeans::fit_main(raft::resources const &handle, const KMeansParams &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_vector_view<const DataT, IndexT> sample_weights, raft::device_matrix_view<DataT, IndexT> centroids, raft::host_scalar_view<DataT> inertia, raft::host_scalar_view<IndexT> n_iter, rmm::device_uvector<char> &workspace)#
struct KMeansParams : public raft::cluster::kmeans_base_params#
#include <kmeans_types.hpp>

Simple object to specify hyper-parameters to the kmeans algorithm.

Public Members

int n_clusters = 8#

The number of clusters to form as well as the number of centroids to generate (default:8).

InitMethod init = KMeansPlusPlus#

Method for initialization, defaults to k-means++:

  • InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm to select the initial cluster centers.

  • InitMethod::Random (random): Choose ‘n_clusters’ observations (rows) at random from the input data for the initial centroids.

  • InitMethod::Array (ndarray): Use ‘centroids’ as initial cluster centers.

int max_iter = 300#

Maximum number of iterations of the k-means algorithm for a single run.

double tol = 1e-4#

Relative tolerance with regard to inertia to declare convergence.

int verbosity = RAFT_LEVEL_INFO#

Verbosity level.

raft::random::RngState rng_state = {0}#

Seed to the random number generator.

int n_init = 1#

Number of times the k-means algorithm will be run with different seeds.

double oversampling_factor = 2.0#

Oversampling factor for use in the k-means|| algorithm.

int batch_centroids = 0#

If 0, then batch_centroids = n_clusters.