pylibcugraphops.operators.agg_hg_basis_n2n_post_bwd

pylibcugraphops.operators.agg_hg_basis_n2n_post_bwd = <nanobind.nb_func object>
Computes the backward pass for node-to-node full-graph RGCN-like basis-regularized aggregation, where features are transformed after (post) the aggregation in the forward pass.

agg_hg_basis_n2n_post_bwd(
    grad_input: Optional[device array],
    grad_weights: Optional[device array], grad_output: device array,
    input_embedding: device array, weights_combination: Optional[device array],
    graph: pylibcugraphops.csc_hg_int[32|64], concat_own: bool = False,
    norm_by_out_degree: bool = False, stream_id: int = 0
) -> None
Parameters:
grad_input : device array type | None

Device array that receives the gradient with respect to the input embeddings of the forward pass. If None, this gradient is not computed. Shape: (graph.n_dst_nodes, dim_in).

grad_weights : device array type | None

Device array that receives the gradient with respect to the combination weights of the forward pass. If None, this gradient is not computed; if both this and grad_input are None, the operation performs no work. Shape: (n_edge_types, n_bases) if set.

grad_output : device array type

Device array containing the incoming gradient with respect to the output embeddings of the forward pass. Shape: (graph.n_dst_nodes, dim_out) for concat_own=False, or (graph.n_dst_nodes, dim_out + dim_in) for concat_own=True, where dim_out = dim_in * n_bases if weights_combination is set and dim_out = dim_in * n_edge_types otherwise.

input_embedding : device array type

Device array containing the input embeddings from the forward pass. Shape: same as grad_input.

weights_combination : device array type | None

Device array containing the combination weights from the forward pass. Shape: same as grad_weights.

graph : opaque graph type

The graph used for the operation (must be the same graph as in the forward pass).

concat_own : bool, default=False

If set, each node's own embedding is concatenated to its aggregation output, matching the forward pass.

norm_by_out_degree : bool, default=False

If set, output embeddings are normalized by the degree of the output node.

stream_id : int, default=0

CUDA stream pointer as a Python int.