spamosaic.architectures.gat
Graph attention layers for SpaMosaic.
Provides a GATConv and a 4-layer tied-weight GAT encoder-decoder for spatial data.
- class spamosaic.architectures.gat.GAT(*args: Any, **kwargs: Any)
Bases: Module
A 4-layer Graph Attention Network used for spatial transcriptomics.
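It implements a symmetric encoder-decoder GAT with tied weights: the encoder maps in_dim to hidden_dim to out_dim, and the decoder mirrors this path back to in_dim while reusing the encoder's weight matrices.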
- Parameters:
hidden_dims (list of int) – A list [in_dim, hidden_dim, out_dim] specifying the feature dimensions.
- forward(features, edge_index)
Forward pass of the GAT model.
- Parameters:
features (Tensor) – Input node features.
edge_index (Tensor) – Edge indices in COO format.
- Returns:
Latent embeddings after encoder (normalized).
Reconstructed features after decoder.
- Return type:
tuple of (Tensor, Tensor)
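The forward pass can be pictured with the NumPy sketch below. It is not SpaMosaic's code. It assumes a STAGATE-style layout (layers 2 and 4 are plain linear maps without attention, layer 3 reuses the layer-1 attention coefficients, and the decoder uses the transposed encoder weights), a dense adjacency matrix in place of `edge_index`, LeakyReLU attention logits, ELU activations, and L2 row normalization as the meaning of "normalized". The names `dense_attention`, `tied_gat_autoencoder`, `a_src` and `a_dst` are hypothetical.

```python
import numpy as np

def leaky_relu(z, negative_slope=0.2):
    return np.where(z > 0, z, negative_slope * z)

def elu(z):
    # expm1 only sees non-positive values, so the unused branch cannot overflow
    return np.where(z > 0, z, np.expm1(np.minimum(z, 0.0)))

def dense_attention(h, adj, a_src, a_dst):
    """att[i, j] = weight of source j for target i; softmax over row i's edges."""
    e = leaky_relu((h @ a_dst)[:, None] + (h @ a_src)[None, :])
    e = np.where(adj > 0, e, -np.inf)        # non-edges get zero weight
    e = e - e.max(axis=1, keepdims=True)     # each row needs >= 1 edge (self-loop)
    w = np.exp(e)
    return w / w.sum(axis=1, keepdims=True)

def tied_gat_autoencoder(x, adj, W1, W2, a_src, a_dst):
    """Hypothetical 4-layer tied-weight GAT: in -> hidden -> out -> hidden -> in."""
    h = x @ W1                                # layer 1: in_dim -> hidden_dim
    att = dense_attention(h, adj, a_src, a_dst)
    h1 = elu(att @ h)                         # attention-weighted neighbourhood mix
    h2 = h1 @ W2                              # layer 2: hidden_dim -> out_dim, no attention
    h3 = elu(att @ (h2 @ W2.T))               # layer 3: tied weights W2.T, reused attention
    h4 = h3 @ W1.T                            # layer 4: tied weights W1.T, no attention
    z = h2 / np.linalg.norm(h2, axis=1, keepdims=True)  # assumed L2-normalized latent
    return z, h4

rng = np.random.default_rng(0)
hidden_dims = [8, 4, 2]                       # [in_dim, hidden_dim, out_dim]
n = 5
x = rng.normal(size=(n, hidden_dims[0]))
adj = np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)  # path graph with self-loops
W1 = rng.normal(size=(hidden_dims[0], hidden_dims[1]))
W2 = rng.normal(size=(hidden_dims[1], hidden_dims[2]))
a_src = rng.normal(size=hidden_dims[1])
a_dst = rng.normal(size=hidden_dims[1])
z, recon = tied_gat_autoencoder(x, adj, W1, W2, a_src, a_dst)
```

With tying, the decoder adds no weight matrices of its own: only `W1`, `W2` and the attention vectors are learned, and the reconstruction has the same width as the input.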
- class spamosaic.architectures.gat.GATConv(*args: Any, **kwargs: Any)
Bases: MessagePassing
Graph Attention Network (GAT) layer adapted from the STAGATE implementation.
- Parameters:
in_channels (int or tuple of int) – Dimension(s) of input node features.
out_channels (int) – Dimension of output node features.
heads (int, default=1) – Number of attention heads.
concat (bool, default=True) – Whether to concatenate multi-head outputs (True) or average them (False).
negative_slope (float, default=0.2) – LeakyReLU angle of the negative slope.
dropout (float, default=0.0) – Dropout probability for attention coefficients.
add_self_loops (bool, default=True) – Whether to add self-loops to the input graph.
bias (bool, default=True) – Whether to add a learnable bias; accepted but not used by this layer.
**kwargs (optional) – Additional arguments for MessagePassing.
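The `concat` flag only changes how per-head outputs are combined. A minimal NumPy sketch, assuming the head outputs are stacked as an array of shape `(heads, N, out_channels)`; the helper `combine_heads` is hypothetical. With `concat=True` each node ends up with `heads * out_channels` features, with `concat=False` it keeps `out_channels`.

```python
import numpy as np

def combine_heads(head_outputs, concat=True):
    """head_outputs: array of shape (heads, N, out_channels)."""
    if concat:
        # (N, heads * out_channels): heads placed side by side per node
        return np.concatenate(list(head_outputs), axis=1)
    # (N, out_channels): element-wise mean over heads
    return head_outputs.mean(axis=0)

heads, n, out_channels = 3, 4, 2
outs = np.arange(heads * n * out_channels, dtype=float).reshape(heads, n, out_channels)
cat = combine_heads(outs, concat=True)    # shape (4, 6)
avg = combine_heads(outs, concat=False)   # shape (4, 2)
```

Concatenation keeps each head's view separate for the next layer; averaging is the usual choice when the layer's output must have exactly `out_channels` features.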
- forward(x: Union[torch.Tensor, torch_geometric.typing.OptPairTensor], edge_index: torch_geometric.typing.Adj, size: Optional[torch_geometric.typing.Size] = None, return_attention_weights=None, attention=True, tied_attention=None)
Forward pass for the GAT layer.
- Parameters:
x (Tensor or tuple of Tensors) – Input node features, shape (N, F).
edge_index (Tensor or SparseTensor) – Graph connectivity in COO format.
size (tuple of int, optional) – Size of source and target node sets.
return_attention_weights (bool, optional) – Whether to return attention weights along with output.
attention (bool, default=True) – Whether to apply attention mechanism.
tied_attention (tuple of Tensors, optional) – Precomputed attention weights to reuse.
- Returns:
Output node embeddings, and optionally attention weights if return_attention_weights is True.
- Return type:
Tensor or (Tensor, Any)
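Two of the forward arguments act on the graph rather than on the features: self-loops let every node attend to itself, and `tied_attention` lets a layer skip computing its own coefficients and reuse ones from another layer. The NumPy sketch below illustrates both on a COO `edge_index` (row 0 holds sources, row 1 targets, the PyTorch Geometric default). The helpers `with_self_loops` and `propagate` are hypothetical, and the uniform coefficients only stand in for attention weights produced by an earlier layer.

```python
import numpy as np

def with_self_loops(edge_index, num_nodes):
    """Append an (i, i) edge for every node; assumes none are present yet."""
    loops = np.arange(num_nodes)
    return np.concatenate([edge_index, np.stack([loops, loops])], axis=1)

def propagate(h, edge_index, alpha):
    """Sum alpha-weighted source features into their target nodes."""
    src, dst = edge_index                      # row 0: sources, row 1: targets
    out = np.zeros((h.shape[0], h.shape[1]))
    np.add.at(out, dst, alpha[:, None] * h[src])
    return out

# Undirected path graph 0 - 1 - 2, both directions listed (COO format)
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
edge_index = with_self_loops(edge_index, num_nodes=3)   # 4 + 3 = 7 edges

# Stand-in for coefficients from an earlier layer: uniform over each
# target's incoming edges, so they sum to 1 per target node
deg = np.bincount(edge_index[1], minlength=3)
alpha = 1.0 / deg[edge_index[1]]

h_enc = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 2 features per node
h_dec = np.array([[3.0], [6.0], [9.0]])                  # 1 feature per node
out_enc = propagate(h_enc, edge_index, alpha)
out_dec = propagate(h_dec, edge_index, alpha)            # same alpha reused ("tied")
```

Because the coefficients are one scalar per edge, the same `alpha` can weight features of any width, which is what makes reusing them in a layer of different dimension possible.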
- message(x_j: torch.Tensor, alpha_j: torch.Tensor, alpha_i: torch_geometric.typing.OptTensor, index: torch.Tensor, ptr: torch_geometric.typing.OptTensor, size_i: Optional[int]) -> torch.Tensor
Message-passing step: computes one attention-weighted message per edge. The attention logits are normalized with a softmax over each target node's incoming edges; aggregating the messages at each target node is left to MessagePassing.
- Parameters:
x_j (Tensor) – Features of source nodes.
alpha_j (Tensor) – Attention logits from source nodes.
alpha_i (Tensor or None) – Attention logits from target nodes.
index (Tensor) – Target indices for aggregation.
ptr (Tensor or None) – Optional pointer array for variable-sized batches.
size_i (int, optional) – Number of target nodes.
- Returns:
Attention-weighted source-node features, one row per edge.
- Return type:
Tensor
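The split between the message step and aggregation can be sketched in NumPy as follows. This assumes LeakyReLU on the summed logits (suggested by `negative_slope`), no attention dropout and a single head; SpaMosaic's actual activation may differ. `message` returns one row per edge, weighted by a softmax over the edges that share a target `index`; summing those rows per target is a separate step. All function names and logit values are illustrative.

```python
import numpy as np

def segment_softmax(logits, index, num_nodes):
    """Softmax of per-edge logits, normalized separately for each target node."""
    seg_max = np.full(num_nodes, -np.inf)
    np.maximum.at(seg_max, index, logits)          # per-target max, for stability
    w = np.exp(logits - seg_max[index])
    seg_sum = np.zeros(num_nodes)
    np.add.at(seg_sum, index, w)
    return w / seg_sum[index]

def message(x_j, alpha_j, alpha_i, index, size_i, negative_slope=0.2):
    """One attention-weighted message per edge; nothing is aggregated yet."""
    logits = alpha_j if alpha_i is None else alpha_j + alpha_i
    logits = np.where(logits > 0, logits, negative_slope * logits)   # LeakyReLU
    alpha = segment_softmax(logits, index, size_i)
    return x_j * alpha[:, None]

def aggregate(messages, index, size_i):
    """Sum the incoming messages of each target node."""
    out = np.zeros((size_i, messages.shape[1]))
    np.add.at(out, index, messages)
    return out

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_index = np.array([[0, 1, 2, 1],     # row 0: source nodes j
                       [1, 1, 1, 0]])    # row 1: target nodes i
src, dst = edge_index
alpha_src = np.array([0.1, -0.4, 0.7])   # hypothetical per-node source logits
alpha_dst = np.array([0.3, 0.0, -0.2])   # hypothetical per-node target logits

msgs = message(x[src], alpha_src[src], alpha_dst[dst], index=dst, size_i=3)  # (4, 2)
out = aggregate(msgs, dst, size_i=3)                                         # (3, 2)
```

Node 0 has a single incoming edge, so its softmax weight is 1 and it receives node 1's features unchanged; node 2 has no incoming edges and stays at zero.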
Classes
GAT | A 4-layer Graph Attention Network used for spatial transcriptomics.
GATConv | Graph Attention Network (GAT) layer adapted from the STAGATE implementation.