equitorch.structs

Defines specialized data structures (NamedTuples) that hold the information needed for sparse operations and tensor products.

equitorch.structs.add_operation_methods(cls)
class equitorch.structs.SparseScaleInfo(scale: Tensor | None = None, index: Tensor | None = None, seg_out: Tensor | None = None, index_out: Tensor | None = None, out_size: int | None = None)

Bases: NamedTuple

$$z_M = \sum_{t \in \mathrm{Ind}^*[M]} s_t \, x_{\mathrm{Ind}'[t]}$$

or

$$z_M = \sum_{M'} s_{M M'} \, x_{M'}$$

scale: Tensor | None

Alias for field number 0

index: Tensor | None

Alias for field number 1

seg_out: Tensor | None

Alias for field number 2

index_out: Tensor | None

Alias for field number 3

out_size: int | None

Alias for field number 4

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)
class equitorch.structs.SparseProductInfo(scale: Tensor | None = None, index1: Tensor | None = None, index2: Tensor | None = None, seg_out: Tensor | None = None, gather_index: Tensor | None = None, index_out: Tensor | None = None, out_size: int | None = None)

Bases: NamedTuple

$$z_M = \sum_{t \in \mathrm{Ind}^*[M]} s_t \, x_{\mathrm{Ind}_1[t]} \, y_{\mathrm{Ind}_2[t]}$$

or

$$z_M = \sum_{M_1, M_2} s_{M M_1 M_2} \, x_{M_1} \, y_{M_2}$$

scale: Tensor | None

Alias for field number 0

index1: Tensor | None

Alias for field number 1

index2: Tensor | None

Alias for field number 2

seg_out: Tensor | None

Alias for field number 3

gather_index: Tensor | None

Alias for field number 4

index_out: Tensor | None

Alias for field number 5

out_size: int | None

Alias for field number 6

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)
class equitorch.structs.TensorProductInfo(info_Mij_fwd, info_Mij_bwd1, info_Mij_bwd2, info_M_fwd, info_M_bwd1, info_M_bwd2, info_kM1j_fwd, info_kM1j_bwd1, info_kM1j_bwd2, info_kM1M2_fwd, info_kM1M2_bwd1, info_kM1M2_bwd2, info_M_kM1M2_fwd, info_M_kM1M2_bwd, out_size)

Bases: NamedTuple

info_Mij_fwd: SparseProductInfo

Alias for field number 0

info_Mij_bwd1: SparseProductInfo

Alias for field number 1

info_Mij_bwd2: SparseProductInfo

Alias for field number 2

info_M_fwd: SparseProductInfo

Alias for field number 3

info_M_bwd1: SparseProductInfo

Alias for field number 4

info_M_bwd2: SparseProductInfo

Alias for field number 5

info_kM1j_fwd: SparseProductInfo

Alias for field number 6

info_kM1j_bwd1: SparseProductInfo

Alias for field number 7

info_kM1j_bwd2: SparseProductInfo

Alias for field number 8

info_kM1M2_fwd: SparseProductInfo

Alias for field number 9

info_kM1M2_bwd1: SparseProductInfo

Alias for field number 10

info_kM1M2_bwd2: SparseProductInfo

Alias for field number 11

info_M_kM1M2_fwd: SparseScaleInfo

Alias for field number 12

info_M_kM1M2_bwd: SparseScaleInfo

Alias for field number 13

out_size: int

Alias for field number 14

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)
class equitorch.structs.IrrepsInfo(rsqrt_dims, rdims, irrep_index, irrep_seg, num_irreps)

Bases: NamedTuple

rsqrt_dims: Tensor

Alias for field number 0

rdims: Tensor

Alias for field number 1

irrep_index: Tensor

Alias for field number 2

irrep_seg: Tensor

Alias for field number 3

num_irreps: int

Alias for field number 4

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)
class equitorch.structs.IrrepsLinearInfo(scale_MM0, M_seg_MM0, ii0_MM0, M0_MM0, M_MM0, M_out, ii0_seg_ii0MM0, M_ii0MM0, M0_ii0MM0, scales_ii0, out_size)

Bases: NamedTuple

scale_MM0: Tensor

Alias for field number 0

M_seg_MM0: Tensor

Alias for field number 1

ii0_MM0: Tensor

Alias for field number 2

M0_MM0: Tensor

Alias for field number 3

M_MM0: Tensor

Alias for field number 4

M_out: Tensor

Alias for field number 5

ii0_seg_ii0MM0: Tensor

Alias for field number 6

M_ii0MM0: Tensor

Alias for field number 7

M0_ii0MM0: Tensor

Alias for field number 8

scales_ii0: Tensor

Alias for field number 9

out_size: int

Alias for field number 10

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)
class equitorch.structs.WignerRotationInfo(j_matrix_info, rotate_z_info_fwd, rotate_z_info_bwd_input, rotate_z_info_bwd_cs, sign, max_m)

Bases: NamedTuple

j_matrix_info: SparseScaleInfo

Alias for field number 0

rotate_z_info_fwd: SparseProductInfo

Alias for field number 1

rotate_z_info_bwd_input: SparseProductInfo | None

Alias for field number 2

rotate_z_info_bwd_cs: SparseProductInfo | None

Alias for field number 3

sign: Tensor | None

Alias for field number 4

max_m: int | None

Alias for field number 5

cpu(*args, **kwargs)
cuda(*args, **kwargs)
to(*args, **kwargs)