equitorch.utils

Utility functions.

equitorch.utils.rand_spherical_xyz(shape: Tuple[int, ...], device=None, dtype=None) -> Tensor

Generate random points uniformly distributed on a unit sphere.

Parameters:
  • shape – Tuple defining the batch dimensions (e.g., (10,) or (5,5))

  • device – Torch device for the output tensor

  • dtype – Torch dtype for the output tensor

Returns:

Tensor of shape (*shape, 3) where each vector has unit norm
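
One standard way to obtain uniformly distributed unit vectors is to normalize isotropic Gaussian samples. The source does not say which method rand_spherical_xyz uses, so the following is only a plain-Python sketch of that technique; sample_unit_vectors is a hypothetical helper, not part of equitorch.

```python
import math
import random

def sample_unit_vectors(n, seed=None):
    """Return n points drawn uniformly from the unit sphere (hypothetical helper)."""
    rng = random.Random(seed)
    points = []
    while len(points) < n:
        # An isotropic Gaussian has no preferred direction, so normalizing
        # a Gaussian sample gives a direction that is uniform over the sphere.
        x, y, z = (rng.gauss(0.0, 1.0) for _ in range(3))
        norm = math.sqrt(x * x + y * y + z * z)
        if norm > 1e-12:  # skip the vanishingly unlikely near-zero draw
            points.append((x / norm, y / norm, z / norm))
    return points

points = sample_unit_vectors(5, seed=0)
```

Every returned vector has unit norm up to floating-point error, which matches the documented contract of the output.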

equitorch.utils.rand_spherical_angles(shape: Tuple[int, ...], device=None, dtype=None) -> Tuple[Tensor, Tensor]

Generate random spherical angles with uniform distribution.

Parameters:
  • shape – Tuple defining the batch dimensions

  • device – Torch device for the output tensors

  • dtype – Torch dtype for the output tensors

Returns:

Tuple of (theta, phi) where

  • theta is in [0, π) (polar angle from +z axis)

  • phi is in [0, 2π) (azimuthal angle from +x axis)

Return type:

Tuple[Tensor, Tensor]
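
The description does not state whether "uniform" refers to the angles themselves or to the resulting points on the sphere. Assuming the latter, a common approach is to draw cos(theta) uniformly rather than theta, since sampling theta uniformly would crowd points near the poles. The sketch below is plain Python and sample_spherical_angles is a hypothetical name.

```python
import math
import random

def sample_spherical_angles(n, seed=None):
    """Return (thetas, phis) whose points are area-uniform on the sphere."""
    rng = random.Random(seed)
    thetas, phis = [], []
    for _ in range(n):
        # u is in [0, 1), so 1 - 2u is in (-1, 1]. Making cos(theta)
        # uniform compensates for the sin(theta) factor in the area element.
        u = rng.random()
        thetas.append(math.acos(1.0 - 2.0 * u))
        # The azimuth is simply uniform over one full turn.
        phis.append(2.0 * math.pi * rng.random())
    return thetas, phis

thetas, phis = sample_spherical_angles(4, seed=0)
```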

equitorch.utils.rand_rotation_angles(shape: Tuple[int, ...], device=None, dtype=None) -> Tuple[Tensor, Tensor, Tensor]

Generate random Euler angles (ZYZ convention) with uniform distribution.

Parameters:
  • shape – Tuple defining the batch dimensions

  • device – Torch device for the output tensors

  • dtype – Torch dtype for the output tensors

Returns:

Tuple of (alpha, beta, gamma) where

  • alpha is in [0, 2π) (first rotation about z-axis)

  • beta is in [0, π) (rotation about y-axis)

  • gamma is in [0, 2π) (second rotation about z-axis)

Return type:

Tuple[Tensor, Tensor, Tensor]
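
To make the ZYZ convention concrete, the sketch below composes the three elementary rotations into one matrix. The composition order R = Rz(alpha) · Ry(beta) · Rz(gamma) is an assumption, since the source names the axes but not the multiplication order. All helper names are hypothetical.

```python
import math

def rot_z(a):
    """Elementary rotation by angle a about the z-axis."""
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_y(b):
    """Elementary rotation by angle b about the y-axis."""
    c, s = math.cos(b), math.sin(b)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def matmul(A, B):
    """3x3 matrix product on nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def euler_zyz_to_matrix(alpha, beta, gamma):
    # Assumed order: R = Rz(alpha) @ Ry(beta) @ Rz(gamma).
    return matmul(rot_z(alpha), matmul(rot_y(beta), rot_z(gamma)))

R = euler_zyz_to_matrix(0.3, 1.1, 2.0)
```

With beta = 0 the two z-rotations merge into a single rotation by alpha + gamma, which is a quick sanity check for any ZYZ implementation.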

equitorch.utils.rand_rotation_matrices(shape: Tuple[int, ...], device=None, dtype=None) -> Tensor

Generate random rotation matrices using Rodrigues’ rotation formula.

Parameters:
  • shape – Tuple defining the batch dimensions

  • device – Torch device for the output tensor

  • dtype – Torch dtype for the output tensor

Returns:

Tensor of shape (*shape, 3, 3) containing valid rotation matrices
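
Rodrigues' rotation formula builds a rotation matrix from a unit axis and an angle as R = I + sin(θ) K + (1 − cos(θ)) K², where K is the cross-product matrix of the axis. The source does not say how the axis and angle are sampled, so this plain-Python sketch only illustrates the formula itself; rodrigues is a hypothetical helper.

```python
import math

def rodrigues(axis, angle):
    """Rotation matrix for a rotation by `angle` (radians) about `axis`."""
    x, y, z = axis
    n = math.sqrt(x * x + y * y + z * z)
    x, y, z = x / n, y / n, z / n  # normalize the axis
    # K is the cross-product matrix of the unit axis: K @ v == axis x v.
    K = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)]
          for i in range(3)]
    s, c = math.sin(angle), 1.0 - math.cos(angle)
    # R = I + sin(angle) * K + (1 - cos(angle)) * K^2
    return [[(1.0 if i == j else 0.0) + s * K[i][j] + c * K2[i][j]
             for j in range(3)] for i in range(3)]

# A quarter turn about +z maps the x-axis onto the y-axis.
R = rodrigues((0.0, 0.0, 1.0), math.pi / 2)
```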

equitorch.utils.expand_left(source: Tensor, target: Tensor, dim: int)
equitorch.utils.extract_batch_segments(keys: List[List[int]])

Process sorted integer key lists to generate batch indices, boundary pointers, and key values.

Parameters:

keys (List[List[int]]) – A list of sorted integer key lists. All lists must have the same length.

Returns:

  • batch (List[int]) – A list where each element indicates the batch index it belongs to.

  • seg (List[int]) – A list of boundary pointers indicating the start and end of each batch.

  • val (List[List[int]]) – A list of lists containing the key values at the boundary points for each key list.

Notes

  • The input key lists must be sorted in ascending order.

  • If the input is empty, the function returns empty lists for batch, seg, and val.

Examples

>>> keys = [
...     [1, 1, 2, 2],
...     [1, 1, 2, 2]
... ]
>>> extract_batch_segments(keys)
([0, 0, 1, 1], [0, 2, 4], [[1, 2], [1, 2]])
>>> keys = [
...     [5, 5, 5],
...     [5, 5, 5]
... ]
>>> extract_batch_segments(keys)
([0, 0, 0], [0, 3], [[5], [5]])
>>> keys = [
...     [1, 1, 2, 3, 3],
...     [1, 2, 2, 3, 3]
... ]
>>> extract_batch_segments(keys)
([0, 1, 2, 3, 3], [0, 1, 2, 3, 5], [[1, 1, 2, 3], [1, 2, 2, 3]])
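
The behaviour shown in these examples can be reproduced with a short pure-Python reference: treat each column of keys as a tuple and start a new segment whenever that tuple changes. This is a sketch inferred from the documented outputs, not the library's implementation, and batch_segments_reference is a hypothetical name.

```python
def batch_segments_reference(keys):
    """Return (batch, seg, val) for sorted, equal-length key lists."""
    if not keys or not keys[0]:
        return [], [], []
    columns = list(zip(*keys))      # one tuple of keys per position
    batch, seg = [], []
    val = [[] for _ in keys]
    for i, col in enumerate(columns):
        if i == 0 or col != columns[i - 1]:
            seg.append(i)           # a new segment starts at position i
            for row, k in zip(val, col):
                row.append(k)       # record the key values of this segment
        batch.append(len(seg) - 1)  # index of the segment holding position i
    seg.append(len(columns))        # closing boundary pointer
    return batch, seg, val

result = batch_segments_reference([[1, 1, 2, 3, 3], [1, 2, 2, 3, 3]])
# -> ([0, 1, 2, 3, 3], [0, 1, 2, 3, 5], [[1, 1, 2, 3], [1, 2, 2, 3]])
```
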
equitorch.utils.sort_by_column_key(to_sort: List[List[Any]], key: List[List[Any]] = None) -> List[List[Any]]

Sort the columns of the first 2D list based on the column-wise lexicographical order of the key 2D list.

Parameters:
  • to_sort (List[List[Any]]) – The first 2D list whose columns are to be sorted.

  • key (List[List[Any]]) – The key 2D list used to determine the sorting order of columns.

Returns:

The first 2D list with columns sorted according to the column-wise lexicographical order of the key.

Return type:

List[List[Any]]

Raises:

ValueError – If either to_sort or key is empty, or if their lengths do not match.

Examples

>>> to_sort = [[1, 2, 3],
...            [4, 5, 6]]
>>> key = [[2, 1, 3],
...        [1, 3, 2]]
>>> sort_by_column_key(to_sort, key)
[[2, 1, 3],
 [5, 4, 6]]
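
A minimal pure-Python sketch of column-wise lexicographic sorting, consistent with the example above. It is not the library's code: sort_columns_reference is a hypothetical name, falling back to to_sort when key is None is an assumption, and "lengths do not match" is read here as a mismatch in the number of columns.

```python
def sort_columns_reference(to_sort, key=None):
    """Reorder the columns of to_sort by the lexicographic order of key's columns."""
    if key is None:
        key = to_sort  # assumption: sort by the data itself when no key is given
    if not to_sort or not key:
        raise ValueError("to_sort and key must be non-empty")
    n_cols = len(to_sort[0])
    if any(len(row) != n_cols for row in list(to_sort) + list(key)):
        raise ValueError("all rows of to_sort and key must have the same length")
    key_columns = list(zip(*key))  # column j of key as a tuple
    # Tuples compare lexicographically, so sorting column indices by their
    # key tuple yields the desired column order.
    order = sorted(range(n_cols), key=lambda j: key_columns[j])
    return [[row[j] for j in order] for row in to_sort]

sorted_cols = sort_columns_reference([[1, 2, 3], [4, 5, 6]],
                                     [[2, 1, 3], [1, 3, 2]])
# -> [[2, 1, 3], [5, 4, 6]]
```
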
equitorch.utils.extract_scatter_indices(keys: List[List[int]]) -> Tuple[List[int], List[List[int]]]

Process integer key lists to generate scatter indices and sorted unique keys.

Parameters:

keys (List[List[int]]) – A list of integer key lists. All lists must have the same length.

Returns:

  • indices (List[int]) – A list where each element is the index of the corresponding key tuple in the sorted unique list.

  • scatter_keys (List[List[int]]) – A list of lists containing the sorted unique key values for each original key list.

Notes

  • If the input is empty, the function returns empty lists for indices and scatter_keys.

Examples

>>> keys = [
...     [1, 1, 2, 2],
...     [1, 1, 2, 2]
... ]
>>> extract_scatter_indices(keys)
([0, 0, 1, 1], [[1, 2], [1, 2]])
>>> keys = [
...     [5, 5, 5],
...     [5, 5, 5]
... ]
>>> extract_scatter_indices(keys)
([0, 0, 0], [[5], [5]])
>>> keys = [
...     [1, 1, 2, 3, 3],
...     [1, 2, 2, 3, 3]
... ]
>>> extract_scatter_indices(keys)
([0, 1, 2, 3, 3], [[1, 1, 2, 3], [1, 2, 2, 3]])
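
The documented outputs are consistent with the following pure-Python reference, which collects the unique column tuples, sorts them, and maps every column to its position in that sorted list. This is a sketch inferred from the examples, not the library's implementation; scatter_indices_reference is a hypothetical name.

```python
def scatter_indices_reference(keys):
    """Return (indices, scatter_keys) for equal-length integer key lists."""
    if not keys or not keys[0]:
        return [], []
    columns = list(zip(*keys))                 # one key tuple per position
    unique = sorted(set(columns))              # sorted unique key tuples
    position = {col: i for i, col in enumerate(unique)}
    indices = [position[col] for col in columns]
    # Transpose back so each output list lines up with one input key list.
    scatter_keys = [list(row) for row in zip(*unique)]
    return indices, scatter_keys

# Unlike extract_batch_segments, the input need not be sorted.
unsorted_result = scatter_indices_reference([[2, 1, 2], [0, 5, 0]])
# -> ([1, 0, 1], [[1, 2], [5, 0]])
```
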
equitorch.utils.sparse_scale_info(index=None, index_out=None, scale=None, out_size=None)
equitorch.utils.sparse_scale_infos(index=None, index_out=None, scale=None, out_size=None, in_size=None)
equitorch.utils.sparse_product_info(index1=None, index2=None, index=None, scale=None, out_size=None)
equitorch.utils.sparse_product_infos(index1=None, index2=None, index=None, scale=None, out_size=None, in1_size=None, in2_size=None)
equitorch.utils.generate_fully_connected_tp_paths(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps)
equitorch.utils.tp_info(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps, path: List[Tuple[int, int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
equitorch.utils.tp_infos(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps, path: List[Tuple[int, int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
equitorch.utils.generate_fully_connected_irreps_linear_paths(irreps_out: Irreps, irreps_in: Irreps)
equitorch.utils.irreps_linear_infos(irreps_out: Irreps, irreps_in: Irreps, path: List[Tuple[int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
equitorch.utils.irreps_info(irreps: Irreps)
equitorch.utils.z_rotation_infos(irreps: Irreps) -> Tuple[SparseProductInfo, SparseProductInfo, SparseProductInfo]

Generates SparseProductInfo for performing z-axis rotation on a tensor with the given Irreps structure using the sparse_scavec operation.

Rotation formula:

  m = 0:  x'_{nimc} = x_{ni0c}
  m != 0: x'_{nimc} = cos(m*phi) * x_{nimc} + sin(m*phi) * x_{ni(-m)c}

Assumes sparse_scavec computes:

output[n, M] = sum_t scale[t] * input[n, index2[t]] * cs[n, index1[t]]

where the sum is segmented by the output index M (implicit via index_out).

Assumed cs Tensor Structure (Input 2): shape (N, cs_dim)

  cs_dim = 1 + 2 * max_l
  cs[:, 0] = 1.0
  cs[:, 2*m - 1] = sin(m*phi) for m = 1..max_l
  cs[:, 2*m] = cos(m*phi) for m = 1..max_l

Parameters:

irreps – The Irreps object describing the geometric structure of the input tensor (Input 1, shape (N, irreps.dim, C)) and output tensor (shape (N, irreps.dim, C)).

Returns:

A tuple (info_fwd, info_bwd1, info_bwd2) containing SparseProductInfo objects for the forward pass and gradients w.r.t. x (input1) and cs (input2).
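
To illustrate the cs layout and the rotation formula described above, the following plain-Python sketch builds one cs row for a single angle and applies the rotation densely to one degree-l block. It does not use sparse_scavec or SparseProductInfo. The helper names are hypothetical, and the component ordering m = -l..l is an assumption.

```python
import math

def build_cs(phi, max_l):
    """One row of cs: [1, sin(phi), cos(phi), sin(2*phi), cos(2*phi), ...]."""
    cs = [0.0] * (1 + 2 * max_l)
    cs[0] = 1.0
    for m in range(1, max_l + 1):
        cs[2 * m - 1] = math.sin(m * phi)
        cs[2 * m] = math.cos(m * phi)
    return cs

def rotate_block_z(x, l, cs):
    """Apply the z-rotation to one block x with components m = -l..l."""
    out = []
    for m in range(-l, l + 1):
        if m == 0:
            out.append(x[l])  # the m = 0 component is unchanged
        else:
            a = abs(m)
            cos_m = cs[2 * a]
            # sin(m*phi) is odd in m, so the stored sin(|m|*phi) flips sign.
            sin_m = cs[2 * a - 1] if m > 0 else -cs[2 * a - 1]
            out.append(cos_m * x[m + l] + sin_m * x[-m + l])
    return out

cs = build_cs(0.7, max_l=2)
x = [0.5, -1.0, 2.0, 0.25, 3.0]  # one l = 2 block
x_rot = rotate_block_z(x, 2, cs)
```

Each (m, -m) pair is mixed by a 2D rotation, so the norm of the block is preserved and phi = 0 leaves it unchanged.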

equitorch.utils.j_matrix_info(irreps: Irreps) -> SparseScaleInfo

Generates SparseScaleInfo for multiplying by the J matrix by extracting non-zero elements from the dense blocks computed by the user’s j_matrix(l) function.

Operation: y_{M'} = sum_M J_{M', M} * x_M

Parameters:

irreps – The Irreps object describing the geometric structure.

Returns:

A SparseScaleInfo object for the forward pass.
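
The idea of extracting non-zero elements from a dense block and then applying them as a sparse scale operation can be sketched in plain Python. The names index_out, index and scale mirror the arguments of sparse_scale_info, but the actual layout of SparseScaleInfo is not documented here, so treat this as an assumption. The toy matrix below is an arbitrary stand-in, not a real J matrix, and both helpers are hypothetical.

```python
def dense_to_sparse(block, tol=1e-12):
    """Collect (output index, input index, value) for every non-zero entry."""
    index_out, index, scale = [], [], []
    for r, row in enumerate(block):
        for c, v in enumerate(row):
            if abs(v) > tol:
                index_out.append(r)
                index.append(c)
                scale.append(v)
    return index_out, index, scale

def sparse_apply(index_out, index, scale, x, out_size):
    """Compute y[M'] = sum over stored entries of scale * x[M]."""
    y = [0.0] * out_size
    for o, i, s in zip(index_out, index, scale):
        y[o] += s * x[i]
    return y

toy = [[0.0, 1.0, 0.0],
       [1.0, 0.0, 0.0],
       [0.0, 0.0, -1.0]]
index_out, index, scale = dense_to_sparse(toy)
y = sparse_apply(index_out, index, scale, [1.0, 2.0, 3.0], out_size=3)
# -> [2.0, 1.0, -3.0]
```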

equitorch.utils.wigner_d_info(irreps: Irreps) -> WignerRotationInfo

Prepares all necessary info objects for arbitrary Wigner D rotation.

equitorch.utils.irreps_blocks_infos(irreps: Irreps) -> SparseProductInfo
equitorch.utils.so2_linear_info(irreps_out, irreps_in, path=None, path_norm=True, channel_norm=False, channel_scale=1.0)
equitorch.utils.so2_linear_infos(irreps_out, irreps_in, path=None, path_norm=True, channel_norm=False, channel_scale=1.0)