pylibcugraphops.dimenet.agg_edge_to_edge_bwd2_main
- pylibcugraphops.dimenet.agg_edge_to_edge_bwd2_main = <nanobind.nb_func object>
- Computes the second-order backward pass of the DimeNet++ interaction-block aggregation layer with respect to the radial basis features, embeddings, and weights used in the first backward pass.
agg_edge_to_edge_bwd2_main(output_grad_rbf_grad_embedding: device array, output_grad_embedding_grad_rbf: device array, output_grad_weights: device array, input_vector: device array, input_rbf: device array, input_grad_grad_rbf: device array, input_embedding: device array, input_grad_grad_embedding: device array, input_grad_embedding: device array, input_weights: device array, coo_idx: device array, src_offsets: device array, mma_operation: pylibcugraphops.MMAOp, cuda_stream: int = 0)
- We define the following dimensions:
n_spherical: number of spherical basis functions. Must be 7.
n_radial: number of radial basis functions. Must be 6.
n_vec: number of vector/position dimensions. Must be 3.
n_emb: input/output embedding dimension. Assumed to be at most 64.
n_mid: intermediate feature dimension. The spherical basis features are projected twice, first to feature dimension n_mid and then to feature dimension n_emb. Assumed to be at most 8.
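To make the dimensions above concrete, the sketch below derives the expected shape of every array argument from #edges, #nodes, n_emb and n_mid, using the shapes listed under Parameters. The helpers expected_shapes and check_shapes are hypothetical (they are not part of pylibcugraphops); the upper bounds on n_emb and n_mid are checked only because this page states them as assumptions.

```python
# Hypothetical helpers (not part of pylibcugraphops): expected array shapes
# for agg_edge_to_edge_bwd2_main, derived from the documented dimensions.
N_SPHERICAL = 7  # must be 7
N_RADIAL = 6     # must be 6
N_VEC = 3        # must be 3


def expected_shapes(n_edges: int, n_nodes: int, n_emb: int, n_mid: int) -> dict:
    """Return the documented shape of every array argument, keyed by name."""
    if n_emb > 64:
        raise ValueError("n_emb is assumed to be at most 64")
    if n_mid > 8:
        raise ValueError("n_mid is assumed to be at most 8")
    n_sbf = N_SPHERICAL * N_RADIAL  # 42 basis features per edge
    n_w = n_sbf + n_emb             # number of rows of the weight matrix
    return {
        "output_grad_rbf_grad_embedding": (n_edges, n_sbf),
        "output_grad_embedding_grad_rbf": (n_edges, n_emb),
        "output_grad_weights": (n_w, n_mid),
        "input_vector": (n_edges, N_VEC),
        "input_rbf": (n_edges, n_sbf),
        "input_grad_grad_rbf": (n_edges, n_sbf),
        "input_embedding": (n_edges, n_emb),
        "input_grad_grad_embedding": (n_edges, n_emb),
        "input_grad_embedding": (n_edges, n_emb),
        "input_weights": (n_w, n_mid),
        "coo_idx": (2, n_edges),
        "src_offsets": (n_nodes + 1,),
    }


def check_shapes(arrays: dict, **dims) -> None:
    """Raise ValueError if any array's .shape differs from the documented one."""
    for name, shape in expected_shapes(**dims).items():
        actual = tuple(arrays[name].shape)
        if actual != shape:
            raise ValueError(f"{name}: expected {shape}, got {actual}")
```

For example, with n_emb = 64 and n_mid = 8 both weight arrays have shape (42 + 64, 8) = (106, 8). check_shapes accepts any objects exposing a .shape attribute, such as CuPy or NumPy arrays.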
- Parameters:
- output_grad_rbf_grad_embedding : device array type
Device array containing the output gradients for radial basis features (from backward). Dimension is assumed to be [#edges, n_spherical * n_radial].
- output_grad_embedding_grad_rbf : device array type
Device array containing the output gradients for input embeddings (from backward). Dimension is assumed to be [#edges, n_emb].
- output_grad_weights : device array type
Device array containing the output gradients for weights (from backward). Dimension is assumed to be [(n_spherical * n_radial + n_emb), n_mid].
- input_vector : device array type
Device array containing the vector values from forward. Dimension is assumed to be [#edges, n_vec].
- input_rbf : device array type
Device array containing the radial basis features from forward. Dimension is assumed to be [#edges, n_spherical * n_radial].
- input_grad_grad_rbf : device array type
Device array containing the incoming gradients with respect to the radial basis feature gradients produced by the first backward pass. Dimension is assumed to be [#edges, n_spherical * n_radial].
- input_embedding : device array type
Device array containing the input embedding values from forward. Dimension is assumed to be [#edges, n_emb].
- input_grad_grad_embedding : device array type
Device array containing the incoming gradients with respect to the input embedding gradients produced by the first backward pass. Dimension is assumed to be [#edges, n_emb].
- input_grad_embedding : device array type
Device array containing the gradients for output embeddings (from backward). Dimension is assumed to be [#edges, n_emb].
- input_weights : device array type
Device array containing the weights used in forward. Dimension is assumed to be [(n_spherical * n_radial + n_emb), n_mid].
- coo_idx : device array type
Device array containing the COO index of the graph. Dimension is assumed to be [2, #edges].
- src_offsets : device array type
Device array containing the CSR-like offsets of the source nodes. Dimension is assumed to be [#nodes + 1].
- mma_operation : pylibcugraphops.MMAOp
MMA precision: pylibcugraphops.MMAOp.HighPrecision performs 3x TF32 MMA operations, while pylibcugraphops.MMAOp.LowPrecision performs a single (1x) TF32 MMA operation.
- cuda_stream : int, default=0
CUDA stream, passed as an integer holding the raw stream pointer.
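The coo_idx and src_offsets arguments together describe the graph. The NumPy sketch below builds both from an unsorted edge list under assumptions this page does not confirm: row 0 of coo_idx holds source node IDs and row 1 destination node IDs, edges are sorted by source node so that src_offsets[v]:src_offsets[v + 1] spans the edges leaving node v, and indices are int32. build_graph_index is a hypothetical helper. The returned permutation order should be applied to every per-edge array (input_vector, input_rbf, input_embedding and the per-edge gradient arrays) so that their rows stay aligned with coo_idx.

```python
import numpy as np


def build_graph_index(src, dst, n_nodes):
    """Hypothetical helper: sort edges by source node and return
    (coo_idx, src_offsets, order) under the assumptions stated above."""
    # int32 indices are an assumption; check what your build expects.
    src = np.asarray(src, dtype=np.int32)
    dst = np.asarray(dst, dtype=np.int32)
    if src.ndim != 1 or src.shape != dst.shape:
        raise ValueError("src and dst must be 1-D arrays of equal length")
    if src.size and (src.min() < 0 or src.max() >= n_nodes):
        raise ValueError("source node IDs must lie in [0, n_nodes)")
    # Stable sort keeps the input order of edges that share a source node.
    order = np.argsort(src, kind="stable")
    src, dst = src[order], dst[order]
    # Assumed layout: row 0 = source node IDs, row 1 = destination node IDs.
    coo_idx = np.stack([src, dst])
    # Count edges per source node; the prefix sum is stored shifted by one
    # slot, so src_offsets[0] == 0 and src_offsets[-1] == #edges.
    counts = np.bincount(src, minlength=n_nodes)
    src_offsets = np.zeros(n_nodes + 1, dtype=np.int32)
    src_offsets[1:] = np.cumsum(counts)
    return coo_idx, src_offsets, order
```

For the edge list src = [2, 0, 1, 0], dst = [0, 1, 2, 2] on 3 nodes, this yields src_offsets = [0, 2, 3, 4]. The arrays would still need to be moved to the GPU (for example with cupy.asarray) before being passed to the function.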
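The mma_operation choice can be illustrated with a CPU emulation of the general 3xTF32 technique. Each FP32 operand x is split into a TF32 "big" part and a TF32 "small" residual, x ≈ big + small, and the three products big·big + big·small + small·big replace the single big·big product, dropping only small·small. This NumPy sketch illustrates that idea under simplifying assumptions (TF32 rounding emulated by rounding half away from zero on the 13 dropped mantissa bits, products accumulated in FP32). It is not the kernel pylibcugraphops runs, and the function names are hypothetical.

```python
import numpy as np


def to_tf32(x):
    """Emulate TF32: keep 10 explicit mantissa bits, rounding the 13 dropped
    bits to nearest (ties away from zero), and return float32 values."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = (bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)


def mma_1x_tf32(a, b):
    """Single TF32 product: each operand is rounded to TF32 once."""
    return to_tf32(a) @ to_tf32(b)


def mma_3x_tf32(a, b):
    """3xTF32: split each operand into big + small TF32 parts and sum three
    products, dropping only the small*small term."""
    a_big, b_big = to_tf32(a), to_tf32(b)
    a_small = to_tf32(a - a_big)  # residual lost by the first rounding
    b_small = to_tf32(b - b_big)
    # Add the small cross terms first, then the dominant big*big term.
    return (a_small @ b_big + a_big @ b_small) + a_big @ b_big
```

On random 64 x 64 inputs, the 3x variant's maximum error against an FP64 reference is far smaller than the 1x variant's, at the cost of three matrix products instead of one.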