equitorch.nn
Contains equivariant neural network modules and functionalities.
- class equitorch.nn.SO3Linear(irreps_in1, irreps_in2, irreps_out, channels_in=None, channels_out=None, internal_weights=True, feature_mode='uu', path_norm=True, channel_norm=False, path=None)
Bases: Module
SO(3) equivariant linear layer using tensor products.
Equivalent to a TensorProduct where input2 does not have a channel dimension. Supports two main modes controlled by feature_mode:
- 'uv': Fully connected in the channel dimension.
  Input1 shape: (..., irreps_in1.dim, channels_in)
  Input2 shape: (..., irreps_in2.dim)
  Weight shape: (num_paths, channels_in, channels_out)
  Output shape: (..., irreps_out.dim, channels_out)
- 'uu': Depthwise/elementwise in the channel dimension.
  Input1 shape: (..., irreps_in1.dim, channels)
  Input2 shape: (..., irreps_in2.dim)
  Weight shape: (num_paths, channels)
  Output shape: (..., irreps_out.dim, channels)
- Parameters:
irreps_in1 (Irreps) – Irreducible representations of the main input tensor (input1).
irreps_in2 (Irreps) – Irreducible representations of the second input tensor (input2), often spherical harmonics that play the role of equivariant weights.
irreps_out (Irreps) – Irreducible representations of the output tensor.
channels_in (int, optional) – Number of channels for the main input (input1). Required if internal_weights=True.
channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.
internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.
feature_mode (str, optional) – Controls the type of linear operation, one of {'uu', 'uv'}. Defaults to 'uu'.
  'uu': Depthwise/elementwise linear. Assumes channels_in == channels_out.
  'uv': Fully connected linear.
path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths leading to each output irrep. Defaults to True.
channel_norm (bool, optional) – Whether to apply channel normalization (specific to 'uv' mode). Divides weights by \(\sqrt{\text{channels_in}}\). Note: this interacts with path_norm. Defaults to False.
path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.
- weight
The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.
- Type:
torch.nn.Parameter or None
- tp_info_forward
Constant information for the forward pass computation.
- Type:
TensorProductInfo
- tp_info_backward1
Constant information for the backward pass w.r.t. input1.
- Type:
TensorProductInfo
- tp_info_backward2
Constant information for the backward pass w.r.t. input2.
- Type:
TensorProductInfo
- num_paths
Number of coupling paths determined by the irreps.
- Type:
int
- weight_numel
Total number of elements in the weight tensor.
- Type:
int
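Which coupling paths exist, and what path normalization does with them, can be illustrated in plain Python. This is a sketch under stated assumptions, not equitorch's path construction: irreps are written as hypothetical `(l, p)` tuples with parity `p = ±1`, paths follow the standard SO(3) selection rules (\(|l_1 - l_2| \le l_{out} \le l_1 + l_2\) and \(p_{out} = p_1 p_2\)), and the path-norm factor is taken per output irrep instance.

```python
from collections import Counter
from math import sqrt


def allowed_paths(irreps_in1, irreps_in2, irreps_out):
    """Return (i, j, k) index triples of irrep instances that may couple.

    Each irrep is a hypothetical (l, p) tuple with parity p in {+1, -1}.
    """
    paths = []
    for i, (l1, p1) in enumerate(irreps_in1):
        for j, (l2, p2) in enumerate(irreps_in2):
            for k, (lo, po) in enumerate(irreps_out):
                # Triangle inequality on the degrees, product rule on the parities.
                if abs(l1 - l2) <= lo <= l1 + l2 and po == p1 * p2:
                    paths.append((i, j, k))
    return paths


def path_norm_factors(paths):
    """Scale each path by 1/sqrt(number of paths that reach its output irrep)."""
    fan_in = Counter(k for _, _, k in paths)
    return [1.0 / sqrt(fan_in[k]) for _, _, k in paths]


irreps = [(0, 1), (1, -1)]  # "0e + 1o" for input1, input2 and the output
paths = allowed_paths(irreps, irreps, irreps)
print(paths)                     # [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]
print(path_norm_factors(paths))  # every output irrep is reached by 2 paths
```

Note that 1o ⊗ 1o contains no 1o path, because the product of two odd parities is even.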
- class equitorch.nn.IrrepWiseLinear(irreps, channels_in, channels_out, internal_weights=True, channel_norm: bool = False)
Bases: Module
Irrep-wise linear layer (channel mixing).
Applies a separate linear transformation to the channels associated with each irrep type. This operation does not change the spherical tensor structure (the irreps).
Input shape: (..., irreps.dim, channels_in)
Weight shape: (num_paths, channels_in, channels_out), where num_paths is the number of unique irreps in irreps.
Output shape: (..., irreps.dim, channels_out)
- Parameters:
irreps (Irreps or str) – Irreducible representations of the input tensor.
channels_in (int) – Number of input channels.
channels_out (int) – Number of output channels.
internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.
channel_norm (bool, optional) – If True, divides the output by \(\sqrt{\text{channels_in}}\). Defaults to False.
- weight
The learnable weights of the module if internal_weights=True.
- Type:
torch.nn.Parameter or None
- irreps_info
Constant information about the input irreps.
- Type:
IrrepsInfo
- num_paths
Number of unique irreps in the input.
- Type:
int
- weight_numel
Total number of elements in the weight tensor.
- Type:
int
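A minimal pure-Python sketch of irrep-wise channel mixing. It assumes features are nested lists with one row per spherical component and one column per channel, and that all instances of the same irrep type share one `channels_in × channels_out` matrix (which is why `num_paths` counts unique irreps). The `(l, p)` tuples and the function name are hypothetical. Because the same matrix multiplies every component \(m\) of an irrep, the map commutes with rotations.

```python
def irrep_wise_linear(x, irreps, weights):
    """Mix channels separately for each irrep type.

    x:       list of rows, one per spherical component, each of length channels_in
    irreps:  list of (l, p) irrep instances in order, e.g. [(0, 1), (1, -1)]
    weights: dict mapping each unique (l, p) to a channels_in x channels_out matrix
    """
    out, row = [], 0
    for ir in irreps:
        w = weights[ir]
        dim = 2 * ir[0] + 1  # an irrep of degree l has 2l + 1 components
        for m in range(dim):
            xm = x[row + m]
            # out[m][v] = sum_u x[m][u] * w[u][v], with the same w for every m
            out.append([sum(xm[u] * w[u][v] for u in range(len(xm)))
                        for v in range(len(w[0]))])
        row += dim
    return out


irreps = [(0, 1), (1, -1)]             # one scalar and one vector
x = [[1.0, 2.0]] + [[1.0, 0.0]] * 3    # irreps.dim = 4 rows, channels_in = 2
weights = {(0, 1): [[1.0], [1.0]],     # scalars: sum both channels
           (1, -1): [[2.0], [0.0]]}    # vectors: keep channel 0, doubled
print(irrep_wise_linear(x, irreps, weights))  # [[3.0], [2.0], [2.0], [2.0]]
```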
- class equitorch.nn.IrrepsLinear(irreps_in, irreps_out, channels_in, channels_out, internal_weights=True, path_norm=True, channel_norm=False, path=None)
Bases: Module
Equivariant linear layer that preserves the spherical tensor structure but mixes channels.
This layer applies a linear transformation across channels while respecting the equivariance constraints imposed by the input and output irreps. It only allows paths between irreps of the same degree and compatible parity (\(l_{in} = l_{out}\) and either \(p_{in} = p_{out}\) or \(p_{out} = 0\)). It is often used for channel mixing in equivariant networks.
Input shape: (..., irreps_in.dim, channels_in)
Weight shape: (num_paths, channels_in, channels_out), where num_paths is the number of allowed paths.
Output shape: (..., irreps_out.dim, channels_out)
- Parameters:
irreps_in (Irreps) – Irreducible representations of the input tensor.
irreps_out (Irreps) – Irreducible representations of the output tensor.
channels_in (int) – Number of input channels.
channels_out (int) – Number of output channels.
internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.
path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths leading to each output irrep. Defaults to True.
channel_norm (bool, optional) – Whether to apply channel normalization. Divides weights by \(\sqrt{\text{channels_in}}\). Defaults to False.
path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.
- weight
The learnable weights of the module if internal_weights=True.
- Type:
torch.nn.Parameter or None
- forward_info
Constant information for the forward pass computation.
- Type:
IrrepsLinearInfo
- backward_info
Constant information for the backward pass computation.
- Type:
IrrepsLinearInfo
- num_paths
Number of allowed coupling paths.
- Type:
int
- weight_numel
Total number of elements in the weight tensor.
- Type:
int
- class equitorch.nn.SO2Linear(irreps_in, irreps_out, channels_in=None, channels_out=None, internal_weights=True, feature_mode='uu', path_norm=True, channel_norm=False, path=None)
Bases: Module
SO(2) equivariant linear layer using tensor products.
This layer applies an SO(2) equivariant linear transformation, as proposed in Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs. It supports two main modes controlled by feature_mode:
- 'uv': Fully connected linear layer.
  Input shape: (..., irreps_in.dim, channels_in)
  Weight shape: (num_weights, channels_in, channels_out)
  Output shape: (..., irreps_out.dim, channels_out)
- 'uu': Depthwise/elementwise linear layer.
  Input shape: (..., irreps_in.dim, channels)
  Weight shape: (num_weights, channels_out)
  Output shape: (..., irreps_out.dim, channels_out)
- Parameters:
irreps_in (Irreps or str) – Irreducible representations of the input tensor.
irreps_out (Irreps or str) – Irreducible representations of the output tensor.
channels_in (int, optional) – Number of channels for the input. Required if internal_weights=True.
channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.
internal_weights (bool, optional) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass. Defaults to True.
feature_mode (str, optional) – Controls the type of linear operation, one of {'uu', 'uv'}. Defaults to 'uu'.
  'uu': Depthwise/elementwise linear. Assumes channels_in == channels_out.
  'uv': Fully connected linear.
path_norm (bool, optional) – Whether to apply path normalization to the weights. Normalizes by the square root of the number of paths leading to each output irrep. Defaults to True.
channel_norm (bool, optional) – Whether to apply channel normalization (specific to 'uv' mode). Divides weights by \(\sqrt{\text{channels_in}}\). Note: this interacts with path_norm. Defaults to False.
path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used. Defaults to None.
- weight
The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.
- Type:
torch.nn.Parameter or None
- info_forward
Constant information for the forward pass computation.
- info_backward1
Constant information for the first backward pass.
- info_backward2
Constant information for the second backward pass.
- num_paths
Number of coupling paths determined by the irreps.
- Type:
int
- weight_numel
Total number of elements in the weight tensor.
- Type:
int
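Why an SO(2) linear layer can be so cheap: once an edge is aligned with the z-axis, each pair of components of order \(\pm m\) rotates like a 2D vector (by \(m\phi\)) under a rotation \(\phi\) about z, and any map of the form \(\begin{pmatrix} a & -b \\ b & a \end{pmatrix}\) commutes with such rotations. The check below illustrates this idea from the cited paper for a single pair; it is not equitorch's kernel, and the sign and ordering conventions of the pair are assumptions.

```python
from math import cos, sin


def rotate(pair, angle):
    """Rotate a 2D pair counter-clockwise by `angle`."""
    x, y = pair
    return (cos(angle) * x - sin(angle) * y, sin(angle) * x + cos(angle) * y)


def so2_linear_pair(pair, a, b):
    """Apply the map [[a, -b], [b, a]] to one (+m, -m) component pair."""
    x, y = pair
    return (a * x - b * y, b * x + a * y)


pair, a, b = (0.3, -1.2), 0.7, 0.4  # feature pair and two learnable weights
m, phi = 2, 0.9                     # order m and a rotation angle about z
lhs = so2_linear_pair(rotate(pair, m * phi), a, b)  # rotate, then apply the layer
rhs = rotate(so2_linear_pair(pair, a, b), m * phi)  # apply the layer, then rotate
print(all(abs(u - v) < 1e-12 for u, v in zip(lhs, rhs)))  # True
```

Only two weights per pair survive the equivariance constraint, instead of the four of a general 2×2 matrix.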
- class equitorch.nn.SplitIrreps(irreps: Irreps, split_num_irreps: Iterable[int], dim: int = -2)
Bases: Module
A module that splits input tensors according to specified irrep segments.
The splitting is done based on the dimensions of the irreducible representations.
- Parameters:
irreps (Irreps) – The irreducible representations specification.
split_num_irreps (Iterable[int]) – Number of irreps in each split segment. Must sum to the total number of irreps. May contain at most one -1 or ... to represent the remaining irreps.
dim (int, optional) – The dimension along which to split the input tensor. Defaults to -2.
Examples
>>> import torch
>>> from equitorch.irreps import Irreps
>>> from equitorch.nn import SplitIrreps
>>> # Split 3 scalar irreps (dim=1 each) and 2 vector irreps (dim=3 each)
>>> irreps = Irreps("3x0e + 2x1o")  # 3 scalars + 2 vectors
>>> split = SplitIrreps(irreps, [2, -1])  # Split first 2 irreps, then remaining
>>> x = torch.randn(5, irreps.dim, 10)  # batch=5, dim=9 (3*1 + 2*3), channels=10
>>> splits = split(x)  # Returns list of 2 tensors
>>> splits[0].shape
torch.Size([5, 2, 10])
>>> splits[1].shape
torch.Size([5, 7, 10])

>>> # Using ... for automatic size calculation
>>> split = SplitIrreps(irreps, [1, ...])  # First irrep, then remaining
>>> splits = split(x)
>>> splits[0].shape
torch.Size([5, 1, 10])
>>> splits[1].shape
torch.Size([5, 8, 10])
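The bookkeeping behind SplitIrreps can be sketched without tensors: resolve the single -1 (or ...) entry to the number of remaining irreps, check the total, then convert irrep counts into sizes along the spherical dimension (each irrep of degree \(l\) spans \(2l + 1\) rows). The helper and its list-of-degrees input are hypothetical; the results match the sizes in the examples above.

```python
def split_sizes(irrep_degrees, split_num_irreps):
    """Convert irrep counts per segment into sizes along the spherical dimension.

    irrep_degrees: degree l of each irrep instance in order ("3x0e + 2x1o" -> [0, 0, 0, 1, 1])
    split_num_irreps: counts per segment; at most one entry may be -1 or Ellipsis
    """
    counts = list(split_num_irreps)
    wildcards = [i for i, c in enumerate(counts) if c is Ellipsis or c == -1]
    if len(wildcards) > 1:
        raise ValueError("at most one -1 or ... is allowed")
    if wildcards:
        known = sum(c for i, c in enumerate(counts) if i != wildcards[0])
        counts[wildcards[0]] = len(irrep_degrees) - known
    if sum(counts) != len(irrep_degrees):
        raise ValueError("split_num_irreps must sum to the total number of irreps")
    sizes, start = [], 0
    for c in counts:
        # each irrep of degree l occupies 2l + 1 rows
        sizes.append(sum(2 * l + 1 for l in irrep_degrees[start:start + c]))
        start += c
    return sizes


degrees = [0, 0, 0, 1, 1]              # "3x0e + 2x1o"
print(split_sizes(degrees, [2, -1]))   # [2, 7]
print(split_sizes(degrees, [1, ...]))  # [1, 8]
```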
- class equitorch.nn.Separable(irreps: Irreps, split_num_irreps: Iterable[int], sub_modules: Iterable[Module], cat_after: bool = True, dim: int = -2)
Bases: Module
A module that applies different transformations to different parts of an input tensor according to its irreducible representations (irreps), with optional concatenation.
- Parameters:
irreps (Irreps) – The irreducible representations specification for the input tensor.
split_num_irreps (Iterable[int]) – Number of irreps in each split segment. May contain at most one -1 or ... to represent the remaining irreps. Length must match the length of sub_modules.
sub_modules (Iterable[Callable]) – Transformation modules for each split segment. Use None for the identity operation.
cat_after (bool, optional) – Whether to concatenate results after transformation. Defaults to True.
dim (int, optional) – The dimension along which to split and concatenate tensors. Defaults to -2.
- Raises:
ValueError –
- If the lengths of split_num_irreps and sub_modules don’t match.
- If the sum of split_num_irreps doesn’t match the total number of irreps.
- If the split_num_irreps specification is invalid.
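Separable amounts to split, transform, concatenate. A list-based sketch, assuming segment sizes have already been converted from irrep counts to rows (here "2x0e + 1x1o" with split_num_irreps [2, 1] becomes row sizes [2, 3]) and that None means identity. Applying a pointwise nonlinearity only to the scalar segment is an illustrative choice, and the helper names are hypothetical.

```python
def separable(rows, sizes, funcs, cat_after=True):
    """Apply funcs[k] (identity when None) to consecutive row segments of length sizes[k]."""
    if len(sizes) != len(funcs):
        raise ValueError("sizes and funcs must have the same length")
    if sum(sizes) != len(rows):
        raise ValueError("segment sizes must cover the whole input")
    parts, start = [], 0
    for size, f in zip(sizes, funcs):
        segment = rows[start:start + size]
        parts.append(segment if f is None else f(segment))
        start += size
    # concatenate along the spherical (row) dimension, or return the pieces
    return [row for part in parts for row in part] if cat_after else parts


def relu(segment):
    return [[max(v, 0.0) for v in row] for row in segment]


# "2x0e + 1x1o" with one channel: 2 scalar rows, then 3 vector rows
x = [[1.0], [-2.0], [0.5], [-0.5], [0.5]]
print(separable(x, [2, 3], [relu, None]))  # [[1.0], [0.0], [0.5], [-0.5], [0.5]]
```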
- class equitorch.nn.SphericalHarmonics(l_max: int, normalize_input: bool = True, integral_normalize: bool = False)
Bases: Module
Computes spherical harmonics from input Cartesian coordinates. Wraps the functional spherical_harmonics().
Spherical harmonics are a set of orthogonal functions defined on the surface of a sphere. They arise as the angular part of solutions to Laplace’s equation in spherical coordinates.
If integral_normalize is True, the output is scaled by \(1/\sqrt{4\pi}\).
- Parameters:
l_max (int) – The maximum degree of the spherical harmonics.
normalize_input (bool, optional) – If True, normalizes the input xyz vector before computing spherical harmonics. Defaults to True.
integral_normalize (bool, optional) – If True, applies normalization for integration over the sphere. Defaults to False.
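The \(1/\sqrt{4\pi}\) factor follows from the sphere's total solid angle of \(4\pi\): a constant function equal to \(1/\sqrt{4\pi}\) has a squared integral of 1 over the sphere. The quadrature below checks this for the \(l = 0\) harmonic, assuming (for illustration only) that without integral normalization the \(l = 0\) output is the constant 1.

```python
from math import pi, sin, sqrt


def sphere_integral(f, n=200):
    """Midpoint-rule estimate of the integral of f(theta, phi) over the unit sphere."""
    d_theta, d_phi = pi / n, 2 * pi / n
    total = 0.0
    for i in range(n):
        theta = (i + 0.5) * d_theta
        for j in range(n):
            phi = (j + 0.5) * d_phi
            # surface element on the unit sphere: sin(theta) dtheta dphi
            total += f(theta, phi) * sin(theta) * d_theta * d_phi
    return total


y00 = 1.0 / sqrt(4 * pi)  # assumed l = 0 output after integral normalization
print(sphere_integral(lambda theta, phi: y00 ** 2))  # close to 1.0
```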
- class equitorch.nn.XYZToSpherical(normalize_input: bool = True, with_r: bool = False, eps: float = 1e-14, dim: int = -1)
Bases: Module
Module to convert Cartesian coordinates \((x, y, z)\) to spherical coordinates \((\theta, \phi, r)\). Wraps the functional xyz_to_spherical().
- Parameters:
normalize_input (bool, optional) – If True, normalizes the input xyz vector before computing angles. Defaults to True.
with_r (bool, optional) – If True, returns \(r\) along with \(\theta\) and \(\phi\). Defaults to False.
eps (float, optional) – Small \(\epsilon\) for numerical stability. Defaults to 1e-14.
dim (int, optional) – Dimension of Cartesian coordinates. Defaults to -1.
- class equitorch.nn.SphericalToXYZ(dim: int = -1)
Bases: Module
Module to convert spherical coordinates \((\theta, \phi, r)\) to Cartesian coordinates \((x, y, z)\). Wraps the functional spherical_to_xyz().
- Parameters:
dim (int, optional) – Dimension along which to stack the output \((x, y, z)\). Defaults to -1.
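A scalar sketch of both conversions and a round-trip check, assuming the convention used on this page for AlignToZWignerD: \(\theta\) is the polar angle from the +z axis and \(\phi\) the azimuthal angle from the +x axis. The helpers are hypothetical stand-ins for xyz_to_spherical() and spherical_to_xyz(), and the order of returned values is illustrative.

```python
from math import acos, atan2, cos, sin, sqrt


def xyz_to_spherical(x, y, z, eps=1e-14):
    """Return (theta, phi, r): polar angle from +z, azimuth from +x, and radius."""
    r = sqrt(x * x + y * y + z * z)
    cos_theta = z / max(r, eps)  # guard the zero vector against division by zero
    theta = acos(max(-1.0, min(1.0, cos_theta)))  # clamp round-off into [-1, 1]
    phi = atan2(y, x)
    return theta, phi, r


def spherical_to_xyz(theta, phi, r=1.0):
    return (r * sin(theta) * cos(phi), r * sin(theta) * sin(phi), r * cos(theta))


p = (1.0, -2.0, 0.5)
back = spherical_to_xyz(*xyz_to_spherical(*p))
print(all(abs(a - b) < 1e-9 for a, b in zip(p, back)))  # True
```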
- class equitorch.nn.XYZToSinCos(max_m: int, normalize_input: bool = True, component_normalize: bool = False, eps: float = 1e-14, dim: int = -1)
Bases: Module
Module to convert Cartesian coordinates \((x, y, z)\) to sin/cos embeddings of the spherical angles \(\theta\) and \(\phi\). Wraps the functional xyz_to_sincos().
- Parameters:
max_m (int) – The maximum multiple of the angles to compute \(\sin\) / \(\cos\) for.
normalize_input (bool, optional) – If True, normalizes xyz before extracting angles. Defaults to True.
component_normalize (bool, optional) – Whether to component-normalize the sin/cos values (compare SinCos). Defaults to False.
eps (float, optional) – Small \(\epsilon\) for numerical stability. Defaults to 1e-14.
dim (int, optional) – Dimension of Cartesian coordinates in xyz. Defaults to -1.
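One way to build these embeddings without calling acos or atan2 is to read \(\cos\theta = z/r\), \(\sin\theta = \rho/r\), \(\cos\phi = x/\rho\) and \(\sin\phi = y/\rho\) (with \(\rho = \sqrt{x^2 + y^2}\)) off the coordinates and to generate higher multiples with the angle-addition formulas. Whether xyz_to_sincos() works this way, and how it lays out its output, is not stated here; the sketch and its names are hypothetical.

```python
from math import acos, atan2, cos, sin, sqrt


def multiples(c, s, max_m):
    """Return [(sin(k a), cos(k a)) for k = 1..max_m] given c = cos(a), s = sin(a)."""
    out, sk, ck = [], s, c
    for _ in range(max_m):
        out.append((sk, ck))
        # sin((k+1)a) = sin(ka)cos(a) + cos(ka)sin(a); cos((k+1)a) = cos(ka)cos(a) - sin(ka)sin(a)
        sk, ck = sk * c + ck * s, ck * c - sk * s
    return out


def xyz_to_sincos_sketch(x, y, z, max_m, eps=1e-14):
    r = max(sqrt(x * x + y * y + z * z), eps)
    rho = max(sqrt(x * x + y * y), eps)  # distance from the z-axis
    theta_terms = multiples(z / r, rho / r, max_m)  # cos(theta) = z/r, sin(theta) = rho/r
    phi_terms = multiples(x / rho, y / rho, max_m)  # cos(phi) = x/rho, sin(phi) = y/rho
    return theta_terms, phi_terms


x, y, z = 1.0, 2.0, -0.5
theta_terms, phi_terms = xyz_to_sincos_sketch(x, y, z, max_m=3)
theta, phi = acos(z / sqrt(x * x + y * y + z * z)), atan2(y, x)
ok = all(abs(s - sin(k * theta)) < 1e-9 and abs(c - cos(k * theta)) < 1e-9
         for k, (s, c) in enumerate(theta_terms, start=1))
print(ok)  # True
```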
- class equitorch.nn.BatchRMSNorm(irreps: Irreps, channels: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, scaled: bool = True)
Bases: Module
Applies Batch Root Mean Square Normalization for equivariant features.
\[x'_{nimc} = \gamma_{ic} \cdot (x_{nimc} / \sigma_{ic})\]
where
\[\sigma_{ic} = \sqrt{E[\text{SquaredNorm}(x_{nic})] + \epsilon}\]
and the expectation \(E\) is taken over the batch index \(n\). The SquaredNorm can be scaled by \(1/\text{irrep}_i\text{.dim}\) depending on the scaled argument. Running statistics are used during evaluation.
- Parameters:
irreps (Irreps) – Irreducible representations of the input tensor.
channels (int) – Number of channels in the input tensor (size of the last dimension).
eps (float, optional) – A value \(\epsilon\) added to the denominator for numerical stability. Defaults to 1e-5.
momentum (float, optional) – The value used for the running_mean computation. Defaults to 0.1.
affine (bool, optional) – If True, this module has learnable affine parameters (weight \(\gamma_{ic}\)). Defaults to True.
scaled (bool, optional) – If True, the SquaredNorm used for calculating statistics is scaled by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.
- irreps_info: IrrepsInfo
- reset_running_stats() → None
- reset_parameters() → None
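The training-mode statistic in the formula above can be sketched for a single irrep instance \(i\) and channel \(c\). This plain-Python version takes the batch as a list of component vectors, applies the optional \(1/\text{irrep}_i\text{.dim}\) scaling, and leaves out the running statistics and \(\gamma_{ic}\) (taken as 1); it illustrates the formula rather than the module's code.

```python
from math import sqrt


def batch_rms_normalize(batch, eps=1e-5, scaled=True):
    """Normalize one (irrep i, channel c) slice with a statistic shared across the batch.

    batch: list over samples n of the components x_{n i m c} (length = irrep dim)
    """
    dim = len(batch[0])
    sq_norms = [sum(v * v for v in x) for x in batch]
    if scaled:
        sq_norms = [s / dim for s in sq_norms]  # divide by irrep_i.dim
    # sigma_ic = sqrt(E_n[SquaredNorm(x_nic)] + eps); gamma_ic is taken as 1 here
    sigma = sqrt(sum(sq_norms) / len(sq_norms) + eps)
    return [[v / sigma for v in x] for x in batch]


batch = [[3.0, 0.0, 4.0], [0.0, 0.0, 0.0]]  # two samples of an l = 1 irrep
out = batch_rms_normalize(batch, eps=0.0, scaled=False)
print(sum(sum(v * v for v in x) for x in out) / len(out))  # mean squared norm, about 1.0
```

A single \(\sigma_{ic}\) rescales every component \(m\) equally, so the direction of each irrep block, and hence equivariance, is preserved.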
- class equitorch.nn.LayerRMSNorm(irreps: Irreps, channels: int, eps: float = 1e-05, affine: bool = True, scaled: bool = True)
Bases: Module
Applies Irrep-wise Layer Root Mean Square Normalization.
Computes statistics independently for each irrep instance within each sample. Normalizes using the RMS value calculated across channels and irrep components for that specific sample and irrep instance.
- Parameters:
irreps (Irreps) – Irreducible representations of the input tensor.
channels (int) – Number of channels in the input tensor (size of the last dimension).
eps (float, optional) – A value \(\epsilon\) added to the denominator for numerical stability. Defaults to 1e-5.
affine (bool, optional) – If True, this module has learnable affine parameters (weight \(\gamma_{ic}\)). Defaults to True.
scaled (bool, optional) – If True, the statistics calculation considers the norm to be scaled by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.
- irreps_info: IrrepsInfo
- reset_parameters() → None
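By contrast with BatchRMSNorm, the statistic here is computed per sample. The sketch below normalizes one (sample, irrep instance) block by the RMS over its components and channels. The exact formula is an assumption modelled on the batch version (with scaled=True averaging over the \(\text{irrep}_i\text{.dim}\) components), and \(\gamma\) is omitted.

```python
from math import sqrt


def layer_rms_normalize(block, eps=1e-5, scaled=True):
    """Normalize one (sample n, irrep instance i) block: rows are components m, columns channels c."""
    dim, channels = len(block), len(block[0])
    sq_norms = [sum(block[m][c] ** 2 for m in range(dim)) for c in range(channels)]
    if scaled:
        sq_norms = [s / dim for s in sq_norms]  # divide by irrep_i.dim
    # a single RMS statistic for this sample and irrep instance, averaged over channels
    sigma = sqrt(sum(sq_norms) / channels + eps)
    return [[v / sigma for v in row] for row in block]


block = [[1.0, -1.0], [2.0, 0.0], [0.0, 2.0]]  # an l = 1 block with 2 channels
out = layer_rms_normalize(block, eps=0.0)
mean_square = sum(v * v for row in out for v in row) / (len(out) * len(out[0]))
print(mean_square)  # 1.0 up to round-off
```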
- equitorch.nn.initialize_tensor_product(weight, feature_mode, gain=1, channel_normed=False)
Initialize weights for tensor product operations.
This function initializes weights for tensor product operations with different feature modes. The initialization uses a uniform distribution with bounds calculated based on the feature mode and whether channel normalization is used.
For 'uvw' mode:
\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]
where \(\text{fan_in} = \text{weight.shape[-2]} \cdot \text{weight.shape[-3]}\)
For 'uuu' mode:
\[a = \sqrt{3} \cdot \text{gain}\]
- Parameters:
weight (torch.Tensor) – The weight tensor to initialize.
feature_mode (str) – The feature mode for initialization. Must be one of ['uvw', 'uuu'].
gain (float, optional) – The gain factor to apply. Default is 1.
channel_normed (bool, optional) – Whether channel normalization is used. Default is False.
- Raises:
ValueError – If an unknown feature_mode is provided.
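The bound formulas translate directly into code. The hypothetical helper below computes \(a\) from a weight shape and draws samples from \([-a, a]\) with Python's random module, assuming the uniform range is symmetric. Since a uniform distribution on \([-a, a]\) has variance \(a^2/3\), the non-channel-normalized 'uvw' bound gives each weight a variance of \(\text{gain}^2/\text{fan_in}\).

```python
import random
from math import prod, sqrt


def tensor_product_bound(shape, feature_mode, gain=1.0, channel_normed=False):
    """Uniform bound a for a tensor-product weight, following the formulas above."""
    if feature_mode == "uvw":
        if channel_normed:
            return sqrt(3) * gain
        fan_in = shape[-2] * shape[-3]  # channels_in2 * channels_in1 for (num_paths, c1, c2, c_out)
        return sqrt(3 / fan_in) * gain
    if feature_mode == "uuu":
        return sqrt(3) * gain
    raise ValueError(f"Unknown feature_mode: {feature_mode}")


shape = (5, 8, 4, 16)  # (num_paths, channels_in1, channels_in2, channels_out)
a = tensor_product_bound(shape, "uvw")
weights = [random.uniform(-a, a) for _ in range(prod(shape))]
print(a, sqrt(3 / 32))  # the same value: fan_in = 8 * 4 = 32
```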
- equitorch.nn.initialize_so3_so2_linear(weight, feature_mode, gain=1, channel_normed=False)
Initialize weights for SO(3) or SO(2) linear operations.
This function initializes weights for SO(3) or SO(2) linear operations with different feature modes. The initialization uses a uniform distribution with bounds calculated based on the feature mode and whether channel normalization is used.
For 'uv' mode:
\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]
where \(\text{fan_in} = \text{weight.shape[-2]}\)
For 'uu' mode:
\[a = \sqrt{3} \cdot \text{gain}\]
- Parameters:
weight (torch.Tensor) – The weight tensor to initialize.
feature_mode (str) – The feature mode for initialization. Must be one of ['uv', 'uu'].
gain (float, optional) – The gain factor to apply. Default is 1.
channel_normed (bool, optional) – Whether channel normalization is used. Default is False.
- Raises:
ValueError – If an unknown feature_mode is provided.
- equitorch.nn.initialize_linear(weight, gain=1, channel_normed=False)
Initialize weights for standard linear operations.
This function initializes weights for standard linear operations using a uniform distribution with bounds calculated based on whether channel normalization is used.
\[\begin{split}a = \begin{cases} \sqrt{3} \cdot \text{gain}, & \text{if channel_normed} = \text{True} \\ \sqrt{\frac{3}{\text{fan_in}}} \cdot \text{gain}, & \text{otherwise} \end{cases}\end{split}\]
where \(\text{fan_in} = \text{weight.shape[-2]}\)
- Parameters:
weight (torch.Tensor) – The weight tensor to initialize.
gain (float, optional) – The gain factor to apply. Default is 1.
channel_normed (bool, optional) – Whether channel normalization is used. Default is False.
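A quick Monte Carlo check of where the \(\sqrt{3}\) comes from, assuming the symmetric range \([-a, a]\): with \(a = \sqrt{3/\text{fan_in}} \cdot \text{gain}\) the sample variance should approach \(\text{gain}^2/\text{fan_in}\), so a sum over fan_in unit-variance inputs keeps roughly unit variance. With channel_normed=True the bound is \(\sqrt{3} \cdot \text{gain}\) (unit-variance weights), presumably because the \(1/\sqrt{\text{channels_in}}\) factor is then applied by the layer's channel normalization instead.

```python
import random
from math import sqrt

random.seed(0)
fan_in, gain = 64, 1.0
a = sqrt(3 / fan_in) * gain  # bound for channel_normed=False
samples = [random.uniform(-a, a) for _ in range(200_000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(var * fan_in)  # close to gain**2 = 1.0
```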
- class equitorch.nn.TensorProduct(irreps_in1, irreps_in2, irreps_out, channels_in1=None, channels_in2=None, channels_out=None, internal_weights=True, feature_mode='uuu', path_norm=True, channel_norm=False, path=None)
Bases: Module
Computes the tensor product of two equivariant feature tensors.
Supports two main modes controlled by feature_mode:
- 'uvw': Fully connected tensor product.
  Input1 shape: (..., irreps_in1.dim, channels_in1)
  Input2 shape: (..., irreps_in2.dim, channels_in2)
  Weight shape: (num_paths, channels_in1, channels_in2, channels_out)
  Output shape: (..., irreps_out.dim, channels_out)
- 'uuu': Depthwise/elementwise tensor product (often used for self-interaction).
  Input1 shape: (..., irreps_in1.dim, channels)
  Input2 shape: (..., irreps_in2.dim, channels)
  Weight shape: (num_paths, channels_out), where channels_out usually equals channels
  Output shape: (..., irreps_out.dim, channels_out)
- Parameters:
irreps_in1 (Irreps) – Irreducible representations of the first input tensor.
irreps_in2 (Irreps) – Irreducible representations of the second input tensor.
irreps_out (Irreps) – Irreducible representations of the output tensor.
channels_in1 (int, optional) – Number of channels for the first input. Required if internal_weights=True or feature_mode='uvw'.
channels_in2 (int, optional) – Number of channels for the second input. Required if internal_weights=True or feature_mode='uvw'.
channels_out (int, optional) – Number of channels for the output. Required if internal_weights=True.
internal_weights (bool, default=True) – If True, the module manages its own weight parameter. If False, weights must be provided during the forward pass.
feature_mode ({'uuu', 'uvw'}, default='uuu') – Controls the type of tensor product:
  'uuu': Depthwise/elementwise product. Assumes channels_in1 == channels_in2 == channels_out.
  'uvw': Fully connected product.
path_norm (bool, default=True) – Whether to apply path normalization to the weights.
channel_norm (bool, default=False) – Whether to apply channel normalization (specific to 'uvw' mode). Divides weights by \(\sqrt{\text{channels_in1} \times \text{channels_in2}}\).
path (list, optional) – Manually specify the coupling paths. If None, all allowed paths are used.
- weight
The learnable weights of the module if internal_weights=True. Shape depends on feature_mode.
- Type:
torch.nn.Parameter or None
- tp_info_forward
Constant information for the forward pass computation.
- Type:
TensorProductInfo
- tp_info_backward1
Constant information for the backward pass w.r.t. input1.
- Type:
TensorProductInfo
- tp_info_backward2
Constant information for the backward pass w.r.t. input2.
- Type:
TensorProductInfo
- num_paths
Number of coupling paths determined by the irreps.
- Type:
int
- weight_numel
Total number of elements in the weight tensor.
- Type:
int
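To make the two channel modes concrete, take the simplest coupling path \(0e \otimes 0e \rightarrow 0e\), whose coupling coefficient is 1, so the path reduces to weighted products of channel values. The sketch contrasts 'uvw', which contracts over all channel pairs, with 'uuu', which keeps channels aligned; it ignores path and channel normalization, and the function names are hypothetical.

```python
def scalar_path_uvw(x1, x2, w):
    """'uvw' on the scalar path: out[k] = sum_{u,v} w[u][v][k] * x1[u] * x2[v]."""
    c_out = len(w[0][0])
    return [sum(w[u][v][k] * x1[u] * x2[v]
                for u in range(len(x1)) for v in range(len(x2)))
            for k in range(c_out)]


def scalar_path_uuu(x1, x2, w):
    """'uuu' on the scalar path: out[u] = w[u] * x1[u] * x2[u] (channels stay aligned)."""
    return [w[u] * x1[u] * x2[u] for u in range(len(x1))]


x1, x2 = [1.0, 2.0], [3.0, -1.0]
w_uvw = [[[1.0], [0.0]], [[0.0], [1.0]]]  # (c1=2, c2=2, c_out=1): x1[0]x2[0] + x1[1]x2[1]
print(scalar_path_uvw(x1, x2, w_uvw))       # [1.0]
print(scalar_path_uuu(x1, x2, [1.0, 0.5]))  # [3.0, -1.0]
```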
- class equitorch.nn.TensorDot(irreps, feature_mode, scaled=False)
Bases: Module
Computes the equivariant irrep-wise dot product of two feature tensors.
Supports two main modes controlled by feature_mode:
- 'uv': Channel-cartesian dot product.
  Input1 shape: (..., irreps.dim, channels1)
  Input2 shape: (..., irreps.dim, channels2)
  Output shape: (..., len(irreps), channels1, channels2)
- 'uu': Channel-wise dot product. The sum runs over the components of each irrep separately for every channel, so the channel dimension is kept.
  Input1 shape: (..., irreps.dim, channels)
  Input2 shape: (..., irreps.dim, channels)
  Output shape: (..., len(irreps), channels)
- Parameters:
irreps (Irreps or str) – Irreducible representations of the input tensors. Both inputs must have the same irreps.
feature_mode ({'uv', 'uu'}) – Controls how the channel dimension is handled:
  'uv': Channel-cartesian dot product.
  'uu': Channel-wise dot product.
scaled (bool, default=False) – If True, scales the dot product by \(1 / \sqrt{\text{irrep.dim}}\).
- irreps_info
Constant information about the input irreps.
- Type:
IrrepsInfo
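A pure-Python sketch of both modes for a single irrep block, with rows as the \(2l + 1\) components and columns as channels: 'uu' keeps the channel axis, while 'uv' forms every channel pair, and scaled=True divides by \(\sqrt{\text{irrep.dim}}\). The function names are hypothetical.

```python
from math import sqrt


def tensor_dot_uu(a, b, scaled=False):
    """Per-channel dot product over the components m of one irrep block."""
    dim, channels = len(a), len(a[0])
    s = 1 / sqrt(dim) if scaled else 1.0
    return [s * sum(a[m][c] * b[m][c] for m in range(dim)) for c in range(channels)]


def tensor_dot_uv(a, b, scaled=False):
    """Dot product over m for every pair (channel u of a, channel v of b)."""
    dim = len(a)
    s = 1 / sqrt(dim) if scaled else 1.0
    return [[s * sum(a[m][u] * b[m][v] for m in range(dim)) for v in range(len(b[0]))]
            for u in range(len(a[0]))]


a = [[1.0, 0.0], [2.0, 1.0], [0.0, 3.0]]  # l = 1 block, 2 channels
b = [[1.0, 1.0], [1.0, 0.0], [1.0, 2.0]]
print(tensor_dot_uu(a, b))  # [3.0, 6.0]
print(tensor_dot_uv(a, b))  # [[3.0, 1.0], [4.0, 6.0]]
```

The 'uu' result is the diagonal of the 'uv' result.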
- class equitorch.nn.SparseWignerRotation(irreps: Irreps)
Bases: Module
Applies a sparse Wigner D-matrix rotation to input features.
This module computes the rotation based on Euler angles \((\alpha, \beta, \gamma)\) provided as precomputed sin/cos tensors. It utilizes sparse matrix operations for the rotation.
Warning
It is currently recommended to use DenseWignerRotation or WignerD for applying rotations, at least when gradients with respect to the angles are not required.
The Wigner D-matrix \(D^l_{m'm}(R)\) transforms spherical tensors under rotation \(R\). This module applies such transformations for a given Irreps.
- Parameters:
irreps (Irreps) – The irreducible representations defining the input and output feature space.
- info: WignerRotationInfo
- class equitorch.nn.DenseWignerRotation(irreps: Irreps)
Bases: Module
Applies a dense Wigner D-matrix rotation to input features.
This module takes a precomputed dense Wigner D-matrix and applies it to the input features. The Wigner D-matrix \(D(R)\) itself should be computed separately, for example, using the WignerD module.
- Parameters:
irreps (Irreps) – The irreducible representations defining the input and output feature space. This is used for validation and representation purposes.
- class equitorch.nn.WignerD(irreps: Irreps)
Bases: Module
Computes the dense Wigner D-matrix \(D(R)\) for given Irreps and Euler angles \((\alpha, \beta, \gamma)\).
The Wigner D-matrix is constructed based on the ZYZ Euler angle convention:
\[D(\alpha, \beta, \gamma) = D_z(\alpha) D_y(\beta) D_z(\gamma)\]
This module caches the necessary sparse rotation information and an identity matrix to efficiently compute the dense D-matrix using the wigner_d_matrix() functional.
- Parameters:
irreps (Irreps) – The irreducible representations for which to compute the D-matrix. The resulting D-matrix will have dimensions (irreps.dim, irreps.dim).
- info: WignerRotationInfo
- class equitorch.nn.AlignToZWignerD(irreps: Irreps, normalized: bool = True, eps: float = 1e-14)
Bases: Module
Computes the Wigner D-matrix \(D(R_{align})\) that rotates a given vector \(\vec{v} = (x, y, z)\) onto the z-axis.
The rotation \(R_{align}\) is defined by Euler angles \((0, -\theta, -\phi)\), where \(\theta\) and \(\phi\) are the polar and azimuthal angles of the vector \(\vec{v}\), respectively. This means:
\[R_{align} \vec{v} = ||\vec{v}|| \hat{z}\]
The Wigner D-matrix is then \(D(0, -\theta, -\phi)\).
This module caches the necessary sparse rotation information and an identity matrix. It utilizes the align_to_z_wigner_d() functional.
- Parameters:
irreps (Irreps) – The irreducible representations for which to compute the D-matrix.
normalized (bool, optional) – Whether to normalize the input xyz vector before calculating angles for rotation. If True, effectively rotates \(\hat{v}\). Defaults to True.
eps (float, optional) – Small \(\epsilon\) value for numerical stability in angle calculation. Defaults to 1e-14.
- info: WignerRotationInfo
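For \(l = 1\), the Wigner D-matrix is, up to a change of basis, an ordinary 3×3 rotation matrix, so the alignment property can be checked with Cartesian matrices. The sketch composes \(R_z(0) R_y(-\theta) R_z(-\phi)\) following the ZYZ convention stated for WignerD and confirms that it maps \(\vec{v}\) to \(||\vec{v}|| \hat{z}\). It assumes standard right-handed active rotations and does not reproduce equitorch's spherical-harmonic basis ordering.

```python
from math import acos, atan2, cos, sin, sqrt


def rot_z(t):
    return [[cos(t), -sin(t), 0.0], [sin(t), cos(t), 0.0], [0.0, 0.0, 1.0]]


def rot_y(t):
    return [[cos(t), 0.0, sin(t)], [0.0, 1.0, 0.0], [-sin(t), 0.0, cos(t)]]


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def apply(R, v):
    return [sum(R[i][k] * v[k] for k in range(3)) for i in range(3)]


def align_to_z(v):
    """ZYZ rotation Rz(alpha) Ry(beta) Rz(gamma) with (alpha, beta, gamma) = (0, -theta, -phi)."""
    r = sqrt(sum(c * c for c in v))
    theta, phi = acos(v[2] / r), atan2(v[1], v[0])  # polar and azimuthal angles
    return matmul(matmul(rot_z(0.0), rot_y(-theta)), rot_z(-phi))


v = [1.0, 2.0, -2.0]  # norm 3
print(apply(align_to_z(v), v))  # approximately [0, 0, 3]
```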
- class equitorch.nn.PolynomialCutoff(r_max: float, r_min: float = 0.0, p: float = 6)
Bases: Module
Polynomial cutoff, as proposed in DimeNet.
The polynomial cutoff function is defined as:
\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ 1 - \frac{(p+1)(p+2)}{2}u^p + p(p+2)u^{p+1} - \frac{p(p+1)}{2}u^{p+2}, & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]
where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.
- Parameters:
r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.
r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than or equal to r_max. Defaults to 0.0.
p (float, optional) – The power parameter \(p\) controlling the smoothness of the cutoff. Must be greater than or equal to 2.0. Defaults to 6.
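A direct transcription of the formula above, assuming \(r_{\text{max}} > r_{\text{min}}\). The coefficients are chosen so that \(f\), \(f'\) and \(f''\) all vanish at \(u = 1\), which is why the cutoff joins the zero region smoothly. The function is a hypothetical scalar version, not the module's code.

```python
def polynomial_cutoff(r, r_max, r_min=0.0, p=6.0):
    """Piecewise polynomial cutoff following the formula above."""
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    u = (r - r_min) / (r_max - r_min)
    return (1.0
            - (p + 1) * (p + 2) / 2 * u ** p
            + p * (p + 2) * u ** (p + 1)
            - p * (p + 1) / 2 * u ** (p + 2))


print(polynomial_cutoff(0.0, r_max=5.0))  # 1.0
print(polynomial_cutoff(5.0, r_max=5.0))  # 0.0
print(polynomial_cutoff(4.999, r_max=5.0))  # tiny: f vanishes to third order at r_max
```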
- class equitorch.nn.CosineCutoff(r_max: float, r_min: float = 0)
Bases: Module
The cosine cutoff function.
The cosine cutoff function is defined as:
\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ \frac{1}{2}\left[1 + \cos\left(\pi \cdot u\right)\right], & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]
where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.
This cutoff function smoothly decreases from 1 to 0 in the range \([r_{\text{min}}, r_{\text{max}}]\) using a cosine function.
- Parameters:
r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.
r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than r_max. Defaults to 0.
- class equitorch.nn.MollifierCutoff(r_max: float, r_min: float = 0, eps: float = 1e-07)
Bases: Module
The mollifier cutoff function.
The mollifier cutoff function is defined as:
\[\begin{split}f(r) = \begin{cases} 1, & r < r_{\text{min}} \\ \exp\left(1 - \frac{1}{1 - u^2 + \epsilon}\right), & r_{\text{min}} \leq r \leq r_{\text{max}} \\ 0, & r > r_{\text{max}} \end{cases}\end{split}\]
where \(u = \frac{r - r_{\text{min}}}{r_{\text{max}} - r_{\text{min}}}\) and \(r\) is the input distance.
This cutoff function smoothly decreases from 1 to 0 in the range \([r_{\text{min}}, r_{\text{max}}]\) using a mollifier (bump) function.
- Parameters:
r_max (float) – The cutoff distance \(r_{\text{max}}\) where the function reaches zero.
r_min (float, optional) – The starting distance \(r_{\text{min}}\) where the function begins to decrease from 1. Must be less than r_max. Defaults to 0.
eps (float, optional) – Small value \(\epsilon\) to prevent division by zero. Defaults to 1e-7.
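Scalar sketches of the cosine and mollifier formulas above, assuming \(r_{\text{max}} > r_{\text{min}}\); the helper names are hypothetical. Taken literally, the mollifier formula gives \(\exp(\epsilon/(1 + \epsilon))\) at \(u = 0\), so it exceeds 1 by roughly \(\epsilon\), and at \(u = 1\) it gives \(\exp(1 - 1/\epsilon)\), which is effectively zero.

```python
from math import cos, exp, pi


def _u(r, r_min, r_max):
    return (r - r_min) / (r_max - r_min)


def cosine_cutoff(r, r_max, r_min=0.0):
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    return 0.5 * (1.0 + cos(pi * _u(r, r_min, r_max)))


def mollifier_cutoff(r, r_max, r_min=0.0, eps=1e-7):
    if r < r_min:
        return 1.0
    if r > r_max:
        return 0.0
    u = _u(r, r_min, r_max)
    return exp(1.0 - 1.0 / (1.0 - u * u + eps))  # exp underflows to 0.0 near u = 1


for r in (0.0, 2.5, 5.0):
    print(r, cosine_cutoff(r, 5.0), mollifier_cutoff(r, 5.0))
```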
- class equitorch.nn.Gate(irreps: Irreps, activation: Module = None, irrep_wise: bool = True)
Bases: Module
Applies element-wise gates to equivariant features.
This module implements gating nonlinearities for features represented by Irreps. It can operate in two primary modes based on how the gate values are provided:
1. Separate gates (gate argument provided): The module takes two distinct inputs: input (the features to be gated) and gate (the gate scalars).
   - If irrep_wise=True (default): Each gate scalar in the gate tensor is applied to its corresponding irrep block within the input features. The gate tensor should have a shape compatible with (..., num_gates, channels), where num_gates is the number of irreps in irreps.
   - If irrep_wise=False: A single gate scalar (or a set of scalars broadcastable across irreps) is applied to all irrep blocks in the input features. The gate tensor should have a shape compatible with (..., 1, channels).
2. Concatenated input (gate=None): The module takes a single input tensor where the features and their corresponding gate scalars are concatenated along the spherical dimension (dim=-2). The last num_gates slices along this dimension are interpreted as the gate scalars. The input tensor shape is expected to be (..., irreps.dim + num_gates, channels). The module internally splits this tensor into features and gates, optionally applies an activation function to the extracted gates, and then proceeds with the gating operation as described in mode 1.
An optional activation function can be applied to the gate scalars before they modulate the features.
Example
import torch
from equitorch.irreps import Irreps
from equitorch.nn import Gate

irreps = Irreps("1x0e + 2x1o")  # Example: one scalar, two l=1 odd irreps
gate_module = Gate(irreps, activation=torch.nn.Tanh())
batch_size, channels = 4, 8
num_gates = len(irreps)  # one gate per irrep in irreps

# Mode 1: Separate input and gate tensors (irrep_wise=True)
features = torch.randn(batch_size, irreps.dim, channels)
gates = torch.randn(batch_size, num_gates, channels)
output_separate = gate_module(features, gates)
print(f"Output shape (separate gates): {output_separate.shape}")

# Mode 2: Concatenated input tensor
# irreps.dim for "1x0e + 2x1o" is 1*1 + 2*3 = 7
concatenated_input = torch.randn(batch_size, irreps.dim + num_gates, channels)
output_concatenated = gate_module(concatenated_input)
print(f"Output shape (concatenated input): {output_concatenated.shape}")
- Parameters:
irreps (Irreps) – The irreducible representations of the feature part of the input tensor (i.e., the part that will be gated).
activation (torch.nn.Module, optional) – An activation function to be applied to the gate scalars before the gating operation. Defaults to None (no activation).
irrep_wise (bool, optional) – Determines how gates are applied. If True (default), gates are applied irrep-by-irrep. This requires the gate tensor (if provided separately) to have a shape like (..., num_gates, channels). If False, a single gate (or a broadcastable set) is applied across all irreps. This requires the gate tensor (if provided separately) to have a shape like (..., 1, channels). num_gates corresponds to len(irreps).
- irreps_info
Cached information about the input feature irreps, used for efficient gating.
- Type:
IrrepsInfo
- num_gates
The number of distinct gate scalars, equal to len(irreps). This dictates the expected size of the gate dimension in the gate tensor or the number of gate slices in a concatenated input.
- Type:
int
- class equitorch.nn.SinCos(max_m: int, with_ones: bool = True, component_normalize: bool = False)
Bases: Module
Module wrapper for the sincos() function.
Computes the sin/cos expansion of an angle \(a\):
\[[1.0, \sin(a), \cos(a), \sin(2a), \cos(2a), \dots, \sin(\text{max\_m} \cdot a), \cos(\text{max\_m} \cdot a)]\]
or, if component_normalize is True,
\[[1.0, \sqrt{2}\sin(a), \sqrt{2}\cos(a), \sqrt{2}\sin(2a), \sqrt{2}\cos(2a), \dots, \sqrt{2}\sin(\text{max\_m} \cdot a), \sqrt{2}\cos(\text{max\_m} \cdot a)]\]
The leading 1.0 is excluded if with_ones is False.
- Parameters:
max_m (int) – The maximum multiple of the angle \(a\) for which \(\sin\) and \(\cos\) are computed.
with_ones (bool, optional) – Whether to include the leading 1.0 in the expansion. Defaults to True.
component_normalize (bool, optional) – If True, multiplies the \(\sin\) and \(\cos\) values by \(\sqrt{2}\) so that each component has a mean square of 1 over \(a \in [0, 2\pi]\). Defaults to False.
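The expansion takes only a few lines of PyTorch. `sincos_reference` below is a hypothetical reference written from the formulas above, not part of equitorch. It uses the same ordering, [1, sin(a), cos(a), sin(2a), cos(2a), ...], and appends the output as a new last axis.

```python
import math
import torch

def sincos_reference(angle: torch.Tensor, max_m: int, with_ones: bool = True,
                     component_normalize: bool = False) -> torch.Tensor:
    """Hypothetical sin/cos expansion; returns shape (..., 2*max_m [+ 1])."""
    m = torch.arange(1, max_m + 1, dtype=angle.dtype, device=angle.device)
    ma = angle.unsqueeze(-1) * m                               # (..., max_m)
    # Interleave sin(m a) and cos(m a): stack on a new axis, then flatten pairs.
    terms = torch.stack([torch.sin(ma), torch.cos(ma)], dim=-1).flatten(-2)
    if component_normalize:
        terms = terms * math.sqrt(2.0)                         # mean square 1 per component
    if with_ones:
        ones = torch.ones_like(angle).unsqueeze(-1)
        terms = torch.cat([ones, terms], dim=-1)
    return terms
```

For a uniform grid of angles over a full period, each \(\sqrt{2}\sin(ma)\) and \(\sqrt{2}\cos(ma)\) column has mean square 1, which matches the leading constant 1.0.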
- class equitorch.nn.BesselBasis(r_max, num_basis=8, trainable=True)
Bases: Module
- __init__(r_max, num_basis=8, trainable=True)
Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123
- Parameters:
r_max (float) – Cutoff radius
num_basis (int) – Number of Bessel basis functions.
trainable (bool) – Whether the \(n \pi\) frequency factors are trainable parameters.
- r_max: float
- prefactor: float
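DimeNet defines the radial basis as \(e_n(r) = \sqrt{2/r_{\max}}\,\sin(n\pi r / r_{\max}) / r\) for \(n = 1, \dots, \text{num\_basis}\). The module below is a sketch of that formula, not equitorch's source. The class name, the `freqs` attribute and the exact prefactor are assumptions, and equitorch may normalize differently. With trainable=True, the \(n\pi\) factors become learnable parameters. Otherwise they are stored as a fixed buffer.

```python
import math
import torch
from torch import nn

class BesselBasisSketch(nn.Module):
    """Hypothetical DimeNet-style radial Bessel basis:
    e_n(r) = prefactor * sin(freq_n * r / r_max) / r, with freq_n initialized to n * pi.
    """

    def __init__(self, r_max: float, num_basis: int = 8, trainable: bool = True):
        super().__init__()
        self.r_max = float(r_max)
        self.prefactor = math.sqrt(2.0 / self.r_max)  # DimeNet normalization (assumed)
        freqs = torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()) * math.pi
        if trainable:
            self.freqs = nn.Parameter(freqs)           # learnable n*pi factors
        else:
            self.register_buffer("freqs", freqs)       # fixed n*pi factors

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        r = r.unsqueeze(-1)                            # (..., 1) -> broadcast over basis
        return self.prefactor * torch.sin(self.freqs * r / self.r_max) / r
```

Every basis function vanishes at \(r = r_{\max}\) while the frequencies keep their initial values \(n\pi\). Inputs must be strictly positive because of the \(1/r\) factor.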
- class equitorch.nn.SquaredNorm(irreps: Irreps, scaled: bool = True)
Bases: Module
Computes the squared L2 norm for each irrep block in an input tensor.
\[\text{Output}_k = \sum_{m \in \text{irrep}_k} (\text{input}_{km}^2)\]
Optionally scales the output by \(1/\text{irrep}_k\text{.dim}\).
- Parameters:
irreps (Irreps) – Irreducible representations of the input tensor.
scaled (bool, optional) – If True, scales the output of each irrep_k by \(1/\text{irrep}_k\text{.dim}\). Defaults to True.
- irreps_info: IrrepsInfo
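Here is a minimal sketch of the per-irrep squared norm. It assumes features are laid out as (..., irreps.dim, channels), with each irrep instance occupying a contiguous block of \(2l+1\) rows. `squared_norm_sketch` and its `block_dims` argument are hypothetical stand-ins for what the module derives from irreps. `block_dims` has one entry per irrep instance, e.g. [1, 3, 3] for "1x0e + 2x1o".

```python
import torch

def squared_norm_sketch(x, block_dims, scaled=True):
    """Hypothetical per-irrep squared norm (not equitorch's code).

    x: (..., sum(block_dims), C) -> returns (..., len(block_dims), C)
    """
    blocks = torch.split(x, block_dims, dim=-2)        # one block per irrep instance
    out = [b.pow(2).sum(dim=-2) for b in blocks]       # sum over m: each (..., C)
    if scaled:
        out = [o / d for o, d in zip(out, block_dims)]  # divide by irrep_k.dim
    return torch.stack(out, dim=-2)
```

The result is rotation invariant because real Wigner D-matrices are orthogonal and therefore preserve the squared norm of each block.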
- class equitorch.nn.Norm(irreps: Irreps, scaled: bool = True)
Bases: Module
Computes the L2 norm for each irrep block in an input tensor.
\[\text{Output}_k = \sqrt{\sum_{m \in \text{irrep}_k} (\text{input}_{km}^2)}\]
Optionally scales the output by \(\sqrt{1/\text{irrep}_k\text{.dim}}\). The gradient at a zero vector is defined to be zero.
- Parameters:
irreps (Irreps) – Irreducible representations of the input tensor.
scaled (bool, optional) – If True, scales the output of each irrep_k by \(\sqrt{1/\text{irrep}_k\text{.dim}}\). Defaults to True.
- irreps_info: IrrepsInfo
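Taking torch.sqrt of the squared norm naively produces NaN gradients at a zero vector, because the derivative of \(\sqrt{s}\) is unbounded at \(s = 0\). One common way to get the zero gradient described above is the "double torch.where" trick shown in this sketch. `norm_sketch` and `block_dims` are hypothetical, and equitorch may implement the zero-gradient behavior differently.

```python
import torch

def norm_sketch(x, block_dims, scaled=True):
    """Hypothetical per-irrep L2 norm with zero gradient at zero vectors.

    x: (..., sum(block_dims), C) -> returns (..., len(block_dims), C)
    """
    blocks = torch.split(x, block_dims, dim=-2)
    sq = torch.stack([b.pow(2).sum(dim=-2) for b in blocks], dim=-2)
    if scaled:
        dims = torch.tensor(block_dims, dtype=x.dtype, device=x.device)
        sq = sq / dims.unsqueeze(-1)  # sqrt(s / d) == sqrt(s) * sqrt(1 / d)
    positive = sq > 0
    # Feed a harmless 1.0 into sqrt where s == 0 so no inf/NaN reaches the backward pass.
    safe_sq = torch.where(positive, sq, torch.ones_like(sq))
    return torch.where(positive, torch.sqrt(safe_sq), torch.zeros_like(sq))
```

The outer torch.where picks the correct forward value. The inner one keeps the discarded branch finite, so its zero upstream gradient stays zero instead of turning into 0 times infinity.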
- class equitorch.nn.MeanSquaredNorm(irreps: Irreps, scaled: bool = True, dim: Literal[0, -1, 2] = -1)
Bases: Module
Computes the mean of squared L2 norms over a specified dimension (batch or channel) for each irrep block.
If dim=0 (batch mean):
\[\text{Output}_{ic} = \frac{1}{N} \sum_n \left( \sum_{m \in \text{irrep}_i} \text{input}_{n(im)c}^2 \right)\]
If dim=-1 (channel mean):
\[\text{Output}_{ni} = \frac{1}{C} \sum_c \left( \sum_{m \in \text{irrep}_i} \text{input}_{n(im)c}^2 \right)\]
Optionally scales the output by \(1/\text{irrep}_i\text{.dim}\).
- Parameters:
irreps (Irreps) – Irreducible representations of the input tensor.
scaled (bool, optional) – If True, scales the output of each irrep_i by \(1/\text{irrep}_i\text{.dim}\). Defaults to True.
dim (int, optional) – Dimension over which to compute the mean. Allowed values: 0 (batch), -1 or 2 (channel). Defaults to -1.
- irreps_info: IrrepsInfo
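For a 3-D input of shape (N, irreps.dim, C), the two reductions can be sketched as follows, including the optional \(1/\text{irrep}_i\text{.dim}\) scaling. `mean_squared_norm_sketch` and `block_dims` are hypothetical, and the contiguous per-irrep block layout is assumed.

```python
import torch

def mean_squared_norm_sketch(x, block_dims, scaled=True, dim=-1):
    """Hypothetical sketch. x: (N, sum(block_dims), C).

    dim=0        -> (num_irreps, C): mean over the batch axis n
    dim=-1 or 2  -> (N, num_irreps): mean over the channel axis c
    """
    if dim not in (0, -1, 2):
        raise ValueError("dim must be 0, -1 or 2")
    blocks = torch.split(x, block_dims, dim=1)
    sq = torch.stack([b.pow(2).sum(dim=1) for b in blocks], dim=1)  # (N, K, C)
    if scaled:
        dims = torch.tensor(block_dims, dtype=x.dtype, device=x.device)
        sq = sq / dims.unsqueeze(-1)                                  # divide by irrep_i.dim
    return sq.mean(dim=0) if dim == 0 else sq.mean(dim=-1)
```

For a 3-D input, dim=2 and dim=-1 refer to the same channel axis, which is why both are accepted.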
- class equitorch.nn.Dropout(p: float = 0.5, irreps: Irreps = None, irrep_wise: bool = True, work_on_eval: bool = False)
Bases: Module
Apply dropout to equivariant features.
Can operate irrep-wise or on the entire feature vector (channel-wise).
- Parameters:
p (float, optional) – Probability of an element being zeroed. Default: 0.5
irreps (Irreps, optional) – Irreps of the input tensor. Required if irrep_wise is True. Default: None
irrep_wise (bool, optional) – If True, applies dropout independently to each (irrep instance, channel) pair, so all components of an irrep instance in a given channel are zeroed or kept together. If False, applies standard 1D dropout, treating (irreps_dim, channels) as a single feature dimension. Default: True
work_on_eval (bool, optional) – If True, dropout is applied even during evaluation. Default: False
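The irrep-wise mode can be sketched as inverted dropout with one Bernoulli draw per (irrep instance, channel) pair. Each draw is broadcast over the \(2l+1\) components of its irrep, so surviving blocks are only rescaled, never distorted. `irrep_dropout_sketch` and `block_dims` are hypothetical. The `training` flag stands in for "module in training mode or work_on_eval". The 1/(1-p) rescaling follows torch.nn.Dropout's convention, and whether equitorch rescales the same way is an assumption.

```python
import torch

def irrep_dropout_sketch(x, block_dims, p=0.5, training=True):
    """Hypothetical irrep-wise dropout. x: (..., sum(block_dims), C)."""
    if not training or p == 0.0:
        return x
    if p >= 1.0:
        return torch.zeros_like(x)
    k = len(block_dims)
    keep_shape = tuple(x.shape[:-2]) + (k, x.shape[-1])  # one draw per (irrep instance, channel)
    keep = torch.bernoulli(torch.full(keep_shape, 1.0 - p, dtype=x.dtype, device=x.device))
    # Expand each draw over all rows of its irrep block.
    owner = torch.repeat_interleave(
        torch.arange(k, device=x.device),
        torch.tensor(block_dims, device=x.device),
    )
    mask = keep[..., owner, :]
    return x * mask / (1.0 - p)  # inverted dropout keeps the expected value unchanged
```

Dropping individual components of an \(l \geq 1\) irrep independently would break equivariance. Masking whole blocks avoids this.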
- class equitorch.nn.AnglesToMatrix
Bases: Module
Module to convert Euler angles (ZYZ convention) to rotation matrices.
The ZYZ Euler angles (alpha, beta, gamma) correspond to the rotation matrix:
\[R(\alpha, \beta, \gamma) = R_z(\gamma) R_y(\beta) R_z(\alpha)\]
which is explicitly:
\[\begin{split}\begin{pmatrix} -\sin(\alpha)\sin(\gamma) + \cos(\alpha)\cos(\beta)\cos(\gamma) & -\sin(\alpha)\cos(\beta)\cos(\gamma) - \sin(\gamma)\cos(\alpha) & \sin(\beta)\cos(\gamma) \\ \sin(\alpha)\cos(\gamma) + \sin(\gamma)\cos(\alpha)\cos(\beta) & -\sin(\alpha)\sin(\gamma)\cos(\beta) + \cos(\alpha)\cos(\gamma) & \sin(\beta)\sin(\gamma) \\ -\sin(\beta)\cos(\alpha) & \sin(\alpha)\sin(\beta) & \cos(\beta) \end{pmatrix}\end{split}\]Wraps the functional version
angles_to_matrix().
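A minimal sketch that builds the rotation from elementary rotations using standard right-handed active \(R_z\) and \(R_y\) matrices. It composes them in the order \(R_z(\gamma) R_y(\beta) R_z(\alpha)\), which is the order that reproduces the explicit matrix above entry by entry. For example, the top-right entry is \(\sin\beta\cos\gamma\) and the bottom-left is \(-\sin\beta\cos\alpha\). `angles_to_matrix_sketch` is hypothetical and not the library's angles_to_matrix().

```python
import torch

def _rz(t):
    # Active rotation about z by angle t; output shape t.shape + (3, 3).
    c, s = torch.cos(t), torch.sin(t)
    o, z = torch.ones_like(t), torch.zeros_like(t)
    return torch.stack([torch.stack([c, -s, z], -1),
                        torch.stack([s, c, z], -1),
                        torch.stack([z, z, o], -1)], -2)

def _ry(t):
    # Active rotation about y by angle t; output shape t.shape + (3, 3).
    c, s = torch.cos(t), torch.sin(t)
    o, z = torch.ones_like(t), torch.zeros_like(t)
    return torch.stack([torch.stack([c, z, s], -1),
                        torch.stack([z, o, z], -1),
                        torch.stack([-s, z, c], -1)], -2)

def angles_to_matrix_sketch(alpha, beta, gamma):
    """Hypothetical ZYZ Euler angles -> rotation matrices, batched over leading dims."""
    return _rz(gamma) @ _ry(beta) @ _rz(alpha)
```

The function accepts 0-dim tensors or batches of angles with matching shapes and returns one \(3 \times 3\) orthogonal matrix per angle triple.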