pylibcugraphops.dimenet.agg_edge_to_edge_fwd

pylibcugraphops.dimenet.agg_edge_to_edge_fwd = <nanobind.nb_func object>

Computes the forward pass of the DimeNet++ interaction-block aggregation layer.

agg_edge_to_edge_fwd(
    output_embedding: device array,
    input_vector: device array,
    input_rbf: device array,
    input_embedding: device array,
    input_weights: device array,
    coo_idx: device array,
    dst_offsets: device array,
    dst_edge_index: device array,
    mma_operation: pylibcugraphops.MMAOp,
    cuda_stream: int = 0
)
We define the following dimensions:
  • n_spherical: number of spherical basis functions. Must be 7.

  • n_radial: number of radial basis functions. Must be 6.

  • n_vec: number of vector/position dimensions. Must be 3.

  • n_emb: input/output embedding dimension. Assumed to be at most 64.

  • n_mid: intermediate (“middle”) feature dimension. The spherical basis features are projected twice: first to dimension n_mid, then to dimension n_emb. Assumed to be at most 8.
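The constraints above can be collected into a small shape check. The sketch below is hypothetical: the helper `expected_shapes` and the `MAX_*` names are not part of pylibcugraphops. It only restates the documented dimensions and the array shapes listed under Parameters, which can help validate inputs before a call.

```python
# Fixed basis sizes stated in the dimension list above.
N_SPHERICAL = 7
N_RADIAL = 6
N_VEC = 3

# Documented upper bounds ("assumed to be at most").
MAX_N_EMB = 64
MAX_N_MID = 8


def expected_shapes(n_edges: int, n_nodes: int, n_emb: int, n_mid: int) -> dict:
    """Return the documented shape of each array argument.

    Hypothetical helper for illustration; not a pylibcugraphops API.
    """
    if not 0 < n_emb <= MAX_N_EMB:
        raise ValueError(f"n_emb must be in (0, {MAX_N_EMB}], got {n_emb}")
    if not 0 < n_mid <= MAX_N_MID:
        raise ValueError(f"n_mid must be in (0, {MAX_N_MID}], got {n_mid}")
    n_basis = N_SPHERICAL * N_RADIAL  # 42 combined basis features per edge
    return {
        "output_embedding": (n_edges, n_emb),
        "input_vector": (n_edges, N_VEC),
        "input_rbf": (n_edges, n_basis),
        "input_embedding": (n_edges, n_emb),
        "input_weights": (n_basis + n_emb, n_mid),
        "coo_idx": (2, n_edges),
        "dst_offsets": (n_nodes + 1,),
        "dst_edge_index": (n_edges,),
    }


shapes = expected_shapes(n_edges=10, n_nodes=4, n_emb=64, n_mid=8)
print(shapes["input_weights"])  # (106, 8)
```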

Parameters:
output_embedding : device array type

Device array containing the output embedding values. Dimension is assumed to be [#edges, n_emb].

input_vector : device array type

Device array containing the vector values (position of input node - position of output node + edge offsets) for each edge. Dimension is assumed to be [#edges, n_vec].
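A minimal NumPy sketch of how `input_vector` could be assembled on the host from node positions. It assumes that row 0 of `coo_idx` holds the input node and row 1 the output node of each edge, and that per-edge offsets (for example periodic-image shifts) are given as an `[#edges, n_vec]` array. The row convention, the float32 dtype, and the helper name `edge_vectors` are assumptions, not documented behavior.

```python
import numpy as np


def edge_vectors(pos, coo_idx, edge_offsets=None):
    """pos[input node] - pos[output node] (+ offset) for every edge.

    Hypothetical helper. Assumes coo_idx[0] = input nodes and
    coo_idx[1] = output nodes; this row convention is an assumption.
    """
    pos = np.asarray(pos, dtype=np.float32)          # [#nodes, n_vec]
    in_nodes, out_nodes = np.asarray(coo_idx)         # each [#edges]
    vec = pos[in_nodes] - pos[out_nodes]              # [#edges, n_vec]
    if edge_offsets is not None:                      # e.g. periodic shifts
        vec = vec + np.asarray(edge_offsets, dtype=np.float32)
    return vec


pos = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
coo = [[1, 2], [0, 0]]  # two edges: 1 -> 0 and 2 -> 0
print(edge_vectors(pos, coo))
```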

input_rbf : device array type

Device array containing the radial basis features used for calculating the spherical basis features. Dimension is assumed to be [#edges, n_spherical * n_radial].

input_embedding : device array type

Device array containing the input embedding values. Dimension is assumed to be [#edges, n_emb].

input_weights : device array type

Device array containing the weights used to project the spherical basis features twice: first to the “middle” dimension n_mid, then to the embedding dimension n_emb. Both weight matrices are stored in a single combined array, with the weights of the second projection transposed. Dimension is assumed to be [(n_spherical * n_radial) + n_emb, n_mid].
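One plausible way to build the combined `input_weights` array from two separate projection matrices. This sketch assumes the first `n_spherical * n_radial` rows hold the first projection (shape `[n_spherical * n_radial, n_mid]`) and the remaining `n_emb` rows hold the transposed second projection (originally `[n_mid, n_emb]`). The row order is inferred from the documented shape, and `pack_weights`/`unpack_weights` are hypothetical helpers.

```python
import numpy as np

N_BASIS = 7 * 6  # n_spherical * n_radial


def pack_weights(w1, w2):
    """Stack w1 [N_BASIS, n_mid] on top of w2.T, where w2 is [n_mid, n_emb].

    Result shape: [N_BASIS + n_emb, n_mid]. The block order is an assumption.
    """
    w1 = np.asarray(w1, dtype=np.float32)
    w2 = np.asarray(w2, dtype=np.float32)
    if w1.shape[0] != N_BASIS or w1.shape[1] != w2.shape[0]:
        raise ValueError("expected w1 [42, n_mid] and w2 [n_mid, n_emb]")
    # The second projection is stored transposed so both blocks have n_mid columns.
    return np.ascontiguousarray(np.vstack([w1, w2.T]))


def unpack_weights(packed, n_emb):
    """Inverse of pack_weights: recover (w1, w2)."""
    w1 = packed[:N_BASIS]                      # [N_BASIS, n_mid]
    w2 = packed[N_BASIS:N_BASIS + n_emb].T     # [n_mid, n_emb]
    return w1, w2


rng = np.random.default_rng(0)
packed = pack_weights(rng.standard_normal((42, 8)), rng.standard_normal((8, 64)))
print(packed.shape)  # (106, 8)
```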

coo_idx : device array type

Device array containing the COO index of the graph. Dimension is assumed to be [2, #edges].

dst_offsets : device array type

Device array containing the CSR-like offsets of the destination nodes. Dimension is assumed to be [#nodes + 1].

dst_edge_index : device array type

Device array containing the CSR-like indices that map the neighbors of each destination node to edge IDs. Dimension is assumed to be [#edges].
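The destination-side CSR arrays can be derived from `coo_idx` by grouping edge IDs by destination node. This NumPy sketch assumes row 1 of `coo_idx` is the destination row and that edges sharing a destination are listed in ascending edge-ID order; neither convention is stated here, the helper `build_dst_csr` is hypothetical, and the integer dtype the library expects (for example int32) should be checked separately.

```python
import numpy as np


def build_dst_csr(coo_idx, n_nodes):
    """Return (dst_offsets [n_nodes + 1], dst_edge_index [n_edges]).

    Hypothetical helper. Assumes coo_idx[1] holds each edge's destination node.
    """
    dst = np.asarray(coo_idx)[1]
    counts = np.bincount(dst, minlength=n_nodes)       # in-degree per node
    dst_offsets = np.zeros(n_nodes + 1, dtype=np.int64)
    np.cumsum(counts, out=dst_offsets[1:])              # prefix sum of degrees
    # A stable sort keeps edges with the same destination in edge-ID order.
    dst_edge_index = np.argsort(dst, kind="stable")
    return dst_offsets, dst_edge_index


offsets, edge_index = build_dst_csr([[0, 1, 2, 0], [1, 2, 1, 2]], n_nodes=3)
# Edges into node v are edge_index[offsets[v]:offsets[v + 1]].
print(offsets, edge_index)
```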

mma_operation : pylibcugraphops.MMAOp

MMA precision: pylibcugraphops.MMAOp.HighPrecision performs three TF32 MMA operations (3xTF32), while pylibcugraphops.MMAOp.LowPrecision performs a single TF32 MMA operation.
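“3xTF32” usually refers to splitting each float32 operand into a TF32 value plus a TF32 residual and summing three TF32 products, which recovers close to float32 accuracy. The NumPy sketch below emulates that idea on the CPU to illustrate the accuracy gap between the two modes. It truncates mantissas rather than rounding as the hardware does, and whether cugraph-ops splits operands exactly this way is an assumption; `to_tf32`, `matmul_1xtf32`, and `matmul_3xtf32` are hypothetical names.

```python
import numpy as np


def to_tf32(x):
    """Emulate TF32 by truncating float32 mantissas from 23 to 10 bits."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)


def matmul_1xtf32(a, b):
    # One product on TF32 inputs, accumulated in float32.
    return to_tf32(a) @ to_tf32(b)


def matmul_3xtf32(a, b):
    # Split each operand into a TF32 "big" part and a TF32 residual "small" part,
    # then sum three products; the small * small term is dropped.
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    a_big, b_big = to_tf32(a), to_tf32(b)
    a_small, b_small = to_tf32(a - a_big), to_tf32(b - b_big)
    return a_big @ b_big + (a_big @ b_small + a_small @ b_big)


rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)
print("1xTF32 max error:", np.abs(matmul_1xtf32(a, b) - ref).max())
print("3xTF32 max error:", np.abs(matmul_3xtf32(a, b) - ref).max())
```

The extra two products cost more compute, which is the usual trade-off between the two modes.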

cuda_stream : int, default=0

CUDA stream, given as an integer holding the raw stream pointer. The default, 0, refers to the default CUDA stream.