equitorch.utils
Utility functions.
- equitorch.utils.rand_spherical_xyz(shape: Tuple[int, ...], device=None, dtype=None) → Tensor
Generate random points uniformly distributed on a unit sphere.
- Parameters:
shape – Tuple defining the batch dimensions (e.g., (10,) or (5,5))
device – Torch device for the output tensor
dtype – Torch dtype for the output tensor
- Returns:
Tensor of shape (*shape, 3) where each vector has unit norm
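A common way to sample uniformly on the unit sphere is to normalize i.i.d. standard Gaussian vectors: the isotropic Gaussian is rotation invariant, so the direction is uniform. The pure-Python sketch below only illustrates that technique. The helper name `sample_unit_vectors` is hypothetical, and the library's torch implementation may use a different method.

```python
import math
import random
from typing import List, Tuple


def sample_unit_vectors(n: int, seed: int = 0) -> List[Tuple[float, float, float]]:
    """Hypothetical helper: n points uniformly distributed on the unit sphere."""
    rng = random.Random(seed)
    points = []
    while len(points) < n:
        # Isotropic Gaussian vector: its direction is uniform on the sphere.
        x, y, z = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        norm = math.sqrt(x * x + y * y + z * z)
        if norm == 0.0:
            continue  # reject the (measure-zero) zero vector
        points.append((x / norm, y / norm, z / norm))
    return points


pts = sample_unit_vectors(5)
```

In the library itself, `shape` only sets the batch dimensions, so `rand_spherical_xyz((10,))` is documented to return a tensor of shape (10, 3).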
- equitorch.utils.rand_spherical_angles(shape: Tuple[int, ...], device=None, dtype=None) → Tuple[Tensor, Tensor]
Generate random spherical angles with uniform distribution.
- Parameters:
shape – Tuple defining the batch dimensions
device – Torch device for the output tensors
dtype – Torch dtype for the output tensors
- Returns:
Tuple (theta, phi) where
theta is in [0, π) (polar angle from +z axis)
phi is in [0, 2π) (azimuthal angle from +x axis)
- Return type:
Tuple[Tensor, Tensor]
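The description does not say which distribution "uniform" refers to. If it means uniform over the sphere (an assumption), phi is drawn uniformly from [0, 2π) and cos(theta) uniformly from (-1, 1], that is, theta = arccos(1 - 2u) with u in [0, 1), which keeps theta in [0, π). The sketch below uses hypothetical helpers `sample_angles` and `angles_to_xyz`; the latter shows how the documented conventions (theta measured from +z, phi from +x) map to Cartesian coordinates.

```python
import math
import random
from typing import List, Tuple


def sample_angles(n: int, seed: int = 0) -> List[Tuple[float, float]]:
    """Hypothetical helper: (theta, phi) pairs for points uniform on the sphere."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        u, v = rng.random(), rng.random()  # both in [0, 1)
        theta = math.acos(1.0 - 2.0 * u)   # 1 - 2u lies in (-1, 1], so theta in [0, pi)
        phi = 2.0 * math.pi * v            # phi in [0, 2*pi)
        out.append((theta, phi))
    return out


def angles_to_xyz(theta: float, phi: float) -> Tuple[float, float, float]:
    """Polar angle measured from +z, azimuth measured from +x."""
    return (math.sin(theta) * math.cos(phi),
            math.sin(theta) * math.sin(phi),
            math.cos(theta))


angles = sample_angles(100)
```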
- equitorch.utils.rand_rotation_angles(shape: Tuple[int, ...], device=None, dtype=None) → Tuple[Tensor, Tensor, Tensor]
Generate random Euler angles (ZYZ convention) with uniform distribution.
- Parameters:
shape – Tuple defining the batch dimensions
device – Torch device for the output tensors
dtype – Torch dtype for the output tensors
- Returns:
Tuple (alpha, beta, gamma) where
alpha is in [0, 2π) (first rotation about z-axis)
beta is in [0, π) (rotation about y-axis)
gamma is in [0, 2π) (second rotation about z-axis)
- Return type:
Tuple[Tensor, Tensor, Tensor]
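One common way to turn ZYZ angles into a matrix is R = R_z(alpha) R_y(beta) R_z(gamma); conventions differ between libraries, so this ordering is an assumption. For rotations that are uniform (Haar) on SO(3), the measure in ZYZ angles has density proportional to sin(beta), so alpha and gamma are drawn uniformly and beta as arccos(1 - 2u). Whether `rand_rotation_angles` samples exactly this way is not stated above; all helper names below are hypothetical.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def rot_z(a: float) -> Matrix:
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rot_y(b: float) -> Matrix:
    c, s = math.cos(b), math.sin(b)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def zyz_to_matrix(alpha: float, beta: float, gamma: float) -> Matrix:
    # One common ZYZ composition; other libraries may order the factors differently.
    return matmul(rot_z(alpha), matmul(rot_y(beta), rot_z(gamma)))


def sample_zyz(rng: random.Random) -> Tuple[float, float, float]:
    alpha = 2.0 * math.pi * rng.random()
    beta = math.acos(1.0 - 2.0 * rng.random())  # density proportional to sin(beta)
    gamma = 2.0 * math.pi * rng.random()
    return alpha, beta, gamma


rng = random.Random(0)
R = zyz_to_matrix(*sample_zyz(rng))
```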
- equitorch.utils.rand_rotation_matrices(shape: Tuple[int, ...], device=None, dtype=None) → Tensor
Generate random rotation matrices using Rodrigues’ rotation formula.
- Parameters:
shape – Tuple defining the batch dimensions
device – Torch device for the output tensor
dtype – Torch dtype for the output tensor
- Returns:
Tensor of shape (*shape, 3, 3) containing valid rotation matrices
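Rodrigues' formula gives the rotation by angle theta about a unit axis k as R = I + sin(theta) K + (1 - cos(theta)) K^2, where K is the cross-product matrix of k (K v = k × v). The sketch below shows only that construction; the helper `rodrigues` is hypothetical, and how `rand_rotation_matrices` draws the axis and angle is not documented here. Note that a uniform axis combined with a uniform angle is not Haar-uniform on SO(3); that requires an angle density proportional to 1 - cos(theta).

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def rodrigues(axis: Sequence[float], theta: float) -> Matrix:
    """Rotation by angle theta about `axis` (normalized here)."""
    n = math.sqrt(sum(c * c for c in axis))
    kx, ky, kz = (c / n for c in axis)
    # Skew-symmetric cross-product matrix: K v = k x v.
    K = [[0.0, -kz, ky], [kz, 0.0, -kx], [-ky, kx, 0.0]]
    K2 = [[sum(K[i][m] * K[m][j] for m in range(3)) for j in range(3)]
          for i in range(3)]
    s, c = math.sin(theta), math.cos(theta)
    return [[(1.0 if i == j else 0.0) + s * K[i][j] + (1.0 - c) * K2[i][j]
             for j in range(3)] for i in range(3)]


R = rodrigues((0.0, 0.0, 1.0), math.pi / 2)  # quarter turn about +z
```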
- equitorch.utils.expand_left(source: Tensor, target: Tensor, dim: int)
- equitorch.utils.extract_batch_segments(keys: List[List[int]])
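Group the positions of sorted integer key lists into segments of identical key tuples, returning a segment (batch) index for each position, the segment boundary pointers, and the key values of each segment.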
- Parameters:
keys (List[List[int]]) – A list of sorted integer key lists. All lists must have the same length.
- Returns:
batch (List[int]) – For each position, the index of the segment (batch) it belongs to.
seg (List[int]) – Segment boundary pointers of length (number of segments + 1); segment i covers positions seg[i] through seg[i+1] - 1.
val (List[List[int]]) – For each key list, the key value of each segment, read at the segment's start position.
Notes
The input key lists must be sorted in ascending order.
If the input is empty, the function returns empty lists for batch, seg, and val.
Examples
>>> keys = [
...     [1, 1, 2, 2],
...     [1, 1, 2, 2]
... ]
>>> extract_batch_segments(keys)
([0, 0, 1, 1], [0, 2, 4], [[1, 2], [1, 2]])

>>> keys = [
...     [5, 5, 5],
...     [5, 5, 5]
... ]
>>> extract_batch_segments(keys)
([0, 0, 0], [0, 3], [[5], [5]])

>>> keys = [
...     [1, 1, 2, 3, 3],
...     [1, 2, 2, 3, 3]
... ]
>>> extract_batch_segments(keys)
([0, 1, 2, 3, 3], [0, 1, 2, 3, 5], [[1, 1, 2, 3], [1, 2, 2, 3]])
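The pure-Python sketch below reproduces the documented examples: a new segment starts wherever the column tuple differs from the previous one. The name `extract_batch_segments_sketch` is hypothetical, and the library's implementation may differ, for example in how it handles unsorted input.

```python
from typing import List, Tuple


def extract_batch_segments_sketch(
        keys: List[List[int]]) -> Tuple[List[int], List[int], List[List[int]]]:
    """Hypothetical re-implementation matching the documented examples."""
    if not keys or not keys[0]:
        return [], [], []
    columns = list(zip(*keys))  # one key tuple per position
    batch: List[int] = []
    seg: List[int] = [0]
    for pos, col in enumerate(columns):
        if pos > 0 and col != columns[pos - 1]:
            seg.append(pos)  # a new segment starts where the key tuple changes
        batch.append(len(seg) - 1)
    seg.append(len(columns))  # closing pointer
    # Key value of every segment, read at the segment's start position.
    val = [[row[start] for start in seg[:-1]] for row in keys]
    return batch, seg, val
```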
- equitorch.utils.sort_by_column_key(to_sort: List[List[Any]], key: List[List[Any]] = None) → List[List[Any]]
Sort the columns of the first 2D list based on the column-wise lexicographical order of the key 2D list.
- Parameters:
to_sort (List[List[Any]]) – The first 2D list whose columns are to be sorted.
key (List[List[Any]]) – The key 2D list used to determine the sorting order of columns.
- Returns:
The first 2D list with columns sorted according to the column-wise lexicographical order of the key.
- Return type:
List[List[Any]]
- Raises:
ValueError – If either to_sort or key is empty, or if their lengths do not match.
Examples
>>> to_sort = [[1, 2, 3],
...            [4, 5, 6]]
>>> key = [[2, 1, 3],
...        [1, 3, 2]]
>>> sort_by_column_key(to_sort, key)
[[2, 1, 3], [5, 4, 6]]
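The column sort can be sketched by ordering column indices by the tuple formed from each key column, then gathering those columns from every row of `to_sort`. In the example, the key columns (2, 1), (1, 3), (3, 2) sort to the index order [1, 0, 2]. The name `sort_by_column_key_sketch` is hypothetical; what the library does when `key` is None is not documented, so falling back to `to_sort` itself is an assumption, as is the exact validation performed.

```python
from typing import Any, List, Optional


def sort_by_column_key_sketch(to_sort: List[List[Any]],
                              key: Optional[List[List[Any]]] = None) -> List[List[Any]]:
    if key is None:
        key = to_sort  # assumption: with no key, sort by the columns of to_sort
    if not to_sort or not key:
        raise ValueError("to_sort and key must be non-empty")
    n_cols = len(to_sort[0])
    if any(len(row) != n_cols for row in to_sort) or any(len(row) != n_cols for row in key):
        raise ValueError("all rows must have the same number of columns")
    # Column indices in lexicographic order of the key's column tuples (stable on ties).
    order = sorted(range(n_cols), key=lambda j: tuple(row[j] for row in key))
    return [[row[j] for j in order] for row in to_sort]
```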
- equitorch.utils.extract_scatter_indices(keys: List[List[int]]) → Tuple[List[int], List[List[int]]]
Process integer key lists to generate scatter indices and sorted unique keys.
- Parameters:
keys (List[List[int]]) – A list of integer key lists. All lists must have the same length.
- Returns:
indices (List[int]) – A list where each element is the index of the corresponding key tuple in the sorted unique list.
scatter_keys (List[List[int]]) – A list of lists containing the sorted unique key values for each original key list.
Notes
If the input is empty, the function returns empty lists for indices and scatter_keys.
Examples
>>> keys = [
...     [1, 1, 2, 2],
...     [1, 1, 2, 2]
... ]
>>> extract_scatter_indices(keys)
([0, 0, 1, 1], [[1, 2], [1, 2]])

>>> keys = [
...     [5, 5, 5],
...     [5, 5, 5]
... ]
>>> extract_scatter_indices(keys)
([0, 0, 0], [[5], [5]])

>>> keys = [
...     [1, 1, 2, 3, 3],
...     [1, 2, 2, 3, 3]
... ]
>>> extract_scatter_indices(keys)
([0, 1, 2, 3, 3], [[1, 1, 2, 3], [1, 2, 2, 3]])
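Because the input here is not required to be sorted, a minimal sketch can collect the unique column tuples, sort them, and look each position up in that sorted list. The name `extract_scatter_indices_sketch` is hypothetical; it reproduces the documented examples but is not the library's code.

```python
from typing import List, Tuple


def extract_scatter_indices_sketch(keys: List[List[int]]) -> Tuple[List[int], List[List[int]]]:
    if not keys or not keys[0]:
        return [], []
    columns = list(zip(*keys))              # one key tuple per position
    unique = sorted(set(columns))           # sorted unique key tuples
    position = {col: i for i, col in enumerate(unique)}
    indices = [position[col] for col in columns]
    # Transpose the unique tuples back into one list per original key list.
    scatter_keys = [list(values) for values in zip(*unique)]
    return indices, scatter_keys
```

For unsorted input such as `[[2, 1, 2], [0, 0, 0]]`, this sketch returns `([1, 0, 1], [[1, 2], [0, 0]])`.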
- equitorch.utils.sparse_scale_info(index=None, index_out=None, scale=None, out_size=None)
- equitorch.utils.sparse_scale_infos(index=None, index_out=None, scale=None, out_size=None, in_size=None)
- equitorch.utils.sparse_product_info(index1=None, index2=None, index=None, scale=None, out_size=None)
- equitorch.utils.sparse_product_infos(index1=None, index2=None, index=None, scale=None, out_size=None, in1_size=None, in2_size=None)
- equitorch.utils.generate_fully_connected_tp_paths(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps)
- equitorch.utils.tp_info(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps, path: List[Tuple[int, int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
- equitorch.utils.tp_infos(irreps_out: Irreps, irreps1: Irreps, irreps2: Irreps, path: List[Tuple[int, int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
- equitorch.utils.generate_fully_connected_irreps_linear_paths(irreps_out: Irreps, irreps_in: Irreps)
- equitorch.utils.irreps_linear_infos(irreps_out: Irreps, irreps_in: Irreps, path: List[Tuple[int, int]] = None, path_norm: bool = True, channel_norm: bool = False, channel_scale: float = 1.0)
- equitorch.utils.z_rotation_infos(irreps: Irreps) → Tuple[SparseProductInfo, SparseProductInfo, SparseProductInfo]
Generates SparseProductInfo for performing z-axis rotation on a tensor with the given Irreps structure using the sparse_scavec operation.
- Rotation formula:
$$x'_{nimc} = \begin{cases} x_{ni0c}, & m = 0 \\ \cos(m\phi)\, x_{nimc} + \sin(m\phi)\, x_{ni(-m)c}, & m \neq 0 \end{cases}$$
- Assumes sparse_scavec computes:
output[n, M] = sum_t scale[t] * input[n, index2[t]] * cs[n, index1[t]]
where the sum is segmented by the output index M (implicit via index_out).
- Assumed cs Tensor Structure (Input 2): shape (N, cs_dim), where
cs_dim = 1 + 2 * max_l
cs[:, 0] = 1.0
cs[:, 2*m - 1] = sin(m*phi) for m = 1..max_l
cs[:, 2*m] = cos(m*phi) for m = 1..max_l
- Parameters:
irreps – The Irreps object describing the geometric structure of the input tensor (Input 1, shape (N, irreps.dim, C)) and output tensor (shape (N, irreps.dim, C)).
- Returns:
A tuple (info_fwd, info_bwd1, info_bwd2) containing SparseProductInfo objects for the forward pass and gradients w.r.t. x (input1) and cs (input2).
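The pieces above fit together as follows: each output component is a short sum of scale × input × cs terms, and cos(mφ) and sin(mφ) for negative m are read from the cs layout as cos(|m|φ) and sign(m)·sin(|m|φ). The pure-Python sketch below builds such terms for a single degree-l block, evaluates the segmented sum given for sparse_scavec, and can be checked against the rotation formula directly. Several things are assumptions: the helper names are hypothetical, components within a block are ordered m = -l..l, the batch, irrep and channel indices (n, i, c) are dropped, and the term lists only mimic what a SparseProductInfo might hold. Following the sparse_scavec formula, cs is indexed by index1 and the input by index2.

```python
import math
from typing import List, Tuple

# (scale, index1 -> cs, index2 -> input, index_out) for each term t
Terms = Tuple[List[float], List[int], List[int], List[int]]


def make_cs(phi: float, max_l: int) -> List[float]:
    """cs layout from above: [1, sin(phi), cos(phi), sin(2 phi), cos(2 phi), ...]."""
    cs = [1.0]
    for m in range(1, max_l + 1):
        cs += [math.sin(m * phi), math.cos(m * phi)]  # cs[2m-1], cs[2m]
    return cs


def z_rotation_terms(l: int) -> Terms:
    """Hypothetical sparse terms for one degree-l block ordered m = -l..l."""
    scale: List[float] = []
    index1: List[int] = []
    index2: List[int] = []
    index_out: List[int] = []
    for m in range(-l, l + 1):
        out = l + m
        if m == 0:
            terms = [(1.0, 0, out)]  # x'_0 = cs[0] * x_0 = x_0
        else:
            a, sign = abs(m), (1.0 if m > 0 else -1.0)
            terms = [(1.0, 2 * a, l + m),       # cos(m phi) = cs[2|m|], times x_m
                     (sign, 2 * a - 1, l - m)]  # sin(m phi) = sign(m) cs[2|m|-1], times x_{-m}
        for s, i_cs, i_x in terms:
            scale.append(s)
            index1.append(i_cs)
            index2.append(i_x)
            index_out.append(out)
    return scale, index1, index2, index_out


def apply_terms(t: Terms, x: List[float], cs: List[float], out_size: int) -> List[float]:
    """output[M] = sum_t scale[t] * x[index2[t]] * cs[index1[t]], grouped by index_out."""
    scale, index1, index2, index_out = t
    out = [0.0] * out_size
    for s, i1, i2, o in zip(scale, index1, index2, index_out):
        out[o] += s * x[i2] * cs[i1]
    return out


l, phi = 2, 0.7
x = [0.3, -1.2, 0.5, 2.0, -0.4]
y = apply_terms(z_rotation_terms(l), x, make_cs(phi, l), 2 * l + 1)
```

Each (m, -m) pair is mixed by a 2 × 2 rotation, so the result should preserve the vector norm, and φ = 0 should leave x unchanged.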
- equitorch.utils.j_matrix_info(irreps: Irreps) → SparseScaleInfo
Generates a SparseScaleInfo for multiplication by the J matrix, built by extracting the non-zero elements of the dense per-degree blocks returned by j_matrix(l).
Operation: $y_{M'} = \sum_{M} J_{M' M}\, x_M$
- Parameters:
irreps – The Irreps object describing the geometric structure.
- Returns:
A SparseScaleInfo object for the forward pass.
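Because J is assembled from per-degree dense blocks that contain many zeros, it can be stored as (output index, input index, value) entries, and y_{M'} = Σ_M J_{M'M} x_M becomes a sum over the stored entries. The pure-Python sketch below shows that extraction and application. The helpers `dense_to_sparse` and `sparse_apply` are hypothetical, the 3 × 3 block is a stand-in with zeros and not the actual J matrix, and the real SparseScaleInfo fields may be organized differently.

```python
from typing import List, Tuple

Entries = Tuple[List[int], List[int], List[float]]  # (index_out, index_in, scale)


def dense_to_sparse(block: List[List[float]], offset: int = 0, tol: float = 0.0) -> Entries:
    """Keep the non-zero entries of one dense block placed at `offset` on the diagonal."""
    index_out: List[int] = []
    index_in: List[int] = []
    scale: List[float] = []
    for i, row in enumerate(block):
        for j, v in enumerate(row):
            if abs(v) > tol:
                index_out.append(offset + i)
                index_in.append(offset + j)
                scale.append(v)
    return index_out, index_in, scale


def sparse_apply(index_out: List[int], index_in: List[int], scale: List[float],
                 x: List[float], out_size: int) -> List[float]:
    """y[index_out[t]] += scale[t] * x[index_in[t]] for every stored entry t."""
    y = [0.0] * out_size
    for o, i, s in zip(index_out, index_in, scale):
        y[o] += s * x[i]
    return y


scalar_block = [[1.0]]  # degree-0 block
stand_in = [[0.0, -1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]  # NOT the real J
o0, i0, s0 = dense_to_sparse(scalar_block, offset=0)
o1, i1, s1 = dense_to_sparse(stand_in, offset=1)
x = [2.0, 0.5, -1.0, 3.0]
y = sparse_apply(o0 + o1, i0 + i1, s0 + s1, x, out_size=4)
```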
- equitorch.utils.wigner_d_info(irreps: Irreps) → WignerRotationInfo
Prepares all necessary info objects for arbitrary Wigner D rotation.
- equitorch.utils.irreps_blocks_infos(irreps: Irreps) → SparseProductInfo
- equitorch.utils.so2_linear_info(irreps_out, irreps_in, path=None, path_norm=True, channel_norm=False, channel_scale=1.0)
- equitorch.utils.so2_linear_infos(irreps_out, irreps_in, path=None, path_norm=True, channel_norm=False, channel_scale=1.0)