equitorch.transforms
Provides data transformation utilities, particularly for graph-based data compatible with PyTorch Geometric.
- class equitorch.transforms.RadiusGraph(r: float, loop: bool = False, max_num_neighbors: int = 32, flow: str = 'source_to_target', pos_attr: str = 'pos', edge_index_attr: str = 'edge_index', edge_vector_attr: str = 'edge_vec', num_workers: int = 1)
Bases: BaseTransform

Connects each node to all nodes within a given cutoff distance of its position, read from pos_attr, and stores the edge index in edge_index_attr and the edge vectors in edge_vector_attr (functional name: radius_graph_eqt).

- Parameters:
  - r (float) – The cutoff distance.
  - loop (bool, optional) – If True, the graph will contain self-loops. Defaults to False.
  - max_num_neighbors (int, optional) – The maximum number of neighbors to return for each element. This flag is only needed for CUDA tensors. Defaults to 32.
  - flow (str, optional) – The flow direction when used in combination with message passing ("source_to_target" or "target_to_source"). Defaults to "source_to_target".
  - pos_attr (str, optional) – The attribute name of the node positions in the data. Defaults to "pos".
  - edge_index_attr (str, optional) – The attribute name under which the edge index is stored in the data. Defaults to "edge_index".
  - edge_vector_attr (str, optional) – The attribute name under which the edge vectors are stored in the data. Defaults to "edge_vec".
  - num_workers (int, optional) – The number of workers to use for computation. Has no effect if batch is not None or if the input lies on the GPU. Defaults to 1.
Example
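>>> import torch
>>> from torch_geometric.data import Data
>>> from equitorch.transforms import RadiusGraph
>>> N = 50
>>> pos = torch.randn(N, 3)  # random positions, so the edge count below varies between runs
>>> data = Data(pos=pos)
>>> print(data)
Data(pos=[50, 3])
>>> data = RadiusGraph(0.5)(data)
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3])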
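To make the behaviour concrete, here is a dense O(N²) sketch of the same idea in plain PyTorch. It is not equitorch's implementation: the name radius_graph_sketch is hypothetical, it ignores max_num_neighbors, flow and num_workers, and it assumes each edge vector points from the source node to the target node.

```python
import torch


def radius_graph_sketch(pos: torch.Tensor, r: float, loop: bool = False):
    """Dense O(N^2) sketch of a radius graph with edge vectors.

    Hypothetical reference helper, not equitorch's implementation. It ignores
    max_num_neighbors, flow and num_workers, and it ASSUMES the edge vector is
    pos[target] - pos[source].
    """
    dist = torch.cdist(pos, pos)  # [N, N] pairwise Euclidean distances
    adj = dist <= r               # keep every ordered pair within the cutoff
    if not loop:
        # Remove the diagonal (self-loops) unless they were requested.
        eye = torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
        adj = adj & ~eye
    source, target = adj.nonzero(as_tuple=True)        # indices of connected pairs
    edge_index = torch.stack([source, target], dim=0)  # [2, E]
    edge_vec = pos[target] - pos[source]               # [E, 3], assumed direction
    return edge_index, edge_vec


# Usage: the number of edges depends on the random positions.
pos = torch.randn(50, 3)
edge_index, edge_vec = radius_graph_sketch(pos, r=0.5)
print(edge_index.shape, edge_vec.shape)
```

The dense distance matrix keeps the sketch short but needs O(N²) memory; for large point clouds a spatial neighbour search is preferable.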
- class equitorch.transforms.AddSphericalHarmonics(l_max: int, vector_attr: str = 'edge_vec', sh_attr: str = 'edge_sh', integral_normalize: bool = False)
Bases: BaseTransform

Creates a spherical harmonics embedding from the direction vectors stored in vector_attr (functional name: add_spherical_harmonics).

- Parameters:
  - l_max (int) – The maximum degree of spherical harmonics. The embedding has (l_max + 1)^2 components, e.g. 16 for l_max = 3.
  - vector_attr (str, optional) – The attribute name of the direction vectors in the data. Defaults to "edge_vec".
  - sh_attr (str, optional) – The attribute name under which the spherical harmonics are stored in the data. Defaults to "edge_sh".
  - integral_normalize (bool, optional) – Whether to normalize the spherical harmonics by \(\sqrt{4\pi / (2l+1)}\). Defaults to False.
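For reference, the scaling factor named in integral_normalize depends only on the degree l. Evaluated for the degrees produced with l_max = 3 (rounded to three decimals):

\[
\sqrt{\frac{4\pi}{2l+1}} =
\begin{cases}
\sqrt{4\pi} \approx 3.545, & l = 0,\\
\sqrt{4\pi/3} \approx 2.047, & l = 1,\\
\sqrt{4\pi/5} \approx 1.585, & l = 2,\\
\sqrt{4\pi/7} \approx 1.340, & l = 3.
\end{cases}
\]

All 2l + 1 components of a given degree share the same factor, so the normalization rescales each degree block uniformly.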
Example
>>> from equitorch.transforms import AddSphericalHarmonics
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3])
>>> data = AddSphericalHarmonics(l_max=3)(data)  # (3 + 1)**2 = 16 components per edge
>>> print(data)
Data(pos=[50, 3], edge_index=[2, 36], edge_vec=[36, 3], edge_sh=[36, 16])
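An edge_sh tensor of width (l_max + 1)^2 holds one block of 2l + 1 components per degree l. The sketch below splits such a tensor into per-degree blocks, for example to process each degree separately. The helper name split_by_degree is hypothetical, and the contiguous, increasing-degree column layout it relies on is an assumption not stated in this reference.

```python
import torch


def split_by_degree(sh: torch.Tensor, l_max: int) -> list[torch.Tensor]:
    """Split a [..., (l_max + 1)**2] spherical-harmonics tensor into per-degree blocks.

    ASSUMES degrees are stored contiguously in increasing order, so degree l
    occupies columns l**2 .. (l + 1)**2 - 1 and has 2l + 1 components.
    """
    expected = (l_max + 1) ** 2
    if sh.size(-1) != expected:
        raise ValueError(f"expected last dim {expected}, got {sh.size(-1)}")
    return [sh[..., l * l:(l + 1) * (l + 1)] for l in range(l_max + 1)]


# Usage with a stand-in for data.edge_sh from the example (36 edges, l_max = 3).
edge_sh = torch.randn(36, 16)
blocks = split_by_degree(edge_sh, l_max=3)
print([b.shape[-1] for b in blocks])  # [1, 3, 5, 7]
```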
- class equitorch.transforms.AddVectorNorm(vector_attr: str = 'edge_vec', norm_attr: str = 'edge_norm')
Bases: BaseTransform

Computes the norm of a vector attribute and adds it to the data object (functional name: add_vector_norm).

This transform is useful for obtaining scalar distance information from vector representations; for example, the lengths of edge vectors are commonly used as edge features in graph neural networks. Unlike the vectors themselves, these norms are invariant under rotations.

- Parameters:
  - vector_attr (str, optional) – The attribute name of the input vector in the torch_geometric.data.Data object. Defaults to 'edge_vec'.
  - norm_attr (str, optional) – The attribute name under which the computed norm is stored in the torch_geometric.data.Data object. Defaults to 'edge_norm'.
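A minimal sketch of the computation, assuming the Euclidean (L2) norm over the last dimension. The helper name add_vector_norm_sketch is hypothetical, and the shape of the stored norm ([E] or [E, 1]) is not specified in this reference, so keepdim exposes both. The usage lines also check that the norms are unchanged when the vectors are rotated (or reflected) by a random orthogonal matrix.

```python
import torch


def add_vector_norm_sketch(vec: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Hypothetical helper: L2 norm of each row of an [E, 3] tensor.

    ASSUMES the Euclidean norm; keepdim=True returns [E, 1], keepdim=False returns [E].
    """
    return torch.linalg.vector_norm(vec, ord=2, dim=-1, keepdim=keepdim)


# Usage: norms of random edge vectors before and after an orthogonal transform.
edge_vec = torch.randn(36, 3)
q, _ = torch.linalg.qr(torch.randn(3, 3))  # Q factor of a QR decomposition is orthogonal
norm = add_vector_norm_sketch(edge_vec)
norm_rot = add_vector_norm_sketch(edge_vec @ q.T)
print(torch.allclose(norm, norm_rot, atol=1e-5))  # True: lengths are preserved
```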