Classes | |
struct | param |
contains all the hyper-parameters for training More... | |
struct | node |
Represents a node in the syntax tree. More... | |
struct | program |
The main data structure to store the AST that represents a program in the current generation. More... | |
Typedefs | |
typedef program * | program_t |
Enumerations | |
enum class | metric_t : uint32_t { mae , mse , rmse , pearson , spearman , logloss } |
enum class | init_method_t : uint32_t { grow , full , half_and_half } |
enum class | transformer_t : uint32_t { sigmoid } |
enum class | mutation_t : uint32_t { none , crossover , subtree , hoist , point , reproduce } |
Functions | |
std::string | stringify (const program &prog) |
Visualize an AST. More... | |
void | symFit (const raft::handle_t &handle, const float *input, const float *labels, const float *sample_weights, const int n_rows, const int n_cols, param &params, program_t &final_progs, std::vector< std::vector< program >> &history) |
Fit either a regressor, classifier or a transformer to the given dataset. More... | |
void | symRegPredict (const raft::handle_t &handle, const float *input, const int n_rows, const program_t &best_prog, float *output) |
Make predictions for a symbolic regressor. More... | |
void | symClfPredictProbs (const raft::handle_t &handle, const float *input, const int n_rows, const param &params, const program_t &best_prog, float *output) |
Probability prediction for a symbolic classifier. If a transformer (like sigmoid) is specified, then it is applied on the output before returning it. More... | |
void | symClfPredict (const raft::handle_t &handle, const float *input, const int n_rows, const param &params, const program_t &best_prog, float *output) |
Return predictions for a binary classification program defining the decision boundary. More... | |
void | symTransform (const raft::handle_t &handle, const float *input, const param &params, const program_t &final_progs, const int n_rows, const int n_cols, float *output) |
Transform the values in the input feature matrix according to the supplied programs. More... | |
void | execute (const raft::handle_t &h, const program_t &d_progs, const int n_rows, const int n_progs, const float *data, float *y_pred) |
Calls the execution kernel to evaluate all programs on the given dataset. More... | |
void | compute_metric (const raft::handle_t &h, int n_rows, int n_progs, const float *y, const float *y_pred, const float *w, float *score, const param &params) |
Compute the loss based on the metric specified in the training hyperparameters. It performs a batched computation for all programs in one shot. More... | |
void | find_fitness (const raft::handle_t &h, program_t &d_prog, float *score, const param &params, const int n_rows, const float *data, const float *y, const float *sample_weights) |
Computes the fitness scores for a single program on the given dataset. More... | |
void | find_batched_fitness (const raft::handle_t &h, int n_progs, program_t &d_progs, float *score, const param &params, const int n_rows, const float *data, const float *y, const float *sample_weights) |
Computes the fitness scores for all programs on the given dataset. More... | |
void | set_fitness (const raft::handle_t &h, program_t &d_prog, program &h_prog, const param &params, const int n_rows, const float *data, const float *y, const float *sample_weights) |
Computes and sets the fitness scores for a single program on the given dataset. More... | |
void | set_batched_fitness (const raft::handle_t &h, int n_progs, program_t &d_progs, std::vector< program > &h_progs, const param &params, const int n_rows, const float *data, const float *y, const float *sample_weights) |
Computes and sets the fitness scores for all programs on the given dataset. More... | |
float | get_fitness (const program &prog, const param &params) |
Returns precomputed fitness score of program on the host, after accounting for parsimony. More... | |
int | get_depth (const program &p_out) |
Evaluates and returns the depth of the current program. More... | |
void | build_program (program &p_out, const param &params, std::mt19937 &rng) |
Build a random program with depth at most 10. More... | |
void | point_mutation (const program &prog, program &p_out, const param &params, std::mt19937 &rng) |
Perform a point mutation on the given program (AST). More... | |
void | crossover (const program &prog, const program &donor, program &p_out, const param &params, std::mt19937 &rng) |
Perform a 'hoisted' crossover mutation using the parent and donor programs. The donor subtree selected is hoisted to satisfy our constraints on total depth. More... | |
void | subtree_mutation (const program &prog, program &p_out, const param &params, std::mt19937 &rng) |
Performs a crossover mutation with a randomly built new program. Since crossover is 'hoisted', this will ensure that depth constraints are not violated. More... | |
void | hoist_mutation (const program &prog, program &p_out, const param &params, std::mt19937 &rng) |
Perform a hoist mutation on a random subtree of the given program (replace a subtree with a subtree of a subtree) More... | |
Variables | |
const int | GENE_TPB = 256 |
const int | MAX_STACK_SIZE = 20 |
typedef program* cuml::genetic::program_t |
program_t is a shorthand for device programs
enum class cuml::genetic::metric_t : uint32_t |
fitness metric types
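void cuml::genetic::build_program | ( | program & | p_out, |
const param & | params, | ||
std::mt19937 & | rng | ||
) |
Build a random program with depth at most 10.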
p_out | The output program |
params | Training hyperparameters |
rng | RNG to decide nodes to add |
void cuml::genetic::compute_metric | ( | const raft::handle_t & | h, |
int | n_rows, | ||
int | n_progs, | ||
const float * | y, | ||
const float * | y_pred, | ||
const float * | w, | ||
float * | score, | ||
const param & | params | ||
) |
Compute the loss based on the metric specified in the training hyperparameters. It performs a batched computation for all programs in one shot.
h | cuML handle |
n_rows | The number of labels/rows in the expected output |
n_progs | The number of programs being batched |
y | Device pointer to the expected output (SIZE = n_rows) |
y_pred | Device pointer to the predicted output (SIZE = n_rows * n_progs) |
w | Device pointer to sample weights (SIZE = n_rows) |
score | Device pointer to final score (SIZE = n_progs) |
params | Training hyperparameters |
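As a rough host-side illustration of what a batched metric computation of this kind does, the sketch below scores several programs against one label vector in a single call. It is not the cuML kernel: `host_metric` and `batched_metric_host` are hypothetical names, only three of the listed metrics are covered, and both the layout (program p's predictions stored contiguously at `y_pred[p * n_rows + i]`) and the normalisation by the sum of the weights are assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side stand-in for a subset of metric_t.
enum class host_metric { mae, mse, rmse };

// Returns one weighted score per program.
// Assumption: y_pred stores program p's predictions at y_pred[p * n_rows + i].
std::vector<float> batched_metric_host(const std::vector<float>& y,
                                       const std::vector<float>& y_pred,
                                       const std::vector<float>& w,
                                       int n_rows,
                                       int n_progs,
                                       host_metric metric)
{
  assert(y.size() == static_cast<std::size_t>(n_rows));
  assert(w.size() == static_cast<std::size_t>(n_rows));
  assert(y_pred.size() == static_cast<std::size_t>(n_rows) * n_progs);

  float w_sum = 0.0f;
  for (float wi : w) { w_sum += wi; }

  std::vector<float> score(n_progs, 0.0f);
  for (int p = 0; p < n_progs; ++p) {
    float acc = 0.0f;
    for (int i = 0; i < n_rows; ++i) {
      const float diff = y_pred[p * n_rows + i] - y[i];
      acc += w[i] * (metric == host_metric::mae ? std::fabs(diff) : diff * diff);
    }
    acc /= w_sum;  // weighted mean of the per-row errors
    score[p] = (metric == host_metric::rmse) ? std::sqrt(acc) : acc;
  }
  return score;
}
```

With labels {1, 2, 3}, unit weights, and two programs predicting {1, 2, 3} and {3, 4, 5}, the MAE scores come out as 0 and 2.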
void cuml::genetic::crossover | ( | const program & | prog, |
const program & | donor, | ||
program & | p_out, | ||
const param & | params, | ||
std::mt19937 & | rng | ||
) |
Perform a 'hoisted' crossover mutation using the parent and donor programs. The donor subtree selected is hoisted to satisfy our constraints on total depth.
prog | The input program |
donor | The donor program |
p_out | The result program |
params | Training hyperparameters |
rng | RNG for subtree selection |
void cuml::genetic::execute | ( | const raft::handle_t & | h, |
const program_t & | d_progs, | ||
const int | n_rows, | ||
const int | n_progs, | ||
const float * | data, | ||
float * | y_pred | ||
) |
Calls the execution kernel to evaluate all programs on the given dataset.
h | cuML handle |
d_progs | Device pointer to programs |
n_rows | Number of rows in the input dataset |
n_progs | Total number of programs being evaluated |
data | Device pointer to input dataset (in col-major format) |
y_pred | Device pointer to output of program evaluation |
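To make the idea of evaluating an AST with a bounded stack concrete, here is a minimal CPU sketch. It assumes programs are stored as nodes in prefix order and evaluated by scanning from the last node to the first; that encoding, the `toy_node` type, the `op` enum, and the function names are all illustrative assumptions and not the actual `node` or kernel of this library. The input uses the column-major layout described for `data`, and the stack capacity mirrors MAX_STACK_SIZE.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kMaxStack = 20;  // mirrors MAX_STACK_SIZE

// Hypothetical node kinds for this sketch only.
enum class op { var, constant, add, sub, mul };

struct toy_node {
  op kind;
  int feature;  // column index, used when kind == op::var
  float value;  // literal, used when kind == op::constant
};

inline int arity(op k) { return (k == op::var || k == op::constant) ? 0 : 2; }

// Evaluates one prefix-encoded program on one row of a column-major matrix.
float eval_row(const std::vector<toy_node>& prog,
               const std::vector<float>& data,
               int n_rows,
               int row)
{
  float stack[kMaxStack];
  int top = 0;
  // Scanning right to left means operands are on the stack before their operator.
  for (int i = static_cast<int>(prog.size()) - 1; i >= 0; --i) {
    const toy_node& n = prog[i];
    if (arity(n.kind) == 0) {
      assert(top < kMaxStack);
      stack[top++] = (n.kind == op::var) ? data[n.feature * n_rows + row] : n.value;
    } else {
      assert(top >= 2);
      const float a = stack[--top];  // first operand
      const float b = stack[--top];  // second operand
      float r = 0.0f;
      switch (n.kind) {
        case op::add: r = a + b; break;
        case op::sub: r = a - b; break;
        case op::mul: r = a * b; break;
        default: break;
      }
      stack[top++] = r;
    }
  }
  assert(top == 1);
  return stack[0];
}

// Evaluates every program on every row; program p writes to y_pred[p * n_rows + row].
std::vector<float> execute_host(const std::vector<std::vector<toy_node>>& progs,
                                const std::vector<float>& data,
                                int n_rows)
{
  std::vector<float> y_pred(progs.size() * n_rows);
  for (std::size_t p = 0; p < progs.size(); ++p) {
    for (int row = 0; row < n_rows; ++row) {
      y_pred[p * n_rows + row] = eval_row(progs[p], data, n_rows, row);
    }
  }
  return y_pred;
}
```

For example, the program `add(mul(x0, x1), 1.0)` is encoded as the five nodes add, mul, x0, x1, 1.0 and needs a stack depth of three in this scheme.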
void cuml::genetic::find_batched_fitness | ( | const raft::handle_t & | h, |
int | n_progs, | ||
program_t & | d_progs, | ||
float * | score, | ||
const param & | params, | ||
const int | n_rows, | ||
const float * | data, | ||
const float * | y, | ||
const float * | sample_weights | ||
) |
Computes the fitness scores for all programs on the given dataset.
h | cuML handle |
n_progs | Batch size (number of programs) |
d_progs | Device pointer to list of programs |
score | Device pointer to fitness values computed for all programs |
params | Training hyperparameters |
n_rows | Number of rows in the input dataset |
data | Device pointer to input dataset |
y | Device pointer to input labels |
sample_weights | Device pointer to sample weights |
void cuml::genetic::find_fitness | ( | const raft::handle_t & | h, |
program_t & | d_prog, | ||
float * | score, | ||
const param & | params, | ||
const int | n_rows, | ||
const float * | data, | ||
const float * | y, | ||
const float * | sample_weights | ||
) |
Computes the fitness scores for a single program on the given dataset.
h | cuML handle |
d_prog | Device pointer to program |
score | Device pointer to fitness values |
params | Training hyperparameters |
n_rows | Number of rows in the input dataset |
data | Device pointer to input dataset |
y | Device pointer to input labels |
sample_weights | Device pointer to sample weights |
int cuml::genetic::get_depth | ( | const program & | p_out | ) |
Evaluates and returns the depth of the current program.
p_out | The given program |
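One common way to compute the depth of a tree stored as a flat array is sketched below. It assumes a prefix-order encoding where only each node's arity matters, and it counts the root as depth 0; both the encoding and the convention are assumptions for illustration, and `prefix_depth` is a hypothetical helper rather than this function's implementation.

```cpp
#include <algorithm>
#include <vector>

// Depth of a tree given the arity of each node in prefix order.
// Convention assumed here: a single-node tree has depth 0.
int prefix_depth(const std::vector<int>& arities)
{
  std::vector<int> pending;  // children still to visit for each open ancestor
  int depth = 0;
  for (int a : arities) {
    // The current node sits one level below each ancestor that is still open.
    depth = std::max(depth, static_cast<int>(pending.size()));
    if (a > 0) {
      pending.push_back(a);
    } else {
      // A leaf completes one child of its parent; close every ancestor that is now full.
      while (!pending.empty() && --pending.back() == 0) { pending.pop_back(); }
    }
  }
  return depth;
}
```

The arity sequence 2, 2, 0, 0, 0 (a binary node whose first child is another binary node) yields a depth of 2 under this convention.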
float cuml::genetic::get_fitness | ( | const program & | prog, |
const param & | params | ||
) |
Returns precomputed fitness score of program on the host, after accounting for parsimony.
prog | The host program |
params | Training hyperparameters |
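"Accounting for parsimony" usually means penalising longer programs so that bloated expressions do not win on raw score alone. The sketch below shows one plausible form, assuming the penalty is a coefficient times the program length and that its sign depends on whether the metric is maximised or minimised. The function and all of its parameter names are hypothetical; the fields actually consulted in `program` and `param` are not shown on this page.

```cpp
#include <cassert>

// Hypothetical parsimony adjustment: longer programs look worse.
// greater_is_better == true  -> the penalty is subtracted from the raw score.
// greater_is_better == false -> the penalty is added to the raw loss.
float penalized_fitness(float raw_fitness,
                        int length,
                        float parsimony_coefficient,
                        bool greater_is_better)
{
  assert(length >= 0);
  const float penalty = parsimony_coefficient * static_cast<float>(length);
  return greater_is_better ? raw_fitness - penalty : raw_fitness + penalty;
}
```

With a loss of 0.5, a length of 4, and a coefficient of 0.125, the adjusted loss is 1.0, so a shorter program with the same raw loss would rank ahead.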
void cuml::genetic::hoist_mutation | ( | const program & | prog, |
program & | p_out, | ||
const param & | params, | ||
std::mt19937 & | rng | ||
) |
Perform a hoist mutation on a random subtree of the given program (replace a subtree with a subtree of a subtree)
prog | The input program |
p_out | The output program |
params | Training hyperparameters |
rng | RNG to control subtree selection |
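The "subtree of a subtree" replacement can be sketched on a flat prefix encoding as follows. The example tracks only node arities, picks both subtrees uniformly at random, and uses hypothetical helper names; the real function works on `program` objects and may weight its choices differently.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Index one past the last node of the subtree rooted at `start`
// (arities are given in prefix order).
int subtree_end(const std::vector<int>& arities, int start)
{
  int open = 1;  // nodes still needed to complete the subtree
  int i = start;
  while (open > 0) {
    assert(i < static_cast<int>(arities.size()));
    open += arities[i] - 1;
    ++i;
  }
  return i;
}

// Replaces a random subtree with a random subtree taken from inside it.
std::vector<int> hoist_sketch(const std::vector<int>& arities, std::mt19937& rng)
{
  assert(!arities.empty());
  std::uniform_int_distribution<int> pick_root(0, static_cast<int>(arities.size()) - 1);
  const int s = pick_root(rng);
  const int s_end = subtree_end(arities, s);

  // The hoisted root lies inside [s, s_end), so its subtree is nested in the first one.
  std::uniform_int_distribution<int> pick_inner(s, s_end - 1);
  const int h = pick_inner(rng);
  const int h_end = subtree_end(arities, h);

  std::vector<int> out(arities.begin(), arities.begin() + s);
  out.insert(out.end(), arities.begin() + h, arities.begin() + h_end);
  out.insert(out.end(), arities.begin() + s_end, arities.end());
  return out;
}
```

Because a complete subtree is swapped for another complete subtree, the result is always a well-formed tree, and it can never be longer or deeper than the input.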
void cuml::genetic::point_mutation | ( | const program & | prog, |
program & | p_out, | ||
const param & | params, | ||
std::mt19937 & | rng | ||
) |
Perform a point mutation on the given program (AST).
prog | The input program |
p_out | The result program |
params | Training hyperparameters |
rng | RNG to decide nodes to mutate |
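A point mutation typically replaces individual nodes while leaving the shape of the tree intact. The sketch below assumes each node is independently replaced, with some probability, by a randomly chosen vocabulary entry of the same arity. The `tok` type, the vocabulary argument, and the `p_replace` parameter are hypothetical; the hyperparameter that actually controls this is not named on this page.

```cpp
#include <cstddef>
#include <random>
#include <string>
#include <vector>

// Hypothetical token: a symbol name plus the number of children it takes.
struct tok {
  std::string name;
  int arity;
};

// Replaces each node with probability p_replace by a same-arity vocabulary entry,
// so the tree structure is preserved.
std::vector<tok> point_mutate_sketch(const std::vector<tok>& prog,
                                     const std::vector<tok>& vocabulary,
                                     double p_replace,
                                     std::mt19937& rng)
{
  std::bernoulli_distribution flip(p_replace);
  std::vector<tok> out = prog;
  for (tok& t : out) {
    if (!flip(rng)) { continue; }
    std::vector<const tok*> candidates;
    for (const tok& v : vocabulary) {
      if (v.arity == t.arity) { candidates.push_back(&v); }
    }
    if (candidates.empty()) { continue; }  // nothing compatible, keep the node
    std::uniform_int_distribution<std::size_t> pick(0, candidates.size() - 1);
    t = *candidates[pick(rng)];
  }
  return out;
}
```

Restricting replacements to the same arity is what lets the mutation run in place on a flat node array without re-validating the tree.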
void cuml::genetic::set_batched_fitness | ( | const raft::handle_t & | h, |
int | n_progs, | ||
program_t & | d_progs, | ||
std::vector< program > & | h_progs, | ||
const param & | params, | ||
const int | n_rows, | ||
const float * | data, | ||
const float * | y, | ||
const float * | sample_weights | ||
) |
Computes and sets the fitness scores for all programs on the given dataset.
h | cuML handle |
n_progs | Batch size |
d_progs | Device pointer to list of programs |
h_progs | Host vector of programs corresponding to d_progs |
params | Training hyperparameters |
n_rows | Number of rows in the input dataset |
data | Device pointer to input dataset |
y | Device pointer to input labels |
sample_weights | Device pointer to sample weights |
void cuml::genetic::set_fitness | ( | const raft::handle_t & | h, |
program_t & | d_prog, | ||
program & | h_prog, | ||
const param & | params, | ||
const int | n_rows, | ||
const float * | data, | ||
const float * | y, | ||
const float * | sample_weights | ||
) |
Computes and sets the fitness scores for a single program on the given dataset.
h | cuML handle |
d_prog | Device pointer to program |
h_prog | Host program object |
params | Training hyperparameters |
n_rows | Number of rows in the input dataset |
data | Device pointer to input dataset |
y | Device pointer to input labels |
sample_weights | Device pointer to sample weights |
std::string cuml::genetic::stringify | ( | const program & | prog | ) |
Visualize an AST.
prog | host object containing the AST |
void cuml::genetic::subtree_mutation | ( | const program & | prog, |
program & | p_out, | ||
const param & | params, | ||
std::mt19937 & | rng | ||
) |
Performs a crossover mutation with a randomly built new program. Since crossover is 'hoisted', this will ensure that depth constraints are not violated.
prog | The input program |
p_out | The result mutated program |
params | Training hyperparameters |
rng | RNG to control subtree selection and temporary program addition |
void cuml::genetic::symClfPredict | ( | const raft::handle_t & | handle, |
const float * | input, | ||
const int | n_rows, | ||
const param & | params, | ||
const program_t & | best_prog, | ||
float * | output | ||
) |
Return predictions for a binary classification program defining the decision boundary.
handle | cuML handle |
input | device pointer to feature matrix |
n_rows | number of rows of the feature matrix |
params | host struct containing training hyperparameters |
best_prog | Best program obtained after training |
output | Device pointer to output predictions |
void cuml::genetic::symClfPredictProbs | ( | const raft::handle_t & | handle, |
const float * | input, | ||
const int | n_rows, | ||
const param & | params, | ||
const program_t & | best_prog, | ||
float * | output | ||
) |
Probability prediction for a symbolic classifier. If a transformer (like sigmoid) is specified, then it is applied on the output before returning it.
handle | cuML handle |
input | device pointer to feature matrix |
n_rows | number of rows of the feature matrix |
params | host struct containing training hyperparameters |
best_prog | The best program obtained during training. Inferences are made using this |
output | device pointer to output probabilities (in column-major format) |
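The following host-side sketch illustrates how a sigmoid transformer can turn raw program outputs into class probabilities and then labels. It assumes a two-column, column-major output with the class-0 probability in the first column and the class-1 probability in the second, and it assumes labels are chosen by comparing the two. These layout and thresholding details, along with the function names, are assumptions and are not confirmed by this page.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Maps raw program outputs to probabilities.
// Assumed layout: probs[i] = P(class 0), probs[n_rows + i] = P(class 1).
std::vector<float> predict_probs_host(const std::vector<float>& raw)
{
  const std::size_t n_rows = raw.size();
  std::vector<float> probs(2 * n_rows);
  for (std::size_t i = 0; i < n_rows; ++i) {
    const float p = sigmoid(raw[i]);
    probs[i] = 1.0f - p;
    probs[n_rows + i] = p;
  }
  return probs;
}

// Picks class 1 whenever its probability exceeds that of class 0.
std::vector<int> predict_labels_host(const std::vector<float>& raw)
{
  const std::vector<float> probs = predict_probs_host(raw);
  const std::size_t n_rows = raw.size();
  std::vector<int> labels(n_rows);
  for (std::size_t i = 0; i < n_rows; ++i) {
    labels[i] = probs[n_rows + i] > probs[i] ? 1 : 0;
  }
  return labels;
}
```

Under these assumptions the decision boundary is the set of inputs where the raw program output is zero, since that is where the sigmoid equals 0.5.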
void cuml::genetic::symFit | ( | const raft::handle_t & | handle, |
const float * | input, | ||
const float * | labels, | ||
const float * | sample_weights, | ||
const int | n_rows, | ||
const int | n_cols, | ||
param & | params, | ||
program_t & | final_progs, | ||
std::vector< std::vector< program >> & | history | ||
) |
Fit either a regressor, classifier or a transformer to the given dataset.
handle | cuML handle |
input | device pointer to the feature matrix |
labels | device pointer to the label vector of length n_rows |
sample_weights | device pointer to the sample weights of length n_rows |
n_rows | number of rows of the feature matrix |
n_cols | number of columns of the feature matrix |
params | host struct containing hyperparameters needed for training |
final_progs | device pointer to the final generation of programs (sorted by decreasing fitness) |
history | host vector containing the list of all programs in every generation (sorted by decreasing fitness) |
Note: symFit allocates device memory, pointed to by final_progs[i].nodes, for each program i in final_progs. The amount of memory allocated is determined at runtime and equals final_progs[i].len * sizeof(node) for each program i. It is not deallocated inside the function because the symRegPredict, symClfPredict, symClfPredictProbs and symTransform functions need it to execute predictions. The caller is expected to deallocate this device memory explicitly AFTER calling the predict function.
void cuml::genetic::symRegPredict | ( | const raft::handle_t & | handle, |
const float * | input, | ||
const int | n_rows, | ||
const program_t & | best_prog, | ||
float * | output | ||
) |
Make predictions for a symbolic regressor.
handle | cuML handle |
input | device pointer to feature matrix |
n_rows | number of rows of the feature matrix |
best_prog | device pointer to best AST fit during training |
output | device pointer to output values |
void cuml::genetic::symTransform | ( | const raft::handle_t & | handle, |
const float * | input, | ||
const param & | params, | ||
const program_t & | final_progs, | ||
const int | n_rows, | ||
const int | n_cols, | ||
float * | output | ||
) |
Transform the values in the input feature matrix according to the supplied programs.
handle | cuML handle |
input | device pointer to feature matrix |
params | Hyperparameters used during training |
final_progs | List of ASTs used for generating new features |
n_rows | number of rows of the feature matrix |
n_cols | number of columns of the feature matrix |
output | device pointer to transformed input |
const int cuml::genetic::GENE_TPB = 256 |
const int cuml::genetic::MAX_STACK_SIZE = 20 |