pylibcugraphops.operators.mha_gat_n2n_efeat_bwd_bf16_fp32_fp32_fp32
- pylibcugraphops.operators.mha_gat_n2n_efeat_bwd_bf16_fp32_fp32_fp32 = <nanobind.nb_func object>
- Computes the backward pass for a GAT-like multi-head attention layer without using cuDNN (mha_gat), operating on bipartite graphs in a node-to-node reduction (n2n) and additionally using edge features when computing the dot product (efeat).
mha_gat_n2n_efeat_bwd(grad_src_node_embedding: device array, grad_dst_node_embedding: device array, grad_edge_embedding: device array, grad_attention_weights: device array, grad_softmax_scores: device array, grad_output_embedding: device array, src_node_embedding: device array, dst_node_embedding: device array, edge_embedding: device array, attention_weights: device array, softmax_scores: device array, graph: pylibcugraphops.bipartite_csc_int[64|32], params: pylibcugraphops.operators.mha_params, grad_attention_weights_workspace: Optional[device array] = None, stream_id: int = 0) -> None
- Parameters:
- grad_src_node_embedding : device array type
Device array containing the gradients on source node embeddings. May be None if not needed. Shape: (graph.n_src_nodes, dim_node).
- grad_dst_node_embedding : device array type
Device array containing the gradients on destination node embeddings. May be None if not needed. Shape: (graph.n_dst_nodes, dim_node).
- grad_edge_embedding : device array type
Device array containing the gradients on edge embeddings. May be None if not needed. Shape: (graph.n_indices, dim_edge).
- grad_attention_weights : device array type
Device array containing the output gradient on the attention weights. May be None if not needed. Shape: (2 * dim_node + dim_edge, ).
- grad_softmax_scores : device array type
Device array containing the gradient of softmax scores (workspace device array). Shape: (2, params.num_heads, graph.n_indices).
- grad_output_embedding : device array type
Device array containing the input gradients on the output node embeddings. Shape: (graph.n_dst_nodes, dim_out), with dim_out = dim_node when params.concat_heads is True; dim_out = dim_node / params.num_heads otherwise.
- src_node_embedding : device array type
Device array containing the source node embeddings. Shape: same as grad_src_node_embedding.
- dst_node_embedding : device array type
Device array containing the destination node embeddings. Shape: same as grad_dst_node_embedding.
- edge_embedding : device array type
Device array containing the input edge embeddings. Shape: same as grad_edge_embedding.
- attention_weights : device array type
Device array containing the (learnable) attention weights. Shape: same as grad_attention_weights.
- softmax_scores : device array type
Device array containing the pre- and post-softmax scores (from the forward pass). Shape: same as grad_softmax_scores.
- graph : opaque graph type
Graph used for the operation.
- params : opaque mha_params type
Structure summarizing hyperparameters of the primitive, such as num_heads, concat_heads, or the activation function used.
- grad_attention_weights_workspace : optional device array type
Optional workspace for computing the weight gradient. If a workspace is passed (not None), the gradient on the weights is computed deterministically using it. Shape: (graph.n_dst_nodes, 2 * dim_node + dim_edge).
- stream_id : int, default=0
CUDA stream pointer as a Python int.
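
The shape rules listed above can be collected in a small helper for checking buffers before a call. This is a minimal sketch in plain Python and is not part of pylibcugraphops: the function name `expected_shapes`, its keyword arguments, and the `deterministic` flag are hypothetical and only mirror the documented shapes. It also assumes that dim_node must be divisible by num_heads when heads are not concatenated, which the text above implies but does not state.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    n_src_nodes: int,
    n_dst_nodes: int,
    n_indices: int,
    dim_node: int,
    dim_edge: int,
    num_heads: int,
    concat_heads: bool,
    deterministic: bool = False,
) -> Dict[str, Optional[Shape]]:
    """Hypothetical helper: expected array shapes for mha_gat_n2n_efeat_bwd.

    The arguments stand in for graph.n_src_nodes, graph.n_dst_nodes,
    graph.n_indices, params.num_heads and params.concat_heads.
    """
    if concat_heads:
        dim_out = dim_node
    else:
        # Assumption: dim_node / num_heads must be an integer.
        if dim_node % num_heads != 0:
            raise ValueError("dim_node must be divisible by num_heads")
        dim_out = dim_node // num_heads

    weights: Shape = (2 * dim_node + dim_edge,)
    src: Shape = (n_src_nodes, dim_node)
    dst: Shape = (n_dst_nodes, dim_node)
    edge: Shape = (n_indices, dim_edge)
    scores: Shape = (2, num_heads, n_indices)

    return {
        "grad_src_node_embedding": src,
        "grad_dst_node_embedding": dst,
        "grad_edge_embedding": edge,
        "grad_attention_weights": weights,
        "grad_softmax_scores": scores,
        "grad_output_embedding": (n_dst_nodes, dim_out),
        "src_node_embedding": src,
        "dst_node_embedding": dst,
        "edge_embedding": edge,
        "attention_weights": weights,
        "softmax_scores": scores,
        # The workspace is only needed for the deterministic weight gradient.
        "grad_attention_weights_workspace": (
            (n_dst_nodes, 2 * dim_node + dim_edge) if deterministic else None
        ),
    }
```

Comparing each device array's shape against this dictionary before the call is one way to catch a mismatch early, for example a grad_output_embedding allocated with dim_node columns while params.concat_heads is False.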