pylibcugraphops.dimenet.radial_basis_bwd
- pylibcugraphops.dimenet.radial_basis_bwd = <nanobind.nb_func object>
Computes the backward pass for DimeNet++ radial basis features.
radial_basis_bwd(output_grad_vector: device array, output_grad_w: device array, input_grad_rbf: device array, input_grad_sbf_rad: device array, input_vector: device array, input_w: device array)
- We define the following dimensions:
n_spherical: number of spherical basis functions. Must be 7.
n_radial: number of radial basis functions. Must be 6.
n_dist: number of distance dimensions. Must be 1.
- Parameters:
- output_grad_vector: device array type
Device array containing the output gradients with respect to the input vector of the forward pass. Dimension is assumed to be [#edges, n_dist].
- output_grad_w: device array type
Device array containing the output gradients with respect to the input frequencies of the forward pass. Dimension is assumed to be [n_radial].
- input_grad_rbf: device array type
Device array containing the gradients with respect to the radial basis features produced by the forward pass. Dimension is assumed to be [#edges, n_radial].
- input_grad_sbf_rad: device array type
Device array containing the gradients with respect to the radial part of the spherical basis features produced by the forward pass. Dimension is assumed to be [#edges, n_spherical * n_radial].
- input_vector: device array type
Device array containing the vector values from the forward pass. Dimension is assumed to be [#edges, n_dist].
- input_w: device array type
Device array containing the input frequencies from the forward pass. Dimension is assumed to be [n_radial].
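The shape requirements listed above can be collected into a small pre-flight check. The following is only a sketch: the helpers `radial_basis_bwd_shapes` and `check_shapes` are hypothetical names introduced here for illustration and are not part of pylibcugraphops. The sketch assumes each array object exposes a `.shape` tuple, and it uses shape-only stand-in objects so that it runs without a GPU.

```python
from types import SimpleNamespace

# Dimension constants as stated in the documentation above.
N_SPHERICAL = 7
N_RADIAL = 6
N_DIST = 1


def radial_basis_bwd_shapes(n_edges):
    """Return the documented shape of each argument, keyed by parameter name.

    Hypothetical helper for illustration; not part of pylibcugraphops.
    """
    return {
        "output_grad_vector": (n_edges, N_DIST),
        "output_grad_w": (N_RADIAL,),
        "input_grad_rbf": (n_edges, N_RADIAL),
        "input_grad_sbf_rad": (n_edges, N_SPHERICAL * N_RADIAL),
        "input_vector": (n_edges, N_DIST),
        "input_w": (N_RADIAL,),
    }


def check_shapes(arrays):
    """Raise ValueError if any array's shape differs from the documented one.

    `arrays` maps parameter names to objects exposing a `.shape` tuple
    (an assumption about the array type, not a documented requirement).
    The number of edges is read from the first axis of `input_vector`.
    """
    n_edges = arrays["input_vector"].shape[0]
    expected = radial_basis_bwd_shapes(n_edges)
    for name, shape in expected.items():
        actual = tuple(arrays[name].shape)
        if actual != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {actual}")
    return expected


# Stand-in objects that carry only a shape, so the check runs on any machine.
demo = {
    name: SimpleNamespace(shape=shape)
    for name, shape in radial_basis_bwd_shapes(10).items()
}
check_shapes(demo)
```

In real use, the values in the dictionary would be device arrays, passed to `radial_basis_bwd` in the order given by the signature once the check succeeds. Which device array types the function accepts is not stated in this section.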