pylibcugraphops.operators.mha_gat_v2_n2n_efeat_fwd

Computes the forward pass of a GAT-like multi-head attention layer (mha_gat_v2) without using cuDNN. As in GATv2, the activation is applied before the dot product with the attention weights, and no activation is applied afterwards. The operator works on bipartite graphs, reduces from nodes to nodes (n2n), and additionally uses edge features when computing the dot product (efeat).
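The NumPy sketch below is a CPU reference of the computation described above, written only for illustration; it is not the library's implementation. It assumes: the pre-activation of an edge is the elementwise sum of its source, destination and edge embeddings; the activation is LeakyReLU with slope 0.2 (the real activation is chosen through params); head h uses the contiguous slice of width dim_node / num_heads of each feature vector and of attention_weights; and the aggregated messages are the source node embeddings. The function name and the offsets/indices arguments are hypothetical stand-ins for the opaque graph.

```python
import numpy as np


def leaky_relu(x, negative_slope=0.2):
    # Assumed activation; the operator selects its activation via params.
    return np.where(x > 0, x, negative_slope * x)


def gat_v2_n2n_efeat_reference(offsets, indices, x_src, x_dst, x_edge,
                               attn, num_heads, concat_heads=True):
    """Hypothetical CPU reference of the forward pass on a CSC graph.

    offsets: (n_dst + 1,) CSC offsets; the edges of destination node d
             occupy positions offsets[d]:offsets[d + 1].
    indices: (n_edges,) source node id of each edge.
    """
    n_dst = offsets.shape[0] - 1
    n_edges, dim_node = x_edge.shape
    head_dim = dim_node // num_heads

    # Destination node of every edge, expanded from the CSC offsets.
    dst_of_edge = np.repeat(np.arange(n_dst), np.diff(offsets))

    # Activation applied BEFORE the dot product (additive combination assumed).
    act = leaky_relu(x_src[indices] + x_dst[dst_of_edge] + x_edge)  # (E, dim_node)

    # Per-head dot product with the attention weights -> (num_heads, E).
    pre = (act * attn).reshape(n_edges, num_heads, head_dim).sum(axis=2).T

    post = np.zeros_like(pre)
    out = np.zeros((n_dst, num_heads, head_dim))
    msg = x_src[indices].reshape(n_edges, num_heads, head_dim)
    for d in range(n_dst):
        lo, hi = offsets[d], offsets[d + 1]
        if lo == hi:
            continue  # no incoming edges: output row stays zero
        # Numerically stable softmax over the incoming edges, per head.
        s = pre[:, lo:hi]
        e = np.exp(s - s.max(axis=1, keepdims=True))
        post[:, lo:hi] = e / e.sum(axis=1, keepdims=True)
        # Attention-weighted sum of source embeddings, per head.
        out[d] = np.einsum("he,ehk->hk", post[:, lo:hi], msg[lo:hi])

    # Concatenate heads (dim_node) or average them (dim_node / num_heads).
    out = out.reshape(n_dst, dim_node) if concat_heads else out.mean(axis=1)
    softmax_scores = np.stack([pre, post])  # (2, num_heads, n_edges)
    return out, softmax_scores, act
```

Returning the pre- and post-softmax scores and the activated edge features mirrors why softmax_scores and activation_scores are outputs of the forward pass: the backward pass can reuse them instead of recomputing them.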

mha_gat_v2_n2n_efeat_fwd(
    output_embedding: device array, softmax_scores: device array,
    activation_scores: device array, src_node_embedding: device array,
    dst_node_embedding: device array, edge_embedding: device array,
    attention_weights: device array, graph: pylibcugraphops.bipartite_csc_int[64|32],
    params: pylibcugraphops.operators.mha_params, stream_id: int = 0
) -> None
Parameters:
output_embedding : device array type

Device array containing the output node embeddings. Shape: (graph.n_dst_nodes, dim_out), with dim_out = dim_node when params.concat_heads is True; dim_out = dim_node / params.num_heads otherwise.

softmax_scores : device array type

Device array receiving the pre- and post-softmax scores, which are needed for the backward pass. Shape: (2, params.num_heads, graph.n_indices).

activation_scores : device array type

Device array receiving the scores after the activation; it is needed only for the backward pass. Shape: (graph.n_indices, dim_node).

src_node_embedding : device array type

Device array containing the source node embeddings. Shape: (graph.n_src_nodes, dim_node).

dst_node_embedding : device array type

Device array containing the destination node embeddings. Shape: (graph.n_dst_nodes, dim_node).

edge_embedding : device array type

Device array containing the input edge embeddings. Shape: (graph.n_indices, dim_node).

attention_weights : device array type

Device array containing the (learnable) attention weights. Shape: (dim_node,).
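The shapes listed above depend only on a few sizes, so they can be computed and the three output buffers preallocated up front. The sketch below does this with hypothetical helpers (expected_shapes, allocate_outputs). Requiring dim_node to be divisible by num_heads and defaulting to float32 are assumptions, and NumPy stands in for a device array library such as CuPy, which can be passed as xp.

```python
import numpy as np


def expected_shapes(n_src_nodes, n_dst_nodes, n_indices, dim_node,
                    num_heads, concat_heads):
    """Shapes of every array argument, following the parameter list."""
    # Assumption: heads split dim_node evenly.
    if dim_node % num_heads != 0:
        raise ValueError("dim_node must be divisible by num_heads")
    dim_out = dim_node if concat_heads else dim_node // num_heads
    return {
        "output_embedding": (n_dst_nodes, dim_out),
        "softmax_scores": (2, num_heads, n_indices),
        "activation_scores": (n_indices, dim_node),
        "src_node_embedding": (n_src_nodes, dim_node),
        "dst_node_embedding": (n_dst_nodes, dim_node),
        "edge_embedding": (n_indices, dim_node),
        "attention_weights": (dim_node,),
    }


def allocate_outputs(shapes, dtype=np.float32, xp=np):
    # xp may be any NumPy-compatible module (e.g. CuPy for real device arrays).
    outputs = ("output_embedding", "softmax_scores", "activation_scores")
    return {name: xp.empty(shapes[name], dtype=dtype) for name in outputs}
```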

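graph : opaque graph type

Bipartite graph in CSC format (pylibcugraphops.bipartite_csc_int64 or pylibcugraphops.bipartite_csc_int32) on which the operation is performed.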
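In the usual CSC layout, each destination node owns a contiguous range of an indices array listing its source nodes, and graph.n_indices is the number of edges. The NumPy sketch below, using a hypothetical helper edges_to_csc, shows how an edge list maps to such offsets and indices arrays; it does not construct the library's graph object, which is created through pylibcugraphops' own helpers. It also returns the edge permutation, on the assumption that per-edge inputs such as edge_embedding must follow the same CSC edge order.

```python
import numpy as np


def edges_to_csc(src, dst, n_dst_nodes):
    """Group edges by destination node and build CSC offsets and indices."""
    src = np.asarray(src)
    dst = np.asarray(dst)
    order = np.argsort(dst, kind="stable")  # edge permutation into CSC order
    indices = src[order]  # source node id of each edge, in CSC order
    counts = np.bincount(dst, minlength=n_dst_nodes)  # in-degree per dst node
    offsets = np.zeros(n_dst_nodes + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)  # edges of node d: offsets[d]:offsets[d+1]
    return offsets, indices, order
```

Edge-level data given in the original edge-list order would then be permuted with the same order, for example edge_embedding[order].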

params : opaque mha_params type

Structure summarizing the hyperparameters of the primitive, such as num_heads, concat_heads, and the activation function to use.

stream_id : int, default=0

CUDA stream handle (the cudaStream_t pointer value) passed as a Python int. The default 0 selects the default CUDA stream.