Source code for spamosaic.architectures.hgt

from typing import List, Optional, Union, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv

class HGT(torch.nn.Module):
    """
    Heterogeneous Graph Transformer (HGT) autoencoder for heterogeneous graphs.

    This implementation builds on PyTorch Geometric's HGTConv. It combines
    node-type specific input projections, a stack of HGT attention layers, an
    optional shared output projection, and per-node-type decoders that
    reconstruct the input features.

    Parameters
    ----------
    in_channels : int
        Input feature dimension for each node (shared by all node types).
    hidden_channels : int
        Hidden feature dimension used in HGT layers.
    num_heads : int
        Number of attention heads in each HGTConv layer.
    num_layers : int
        Number of stacked HGTConv layers.
    n_dec_l : int
        Number of layers in the decoder. If 1, a single linear layer is used.
        Any other value uses a two-layer MLP (Linear -> ReLU -> Linear).
    data_obj : HeteroData
        A PyG HeteroData object describing the node and edge types in the
        graph. Only ``data_obj.node_types`` and ``data_obj.metadata()`` are
        read.
    out_channels : int, optional
        If specified, a linear projection from ``hidden_channels`` to
        ``out_channels`` is applied to the HGT output. The projected features
        are returned as the embedding and fed to the decoder.

    Attributes
    ----------
    lin_dict : nn.ModuleDict
        Node-type specific input projections to hidden space.
    convs : nn.ModuleList
        List of stacked HGTConv layers.
    decoder : nn.ModuleDict
        Node-type specific decoders for reconstruction.
    lin : nn.Linear or None
        Optional projection layer, present only if out_channels is specified.
    node_type : str
        The primary node type (``data_obj.node_types[0]``) whose embedding and
        reconstruction are returned by ``forward``.
    """

    def __init__(self, in_channels: int, hidden_channels: int, num_heads: int,
                 num_layers: int, n_dec_l: int, data_obj,
                 out_channels: Optional[int] = None):
        super().__init__()

        # One input projection per node type, mapping raw features to hidden space.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data_obj.node_types:
            self.lin_dict[node_type] = torch.nn.Linear(in_channels, hidden_channels)

        # The first node type is treated as the primary (target) type.
        self.node_type = data_obj.node_types[0]

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data_obj.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        # Optional shared projection; the decoders are sized for its output.
        self.lin, dec_inp_channels = None, hidden_channels
        if out_channels is not None:
            self.lin = torch.nn.Linear(hidden_channels, out_channels)
            dec_inp_channels = out_channels

        self.decoder = torch.nn.ModuleDict()
        for node_type in data_obj.node_types:
            if n_dec_l == 1:
                self.decoder[node_type] = torch.nn.Linear(dec_inp_channels, in_channels)
            else:
                # Any n_dec_l other than 1 yields this fixed two-layer MLP.
                self.decoder[node_type] = torch.nn.Sequential(
                    torch.nn.Linear(dec_inp_channels, dec_inp_channels),
                    torch.nn.ReLU(),
                    torch.nn.Linear(dec_inp_channels, in_channels),
                )

    def forward(self, x_dict, edge_index_dict):
        """
        Forward pass of the HGT model.

        Follows https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py

        Parameters
        ----------
        x_dict : Dict[str, Tensor]
            Dictionary mapping node types to feature matrices.
        edge_index_dict : Dict[Tuple[str, str, str], Tensor]
            Dictionary mapping edge types to edge index tensors.

        Returns
        -------
        z : Tensor
            The final node embeddings for the primary node type, L2 normalized
            row-wise. They have ``out_channels`` columns if a projection was
            configured, otherwise ``hidden_channels``.
        recon : Tensor
            Reconstructed input for the primary node type (for autoencoder
            training), produced from the unnormalized embedding.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        h = x_dict[self.node_type]
        # Apply the optional projection so the decoder receives the input
        # width it was built for (previously self.lin was never used).
        if self.lin is not None:
            h = self.lin(h)

        # Only the primary node type's reconstruction is returned, so only its
        # decoder needs to be evaluated.
        recon = self.decoder[self.node_type](h)
        return F.normalize(h, p=2, dim=1), recon