equitorch.transforms

Provides data transformation utilities, particularly for graph-based data compatible with PyTorch Geometric.

class equitorch.transforms.RadiusGraph(r: float, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', pos_attr: str = 'pos', edge_index_attr: str = 'edge_index', edge_vector_attr: str = 'edge_vec', num_workers: int = 1)

Bases: BaseTransform

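Creates edges between all pairs of nodes whose positions, stored in pos_attr, lie within a given cutoff distance of each other, and stores the resulting edge index in edge_index_attr and the edge vectors in edge_vector_attr (functional name: radius_graph_eqt).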

Parameters:
  • r (float) – The cutoff distance.

  • loop (bool, optional) – If True, the graph will contain self-loops. Defaults to False.

  • max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element. This flag is only needed for CUDA tensors. Defaults to 32.

  • flow (str, optional) – The flow direction when used in combination with message passing ("source_to_target" or "target_to_source"). Defaults to "source_to_target".

  • pos_attr (str, optional) – The attribute name for positions in the data. Defaults to "pos".

  • edge_index_attr (str, optional) – The attribute name under which the edge index is stored in the data. Defaults to "edge_index".

  • edge_vector_attr (str, optional) – The attribute name under which the edge vectors are stored in the data. Defaults to "edge_vec".

  • num_workers (int, optional) – Number of workers to use for computation. Has no effect if batch is not None or if the input lies on the GPU. Defaults to 1.

Example

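>>> import torch
>>> from torch_geometric.data import Data
>>> from equitorch.transforms import RadiusGraph
>>> N = 50
>>> pos = torch.randn(N, 3)
>>> data = Data(pos=pos)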
>>> print(data)
Data(pos=[50, 3])
>>> data = RadiusGraph(0.5)(data)
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3])
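The parameters above can be illustrated with a brute-force sketch. The helper radius_graph_sketch below is hypothetical and not part of equitorch: it works on plain Python lists instead of tensors, runs in O(N^2) time, and makes three assumptions this documentation does not state: the cutoff is inclusive, the closest neighbors are kept when max_num_neighbors is exceeded, and edge_vec is pos[edge_index[1]] - pos[edge_index[0]]. The placement of the centre node in the edge index follows the usual PyTorch Geometric convention for flow.

```python
import math


def radius_graph_sketch(pos, r, loop=False, max_num_neighbors=32,
                        flow="source_to_target"):
    """Brute-force radius graph over a list of 3D points (hypothetical helper).

    Returns (edge_index, edge_vec), where edge_index is a pair of lists
    [row0, row1] in the PyG layout.
    """
    row0, row1 = [], []
    for i in range(len(pos)):          # i is the centre node
        neighbors = []
        for j in range(len(pos)):
            if i == j and not loop:
                continue               # skip self-loops unless requested
            d = math.dist(pos[i], pos[j])
            if d <= r:                 # assumption: inclusive cutoff
                neighbors.append((d, j))
        neighbors.sort()               # assumption: keep the closest ones
        for _, j in neighbors[:max_num_neighbors]:
            if flow == "source_to_target":
                row0.append(j)         # neighbor j is the source ...
                row1.append(i)         # ... and centre node i the target
            else:                      # "target_to_source"
                row0.append(i)
                row1.append(j)
    # Assumed convention: edge_vec = pos[edge_index[1]] - pos[edge_index[0]].
    edge_vec = [tuple(b - a for a, b in zip(pos[s], pos[t]))
                for s, t in zip(row0, row1)]
    return [row0, row1], edge_vec


pos = [(0.0, 0.0, 0.0), (0.3, 0.0, 0.0), (2.0, 0.0, 0.0)]
edge_index, edge_vec = radius_graph_sketch(pos, r=0.5)
print(edge_index)  # [[1, 0], [0, 1]]
print(edge_vec)    # [(-0.3, 0.0, 0.0), (0.3, 0.0, 0.0)]
```

With r = 0.5 only the first two points are connected, giving one edge in each direction; the third point is isolated unless loop=True adds its self-loop.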

class equitorch.transforms.AddSphericalHarmonics(l_max: int, vector_attr: str = 'edge_vec', sh_attr: str = 'edge_sh', integral_normalize: bool = False)

Bases: BaseTransform

Computes a spherical harmonics embedding of the direction vectors stored in vector_attr and stores it in sh_attr (functional name: add_spherical_harmonics).

Parameters:
  • l_max (int) – The maximum degree of spherical harmonics.

  • vector_attr (str, optional) – The attribute name for direction vectors in the data. Defaults to "edge_vec".

  • sh_attr (str, optional) – The attribute name for storing spherical harmonics in the data. Defaults to "edge_sh".

  • integral_normalize (bool, optional) – Whether to normalize the spherical harmonics by \(\sqrt{4\pi / (2l+1)}\). Defaults to False.
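The factor \(\sqrt{4\pi / (2l+1)}\) can be checked numerically on degree l = 1. The sketch below assumes, without confirmation from this documentation, that by default each degree-l block has unit Euclidean norm and that integral_normalize=True divides each block by this factor. Under that assumption the result satisfies the integral normalization (the integral of each squared component over the unit sphere equals 1) and hence Unsöld's theorem, sum over m of Y_lm^2 = (2l+1)/(4π). The helpers sh_degree1_unit and integral_normalize_block are hypothetical, and the component order (x, y, z) is also an assumption.

```python
import math


def sh_degree1_unit(v):
    """Degree-1 real spherical harmonics with unit block norm (assumed default).

    Hypothetical helper: the component order (x, y, z) is an assumption.
    """
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]


def integral_normalize_block(block, l):
    """Divide a degree-l block by sqrt(4*pi / (2l + 1)) (assumed flag semantics)."""
    factor = math.sqrt(4 * math.pi / (2 * l + 1))
    return [c / factor for c in block]


# Unsöld's theorem: with integral normalization, sum_m Y_lm(v)^2 = (2l + 1) / (4*pi).
y1 = integral_normalize_block(sh_degree1_unit((1.0, 2.0, 2.0)), l=1)
print(sum(c * c for c in y1), 3 / (4 * math.pi))

# The six points +-e_x, +-e_y, +-e_z integrate quadratics on the sphere exactly,
# so 4*pi * mean(Y_1m^2) over them is the surface integral; it should equal 1.
axes = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
blocks = [integral_normalize_block(sh_degree1_unit(p), l=1) for p in axes]
integrals = [4 * math.pi * sum(b[m] ** 2 for b in blocks) / len(blocks)
             for m in range(3)]
print(integrals)  # approximately [1.0, 1.0, 1.0]
```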

Example

>>> from equitorch.transforms import AddSphericalHarmonics
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3])
>>> data = AddSphericalHarmonics(l_max=3)(data)
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3], edge_sh=[36, 16])
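The width of edge_sh in this example follows from counting components: degree l contributes 2l + 1 components, so degrees 0 through l_max give (l_max + 1)^2 in total, which is 16 for l_max=3. The sketch below also lists the column range of each degree, assuming (this documentation does not say so) that the degree blocks are stored contiguously in increasing order of l; sh_dim and degree_slices are hypothetical helpers.

```python
def sh_dim(l_max):
    """Number of spherical-harmonic components for degrees 0..l_max."""
    return sum(2 * l + 1 for l in range(l_max + 1))  # equals (l_max + 1) ** 2


def degree_slices(l_max):
    """Half-open column range of each degree l, assuming increasing-l storage."""
    return {l: (l * l, (l + 1) * (l + 1)) for l in range(l_max + 1)}


print(sh_dim(3))         # 16
print(degree_slices(3))  # {0: (0, 1), 1: (1, 4), 2: (4, 9), 3: (9, 16)}
```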

class equitorch.transforms.AddVectorNorm(vector_attr: str = 'edge_vec', norm_attr: str = 'edge_norm')

Bases: BaseTransform

Computes the norm of a vector attribute and adds it to the data object (functional name: add_vector_norm).

This transform is useful for obtaining scalar distance information from vector attributes such as edge vectors; such distances are commonly used as edge features in graph neural networks.

Parameters:
  • vector_attr (str, optional) – The attribute name for the input vector in the torch_geometric.data.Data object. Defaults to 'edge_vec'.

  • norm_attr (str, optional) – The attribute name under which the computed norm will be stored in the torch_geometric.data.Data object. Defaults to 'edge_norm'.
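A minimal sketch of what this transform computes, assuming the norm is the Euclidean (L2) norm of each row of the vector attribute. The helper add_vector_norm_sketch is hypothetical and operates on a plain dict of lists rather than a torch_geometric.data.Data object; in tensor code the same operation would typically be written with torch.linalg.norm(..., dim=-1). Whether equitorch stores the result with shape [E] or [E, 1] is not stated here.

```python
import math


def add_vector_norm_sketch(data, vector_attr="edge_vec", norm_attr="edge_norm"):
    """Store the Euclidean norm of each row of data[vector_attr] in data[norm_attr].

    Hypothetical pure-Python stand-in: `data` is a plain dict of lists here,
    not a torch_geometric.data.Data object.
    """
    data[norm_attr] = [math.sqrt(sum(c * c for c in v)) for v in data[vector_attr]]
    return data


data = {"edge_vec": [(3.0, 4.0, 0.0), (1.0, 2.0, 2.0)]}
print(add_vector_norm_sketch(data)["edge_norm"])  # [5.0, 3.0]
```

As with the real transform, both attribute names are configurable, so the same helper can read from and write to any keys.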