equitorch.structs
Defines specialized NamedTuple data structures that hold the index and scale information used by sparse operations and tensor products.
- equitorch.structs.add_operation_methods(cls)
- class equitorch.structs.SparseScaleInfo(scale: Tensor | None = None, index: Tensor | None = None, seg_out: Tensor | None = None, index_out: Tensor | None = None, out_size: int | None = None)
Bases: NamedTuple

$$z_M = \sum_{t \in \mathrm{Ind}^*[M]} s_t \, x_{\mathrm{Ind}'[t]}$$

or

$$z_M = \sum_{M'} s_{MM'} \, x_{M'}$$
- scale: Tensor | None
Alias for field number 0
- index: Tensor | None
Alias for field number 1
- seg_out: Tensor | None
Alias for field number 2
- index_out: Tensor | None
Alias for field number 3
- out_size: int | None
Alias for field number 4
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
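To make the two forms above concrete, here is a small pure-PyTorch reference for the sparse form, checked against the dense form on a toy example. It is a sketch, not equitorch's implementation, and the field semantics are assumptions read off the formula and the field names: `index` holds Ind'[t] for each term t; `seg_out` is a CSR-style pointer array, so segment k covers terms `seg_out[k]` up to `seg_out[k+1] - 1`; `index_out`, when given, maps segment k to its output row M (otherwise segment k writes row k); and the sparse axis is dim 0. `ScaleInfoSketch` and `sparse_scale_reference` are hypothetical names.

```python
import torch
from typing import NamedTuple, Optional


class ScaleInfoSketch(NamedTuple):
    # Hypothetical stand-in with the same fields as SparseScaleInfo
    scale: Optional[torch.Tensor] = None      # s_t, one per term
    index: Optional[torch.Tensor] = None      # Ind'[t], input position of term t
    seg_out: Optional[torch.Tensor] = None    # assumed CSR pointers over terms
    index_out: Optional[torch.Tensor] = None  # assumed output row of each segment
    out_size: Optional[int] = None


def sparse_scale_reference(x: torch.Tensor, info: ScaleInfoSketch) -> torch.Tensor:
    """z_M = sum_{t in Ind*[M]} s_t * x_{Ind'[t]}, with the sparse axis on dim 0."""
    terms = x if info.index is None else x[info.index]  # gather x_{Ind'[t]}
    if info.scale is not None:
        # broadcast s_t over any trailing (e.g. channel) dimensions
        terms = terms * info.scale.view(-1, *([1] * (terms.dim() - 1)))
    counts = info.seg_out[1:] - info.seg_out[:-1]  # number of terms per segment
    seg_id = torch.repeat_interleave(torch.arange(counts.numel()), counts)
    rows = seg_id if info.index_out is None else info.index_out[seg_id]
    out = terms.new_zeros((info.out_size, *terms.shape[1:]))
    return out.index_add_(0, rows, terms)  # sum each segment's terms into row M


# Toy example: z_0 = 2*x_0 + 3*x_2,  z_1 = 0.5*x_1
info = ScaleInfoSketch(
    scale=torch.tensor([2.0, 3.0, 0.5]),
    index=torch.tensor([0, 2, 1]),
    seg_out=torch.tensor([0, 2, 3]),
    out_size=2,
)
x = torch.tensor([1.0, 2.0, 3.0])
z = sparse_scale_reference(x, info)

# Dense form z_M = sum_{M'} s_{MM'} x_{M'} gives the same values, [11.0, 1.0]
S = torch.tensor([[2.0, 0.0, 3.0],
                  [0.0, 0.5, 0.0]])
z_dense = S @ x
```

Storing only the nonzero s_{MM'} entries keeps the cost proportional to the number of terms rather than to the full size of the dense matrix.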
- class equitorch.structs.SparseProductInfo(scale: Tensor | None = None, index1: Tensor | None = None, index2: Tensor | None = None, seg_out: Tensor | None = None, gather_index: Tensor | None = None, index_out: Tensor | None = None, out_size: int | None = None)
Bases: NamedTuple

$$z_M = \sum_{t \in \mathrm{Ind}^*[M]} s_t \, x_{\mathrm{Ind}_1[t]} \, y_{\mathrm{Ind}_2[t]}$$

or

$$z_M = \sum_{M_1 M_2} s_{M M_1 M_2} \, x_{M_1} \, y_{M_2}$$
- scale: Tensor | None
Alias for field number 0
- index1: Tensor | None
Alias for field number 1
- index2: Tensor | None
Alias for field number 2
- seg_out: Tensor | None
Alias for field number 3
- gather_index: Tensor | None
Alias for field number 4
- index_out: Tensor | None
Alias for field number 5
- out_size: int | None
Alias for field number 6
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
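The bilinear form above can be sketched the same way in plain PyTorch and checked against the dense sum with `torch.einsum`. This is not equitorch's kernel; it assumes (the page does not document this) that `seg_out` is a CSR-style pointer array grouping terms by output, that `index_out` (when given) maps each segment to its output row M, and that the sparse axis is dim 0. `gather_index` is ignored because its role is not described here. `ProductInfoSketch` and `sparse_product_reference` are hypothetical names.

```python
import torch
from typing import NamedTuple, Optional


class ProductInfoSketch(NamedTuple):
    # Hypothetical stand-in with the same fields as SparseProductInfo
    scale: Optional[torch.Tensor] = None
    index1: Optional[torch.Tensor] = None        # Ind1[t]
    index2: Optional[torch.Tensor] = None        # Ind2[t]
    seg_out: Optional[torch.Tensor] = None       # assumed CSR pointers over terms
    gather_index: Optional[torch.Tensor] = None  # not used in this sketch
    index_out: Optional[torch.Tensor] = None     # assumed output row of each segment
    out_size: Optional[int] = None


def sparse_product_reference(x, y, info):
    """z_M = sum_{t in Ind*[M]} s_t * x_{Ind1[t]} * y_{Ind2[t]}, sparse axis on dim 0."""
    xt = x if info.index1 is None else x[info.index1]
    yt = y if info.index2 is None else y[info.index2]
    terms = xt * yt
    if info.scale is not None:
        terms = terms * info.scale.view(-1, *([1] * (terms.dim() - 1)))
    counts = info.seg_out[1:] - info.seg_out[:-1]
    seg_id = torch.repeat_interleave(torch.arange(counts.numel()), counts)
    rows = seg_id if info.index_out is None else info.index_out[seg_id]
    out = terms.new_zeros((info.out_size, *terms.shape[1:]))
    return out.index_add_(0, rows, terms)


# Toy example: z_0 = 1*x_0*y_0 + 2*x_1*y_1,  z_1 = -1*x_0*y_1
info = ProductInfoSketch(
    scale=torch.tensor([1.0, 2.0, -1.0]),
    index1=torch.tensor([0, 1, 0]),
    index2=torch.tensor([0, 1, 1]),
    seg_out=torch.tensor([0, 2, 3]),
    out_size=2,
)
x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
z = sparse_product_reference(x, y, info)

# Dense form z_M = sum_{M1 M2} s_{M M1 M2} x_{M1} y_{M2}
S = torch.zeros(2, 2, 2)
S[0, 0, 0] = 1.0
S[0, 1, 1] = 2.0
S[1, 0, 1] = -1.0
z_dense = torch.einsum("mab,a,b->m", S, x, y)
```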
- class equitorch.structs.TensorProductInfo(info_Mij_fwd, info_Mij_bwd1, info_Mij_bwd2, info_M_fwd, info_M_bwd1, info_M_bwd2, info_kM1j_fwd, info_kM1j_bwd1, info_kM1j_bwd2, info_kM1M2_fwd, info_kM1M2_bwd1, info_kM1M2_bwd2, info_M_kM1M2_fwd, info_M_kM1M2_bwd, out_size)
Bases: NamedTuple
- info_Mij_fwd: SparseProductInfo
Alias for field number 0
- info_Mij_bwd1: SparseProductInfo
Alias for field number 1
- info_Mij_bwd2: SparseProductInfo
Alias for field number 2
- info_M_fwd: SparseProductInfo
Alias for field number 3
- info_M_bwd1: SparseProductInfo
Alias for field number 4
- info_M_bwd2: SparseProductInfo
Alias for field number 5
- info_kM1j_fwd: SparseProductInfo
Alias for field number 6
- info_kM1j_bwd1: SparseProductInfo
Alias for field number 7
- info_kM1j_bwd2: SparseProductInfo
Alias for field number 8
- info_kM1M2_fwd: SparseProductInfo
Alias for field number 9
- info_kM1M2_bwd1: SparseProductInfo
Alias for field number 10
- info_kM1M2_bwd2: SparseProductInfo
Alias for field number 11
- info_M_kM1M2_fwd: SparseScaleInfo
Alias for field number 12
- info_M_kM1M2_bwd: SparseScaleInfo
Alias for field number 13
- out_size: int
Alias for field number 14
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
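Each `_fwd` entry of TensorProductInfo is paired with `_bwd1`/`_bwd2` entries (or a single `_bwd` for the SparseScaleInfo pair). A plausible reason, inferred from the field names rather than stated on this page, is that the gradients of a sparse product are themselves sparse products with the same scales and permuted index roles. If $z_M = \sum_t s_t\, x_{\mathrm{Ind}_1[t]}\, y_{\mathrm{Ind}_2[t]}$ and $g_M = \partial L / \partial z_M$, then $\partial L / \partial x_{M_1} = \sum_{t:\,\mathrm{Ind}_1[t] = M_1} s_t\, g_{M(t)}\, y_{\mathrm{Ind}_2[t]}$, where $M(t)$ is the output of term $t$, and symmetrically for $y$. The sketch below checks this against autograd, using a simpler COO layout (one output index per term) instead of `seg_out`; `sparse_product_coo` is a hypothetical helper.

```python
import torch


def sparse_product_coo(x, y, scale, idx_x, idx_y, idx_out, out_size):
    """z[idx_out[t]] += scale[t] * x[idx_x[t]] * y[idx_y[t]] (one output index per term)."""
    terms = scale * x[idx_x] * y[idx_y]
    return terms.new_zeros(out_size).index_add_(0, idx_out, terms)


# Forward terms t: output M(t), Ind1[t], Ind2[t], scale s_t
m_out = torch.tensor([0, 0, 1])
m1 = torch.tensor([0, 1, 0])
m2 = torch.tensor([0, 1, 1])
s = torch.tensor([1.0, 2.0, -1.0])

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
z = sparse_product_coo(x, y, s, m1, m2, m_out, out_size=2)

g = torch.tensor([0.5, -2.0])  # upstream gradient dL/dz
z.backward(g)                  # autograd fills x.grad and y.grad

# "bwd1"-style product: inputs (g, y), gathered by (M(t), Ind2[t]), scattered by Ind1[t]
grad_x = sparse_product_coo(g, y.detach(), s, m_out, m2, m1, out_size=x.numel())
# "bwd2"-style product: inputs (g, x), gathered by (M(t), Ind1[t]), scattered by Ind2[t]
grad_y = sparse_product_coo(g, x.detach(), s, m_out, m1, m2, out_size=y.numel())
```

Precomputing the permuted index sets once lets the backward pass reuse the same sparse-product kernel as the forward pass.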
- class equitorch.structs.IrrepsInfo(rsqrt_dims, rdims, irrep_index, irrep_seg, num_irreps)
Bases: NamedTuple
- rsqrt_dims: Tensor
Alias for field number 0
- rdims: Tensor
Alias for field number 1
- irrep_index: Tensor
Alias for field number 2
- irrep_seg: Tensor
Alias for field number 3
- num_irreps: int
Alias for field number 4
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
- class equitorch.structs.IrrepsLinearInfo(scale_MM0, M_seg_MM0, ii0_MM0, M0_MM0, M_MM0, M_out, ii0_seg_ii0MM0, M_ii0MM0, M0_ii0MM0, scales_ii0, out_size)
Bases: NamedTuple
- scale_MM0: Tensor
Alias for field number 0
- M_seg_MM0: Tensor
Alias for field number 1
- ii0_MM0: Tensor
Alias for field number 2
- M0_MM0: Tensor
Alias for field number 3
- M_MM0: Tensor
Alias for field number 4
- M_out: Tensor
Alias for field number 5
- ii0_seg_ii0MM0: Tensor
Alias for field number 6
- M_ii0MM0: Tensor
Alias for field number 7
- M0_ii0MM0: Tensor
Alias for field number 8
- scales_ii0: Tensor
Alias for field number 9
- out_size: int
Alias for field number 10
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
- class equitorch.structs.WignerRotationInfo(j_matrix_info, rotate_z_info_fwd, rotate_z_info_bwd_input, rotate_z_info_bwd_cs, sign, max_m)
Bases: NamedTuple
- j_matrix_info: SparseScaleInfo
Alias for field number 0
- rotate_z_info_fwd: SparseProductInfo
Alias for field number 1
- rotate_z_info_bwd_input: SparseProductInfo | None
Alias for field number 2
- rotate_z_info_bwd_cs: SparseProductInfo | None
Alias for field number 3
- sign: Tensor | None
Alias for field number 4
- max_m: int | None
Alias for field number 5
- cpu(*args, **kwargs)
- cuda(*args, **kwargs)
- to(*args, **kwargs)
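All structs on this page expose `cpu`, `cuda` and `to`, and several of them nest other structs (TensorProductInfo holds SparseProductInfo and SparseScaleInfo fields; WignerRotationInfo holds both kinds). One way to provide such methods is a class decorator that maps the call over every field, recursing into nested structs and passing non-tensor fields such as `out_size` through unchanged. `add_operation_methods(cls)` may serve this role, but its behavior is not documented here; `add_device_methods`, `InnerInfo` and `OuterInfo` below are hypothetical.

```python
import torch
from typing import NamedTuple, Optional


def add_device_methods(cls):
    """Hypothetical decorator: give a NamedTuple struct to/cpu/cuda methods."""

    def _apply(self, method, *args, **kwargs):
        def convert(value):
            # Tensors and nested structs that have the method are converted;
            # ints, None and other plain values pass through unchanged.
            is_struct = isinstance(value, tuple) and hasattr(value, "_fields")
            if isinstance(value, torch.Tensor) or (is_struct and hasattr(value, method)):
                return getattr(value, method)(*args, **kwargs)
            return value
        return type(self)(*(convert(v) for v in self))

    def to(self, *args, **kwargs):
        return _apply(self, "to", *args, **kwargs)

    def cpu(self, *args, **kwargs):
        return _apply(self, "cpu", *args, **kwargs)

    def cuda(self, *args, **kwargs):
        return _apply(self, "cuda", *args, **kwargs)

    cls.to, cls.cpu, cls.cuda = to, cpu, cuda
    return cls


@add_device_methods
class InnerInfo(NamedTuple):  # hypothetical, plays the role of SparseScaleInfo
    scale: Optional[torch.Tensor] = None
    out_size: Optional[int] = None


@add_device_methods
class OuterInfo(NamedTuple):  # hypothetical, plays the role of a nesting struct
    inner: InnerInfo
    sign: Optional[torch.Tensor] = None


info = OuterInfo(InnerInfo(torch.tensor([1.0, 2.0]), out_size=2))
info64 = info.to(torch.float64)  # casts every tensor, including the nested one
```

Because NamedTuples are immutable, each call returns a new struct and leaves the original untouched.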