spamosaic.architectures.hgt

Heterogeneous Graph Transformer (HGT) for SpaMosaic.

Defines an HGT model built on PyTorch Geometric’s HGTConv with per-type projections and decoders.

class spamosaic.architectures.hgt.HGT(*args: Any, **kwargs: Any)

Bases: Module

Heterogeneous Graph Transformer (HGT) model for learning on heterogeneous graphs.

This implementation builds on PyTorch Geometric’s HGTConv and supports node-type-specific input projections, multiple stacked layers of multi-head HGT attention, and per-node-type decoders.

Parameters:
  • in_channels (int) – Input feature dimension for each node.

  • hidden_channels (int) – Hidden feature dimension used in HGT layers.

  • num_heads (int) – Number of attention heads in each HGTConv layer.

  • num_layers (int) – Number of stacked HGTConv layers.

  • n_dec_l (int) – Number of layers in each node-type decoder. If 1, the decoder is a single linear layer.

  • data_obj (HeteroData) – PyG HeteroData object describing node/edge types in the graph.

  • out_channels (int, optional) – If specified, adds an intermediate projection to this dimension before decoding.
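The exact decoder layout is not documented here beyond the n_dec_l == 1 case. As a hedged sketch, assuming that hidden decoder layers keep the decoder's input width and use ReLU, and that each decoder maps from out_channels (or hidden_channels when out_channels is omitted) back to in_channels, the n_dec_l and out_channels parameters could interact like this. The helpers build_decoder and build_decoders and the node type name "spot" are hypothetical.

```python
import torch
from torch import nn


def build_decoder(in_dim: int, out_dim: int, n_dec_l: int) -> nn.Module:
    """Hypothetical decoder builder following the documented n_dec_l rule."""
    if n_dec_l < 1:
        raise ValueError("n_dec_l must be at least 1")
    if n_dec_l == 1:
        # Documented case: a single linear layer.
        return nn.Linear(in_dim, out_dim)
    # Assumed case: (n_dec_l - 1) hidden Linear + ReLU blocks of width in_dim,
    # followed by a final Linear to out_dim.
    layers = []
    for _ in range(n_dec_l - 1):
        layers += [nn.Linear(in_dim, in_dim), nn.ReLU()]
    layers.append(nn.Linear(in_dim, out_dim))
    return nn.Sequential(*layers)


def build_decoders(node_types, in_channels, hidden_channels, n_dec_l, out_channels=None):
    # Assumption: decoders read the optional out_channels projection if present,
    # otherwise the hidden HGT representation, and reconstruct in_channels features.
    dec_in = out_channels if out_channels is not None else hidden_channels
    return nn.ModuleDict(
        {t: build_decoder(dec_in, in_channels, n_dec_l) for t in node_types}
    )


decoders = build_decoders(["spot"], in_channels=50, hidden_channels=64,
                          n_dec_l=2, out_channels=32)
x_hat = decoders["spot"](torch.randn(10, 32))
print(x_hat.shape)  # torch.Size([10, 50])
```

Keying the decoders by node type in an nn.ModuleDict mirrors the documented decoder attribute and ensures every decoder is registered as a submodule of the model.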

Attributes:
  • lin_dict (nn.ModuleDict) – Node-type-specific input projections to the hidden space.

  • convs (nn.ModuleList) – Stacked HGTConv layers.

  • decoder (nn.ModuleDict) – Node-type-specific decoders for reconstruction.

  • lin (nn.Linear or None) – Optional projection layer, created only if out_channels is specified.

  • node_type (str) – The primary node type (the first entry of data_obj.node_types), whose embeddings and reconstructions are returned by forward().

forward(x_dict, edge_index_dict)

Forward pass of the HGT model.

Parameters:
  • x_dict (dict[str, torch.Tensor]) – Mapping from node types to feature matrices.

  • edge_index_dict (dict[tuple[str, str, str], torch.Tensor]) – Mapping from edge types to edge index tensors in COO format.
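The two dictionaries follow PyG's heterogeneous conventions: node types map to [num_nodes, num_features] tensors, and edge types are (source_type, relation, target_type) triples mapping to [2, num_edges] integer tensors whose first row holds source indices and second row target indices. A minimal hand-built example follows, using a hypothetical "spot" node type and "near" relation; in practice the keys are assumed to match the types in data_obj, and a HeteroData object exposes both dictionaries as data.x_dict and data.edge_index_dict.

```python
import torch

# Hypothetical toy graph: 4 "spot" nodes with 50-dimensional features.
x_dict = {"spot": torch.randn(4, 50)}

# COO format: row 0 holds source node indices, row 1 holds target node indices.
# The chain 0-1-2-3 is stored in both directions (6 directed edges).
src = torch.tensor([0, 1, 2, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 0, 1, 2])
edge_index_dict = {("spot", "near", "spot"): torch.stack([src, dst], dim=0)}

for (src_type, _, dst_type), edge_index in edge_index_dict.items():
    # Every index must address a valid row of the matching feature matrix.
    assert edge_index.size(0) == 2 and edge_index.dtype == torch.long
    assert int(edge_index[0].max()) < x_dict[src_type].size(0)
    assert int(edge_index[1].max()) < x_dict[dst_type].size(0)
    print((src_type, dst_type), tuple(edge_index.shape))  # ('spot', 'spot') (2, 6)
```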

Returns:

  • torch.Tensor – L2-normalized node embeddings for the primary node type.

  • torch.Tensor – Reconstructed inputs for the primary node type (for autoencoder-style reconstruction training).
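L2 normalization rescales each embedding row to unit Euclidean norm, so dot products between embeddings equal cosine similarities. The sketch below uses random stand-ins for the two outputs (sizes are hypothetical) and pairs the reconstruction with a mean squared error loss; MSE is an assumed choice of autoencoder objective, not necessarily the one SpaMosaic uses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for forward() outputs on the primary node type (hypothetical sizes).
h = torch.randn(10, 32)      # embeddings before normalization
x = torch.randn(10, 50)      # original input features
x_hat = torch.randn(10, 50)  # decoder reconstruction of x

# L2 normalization along the feature dimension: every row gets unit norm.
z = F.normalize(h, p=2.0, dim=1)

# For unit-norm rows, the Gram matrix holds pairwise cosine similarities.
cos_sim = z @ z.T

# Assumed autoencoder objective: mean squared reconstruction error (a scalar).
recon_loss = F.mse_loss(x_hat, x)
print(z.norm(dim=1))  # 1.0 for every row, up to floating-point error
```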

Notes

The layer stacking follows the official PyG HGT example on the DBLP dataset (examples/hetero/hgt_dblp.py in the PyG repository).
