equitorch.nn

Contains equivariant neural network modules and functionalities.

class equitorch.nn.SO3Linear(irreps_in1, irreps_in2, irreps_out, channels_in=None, channels_out=None, internal_weights=True, feature_mode='uu', path_norm=True, channel_norm=False, path=None)

Bases: Module

SO(3) equivariant linear layer using tensor products.

Equivalent to a TensorProduct where input2 does not have a channel dimension.

Supports two main modes controlled by feature_mode:

  • 'uv': Fully connected in channel dimension.

    • Input1 shape: (..., irreps_in1.dim, channels_in)

    • Input2 shape: (..., irreps_in2.dim)

    • Weight shape: (num_paths, channels_in, channels_out)

    • Output shape: (..., irreps_out.dim, channels_out)

  • 'uu': Depthwise/elementwise in channel dimension.

    • Input1 shape: (..., irreps_in1.dim, channels)

    • Input2 shape: (..., irreps_in2.dim)

    • Weight shape: (num_paths, channels)

    • Output shape: (..., irreps_out.dim, channels)

Parameters:
  • irreps_in1 (Irreps) – Irreducible representations of the main input tensor (input1).

  • irreps_in2 (Irreps) – Irreducible representations of the second input tensor (input2), often representing weights like spherical harmonics.

  • irreps_out (Irreps) – Irreducible representations of the output tensor.

  • channels_in (int, optional) – Number of channels for the main input (input1). Required if internal_weights=True.

  • channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.

  • internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.

  • feature_mode (str, optional) –

    Controls the type of linear operation: {'uu', 'uv'}. Defaults to 'uu'.

    • 'uu': Depthwise/elementwise linear. Assumes channels_in == channels_out.

    • 'uv': Fully connected linear.

  • path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths to each output irrep. Defaults to True.

  • channel_norm (bool, optional) – Whether to apply channel normalization (specific to 'uv' mode). Divides weights by \(\sqrt{\text{channels\_in}}\). Note: This interacts with path_norm. Defaults to False.

  • path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.
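To build intuition for num_paths and path_norm, here is a minimal sketch that enumerates coupling paths with the usual O(3) selection rule (triangle inequality on l, product of parities) and computes the path-normalization factor 1/sqrt(number of paths into each output irrep). The (l, p) tuple encoding and the helper names are hypothetical. equitorch's own path bookkeeping (ordering, handling of multiplicities) may differ.

```python
import math
from collections import Counter

# Hypothetical encoding: an irrep is an (l, p) pair with parity p = +1 (even) or -1 (odd).

def allowed_paths(irreps_in1, irreps_in2, irreps_out):
    """List (i1, i2, o) index triples allowed by the usual O(3) selection rule:
    |l1 - l2| <= l_out <= l1 + l2 and p1 * p2 == p_out."""
    return [
        (i1, i2, o)
        for i1, (l1, p1) in enumerate(irreps_in1)
        for i2, (l2, p2) in enumerate(irreps_in2)
        for o, (lo, po) in enumerate(irreps_out)
        if abs(l1 - l2) <= lo <= l1 + l2 and p1 * p2 == po
    ]

def path_norm_factors(paths):
    """One factor per path: 1 / sqrt(number of paths ending in the same output irrep)."""
    fan_in = Counter(o for _, _, o in paths)
    return [1.0 / math.sqrt(fan_in[o]) for _, _, o in paths]

# input1 = 0e + 1o, input2 (e.g. spherical harmonics up to l = 1) = 0e + 1o, output = 0e + 1o
paths = allowed_paths([(0, 1), (1, -1)], [(0, 1), (1, -1)], [(0, 1), (1, -1)])
print(paths)                     # [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]
print(path_norm_factors(paths))  # each output irrep receives 2 paths, so 1/sqrt(2) each
```

Note that 1o x 1o does not reach 1o in this example: the triangle rule allows l = 1, but the product of two odd parities is even.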

weight

The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.

Type:

torch.nn.Parameter or None

tp_info_forward

Constant information for the forward pass computation.

Type:

TensorProductInfo

tp_info_backward1

Constant information for the backward pass w.r.t. input1.

Type:

TensorProductInfo

tp_info_backward2

Constant information for the backward pass w.r.t. input2.

Type:

TensorProductInfo

num_paths

Number of coupling paths determined by the irreps.

Type:

int

weight_numel

Total number of elements in the weight tensor.

Type:

int

tp_info_forward: TensorProductInfo
tp_info_backward1: TensorProductInfo
tp_info_backward2: TensorProductInfo
class equitorch.nn.IrrepWiseLinear(irreps, channels_in, channels_out, internal_weights=True, channel_norm: bool = False)

Bases: Module

Irrep-wise linear layer (channel mixing).

Applies a separate linear transformation to the channels associated with each irrep type. This operation does not change the spherical tensor structure (the irreps).

  • Input shape: (..., irreps.dim, channels_in)

  • Weight shape: (num_paths, channels_in, channels_out) where num_paths is the number of unique irreps in irreps.

  • Output shape: (..., irreps.dim, channels_out)
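The operation above can be sketched on nested Python lists as follows. This is a hypothetical illustration, not equitorch's implementation. It reads "unique irreps" as irrep *types*, so every copy of the same type (for example both 1o blocks in 2x1o) shares one channels_in x channels_out matrix. It also applies channel_norm as a 1/sqrt(channels_in) factor on the output.

```python
import math

def irrep_wise_linear(x, irreps, weights, channel_norm=False):
    """Hypothetical irrep-wise channel mixing.

    x:       rows along the spherical axis (irreps.dim rows), each a list of channels_in values.
    irreps:  (key, dim) per irrep instance, in order, e.g. [("0e", 1), ("1o", 3)].
    weights: dict mapping each irrep type key to a channels_in x channels_out matrix.
    """
    out, row = [], 0
    for key, dim in irreps:
        w = weights[key]
        c_in, c_out = len(w), len(w[0])
        scale = 1.0 / math.sqrt(c_in) if channel_norm else 1.0
        for _ in range(dim):
            # All m components of one irrep share the same matrix, so rotations,
            # which only mix m components within an irrep, commute with this map.
            out.append([scale * sum(x[row][u] * w[u][v] for u in range(c_in))
                        for v in range(c_out)])
            row += 1
    return out

x = [[1.0, 1.0],                          # 0e: one component, 2 channels
     [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # 1o: three components, 2 channels
weights = {"0e": [[1.0], [2.0]], "1o": [[0.5], [-1.0]]}
print(irrep_wise_linear(x, [("0e", 1), ("1o", 3)], weights))  # [[3.0], [0.5], [-1.0], [-1.0]]
```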

Parameters:
  • irreps (Irreps or str) – Irreducible representations of the input tensor.

  • channels_in (int) – Number of input channels.

  • channels_out (int) – Number of output channels.

  • internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.

  • channel_norm (bool, optional) – If True, divides the output by \(\sqrt{\text{channels\_in}}\). Defaults to False.

weight

The learnable weights of the module if internal_weights=True.

Type:

torch.nn.Parameter or None

irreps_info

Constant information about the input irreps.

Type:

IrrepsInfo

num_paths

Number of unique irreps in the input.

Type:

int

weight_numel

Total number of elements in the weight tensor.

Type:

int

irreps_info: IrrepsInfo
class equitorch.nn.IrrepsLinear(irreps_in, irreps_out, channels_in, channels_out, internal_weights=True, path_norm=True, channel_norm=False, path=None)

Bases: Module

Equivariant linear layer that preserves the spherical tensor structure but mixes channels.

This layer applies a linear transformation across channels while respecting the equivariance constraints imposed by the input and output irreps. It only allows paths where the input and output irreps match: \(l_{in} = l_{out}\) and either \(p_{in} = p_{out}\) or \(p_{out} = 0\). This is often used for channel mixing in equivariant networks.

  • Input shape: (..., irreps_in.dim, channels_in)

  • Weight shape: (num_paths, channels_in, channels_out) where num_paths is the number of allowed paths.

  • Output shape: (..., irreps_out.dim, channels_out)
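The path rule stated above can be sketched directly, together with the path-normalization factor described under path_norm. Irreps are encoded as hypothetical (l, p) tuples, and p_out = 0 is treated as a wildcard parity, exactly as the rule reads. This does not reproduce the library's internal path ordering.

```python
import math
from collections import Counter

def irreps_linear_paths(irreps_in, irreps_out):
    """(i, o) pairs allowed by: l_in == l_out and (p_in == p_out or p_out == 0)."""
    return [
        (i, o)
        for i, (l_in, p_in) in enumerate(irreps_in)
        for o, (l_out, p_out) in enumerate(irreps_out)
        if l_in == l_out and (p_in == p_out or p_out == 0)
    ]

def path_norm(paths):
    """1 / sqrt(number of paths feeding the same output irrep), one value per path."""
    fan_in = Counter(o for _, o in paths)
    return [1.0 / math.sqrt(fan_in[o]) for _, o in paths]

irreps_in = [(0, 1), (1, -1), (1, 1)]                     # 0e + 1o + 1e
print(irreps_linear_paths(irreps_in, [(0, 1), (1, -1)]))  # [(0, 0), (1, 1)]: 1e has no match
paths = irreps_linear_paths(irreps_in, [(1, 0)])          # p_out = 0 accepts both parities
print(paths, path_norm(paths))                            # [(1, 0), (2, 0)], two factors of 1/sqrt(2)
```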

Parameters:
  • irreps_in (Irreps) – Irreducible representations of the input tensor.

  • irreps_out (Irreps) – Irreducible representations of the output tensor.

  • channels_in (int) – Number of input channels.

  • channels_out (int) – Number of output channels.

  • internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.

  • path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths to each output irrep. Defaults to True.

  • channel_norm (bool, optional) – Whether to apply channel normalization. Divides weights by \(\sqrt{\text{channels\_in}}\). Defaults to False.

  • path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.

weight

The learnable weights of the module if internal_weights=True.

Type:

torch.nn.Parameter or None

forward_info

Constant information for the forward pass computation.

Type:

IrrepsLinearInfo

backward_info

Constant information for the backward pass computation.

Type:

IrrepsLinearInfo

num_paths

Number of allowed coupling paths.

Type:

int

weight_numel

Total number of elements in the weight tensor.

Type:

int

forward_info: IrrepsLinearInfo
backward_info: IrrepsLinearInfo
class equitorch.nn.SO2Linear(irreps_in, irreps_out, channels_in=None, channels_out=None, internal_weights=True, feature_mode='uu', path_norm=True, channel_norm=False, path=None)

Bases: Module

SO(2) equivariant linear layer using tensor products.

This layer applies an SO(2) equivariant linear transformation, as proposed in Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs. It supports two main modes controlled by feature_mode:

  • 'uv': Fully connected linear layer.

    • Input shape: (..., irreps_in.dim, channels_in)

    • Weight shape: (num_weights, channels_in, channels_out)

    • Output shape: (..., irreps_out.dim, channels_out)

  • 'uu': Depthwise/elementwise linear layer.

    • Input shape: (..., irreps_in.dim, channels)

    • Weight shape: (num_weights, channels_out)

    • Output shape: (..., irreps_out.dim, channels_out)
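In the SO(2) formulation of the cited paper, once an edge is aligned with the z-axis, the remaining symmetry is rotation about z. That rotation acts on each pair of +m/-m components as a 2-D rotation by m times the angle. The most general real linear map on such a pair that commutes with every such rotation is complex multiplication, which has two free weights per pair (m = 0 components need only one). This is one reason SO(2) weights may be counted per m component rather than per path, which is why the shapes above use num_weights.

The sketch below checks that property. The (x_{+m}, x_{-m}) pairing and its sign convention are assumptions, not equitorch's documented layout.

```python
import math

# Assumed convention (not taken from equitorch): a rotation by phi about z acts on the
# pair (x_pos, x_neg) = (x_{+m}, x_{-m}) as a 2-D rotation by m * phi.

def rotate_pair(x_pos, x_neg, m, phi):
    c, s = math.cos(m * phi), math.sin(m * phi)
    return c * x_pos - s * x_neg, s * x_pos + c * x_neg

def so2_linear_pair(x_pos, x_neg, w_re, w_im):
    """Multiply x_pos + i*x_neg by w_re + i*w_im: two free weights per (+m, -m) pair."""
    return w_re * x_pos - w_im * x_neg, w_im * x_pos + w_re * x_neg

# Equivariance check: rotate-then-map equals map-then-rotate.
m, phi = 2, 0.3
x, w = (0.7, -1.2), (0.5, 1.5)
lhs = so2_linear_pair(*rotate_pair(*x, m, phi), *w)
rhs = rotate_pair(*so2_linear_pair(*x, *w), m, phi)
print(lhs, rhs)
```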

Parameters:
  • irreps_in (Irreps or str) – Irreducible representations of the input tensor.

  • irreps_out (Irreps or str) – Irreducible representations of the output tensor.

  • channels_in (int, optional) – Number of channels for the input. Required if internal_weights=True.

  • channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.

  • internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.

  • feature_mode (str, optional) –

    Controls the type of linear operation: {'uu', 'uv'}. Defaults to 'uu'.

    • 'uu': Depthwise/elementwise linear. Assumes channels_in == channels_out.

    • 'uv': Fully connected linear.

  • path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths to each output irrep. Defaults to True.

  • channel_norm (bool, optional) – Whether to apply channel normalization (specific to 'uv' mode). Divides weights by \(\sqrt{\text{channels\_in}}\). Note: This interacts with path_norm. Defaults to False.

  • path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.

weight

The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.

Type:

torch.nn.Parameter or None

info_forward

Constant information for the forward pass computation.

Type:

SparseProductInfo

info_backward1

Constant information for the first backward pass.

Type:

SparseProductInfo

info_backward2

Constant information for the second backward pass.

Type:

SparseProductInfo

num_paths

Number of coupling paths determined by the irreps.

Type:

int

weight_numel

Total number of elements in the weight tensor.

Type:

int

class equitorch.nn.SplitIrreps(irreps: Irreps, split_num_irreps: Iterable[int], dim: int = -2)

Bases: Module

A module that splits input tensors according to specified irrep segments.

The splitting is done based on the dimensions of the irreducible representations.

Parameters:
  • irreps (Irreps) – The irreducible representations specification.

  • split_num_irreps (Iterable[int]) – Number of irreps in each split segment. Must sum to total irreps. May contain at most one -1 or ... to represent the remaining irreps.

  • dim (int, optional) – The dimension along which to split the input tensor. Defaults to -2.

Examples

>>> import torch
>>> from equitorch.irreps import Irreps
>>> from equitorch.nn import SplitIrreps
>>> # Split 3 scalar irreps (dim=1 each) and 2 vector irreps (dim=3 each)
>>> irreps = Irreps("3x0e + 2x1o")  # 3 scalars + 2 vectors
>>> split = SplitIrreps(irreps, [2, -1])  # Split first 2 irreps, then remaining
>>> x = torch.randn(5, irreps.dim, 10)  # batch=5, dim=9 (3*1 + 2*3), channels=10
>>> splits = split(x)  # Returns list of 2 tensors
>>> splits[0].shape
torch.Size([5, 2, 10])
>>> splits[1].shape
torch.Size([5, 7, 10])
>>> # Using ... for automatic size calculation
>>> split = SplitIrreps(irreps, [1, ...])  # First irrep, then remaining
>>> splits = split(x)
>>> splits[0].shape
torch.Size([5, 1, 10])
>>> splits[1].shape
torch.Size([5, 8, 10])
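The sizes in the example follow from expanding split_num_irreps into per-segment widths along the split dimension. Here is a minimal sketch of that bookkeeping, assuming (as the example shows) that counts refer to individual irrep copies. resolve_split_sizes is a hypothetical helper, not part of equitorch.

```python
def resolve_split_sizes(irrep_dims, split_num_irreps):
    """Turn irrep counts per segment into widths along the spherical dimension.

    irrep_dims:       dimension of every individual irrep, in order
                      (e.g. [1, 1, 1, 3, 3] for "3x0e + 2x1o").
    split_num_irreps: irrep counts per segment; at most one entry may be -1 or ... (Ellipsis).
    """
    counts = list(split_num_irreps)
    wild = [i for i, c in enumerate(counts) if c is Ellipsis or c == -1]
    if len(wild) > 1:
        raise ValueError("at most one -1 or ... is allowed")
    if wild:
        known = sum(c for i, c in enumerate(counts) if i != wild[0])
        counts[wild[0]] = len(irrep_dims) - known
    if sum(counts) != len(irrep_dims) or min(counts) < 0:
        raise ValueError("split_num_irreps must sum to the total number of irreps")
    sizes, start = [], 0
    for c in counts:
        sizes.append(sum(irrep_dims[start:start + c]))
        start += c
    return sizes

dims = [1, 1, 1, 3, 3]                        # "3x0e + 2x1o"
print(resolve_split_sizes(dims, [2, -1]))     # [2, 7]
print(resolve_split_sizes(dims, [1, ...]))    # [1, 8]
```

Widths like these could then be passed to torch.split(x, sizes, dim=-2), which returns one tensor per segment.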
class equitorch.nn.Separable(irreps: Irreps, split_num_irreps: Iterable[int], sub_modules: Iterable[Module], cat_after: bool = True, dim: int = -2)

Bases: Module

A module that applies different transformations to different parts of the input tensor, split according to its irreducible representations (irreps), with optional concatenation of the results.

Parameters:
  • irreps (Irreps) – The irreducible representations specification for the input tensor.

  • split_num_irreps (Iterable[int]) – Number of irreps in each split segment. May contain at most one -1 or ... to represent remaining irreps. Length must match length of sub_modules.

  • sub_modules (Iterable[Callable]) – Transformation modules for each split segment. Use None for identity operation.

  • cat_after (bool, optional) – Whether to concatenate results after transformation. Defaults to True.

  • dim (int, optional) – The dimension along which to split and concatenate tensors. Defaults to -2.

Raises:

ValueError

  • If the lengths of split_num_irreps and sub_modules do not match.

  • If the sum of split_num_irreps does not match the total number of irreps.

  • If split_num_irreps is otherwise invalid.
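The split, transform, and concatenate pattern can be sketched on a plain Python list standing in for one tensor axis. The helper and its arguments are hypothetical. The real module works on tensors along dim and takes irrep counts rather than precomputed widths.

```python
def separable(rows, split_sizes, sub_modules, cat_after=True):
    """Split rows into consecutive segments, transform each, optionally concatenate.

    rows:        items along the split dimension (a stand-in for one tensor axis).
    split_sizes: segment widths along that axis (already resolved from irrep counts).
    sub_modules: one callable per segment, or None for the identity.
    """
    if len(split_sizes) != len(sub_modules):
        raise ValueError("split sizes and sub_modules must have the same length")
    if sum(split_sizes) != len(rows):
        raise ValueError("split sizes must cover the whole dimension")
    outputs, start = [], 0
    for size, module in zip(split_sizes, sub_modules):
        segment = rows[start:start + size]
        outputs.append(segment if module is None else module(segment))
        start += size
    if not cat_after:
        return outputs
    return [item for out in outputs for item in out]

# The first segment passes through unchanged; the second is scaled by 10.
rows = [1, 2, 3, 4, 5]
print(separable(rows, [2, 3], [None, lambda seg: [10 * v for v in seg]]))  # [1, 2, 30, 40, 50]
```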

class equitorch.nn.SphericalHarmonics(l_max: int, normalize_input: bool = True, integral_normalize: bool = False)

Bases: Module

Computes spherical harmonics from input Cartesian coordinates. Wraps the functional spherical_harmonics().

Spherical harmonics are a set of orthogonal functions defined on the surface of a sphere. They are solutions to Laplace’s equation in spherical coordinates.

If integral_normalize is True, the output is scaled by \(1/\sqrt{4\pi}\).

Parameters:
  • l_max (int) – The maximum degree of the spherical harmonics.

  • normalize_input (bool, optional) – If True, normalizes the input xyz vector before computing spherical harmonics. Defaults to True.

  • integral_normalize (bool, optional) – If True, applies normalization for integration over the sphere. Defaults to False.
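Because the integral normalization multiplies by \(1/\sqrt{4\pi}\), the default output has \(Y_0 = 1\). By the addition theorem, orthonormal real harmonics satisfy \(\sum_m Y_{lm}^2 = (2l+1)/(4\pi)\) on the unit sphere, so the default (unscaled) output satisfies \(\sum_m Y_{lm}^2 = 2l+1\).

The hypothetical sketch below checks this for l <= 1. The (x, y, z) component order for l = 1 is an assumption; equitorch may order or sign the components differently.

```python
import math

def spherical_harmonics_l1(x, y, z, normalize_input=True, integral_normalize=False):
    """Real spherical harmonics for l = 0 and l = 1 (hypothetical sketch).

    Default convention: Y_0 = 1, so each degree l satisfies sum_m Y_lm^2 = 2l + 1
    on the unit sphere. The (x, y, z) order used for l = 1 is an assumption.
    """
    if normalize_input:
        r = math.sqrt(x * x + y * y + z * z)
        x, y, z = x / r, y / r, z / r
    out = [1.0] + [math.sqrt(3.0) * c for c in (x, y, z)]
    if integral_normalize:
        # Orthonormal on the sphere: sum_m Y_lm^2 = (2l + 1) / (4 pi).
        out = [v / math.sqrt(4.0 * math.pi) for v in out]
    return out

sh = spherical_harmonics_l1(1.0, 2.0, 2.0)
print(sh[0], sum(v * v for v in sh[1:]))  # 1.0 and (up to rounding) 3.0
```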

class equitorch.nn.XYZToSpherical(normalize_input: bool = True, with_r: bool = False, eps: float = 1e-14, dim: int = -1)

Bases: Module

Module to convert Cartesian coordinates \((x, y, z)\) to spherical coordinates \((\theta, \phi, r)\). Wraps the functional xyz_to_spherical().

Parameters:
  • normalize_input (bool, optional) – If True, normalizes the input xyz vector before computing angles. Defaults to True.

  • with_r (bool, optional) – If True, returns \(r\) along with \(\theta\) and \(\phi\). Defaults to False.

  • eps (float, optional) – Small \(\epsilon\) for numerical stability. Defaults to 1e-14.

  • dim (int, optional) – Dimension of Cartesian coordinates. Defaults to -1.
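A minimal scalar sketch of the conversion is shown below. It assumes the common physics convention in which \(\theta\) is the polar angle measured from +z and \(\phi\) is the azimuth measured from +x toward +y; this matches the polar/azimuthal wording used by AlignToZWignerD. The function is hypothetical and omits normalize_input and dim.

```python
import math

def xyz_to_spherical(x, y, z, with_r=False, eps=1e-14):
    """Polar angle theta in [0, pi], azimuth phi in (-pi, pi], and optionally r.

    Assumed convention: theta measured from +z, phi measured from +x toward +y.
    """
    r = math.sqrt(x * x + y * y + z * z)
    # Guard against r == 0 and clamp rounding slightly outside [-1, 1] before acos.
    cos_theta = max(-1.0, min(1.0, z / max(r, eps)))
    theta = math.acos(cos_theta)
    phi = math.atan2(y, x)
    return (theta, phi, r) if with_r else (theta, phi)

print(xyz_to_spherical(0.0, 2.0, 0.0, with_r=True))  # (pi/2, pi/2, 2.0)
```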

class equitorch.nn.SphericalToXYZ(dim: int = -1)

Bases: Module

Module to convert spherical coordinates \((\theta, \phi, r)\) to Cartesian coordinates \((x, y, z)\). Wraps the functional spherical_to_xyz().

Parameters:

dim (int, optional) – Dimension along which to stack output (x,y,z). Defaults to -1.
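The inverse mapping is sketched below. It assumes \(\theta\) is the polar angle from +z and \(\phi\) is the azimuth from +x toward +y. The function is a hypothetical scalar version without the dim argument.

```python
import math

def spherical_to_xyz(theta, phi, r=1.0):
    """Assumed convention: theta is the polar angle from +z, phi the azimuth from +x toward +y."""
    sin_theta = math.sin(theta)
    return (r * sin_theta * math.cos(phi),
            r * sin_theta * math.sin(phi),
            r * math.cos(theta))

print(spherical_to_xyz(math.pi / 2, 0.0, 2.0))  # approximately (2.0, 0.0, 0.0)
```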

class equitorch.nn.XYZToSinCos(max_m: int, normalize_input: bool = True, component_normalize: bool = False, eps: float = 1e-14, dim: int = -1)

Bases: Module

Module to convert Cartesian coordinates \((x, y, z)\) to sin/cos embeddings of the spherical angles \(\theta\) and \(\phi\). Wraps the functional xyz_to_sincos().

Parameters:
  • max_m (int) – The maximum multiple of the angles to compute \(\sin\) / \(\cos\) for.

  • normalize_input (bool, optional) – If True, normalizes xyz before extracting angles. Defaults to True.

  • component_normalize (bool, optional) – Whether to apply component normalization to the sin/cos terms. Defaults to False.

  • eps (float, optional) – Small \(\epsilon\) for numerical stability. Defaults to 1e-14.

  • dim (int, optional) – Dimension of Cartesian coordinates in xyz. Defaults to -1.
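The sin/cos values of multiples of both angles can be obtained from xyz without calling atan2 or acos. The key identities are \(e^{i\theta} = (z + i\rho)/r\) and \(e^{i\phi} = (x + iy)/\rho\), where \(\rho = \sqrt{x^2 + y^2}\). De Moivre's formula then gives \(e^{im\theta}\) and \(e^{im\phi}\) by repeated multiplication.

This is a hypothetical sketch of one technique. Whether xyz_to_sincos computes the values this way, and how it orders or normalizes them, is not stated here.

```python
import math

def _powers(unit, max_m):
    """[sin(a), cos(a), ..., sin(max_m*a), cos(max_m*a)] for unit = exp(i*a)."""
    out, power = [], complex(1.0, 0.0)
    for _ in range(max_m):
        power *= unit  # De Moivre: exp(i*a)**m == exp(i*m*a)
        out.extend([power.imag, power.real])
    return out

def xyz_sincos(x, y, z, max_m, eps=1e-14):
    """Return (theta terms, phi terms); theta from +z, phi = atan2(y, x) (assumed convention)."""
    r = max(math.sqrt(x * x + y * y + z * z), eps)
    rho = math.hypot(x, y)
    theta_terms = _powers(complex(z / r, rho / r), max_m)  # exp(i*theta), since sin(theta) >= 0
    rho = max(rho, eps)                                    # phi is undefined on the z-axis
    phi_terms = _powers(complex(x / rho, y / rho), max_m)  # exp(i*phi)
    return theta_terms, phi_terms

theta_terms, phi_terms = xyz_sincos(0.3, -1.7, 0.8, 3)
```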

class equitorch.nn.BatchRMSNorm(irreps: Irreps, channels: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, scaled: bool = True)

Bases: Module

Applies Batch Root Mean Square Normalization for equivariant features.

\[x'_{nimc} = \gamma_{ic} \cdot (x_{nimc} / \sigma_{ic})\]

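where \(n\) indexes samples in the batch, \(i\) irreps, \(m\) components within an irrep, and \(c\) channels, \(E\) is the mean over the batch, and

\[\sigma_{ic} = \sqrt{E[\text{SquaredNorm}(x_{nic})] + \epsilon}\]

The SquaredNorm (the sum of squares over the \(m\) components of irrep \(i\)) can be scaled by \(1/\text{irrep}_i\text{.dim}\) depending on the scaled argument. Running statistics are used during evaluation.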

Parameters:
  • irreps (Irreps) – Irreducible representations of the input tensor.

  • channels (int) – Number of channels in the input tensor (size of the last dimension).

  • eps (float, optional) – A value \(\epsilon\) added under the square root in the denominator for numerical stability. Defaults to 1e-5.

  • momentum (float, optional) – The value used for the running_mean computation. Defaults to 0.1.

  • affine (bool, optional) – If True, this module has learnable affine parameters (weight \(\gamma_{ic}\)). Defaults to True.

  • scaled (bool, optional) – If True, the SquaredNorm used for calculating statistics is scaled by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.
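The training-mode formula above can be sketched with nested lists indexed as x[n][row][c]. This is a hypothetical illustration: it uses batch statistics only (no running statistics or momentum), and gamma, if given, is indexed gamma[i][c]. After normalization, the batch mean of the (scaled) squared norm of each irrep and channel is close to 1.

```python
import math

def batch_rms_norm(x, irrep_dims, gamma=None, eps=1e-5, scaled=True):
    """Hypothetical training-mode batch RMS norm on nested lists x[n][row][c]."""
    batch, channels = len(x), len(x[0][0])
    out = [[row[:] for row in sample] for sample in x]
    start = 0
    for i, dim in enumerate(irrep_dims):
        rows = range(start, start + dim)
        for c in range(channels):
            # Batch mean of the squared norm of irrep i in channel c (sum over its m components).
            mean_sq = sum(x[n][r][c] ** 2 for n in range(batch) for r in rows) / batch
            if scaled:
                mean_sq /= dim
            sigma = math.sqrt(mean_sq + eps)
            g = 1.0 if gamma is None else gamma[i][c]
            for n in range(batch):
                for r in rows:
                    out[n][r][c] = g * x[n][r][c] / sigma
        start += dim
    return out

# Batch of 2 samples, irreps "0e + 1o" (dims [1, 3]), 1 channel.
x = [[[3.0], [1.0], [2.0], [2.0]],
     [[4.0], [0.0], [0.0], [3.0]]]
y = batch_rms_norm(x, [1, 3])
```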

irreps_info: IrrepsInfo
reset_running_stats() None
reset_parameters() None
class equitorch.nn.LayerRMSNorm(irreps: Irreps, channels: int, eps: float = 1e-05, affine: bool = True, scaled: bool = True)

Bases: Module

Applies Irrep-wise Layer Root Mean Square Normalization.

Computes statistics independently for each irrep instance within each sample. Normalizes using the RMS value calculated across channels and irrep components for that specific sample and irrep instance.

Parameters:
  • irreps (Irreps) – Irreducible representations of the input tensor.

  • channels (int) – Number of channels in the input tensor (size of the last dimension).

  • eps (float, optional) – A value \(\epsilon\) added to the denominator for numerical stability. Defaults to 1e-5.

  • affine (bool, optional) – If True, this module has learnable affine parameters (weight \(\gamma_{ic}\)). Defaults to True.

  • scaled (bool, optional) – If True, the statistics calculation considers the norm to be scaled by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.
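One reading of the description above is sketched below for a single sample stored as sample[row][c]. For each irrep instance, the squared values are pooled over its components and all channels, divided by the channel count (and by the irrep dimension when scaled=True), and the result sets the RMS divisor. Where exactly \(\epsilon\) enters is an assumption, and the helper is hypothetical.

```python
import math

def layer_rms_norm(sample, irrep_dims, gamma=None, eps=1e-5, scaled=True):
    """Hypothetical per-sample RMS norm: one RMS value per irrep instance, pooled over channels."""
    channels = len(sample[0])
    out = [row[:] for row in sample]
    start = 0
    for i, dim in enumerate(irrep_dims):
        rows = range(start, start + dim)
        mean_sq = sum(sample[r][c] ** 2 for r in rows for c in range(channels)) / channels
        if scaled:
            mean_sq /= dim
        sigma = math.sqrt(mean_sq + eps)
        for c in range(channels):
            g = 1.0 if gamma is None else gamma[i][c]
            for r in rows:
                out[r][c] = g * sample[r][c] / sigma
        start += dim
    return out

# One sample, irreps "0e + 1o" (dims [1, 3]), 2 channels.
sample = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
y = layer_rms_norm(sample, [1, 3])
```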

irreps_info: IrrepsInfo
reset_parameters() None
equitorch.nn.initialize_tensor_product(weight, feature_mode, gain=1, channel_normed=False)

Initialize weights for tensor product operations.

This function initializes weights for tensor product operations with different feature modes. The initialization uses a uniform distribution with bounds calculated based on the feature mode and whether channel normalization is used.

For 'uvw' mode:

\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel\_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan\_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]

where \(\text{fan\_in} = \text{weight.shape[-2]} \cdot \text{weight.shape[-3]}\)

For 'uuu' mode:

\[a = \sqrt{3} \cdot \text{gain}\]
Parameters:
  • weight (torch.Tensor) – The weight tensor to initialize.

  • feature_mode (str) – The feature mode for initialization. Must be one of ['uvw', 'uuu'].

  • gain (float, optional) – The gain factor to apply. Default is 1.

  • channel_normed (bool, optional) – Whether channel normalization is used. Default is False.

Raises:

ValueError – If an unknown feature_mode is provided.
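The bound computation above can be written out as a small helper. This is a hypothetical sketch that mirrors the formulas; it is not the library function itself. For 'uvw' weights of shape (num_paths, channels_in1, channels_in2, channels_out), fan_in is channels_in1 * channels_in2. In PyTorch, the sampling step would presumably correspond to torch.nn.init.uniform_(weight, -a, a). Here it is emulated with the random module to keep the sketch dependency-free.

```python
import math
import random

def tensor_product_init_bound(feature_mode, shape, gain=1.0, channel_normed=False):
    """Uniform bound a for a weight of the given shape, following the formulas above."""
    if feature_mode == "uvw":
        if channel_normed:
            return math.sqrt(3.0) * gain
        fan_in = shape[-2] * shape[-3]  # channels_in2 * channels_in1
        return math.sqrt(3.0 / fan_in) * gain
    if feature_mode == "uuu":
        return math.sqrt(3.0) * gain
    raise ValueError(f"Unknown feature_mode: {feature_mode!r}")

shape = (4, 8, 16, 32)  # (num_paths, channels_in1, channels_in2, channels_out)
a = tensor_product_init_bound("uvw", shape)
rng = random.Random(0)
samples = [rng.uniform(-a, a) for _ in range(100_000)]
# U(-a, a) has variance a**2 / 3, i.e. gain**2 / fan_in = 1 / 128 here.
print(a, sum(s * s for s in samples) / len(samples))
```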

equitorch.nn.initialize_so3_so2_linear(weight, feature_mode, gain=1, channel_normed=False)

Initialize weights for SO(3) or SO(2) linear operations.

This function initializes weights for SO(3) or SO(2) linear operations with different feature modes. The initialization uses a uniform distribution with bounds calculated based on the feature mode and whether channel normalization is used.

For 'uv' mode:

\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel\_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan\_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]

where \(\text{fan\_in} = \text{weight.shape[-2]}\)

For 'uu' mode:

\[a = \sqrt{3} \cdot \text{gain}\]
Parameters:
  • weight (torch.Tensor) – The weight tensor to initialize.

  • feature_mode (str) – The feature mode for initialization. Must be one of ['uv', 'uu'].

  • gain (float, optional) – The gain factor to apply. Default is 1.

  • channel_normed (bool, optional) – Whether channel normalization is used. Default is False.

Raises:

ValueError – If an unknown feature_mode is provided.

equitorch.nn.initialize_linear(weight, gain=1, channel_normed=False)

Initialize weights for standard linear operations.

This function initializes weights for standard linear operations using a uniform distribution with bounds calculated based on whether channel normalization is used.

\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel\_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan\_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]

where \(\text{fan\_in} = \text{weight.shape[-2]}\)

Parameters:
  • weight (torch.Tensor) – The weight tensor to initialize.

  • gain (float, optional) – The gain factor to apply. Default is 1.

  • channel_normed (bool, optional) – Whether channel normalization is used. Default is False.
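The factor \(\sqrt{3}\) in these bounds comes from the variance of a uniform distribution. The short derivation below assumes independent, zero-mean weights \(W_k\) and inputs \(x_k\) with \(\mathbb{E}[x_k^2] = 1\); these are assumptions, not statements from the source.

```latex
\begin{aligned}
W \sim \mathcal{U}(-a, a) \;&\Rightarrow\; \operatorname{Var}[W] = \frac{1}{2a}\int_{-a}^{a} w^2\,dw = \frac{a^2}{3}, \\
a = \sqrt{\frac{3}{\text{fan\_in}}}\cdot\text{gain} \;&\Rightarrow\; \operatorname{Var}[W] = \frac{\text{gain}^2}{\text{fan\_in}}, \\
y = \sum_{k=1}^{\text{fan\_in}} W_k x_k \;&\Rightarrow\; \operatorname{Var}[y] = \text{fan\_in}\cdot\frac{\text{gain}^2}{\text{fan\_in}} = \text{gain}^2.
\end{aligned}
```

With channel_normed=True the bound is \(\sqrt{3}\cdot\text{gain}\), so \(\operatorname{Var}[W] = \text{gain}^2\). If channel normalization then divides the sum by \(\sqrt{\text{fan\_in}}\), as the channel_norm descriptions state for channels_in, the output variance is again \(\text{gain}^2\).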

class equitorch.nn.TensorProduct(irreps_in1, irreps_in2, irreps_out, channels_in1=None, channels_in2=None, channels_out=None, internal_weights=True, feature_mode='uuu', path_norm=True, channel_norm=False, path=None)

Bases: Module

Computes the tensor product of two equivariant feature tensors.

Supports two main modes controlled by feature_mode:

  • 'uvw': Fully connected tensor product.

    • Input1 shape: (..., irreps_in1.dim, channels_in1)

    • Input2 shape: (..., irreps_in2.dim, channels_in2)

    • Weight shape: (num_paths, channels_in1, channels_in2, channels_out)

    • Output shape: (..., irreps_out.dim, channels_out)

  • 'uuu': Depthwise/elementwise tensor product with 'uuu' instructions (often used for self-interaction).

    • Input1 shape: (..., irreps_in1.dim, channels)

    • Input2 shape: (..., irreps_in2.dim, channels)

    • Weight shape: (num_paths, channels_out) (where channels_out usually equals channels)

    • Output shape: (..., irreps_out.dim, channels_out)

Parameters:
  • irreps_in1 (Irreps) – Irreducible representations of the first input tensor.

  • irreps_in2 (Irreps) – Irreducible representations of the second input tensor.

  • irreps_out (Irreps) – Irreducible representations of the output tensor.

  • channels_in1 (int, optional) – Number of channels for the first input. Required if internal_weights=True or feature_mode='uvw'.

  • channels_in2 (int, optional) – Number of channels for the second input. Required if internal_weights=True or feature_mode='uvw'.

  • channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.

  • internal_weights (bool, default=True) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass.

  • feature_mode ({'uuu', 'uvw'}, default='uuu') –

    Controls the type of tensor product:

    • 'uuu': Depthwise/elementwise product. Assumes channels_in1 == channels_in2 == channels_out.

    • 'uvw': Fully connected product.

  • path_norm (bool, default=True) – Whether to apply path normalization to the weights.

  • channel_norm (bool, default=False) – Whether to apply channel normalization (specific to 'uvw' mode). Divides weights by \(\sqrt{\text{channels\_in1} \times \text{channels\_in2}}\).

  • path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used.
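The difference between the two channel modes can be sketched for a single coupling path. As a stand-in for that path, the sketch uses the plain 3-D dot product, which is rotation invariant like the 1o x 1o -> 0e coupling; Clebsch-Gordan coefficients and normalization are omitted. Inputs are hypothetical lists over channels of 3-vectors, which is transposed relative to the (irreps.dim, channels) layout above. The per-path weight counts match the documented shapes: one weight per channel in 'uuu', and channels_in1 x channels_in2 x channels_out in 'uvw'.

```python
def dot_path(a, b):
    """Stand-in for one 1o x 1o -> 0e coupling path: the unnormalized 3-D dot product."""
    return sum(p * q for p, q in zip(a, b))

def tp_uuu(x1, x2, w):
    """'uuu': channel u of input1 only meets channel u of input2; w[u] is one weight per channel."""
    return [w[u] * dot_path(x1[u], x2[u]) for u in range(len(x1))]

def tp_uvw(x1, x2, w):
    """'uvw': every (u, v) channel pair feeds every output channel k through w[u][v][k]."""
    c_out = len(w[0][0])
    return [sum(w[u][v][k] * dot_path(x1[u], x2[v])
                for u in range(len(x1)) for v in range(len(x2)))
            for k in range(c_out)]

x1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 2 channels of 3-vectors
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(tp_uuu(x1, x2, [2.0, 3.0]))                           # [2.0, 15.0]
print(tp_uvw(x1, x2, [[[1.0], [1.0]], [[1.0], [1.0]]]))     # [12.0]
```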

weight

The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.

Type:

torch.nn.Parameter or None

tp_info_forward

Constant information for the forward pass computation.

Type:

TensorProductInfo

tp_info_backward1

Constant information for the backward pass w.r.t. input1.

Type:

TensorProductInfo

tp_info_backward2

Constant information for the backward pass w.r.t. input2.

Type:

TensorProductInfo

num_paths

Number of coupling paths determined by the irreps.

Type:

int

weight_numel

Total number of elements in the weight tensor.

Type:

int

tp_info_forward: TensorProductInfo
tp_info_backward1: TensorProductInfo
tp_info_backward2: TensorProductInfo
class equitorch.nn.TensorDot(irreps, feature_mode, scaled=False)

Bases: Module

Computes the equivariant irrep-wise dot product of two feature tensors.

Supports two main modes controlled by feature_mode:

  • 'uv': Channel-cartesian dot product.

    • Input1 shape: (..., irreps.dim, channels1)

    • Input2 shape: (..., irreps.dim, channels2)

    • Output shape: (..., len(irreps), channels1, channels2)

  • 'uu': Channel-wise dot product. Each channel of input1 is paired with the same channel of input2; the sum runs over the components of each irrep, not over channels.

    • Input1 shape: (..., irreps.dim, channels)

    • Input2 shape: (..., irreps.dim, channels)

    • Output shape: (..., len(irreps), channels)

Parameters:
  • irreps (Irreps or str) – Irreducible representations of the input tensors. Both inputs must have the same irreps.

  • feature_mode ({'uv', 'uu'}) –

    Controls how the channel dimension is handled:

    • 'uv': Channel-cartesian dot product.

    • 'uu': Channel-wise dot product.

  • scaled (bool, default=False) – If True, scales the dot product by \(1 / \sqrt{\text{irrep.dim}}\).
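Both modes can be sketched on nested lists x[row][channel]. Each irrep segment contributes one entry: a per-channel vector for 'uu', and a channels1 x channels2 matrix for 'uv'. The helper is hypothetical and mirrors the shapes listed above, with one segment per entry of irrep_dims.

```python
import math

def tensor_dot(x1, x2, irrep_dims, feature_mode, scaled=False):
    """Hypothetical irrep-wise dot product over nested lists x[row][channel].

    'uu' -> per segment, a list over c of      sum_m x1[m][c]  * x2[m][c]
    'uv' -> per segment, a matrix over (c1, c2) of sum_m x1[m][c1] * x2[m][c2]
    """
    out, start = [], 0
    for dim in irrep_dims:
        rows = range(start, start + dim)
        scale = 1.0 / math.sqrt(dim) if scaled else 1.0
        if feature_mode == "uu":
            out.append([scale * sum(x1[r][c] * x2[r][c] for r in rows)
                        for c in range(len(x1[0]))])
        elif feature_mode == "uv":
            out.append([[scale * sum(x1[r][a] * x2[r][b] for r in rows)
                         for b in range(len(x2[0]))]
                        for a in range(len(x1[0]))])
        else:
            raise ValueError(f"Unknown feature_mode: {feature_mode!r}")
        start += dim
    return out

# Irreps "0e + 1o" (dims [1, 3]), 2 channels per input.
x1 = [[1, 2], [1, 0], [0, 1], [1, 1]]
x2 = [[3, 4], [2, 0], [0, 2], [1, 1]]
print(tensor_dot(x1, x2, [1, 3], "uu"))  # [[3.0, 8.0], [3.0, 3.0]]
```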

irreps_info

Constant information about the input irreps.

Type:

IrrepsInfo

irreps_info: IrrepsInfo
class equitorch.nn.SparseWignerRotation(irreps: Irreps)

Bases: Module

Applies a sparse Wigner D-matrix rotation to input features.

This module computes the rotation based on Euler angles (\(\alpha, \beta, \gamma\)) provided as precomputed sin/cos tensors. It utilizes sparse matrix operations for the rotation.

Warning

It is currently suggested to use DenseWignerRotation or WignerD for applying rotations, at least when gradients with respect to angles are not required.

The Wigner D-matrix \(D^l_{m'm}(R)\) transforms spherical tensors under rotation \(R\). This module applies such transformations for a given Irreps.

Parameters:

irreps (Irreps) – The irreducible representations defining the input and output feature space.

info: WignerRotationInfo
class equitorch.nn.DenseWignerRotation(irreps: Irreps)

Bases: Module

Applies a dense Wigner D-matrix rotation to input features.

This module takes a precomputed dense Wigner D-matrix and applies it to the input features. The Wigner D-matrix \(D(R)\) itself should be computed separately, for example, using the WignerD module.

Parameters:

irreps (Irreps) – The irreducible representations defining the input and output feature space. This is used for validation and representation purposes.

class equitorch.nn.WignerD(irreps: Irreps)

Bases: Module

Computes the dense Wigner D-matrix \(D(R)\) for given Irreps and Euler angles \((\alpha, \beta, \gamma)\).

The Wigner D-matrix is constructed based on the ZYZ Euler angle convention:

\[D(\alpha, \beta, \gamma) = D_z(\alpha) D_y(\beta) D_z(\gamma)\]

This module caches the necessary sparse rotation information and an identity matrix to efficiently compute the dense D-matrix using the wigner_d_matrix() functional.

Parameters:

irreps (Irreps) – The irreducible representations for which to compute the D-matrix. The resulting D-matrix will have dimensions (irreps.dim, irreps.dim).
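For l = 1, the Wigner D-matrix coincides with the ordinary 3x3 rotation matrix, up to a fixed change of basis that depends on how the m components are ordered (an assumption not stated here). The hypothetical sketch below composes ZYZ rotations in Cartesian form to illustrate the convention \(D = D_z(\alpha) D_y(\beta) D_z(\gamma)\): acting on a column vector, the \(\gamma\) rotation is applied first and \(\alpha\) last.

```python
import math

def rot_z(a):
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_y(b):
    c, s = math.cos(b), math.sin(b)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def zyz_rotation(alpha, beta, gamma):
    """Cartesian analogue of D(alpha, beta, gamma) = D_z(alpha) D_y(beta) D_z(gamma)."""
    return matmul(matmul(rot_z(alpha), rot_y(beta)), rot_z(gamma))

R = zyz_rotation(0.3, 1.1, -0.7)
R_inv = zyz_rotation(0.7, -1.1, -0.3)  # inverse: negate the angles and reverse their order
print(matmul(R, R_inv))                # approximately the 3x3 identity
```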

info: WignerRotationInfo
class equitorch.nn.AlignToZWignerD(irreps: Irreps, normalized: bool = True, eps: float = 1e-14)

Bases: Module

Computes the Wigner D-matrix \(D(R_{align})\) that rotates a given vector \(\vec{v} = (x, y, z)\) onto the z-axis.

The rotation \(R_{align}\) is defined by Euler angles \((0, -\theta, -\phi)\), where \(\theta\) and \(\phi\) are the polar and azimuthal angles of the vector \(\vec{v}\), respectively. This means:

\[R_{align} \vec{v} = ||\vec{v}|| \hat{z}\]

The Wigner D-matrix is then \(D(0, -\theta, -\phi)\).

This module caches the necessary sparse rotation information and an identity matrix. It utilizes the align_to_z_wigner_d() functional.

Parameters:
  • irreps (Irreps) – The irreducible representations for which to compute the D-matrix.

  • normalized (bool, optional) – Whether to normalize the input xyz vector before calculating angles for rotation. If True, effectively rotates \(\hat{v}\). Defaults to True.

  • eps (float, optional) – Small \(\epsilon\) value for numerical stability in angle calculation. Defaults to 1e-14.
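The claim \(R_{align}\vec{v} = \|\vec{v}\|\hat{z}\) can be checked in Cartesian form (the l = 1 case, up to a change of basis). With Euler angles \((0, -\theta, -\phi)\) in the ZYZ convention, \(R_z(-\phi)\) first moves \(\vec{v}\) into the xz-plane, and \(R_y(-\theta)\) then tilts it onto +z. The functions are a hypothetical sketch, not equitorch's API.

```python
import math

def rot_z(a):
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_y(b):
    c, s = math.cos(b), math.sin(b)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def matvec(A, v):
    return [sum(A[i][k] * v[k] for k in range(3)) for i in range(3)]

def align_to_z_rotation(x, y, z, eps=1e-14):
    """Euler angles (0, -theta, -phi), ZYZ: Rz(0) Ry(-theta) Rz(-phi) = Ry(-theta) Rz(-phi)."""
    r = math.sqrt(x * x + y * y + z * z)
    theta = math.acos(max(-1.0, min(1.0, z / max(r, eps))))  # polar angle from +z
    phi = math.atan2(y, x)                                   # azimuth from +x
    return matmul(rot_y(-theta), rot_z(-phi))

v = [1.0, 2.0, 2.0]
print(matvec(align_to_z_rotation(*v), v))  # approximately [0, 0, 3], i.e. |v| * z_hat
```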

info: WignerRotationInfo
class equitorch.nn.PolynomialCutoff(r_max: float, r_min: float = 0.0, p: float = 6)

Bases: Module

Polynomial cutoff, as proposed in DimeNet.

The polynomial cutoff function is defined as:

\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ 1 - \frac{(p+1)(p+2)}{2}u^p + p(p+2)u^{p+1} - \frac{p(p+1)}{2}u^{p+2}, & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]

where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.

Parameters:
  • r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.

  • r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than or equal to r_max. Defaults to 0.0.

  • p (float, optional) – The power parameter \(p\) controlling the smoothness of the cutoff. Must be greater than or equal to 2.0. Defaults to 6.
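A direct scalar transcription of the formula is given below as a sketch; the function name is hypothetical. The polynomial reaches exactly 0 at \(u = 1\) for every \(p\), because \(1 - \tfrac{1}{2}\left[(p+1)(p+2) - 2p(p+2) + p(p+1)\right] = 1 - \tfrac{1}{2}\cdot 2 = 0\).

```python
def polynomial_cutoff(r, r_max, r_min=0.0, p=6.0):
    """Scalar sketch of the polynomial cutoff formula above."""
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    u = (r - r_min) / (r_max - r_min)
    return (1.0
            - (p + 1) * (p + 2) / 2 * u ** p
            + p * (p + 2) * u ** (p + 1)
            - p * (p + 1) / 2 * u ** (p + 2))

print(polynomial_cutoff(2.5, 5.0))  # 0.85546875 for p = 6 (u = 0.5)
```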

class equitorch.nn.CosineCutoff(r_max: float, r_min: float = 0)

Bases: Module

The cosine cutoff function.

The cosine cutoff function is defined as:

\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ \frac{1}{2}\left[1 + \cos\left(\pi \cdot u\right)\right], & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]

where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.

This cutoff function smoothly decreases from 1 to 0 in the range \([r_{\text{min}}, r_{\text{max}}]\) using a cosine function.

Parameters:
  • r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.

  • r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than r_max. Defaults to 0.
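A scalar sketch of the cosine cutoff follows; the function name is hypothetical. Its derivative, \(-\frac{\pi}{2}\sin(\pi u)/(r_{\text{max}} - r_{\text{min}})\), vanishes at both ends of the interval, so the piecewise function is continuously differentiable.

```python
import math

def cosine_cutoff(r, r_max, r_min=0.0):
    """Scalar sketch of the cosine cutoff formula above."""
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    u = (r - r_min) / (r_max - r_min)
    return 0.5 * (1.0 + math.cos(math.pi * u))

print(cosine_cutoff(3.0, 5.0, 1.0))  # midpoint of [1, 5], approximately 0.5
```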

class equitorch.nn.MollifierCutoff(r_max: float, r_min: float = 0, eps: float = 1e-07)

Bases: Module

The mollifier cutoff function.

The mollifier cutoff function is defined as:

\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ \exp\left(1 - \frac{1}{1 - u^2 + \epsilon}\right), & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]

where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.

This cutoff function smoothly decreases from 1 to 0 in the range \([r_{\text{min}}, r_{\text{max}}]\) using a mollifier (bump) function.

Parameters:
  • r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.

  • r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than r_max. Defaults to 0.

  • eps (float, optional) – Small epsilon value \(\epsilon\) to prevent division by zero. Defaults to 1e-7.
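A scalar sketch of the mollifier formula is given below; the function name is hypothetical. Without \(\epsilon\), the denominator \(1 - u^2\) would reach zero at \(u = 1\). With it, the value at \(r_{\text{max}}\) is \(\exp(1 - 1/\epsilon)\), which underflows to 0. A side effect is that the value at \(r_{\text{min}}\) is \(\exp(\epsilon/(1+\epsilon)) \approx 1 + \epsilon\), slightly above 1.

```python
import math

def mollifier_cutoff(r, r_max, r_min=0.0, eps=1e-7):
    """Scalar sketch of the mollifier cutoff formula above."""
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    u = (r - r_min) / (r_max - r_min)
    return math.exp(1.0 - 1.0 / (1.0 - u * u + eps))

print(mollifier_cutoff(0.0, 5.0), mollifier_cutoff(5.0, 5.0))  # approximately 1.0, then 0.0
```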

class equitorch.nn.Gate(irreps: Irreps, activation: Module = None, irrep_wise: bool = True)

Bases: Module

Applies element-wise gates to equivariant features.

This module implements gating nonlinearities for features represented by Irreps. It can operate in two primary modes based on how the gate values are provided:

  1. Separate Gates (gate argument provided): The module takes two distinct inputs: input (the features to be gated) and gate (the gate scalars).

    • If irrep_wise=True (default): Each gate scalar in the gate tensor is applied to its corresponding irrep block within the input features. The gate tensor should have a shape compatible with (..., num_gates, channels), where num_gates is the number of irreps in irreps.

    • If irrep_wise=False: A single gate scalar (or a set of scalars broadcastable across irreps) is applied to all irrep blocks in the input features. The gate tensor should have a shape compatible with (..., 1, channels).

  2. Concatenated Input (gate=None): The module takes a single input tensor where the features and their corresponding gate scalars are concatenated along the spherical dimension (dim=-2). The last num_gates slices along this dimension are interpreted as the gate scalars. The input tensor shape is expected to be (..., irreps.dim + num_gates, channels). The module internally splits this tensor into features and gates, optionally applies an activation function to the extracted gates, and then proceeds with the gating operation as described in mode 1.

An optional activation function can be applied to the gate scalars before they modulate the features.

Example

import torch
from equitorch.irreps import Irreps
from equitorch.nn import Gate

irreps = Irreps("1x0e + 2x1o") # Example: one scalar, two l=1 odd irreps
gate_module = Gate(irreps, activation=torch.nn.Tanh())

batch_size, channels = 4, 8
num_gates = len(irreps) # This will be 2 for the example irreps

# Mode 1: Separate input and gate tensors (irrep_wise=True)
features = torch.randn(batch_size, irreps.dim, channels)
gates = torch.randn(batch_size, num_gates, channels)
output_separate = gate_module(features, gates)
print(f"Output shape (separate gates): {output_separate.shape}")

# Mode 2: Concatenated input tensor
# irreps.dim for "1x0e + 2x1o" is 1*1 + 2*3 = 7
# num_gates is 2
concatenated_input = torch.randn(batch_size, irreps.dim + num_gates, channels)
output_concatenated = gate_module(concatenated_input)
print(f"Output shape (concatenated input): {output_concatenated.shape}")

Parameters:
  • irreps (Irreps) – The irreducible representations of the feature part of the input tensor (i.e., the part that will be gated).

  • activation (torch.nn.Module, optional) – An activation function to be applied to the gate scalars before the gating operation. Defaults to None (no activation).

  • irrep_wise (bool, optional) – Determines how gates are applied. If True (default), gates are applied irrep-by-irrep. This requires the gate tensor (if provided separately) to have a shape like (..., num_gates, channels). If False, a single gate (or a broadcastable set) is applied across all irreps. This requires the gate tensor (if provided separately) to have a shape like (..., 1, channels). num_gates corresponds to len(irreps).

irreps_info

Cached information about the input feature irreps, used for efficient gating.

Type:

IrrepsInfo

num_gates

The number of distinct gate scalars, equal to len(irreps). This dictates the expected size of the gate dimension in the gate tensor or the number of gate slices in a concatenated input.

Type:

int

class equitorch.nn.SinCos(max_m: int, with_ones: bool = True, component_normalize: bool = False)

Bases: Module

Module wrapper for the sincos() function.

Computes the sin/cos expansion of an angle \(a\):

\[[1.0, \sin(a), \cos(a), \sin(2a), \cos(2a), \dots, \sin(\text{max\_m} \cdot a), \cos(\text{max\_m} \cdot a)]\]

or

\[[1.0, \sqrt{2}\sin(a), \sqrt{2}\cos(a), \sqrt{2}\sin(2a), \sqrt{2}\cos(2a), \dots, \sqrt{2}\sin(\text{max\_m} \cdot a), \sqrt{2}\cos(\text{max\_m} \cdot a)]\]

The leading 1.0 is excluded if with_ones is False.

Parameters:
  • max_m (int) – The maximum multiple of the angle \(a\) for which \(\sin\) and \(\cos\) are computed.

  • with_ones (bool, optional) – Whether to include the leading 1.0 in the expansion. Defaults to True.

  • component_normalize (bool, optional) – If True, multiplies the \(\sin\) and \(\cos\) values by \(\sqrt{2}\) so that each component has a mean square of 1 when \(a\) is uniformly distributed over \([0, 2\pi]\). Defaults to False.
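The expansion above can be sketched in plain PyTorch as follows. The helper sincos_sketch is hypothetical and only mirrors the formulas in this docstring (interleaved ordering, optional leading 1.0, optional \(\sqrt{2}\) factor); the actual sincos() function may differ in output layout or dtype handling.

```python
import math
import torch

def sincos_sketch(a, max_m, with_ones=True, component_normalize=False):
    """Hypothetical stand-in for the sin/cos expansion (not equitorch's code).

    a: tensor of angles of any shape. Returns shape (*a.shape, K) with
    K = 2 * max_m + (1 if with_ones else 0).
    """
    m = torch.arange(1, max_m + 1, dtype=a.dtype, device=a.device)
    ma = a.unsqueeze(-1) * m  # (..., max_m): a, 2a, ..., max_m * a
    # Interleave as sin(a), cos(a), sin(2a), cos(2a), ...
    terms = torch.stack([torch.sin(ma), torch.cos(ma)], dim=-1).flatten(-2)
    if component_normalize:
        # sin^2 and cos^2 average to 1/2 over a full period, so sqrt(2) gives mean square 1.
        terms = terms * math.sqrt(2.0)
    if with_ones:
        terms = torch.cat([torch.ones_like(a).unsqueeze(-1), terms], dim=-1)
    return terms
```

Averaging the squared normalized components over a uniform grid on \([0, 2\pi)\) returns values close to 1, which is a quick way to check the normalization claim.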

class equitorch.nn.BesselBasis(r_max, num_basis=8, trainable=True)

Bases: Module


Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

Parameters:
  • r_max (float) – Cutoff radius.

  • num_basis (int, optional) – Number of Bessel basis functions. Defaults to 8.

  • trainable (bool, optional) – Whether the \(n \pi\) frequencies are trainable. Defaults to True.

r_max: float
prefactor: float
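For reference, DimeNet defines the radial Bessel basis as \(e_n(r) = \sqrt{2/r_{\max}}\,\sin(n\pi r / r_{\max}) / r\) for \(n = 1, \dots, \text{num\_basis}\). The sketch below evaluates that formula; the helper name bessel_basis_sketch is hypothetical, and the prefactor attribute of this class may use a different constant. With trainable=True, the \(n\pi\) frequencies would be a learnable parameter rather than a fixed tensor. Inputs at exactly \(r = 0\) need special handling because of the division by \(r\).

```python
import math
import torch

def bessel_basis_sketch(r, r_max, num_basis=8):
    """DimeNet radial basis sqrt(2/r_max) * sin(n*pi*r/r_max) / r, n = 1..num_basis.

    Hypothetical helper, not equitorch's BesselBasis. Returns shape (*r.shape, num_basis).
    """
    n = torch.arange(1, num_basis + 1, dtype=r.dtype, device=r.device)
    freqs = n * math.pi  # the "n*pi" part (learnable when trainable=True)
    r = r.unsqueeze(-1)  # (..., 1) so it broadcasts against the basis index
    return math.sqrt(2.0 / r_max) * torch.sin(freqs * r / r_max) / r
```

Every basis function vanishes at \(r = r_{\max}\), since \(\sin(n\pi) = 0\).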
class equitorch.nn.SquaredNorm(irreps: Irreps, scaled: bool = True)

Bases: Module

Computes the squared L2 norm for each irrep block in an input tensor.

\[\text{Output}_k = \sum_{m \in \text{irrep}_k} (\text{input}_{km}^2)\]

Optionally scales the output by \(1/\text{irrep}_k\text{.dim}\).

Parameters:
  • irreps (Irreps) – Irreducible representations of the input tensor.

  • scaled (bool, optional) – If True, scales the output of each irrep_k by \(1/\text{irrep}_k\text{.dim}\). Defaults to True.

irreps_info: IrrepsInfo
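A minimal sketch of the per-block computation, assuming an input of shape (..., irreps.dim, channels) and a hypothetical block_dims list giving the dimension of each irrep block in order (for example [1, 3] for one scalar followed by one l=1 irrep). It illustrates the formula only and is not equitorch's implementation.

```python
import torch

def squared_norm_sketch(x, block_dims, scaled=True):
    """Per-irrep squared L2 norm (hypothetical helper, not equitorch's code).

    x: (..., sum(block_dims), C) -> returns (..., len(block_dims), C).
    """
    blocks = torch.split(x, block_dims, dim=-2)  # one slice per irrep block
    out = torch.stack([b.pow(2).sum(dim=-2) for b in blocks], dim=-2)
    if scaled:
        dims = torch.tensor(block_dims, dtype=x.dtype, device=x.device)
        out = out / dims.unsqueeze(-1)  # divide block k by its dimension
    return out
```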
class equitorch.nn.Norm(irreps: Irreps, scaled: bool = True)

Bases: Module

Computes the L2 norm for each irrep block in an input tensor.

\[\text{Output}_k = \sqrt{\sum_{m \in \text{irrep}_k} (\text{input}_{km}^2)}\]

Optionally scales the output by \(\sqrt{1/\text{irrep}_k\text{.dim}}\). The gradient at the zero vector is defined to be zero.

Parameters:
  • irreps (Irreps) – Irreducible representations of the input tensor.

  • scaled (bool, optional) – If True, scales the output of each irrep_k by \(\sqrt{1/\text{irrep}_k\text{.dim}}\). Defaults to True.

irreps_info: IrrepsInfo
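The square root has an infinite derivative at zero, so a naive torch.sqrt of the squared norm yields NaN gradients for an all-zero block. One common way to obtain the zero gradient described above is the double torch.where pattern sketched here; whether equitorch uses this exact trick is an assumption, and norm_sketch and block_dims (the dimension of each irrep block) are hypothetical names.

```python
import torch

def norm_sketch(x, block_dims, scaled=True):
    """Per-irrep L2 norm with zero gradient at zero (hypothetical helper).

    x: (..., sum(block_dims), C) -> returns (..., len(block_dims), C).
    """
    blocks = torch.split(x, block_dims, dim=-2)
    sq = torch.stack([b.pow(2).sum(dim=-2) for b in blocks], dim=-2)
    if scaled:
        dims = torch.tensor(block_dims, dtype=x.dtype, device=x.device)
        sq = sq / dims.unsqueeze(-1)  # sqrt later turns this into sqrt(1/dim)
    nonzero = sq > 0
    # Feed a dummy 1.0 into sqrt where sq == 0, then mask the result, so the
    # masked branch contributes exactly zero gradient instead of inf * 0 = NaN.
    safe_sq = torch.where(nonzero, sq, torch.ones_like(sq))
    return torch.where(nonzero, torch.sqrt(safe_sq), torch.zeros_like(sq))
```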
class equitorch.nn.MeanSquaredNorm(irreps: Irreps, scaled: bool = True, dim: Literal[0, -1, 2] = -1)

Bases: Module

Computes the mean of squared L2 norms over a specified dimension (batch or channel) for each irrep block.

If dim=0 (batch mean):

\[\text{Output}_{ic} = \frac{1}{N} \sum_n \left( \sum_{m \in \text{irrep}_i} \text{input}_{n,im,c}^2 \right)\]

If dim=-1 (channel mean):

\[\text{Output}_{ni} = \frac{1}{C} \sum_c \left( \sum_{m \in \text{irrep}_i} \text{input}_{n,im,c}^2 \right)\]

Optionally scales the output by \(1/\text{irrep}_i\text{.dim}\).

Parameters:
  • irreps (Irreps) – Irreducible representations of the input tensor.

  • scaled (bool, optional) – If True, scales the output of each irrep_i by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.

  • dim (int, optional) – Dimension over which to compute the mean. Allowed values: 0 (batch), -1 or 2 (channel). Defaults to -1.

irreps_info: IrrepsInfo
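A sketch of the two reduction modes, assuming a 3-D input of shape (N, irreps.dim, C) and a hypothetical block_dims list giving the dimension of each irrep block. The helper mean_squared_norm_sketch only illustrates the formulas above and is not equitorch's implementation.

```python
import torch

def mean_squared_norm_sketch(x, block_dims, scaled=True, dim=-1):
    """Mean of per-irrep squared norms over batch (dim=0) or channels (dim=-1 or 2).

    x: (N, sum(block_dims), C). Hypothetical helper, not equitorch's code.
    """
    blocks = torch.split(x, block_dims, dim=1)
    sq = torch.stack([b.pow(2).sum(dim=1) for b in blocks], dim=1)  # (N, K, C)
    if scaled:
        dims = torch.tensor(block_dims, dtype=x.dtype, device=x.device)
        sq = sq / dims.unsqueeze(-1)
    if dim == 0:
        return sq.mean(dim=0)  # batch mean -> (K, C)
    if dim in (-1, 2):
        return sq.mean(dim=-1)  # channel mean -> (N, K)
    raise ValueError("dim must be 0, -1 or 2")
```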
class equitorch.nn.Dropout(p: float = 0.5, irreps: Irreps = None, irrep_wise: bool = True, work_on_eval: bool = False)

Bases: Module

Apply dropout to equivariant features.

Can operate irrep-wise or on the entire feature vector (channel-wise).

Parameters:
  • p (float, optional) – Probability of an element to be zeroed. Default: 0.5

  • irreps (Irreps, optional) – Irreps of the input tensor. Required if irrep_wise is True. Default: None

  • irrep_wise (bool, optional) – If True, applies dropout independently for each (irrep_instance, channel). If False, applies standard 1D dropout treating (irreps_dim, channels) as a single feature dimension for dropout. Default: True

  • work_on_eval (bool, optional) – If True, dropout is applied even during evaluation. Default: False
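A sketch of irrep-wise dropout under one stated assumption: all 2l+1 components of an irrep instance in a given channel share a single mask entry. That sharing is what keeps the operation equivariant, because a rotation mixes components within a block and zeroing only some of them would break equivariance. The helper irrep_dropout_sketch and its block_dims argument are hypothetical; surviving entries are rescaled by 1/(1-p), as in torch.nn.Dropout.

```python
import torch

def irrep_dropout_sketch(x, block_dims, p=0.5, training=True):
    """Irrep-wise dropout (hypothetical helper, not equitorch's code).

    x: (..., sum(block_dims), C). One Bernoulli draw per (irrep instance, channel).
    """
    if not training or p == 0.0:
        return x
    if p >= 1.0:
        return torch.zeros_like(x)
    keep_prob = torch.full((*x.shape[:-2], len(block_dims), x.shape[-1]),
                           1.0 - p, dtype=x.dtype, device=x.device)
    mask = torch.bernoulli(keep_prob)  # (..., K, C), entries are 0 or 1
    # Share each mask entry across all components of its irrep block.
    repeats = torch.tensor(block_dims, device=x.device)
    mask = torch.repeat_interleave(mask, repeats, dim=-2)
    return x * mask / (1.0 - p)
```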

class equitorch.nn.AnglesToMatrix

Bases: Module

Module to convert Euler angles (ZYZ convention) to rotation matrices.

The ZYZ Euler angles (alpha, beta, gamma) correspond to the rotation matrix:

\[R(\alpha, \beta, \gamma) = R_z(\alpha) R_y(\beta) R_z(\gamma)\]

which is explicitly:

\[\begin{split}\begin{pmatrix} -\sin(\alpha)\sin(\gamma) + \cos(\alpha)\cos(\beta)\cos(\gamma) & -\sin(\alpha)\cos(\beta)\cos(\gamma) - \sin(\gamma)\cos(\alpha) & \sin(\beta)\cos(\gamma) \\ \sin(\alpha)\cos(\gamma) + \sin(\gamma)\cos(\alpha)\cos(\beta) & -\sin(\alpha)\sin(\gamma)\cos(\beta) + \cos(\alpha)\cos(\gamma) & \sin(\beta)\sin(\gamma) \\ -\sin(\beta)\cos(\alpha) & \sin(\alpha)\sin(\beta) & \cos(\beta) \end{pmatrix}\end{split}\]

Wraps the functional version angles_to_matrix().