pylibcugraphops.dimenet.radial_basis_bwd_bwd

pylibcugraphops.dimenet.radial_basis_bwd_bwd = <nanobind.nb_func object>

Computes the second-order backward pass for Dimenet++ radial basis features.

radial_basis_bwd_bwd(
    output_grad_grad_rbf: device array,
    output_grad_grad_sbf_rad: device array, output_grad_w: device array,
    input_grad_grad_vector: device array, input_grad_grad_w: device array,
    input_grad_rbf: device array, input_grad_sbf_rad: device array,
    input_vector: device array, input_w: device array
)
We define the following dimensions:
  • n_spherical: number of spherical basis functions. Must be 7.

  • n_radial: number of radial basis functions. Must be 6.

  • n_dist: number of distance dimensions. Must be 1.

Parameters:
output_grad_grad_rbf : device array type

Device array containing the output gradients for gradients of radial basis features (of backward). Dimension is assumed to be [#edges, n_radial].

output_grad_grad_sbf_rad : device array type

Device array containing the output gradients for gradients of the radial part of spherical basis features (of backward). Dimension is assumed to be [#edges, n_spherical * n_radial].

output_grad_w : device array type

Device array containing the output gradients of input frequencies (of backward). Dimension is assumed to be [n_radial].

input_grad_grad_vector : device array type

Device array containing the gradients on gradients of vector values (from backward). Dimension is assumed to be [#edges, n_dist].

input_grad_grad_w : device array type

Device array containing the gradients on gradients of frequencies (from backward). Dimension is assumed to be [n_radial].

input_grad_rbf : device array type

Device array containing the gradients on radial basis features from forward. Dimension is assumed to be [#edges, n_radial].

input_grad_sbf_rad : device array type

Device array containing the gradients on the radial part of spherical basis features from forward. Dimension is assumed to be [#edges, n_spherical * n_radial].

input_vector : device array type

Device array containing the vector values from forward. Dimension is assumed to be [#edges, n_dist].

input_w : device array type

Device array containing the input frequencies from forward. Dimension is assumed to be [n_radial].
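
The shapes listed above can be collected into a small lookup, which may help when allocating buffers before a call. The following is a minimal sketch in plain Python. The helpers `expected_shapes` and `check_shapes` are hypothetical names introduced here for illustration and are not part of pylibcugraphops. The sketch assumes only that each array exposes a `.shape` tuple, and it does not call `radial_basis_bwd_bwd` itself.

```python
# Fixed dimensions, as stated in the documentation above.
N_SPHERICAL = 7
N_RADIAL = 6
N_DIST = 1


def expected_shapes(n_edges):
    """Hypothetical helper: map each argument name to its documented shape.

    Keys are in the same order as the arguments in the signature.
    """
    per_edge_rbf = (n_edges, N_RADIAL)
    per_edge_sbf = (n_edges, N_SPHERICAL * N_RADIAL)
    per_edge_vec = (n_edges, N_DIST)
    per_freq = (N_RADIAL,)
    return {
        "output_grad_grad_rbf": per_edge_rbf,
        "output_grad_grad_sbf_rad": per_edge_sbf,
        "output_grad_w": per_freq,
        "input_grad_grad_vector": per_edge_vec,
        "input_grad_grad_w": per_freq,
        "input_grad_rbf": per_edge_rbf,
        "input_grad_sbf_rad": per_edge_sbf,
        "input_vector": per_edge_vec,
        "input_w": per_freq,
    }


def check_shapes(arrays, n_edges):
    """Hypothetical helper: raise ValueError on a missing or mis-shaped array.

    `arrays` maps argument names to objects that expose a `.shape` tuple.
    """
    for name, shape in expected_shapes(n_edges).items():
        if name not in arrays:
            raise ValueError(f"missing argument: {name}")
        actual = tuple(arrays[name].shape)
        if actual != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {actual}")
```

With a device array library, the values returned by `expected_shapes(n_edges)` could be used as allocation shapes, and `check_shapes` could be run on the resulting arrays before passing them positionally in signature order. The required dtype is not stated in this section, so it is left out of the sketch.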