pylibcugraphops.operators.mha_gat_n2n_bwd

Computes the backward pass of a GAT-like multi-head attention layer that does not use cuDNN (mha_gat), operating on bipartite graphs with a node-to-node reduction (n2n).

mha_gat_n2n_bwd(
    grad_src_node_embedding: device array, grad_dst_node_embedding: device array,
    grad_attention_weights: device array,
    grad_softmax_scores: device array, grad_output_embedding: device array,
    src_node_embedding: device array, dst_node_embedding: device array,
    attention_weights: device array,
    softmax_scores: device array, graph: pylibcugraphops.bipartite_csc_int[64|32],
    params: pylibcugraphops.operators.mha_params,
    grad_attention_weights_workspace: Optional[device array] = None, stream_id: int = 0
) -> None
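As a quick cross-check of the shapes listed under Parameters, the following sketch derives every documented shape from the graph sizes and hyperparameters. The helper `mha_gat_n2n_bwd_shapes` is hypothetical and not part of pylibcugraphops; it assumes dim_node is divisible by params.num_heads and uses the src_node_embedding / dst_node_embedding names from the parameter descriptions.

```python
# Hypothetical helper (not part of pylibcugraphops): derives the array shapes
# documented for mha_gat_n2n_bwd from graph sizes and hyperparameters.
def mha_gat_n2n_bwd_shapes(n_src_nodes, n_dst_nodes, n_indices,
                           dim_node, num_heads, concat_heads):
    # Assumption: dim_node splits evenly across the attention heads.
    if dim_node % num_heads != 0:
        raise ValueError("dim_node is assumed to be divisible by num_heads")
    dim_out = dim_node if concat_heads else dim_node // num_heads
    return {
        "grad_src_node_embedding": (n_src_nodes, dim_node),
        "grad_dst_node_embedding": (n_dst_nodes, dim_node),
        "grad_attention_weights": (2 * dim_node,),
        "grad_softmax_scores": (2, num_heads, n_indices),
        "grad_output_embedding": (n_dst_nodes, dim_out),
        "src_node_embedding": (n_src_nodes, dim_node),
        "dst_node_embedding": (n_dst_nodes, dim_node),
        "attention_weights": (2 * dim_node,),
        "softmax_scores": (2, num_heads, n_indices),
        "grad_attention_weights_workspace": (n_dst_nodes, 2 * dim_node),
    }


shapes = mha_gat_n2n_bwd_shapes(n_src_nodes=5, n_dst_nodes=3, n_indices=7,
                                dim_node=8, num_heads=2, concat_heads=False)
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

Allocating device arrays with these shapes (for example with CuPy) before the call makes shape mismatches easier to spot.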
Parameters:
grad_src_node_embedding : device array type

Device array containing the output gradient on the source node embeddings. May be None if not needed. Shape: (graph.n_src_nodes, dim_node).

grad_dst_node_embedding : device array type

Device array containing the output gradient on the destination node embeddings. May be None if not needed. Shape: (graph.n_dst_nodes, dim_node).

grad_attention_weights : device array type

Device array containing the output gradient on the attention weights. May be None if not needed. Shape: (2 * dim_node,).

grad_softmax_scores : device array type

Workspace device array that receives the gradients of the softmax scores. Shape: (2, params.num_heads, graph.n_indices).

grad_output_embedding : device array type

Device array containing the incoming gradients with respect to the output node embeddings. Shape: (graph.n_dst_nodes, dim_out), where dim_out = dim_node if params.concat_heads is True and dim_out = dim_node / params.num_heads otherwise.
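To illustrate why dim_out depends on params.concat_heads, here is a minimal sketch of combining per-head outputs for one destination node. It assumes that concat_heads=False averages the heads, as in the original GAT formulation; that is an assumption about this library, and `combine_heads` is a hypothetical helper.

```python
# Sketch: combining per-head outputs for a single destination node.
# Assumption: concat_heads=False averages the heads (original GAT style);
# the library may combine heads differently.
def combine_heads(head_outputs, concat_heads):
    num_heads = len(head_outputs)
    head_dim = len(head_outputs[0])
    if concat_heads:
        # Concatenation: dim_out = num_heads * head_dim = dim_node
        return [v for head in head_outputs for v in head]
    # Mean over heads: dim_out = head_dim = dim_node / num_heads
    return [sum(head[k] for head in head_outputs) / num_heads
            for k in range(head_dim)]


heads = [[1.0, 2.0], [3.0, 4.0]]  # num_heads=2, head_dim=2, so dim_node=4
print(combine_heads(heads, concat_heads=True))   # [1.0, 2.0, 3.0, 4.0]
print(combine_heads(heads, concat_heads=False))  # [2.0, 3.0]
```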

src_node_embedding : device array type

Device array containing the source node embeddings. Shape: same as grad_src_node_embedding.

dst_node_embedding : device array type

Device array containing the destination node embeddings. Shape: same as grad_dst_node_embedding.

attention_weights : device array type

Device array containing the (learnable) attention weights. Shape: same as grad_attention_weights.

softmax_scores : device array type

Device array containing the pre- and post-softmax scores from the forward pass. Shape: same as grad_softmax_scores.
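A node-to-node (n2n) reduction normalizes the attention scores over the incoming edges of each destination node. The following sketch shows that per-destination softmax for a single head. It assumes that CSC offsets group the edges by destination node and that softmax_scores[0] and softmax_scores[1] hold the pre- and post-softmax scores, respectively; the row order and the helper `edge_softmax` are assumptions for illustration only.

```python
import math


# Sketch: per-destination edge softmax over a CSC layout, for one head.
# Assumptions: `offsets` (length n_dst_nodes + 1) groups edges by destination,
# and row 0 / row 1 of softmax_scores hold pre- / post-softmax scores.
def edge_softmax(offsets, pre_scores):
    post = [0.0] * len(pre_scores)
    for dst in range(len(offsets) - 1):
        start, end = offsets[dst], offsets[dst + 1]
        if start == end:
            continue  # destination node without incoming edges
        m = max(pre_scores[start:end])  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in pre_scores[start:end]]
        total = sum(exps)
        for k, e in enumerate(exps):
            post[start + k] = e / total
    return post


offsets = [0, 2, 3]  # dst 0 owns edges 0-1, dst 1 owns edge 2
pre = [0.0, 0.0, 5.0]
softmax_scores = [pre, edge_softmax(offsets, pre)]
print(softmax_scores[1])  # [0.5, 0.5, 1.0]
```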

graph : opaque graph type

Bipartite CSC graph (pylibcugraphops.bipartite_csc_int64 or pylibcugraphops.bipartite_csc_int32) used for the operation.
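In a CSC layout, edges are grouped by destination node, so graph.n_indices (used in the softmax-score shapes) is presumably the number of edges, i.e. the length of the indices array. The following minimal sketch builds the standard CSC offsets and indices from a bipartite edge list; `edges_to_csc` is a hypothetical helper, and it does not show how pylibcugraphops graph objects are actually constructed.

```python
# Sketch: standard CSC arrays for a bipartite graph.
# offsets: length n_dst_nodes + 1, edges of dst d are offsets[d]:offsets[d + 1]
# indices: source node id of each edge, grouped by destination
def edges_to_csc(src, dst, n_dst_nodes):
    counts = [0] * n_dst_nodes
    for d in dst:
        counts[d] += 1
    offsets = [0] * (n_dst_nodes + 1)
    for i in range(n_dst_nodes):
        offsets[i + 1] = offsets[i] + counts[i]
    indices = [0] * len(src)
    cursor = offsets[:-1]  # slicing returns a copy: next free slot per dst
    for s, d in zip(src, dst):
        indices[cursor[d]] = s
        cursor[d] += 1
    return offsets, indices


# 3 source nodes, 2 destination nodes, 4 edges (src -> dst)
src = [0, 2, 1, 2]
dst = [1, 0, 1, 1]
offsets, indices = edges_to_csc(src, dst, n_dst_nodes=2)
print(offsets)  # [0, 1, 4]
print(indices)  # [2, 0, 1, 2]
```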

params : opaque mha_params type

Structure summarizing the hyperparameters of the primitive, such as num_heads, concat_heads, or the activation function used.

grad_attention_weights_workspace : optional device array type

Optional workspace for computing the weight gradient. If provided (i.e. not None), the gradient on the attention weights is computed deterministically using this workspace. Shape: (graph.n_dst_nodes, 2 * dim_node).
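Determinism matters here because floating-point addition is not associative: accumulating partial gradients with atomics typically yields results that depend on thread scheduling. The following sketch assumes, without confirmation from this page, that each destination node writes its partial weight gradient into its own row of the (graph.n_dst_nodes, 2 * dim_node) workspace and that the rows are then summed in a fixed order. It illustrates the idea only; `reduce_rows_fixed_order` is hypothetical and not the library's kernel.

```python
# Floating-point addition is not associative: the order changes the result.
print((1e16 + 1.0) - 1e16)  # 0.0
print((1e16 - 1e16) + 1.0)  # 1.0


# Sketch (assumption): reduce per-destination partial gradients stored in a
# workspace row by row, always in the same order, so repeated runs agree.
def reduce_rows_fixed_order(workspace):
    n_cols = len(workspace[0])
    grad = [0.0] * n_cols
    for row in workspace:  # always row 0, 1, 2, ...
        for j in range(n_cols):
            grad[j] += row[j]
    return grad


# n_dst_nodes=3, dim_node=1, so each row has 2 * dim_node = 2 entries
workspace = [[1e16, 0.5], [1.0, 0.25], [-1e16, 0.25]]
print(reduce_rows_fixed_order(workspace))  # [0.0, 1.0] on every call
```

A fixed reduction order does not make the result more accurate (the 1.0 above is still absorbed), only reproducible.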

stream_id : int, default=0

CUDA stream pointer as a Python int.