"""Heterogeneous LightGCN layers and models for SpaMosaic.
Provides a LightGCN-style convolution for heterogeneous graphs and an encoder-decoder network.
"""
from typing import List, Optional, Union, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv
class HG_LGCN_Conv(MessagePassing):
    """
    Heterogeneous Graph LightGCN convolution layer.

    Performs symmetric degree-normalized aggregation with self-loops and
    separates intra-group vs. inter-group message passing, then concatenates
    the two streams. The output therefore has twice the feature width of the
    input.

    Parameters
    ----------
    self_loop_type : int, default=1
        Edge type given to the self-loops this layer adds for nodes that do not
        already have one (``1`` routes them to the intra-group stream, ``0`` to
        the inter-group stream). Self-loops already present in ``edge_index``
        keep their own weight and type.
    """

    def __init__(self, self_loop_type=1):
        super().__init__(aggr='add')
        self.self_loop_type = self_loop_type

    def forward(self, x, edge_index, edge_weight, edge_type):
        """
        Forward pass for heterogeneous LightGCN convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input node features of shape ``(N, F)``.
        edge_index : torch.LongTensor
            Edge list in COO format, shape ``(2, E)``.
        edge_weight : torch.Tensor or None
            Edge weights, shape ``(E,)``. ``None`` means every edge has weight 1.
        edge_type : torch.Tensor
            Edge type indicator, shape ``(E,)`` and aligned with the columns of
            ``edge_index`` (``1`` for intra-group, ``0`` for inter-group). Edges
            with any other type take part in the degree normalization but are
            not propagated in either stream.

        Returns
        -------
        torch.Tensor
            ``(N, 2 * F)`` tensor: intra-group aggregation followed by
            inter-group aggregation along the feature dimension.

        Raises
        ------
        ValueError
            If ``edge_type`` does not have one entry per edge of ``edge_index``.
        """
        num_nodes = x.size(0)
        num_edges = edge_index.size(1)
        if edge_weight is None:
            edge_weight = x.new_ones(num_edges)
        if edge_type.numel() != num_edges:
            raise ValueError(
                f"edge_type has {edge_type.numel()} entries but edge_index has "
                f"{num_edges} edges; they must be aligned one-to-one."
            )

        # Self-loops must be added jointly with edge_type. Otherwise the type
        # mask no longer lines up with the (re-ordered, extended) edge list.
        edge_index, edge_weight, edge_type = self._add_self_loops(
            edge_index, edge_weight, edge_type, num_nodes
        )

        # Symmetric degree normalization over the whole graph (both edge types).
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree nodes would otherwise produce inf and poison the messages.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float('inf'), 0)
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # Convolve each edge type separately and concatenate the two streams.
        # An explicit size keeps N output rows even when one subset is empty.
        size = (num_nodes, num_nodes)
        intra_mask = edge_type == 1
        inter_mask = edge_type == 0
        intra_features = self.propagate(
            edge_index[:, intra_mask], x=x, norm=norm[intra_mask], size=size
        )
        inter_features = self.propagate(
            edge_index[:, inter_mask], x=x, norm=norm[inter_mask], size=size
        )
        return torch.cat([intra_features, inter_features], dim=1)

    def _add_self_loops(self, edge_index, edge_weight, edge_type, num_nodes):
        """Add one self-loop per node while keeping weights and types aligned.

        Existing self-loops are removed from their original positions, and
        exactly one loop per node is appended in node order. A node that already
        had a self-loop keeps that loop's weight and type. Every other node gets
        weight ``1`` and type ``self.self_loop_type``.
        """
        row, col = edge_index
        is_loop = row == col
        keep = ~is_loop

        loop_nodes = torch.arange(num_nodes, dtype=edge_index.dtype, device=edge_index.device)
        loop_weight = edge_weight.new_ones(num_nodes)
        loop_type = edge_type.new_full((num_nodes,), self.self_loop_type)
        # Preserve the attributes of self-loops that were already in the graph.
        loop_weight[row[is_loop]] = edge_weight[is_loop]
        loop_type[row[is_loop]] = edge_type[is_loop]

        edge_index = torch.cat(
            [edge_index[:, keep], torch.stack([loop_nodes, loop_nodes], dim=0)], dim=1
        )
        edge_weight = torch.cat([edge_weight[keep], loop_weight], dim=0)
        edge_type = torch.cat([edge_type[keep], loop_type], dim=0)
        return edge_index, edge_weight, edge_type

    def message(self, x_j, norm):
        """
        Compute per-edge messages.

        Parameters
        ----------
        x_j : torch.Tensor
            Source-node features for each edge.
        norm : torch.Tensor
            Normalized scalar weight per edge.

        Returns
        -------
        torch.Tensor
            Weighted messages ``norm * x_j``.
        """
        return norm.view(-1, 1) * x_j
class HG_LGCN_vanilla(torch.nn.Module):
    """
    Stack of ``HG_LGCN_Conv`` layers whose outputs are concatenated.

    Parameters
    ----------
    num_layers : int
        How many ``HG_LGCN_Conv`` layers to stack (``0`` gives an identity map).

    Notes
    -----
    The result concatenates the raw input with the output of every layer.
    Because each ``HG_LGCN_Conv`` doubles the feature width (intra- and
    inter-group streams side by side), an input of width ``F`` yields an output
    of width ``(1 + 2 + ... + 2**num_layers) * F``.
    """

    def __init__(self, num_layers):
        super().__init__()
        layers = [HG_LGCN_Conv() for _ in range(num_layers)]
        self.convs = torch.nn.ModuleList(layers)

    def forward(self, x, edge_index, edge_weight, edge_type):
        """
        Run the features through every layer and concatenate all stages.

        Parameters
        ----------
        x : torch.Tensor
            Input node features.
        edge_index : torch.LongTensor
            Edge list in COO format, forwarded unchanged to each layer.
        edge_weight : torch.Tensor
            Per-edge weights, forwarded unchanged to each layer.
        edge_type : torch.Tensor
            Intra- vs. inter-group indicator, forwarded unchanged to each layer.

        Returns
        -------
        torch.Tensor
            ``[x, layer_1(x), layer_2(layer_1(x)), ...]`` concatenated along
            ``dim=1``.
        """
        stages = [x]
        hidden = x
        for layer in self.convs:
            hidden = layer(hidden, edge_index, edge_weight, edge_type)
            stages.append(hidden)
        return torch.cat(stages, dim=1)
class HG_LGCN(torch.nn.Module):
    """
    Encoder-decoder network built on stacked heterogeneous LightGCN layers.

    The encoder propagates node features with ``HG_LGCN_vanilla`` and passes
    the concatenated multi-hop features through an MLP head
    (Linear -> LeakyReLU -> BatchNorm -> Dropout -> Linear). The decoder maps
    the latent code back to the input feature space.

    Parameters
    ----------
    input_size : int
        Per-node input feature dimensionality.
    output_size : int
        Latent embedding dimensionality.
    K : int, default=8
        Number of stacked LightGCN layers.
    dec_l : int, default=1
        Decoder depth selector: ``1`` gives a single linear layer; any other
        value gives a two-layer MLP (Linear -> ReLU -> Linear).
    hidden_size : int, default=512
        Width of the hidden layer in the MLP head.
    dropout : float, default=0.2
        Dropout probability in the MLP head.
    """

    def __init__(self, input_size, output_size, K=8, dec_l=1, hidden_size=512, dropout=0.2):
        super().__init__()
        self.conv1 = HG_LGCN_vanilla(num_layers=K)
        # The encoder output is the input plus every layer's output; layer i is
        # 2**i times as wide as the input because each conv layer concatenates
        # its intra- and inter-group streams.
        concat_width = input_size * sum(2 ** i for i in range(K + 1))
        # Keep the submodule creation order fixed: parameter initialisation
        # draws from the global RNG in this order.
        self.fc1 = torch.nn.Linear(concat_width, hidden_size)
        self.bn = torch.nn.BatchNorm1d(hidden_size)
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.fc2 = torch.nn.Linear(hidden_size, output_size)
        if dec_l != 1:
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(output_size, output_size),
                torch.nn.ReLU(),
                torch.nn.Linear(output_size, input_size),
            )
        else:
            self.decoder = torch.nn.Linear(output_size, input_size)

    def forward(self, feature, edge_index, edge_weight, edge_type):
        """
        Encode node features and reconstruct them from the latent code.

        Parameters
        ----------
        feature : torch.Tensor
            Input node features, ``[N, input_size]``.
        edge_index : torch.LongTensor
            COO edge list.
        edge_weight : torch.Tensor
            Edge weights.
        edge_type : torch.Tensor
            Edge type indicators.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            - L2-normalized latent representation (``[N, output_size]``).
            - Reconstructed features (``[N, input_size]``).
        """
        multi_hop = self.conv1(feature, edge_index, edge_weight, edge_type)
        hidden = F.leaky_relu(self.fc1(multi_hop), negative_slope=0.2)
        hidden = self.dropout1(self.bn(hidden))
        latent = self.fc2(hidden)
        # The decoder sees the raw latent; only the returned embedding is
        # L2-normalised.
        reconstruction = self.decoder(latent)
        return F.normalize(latent, p=2, dim=1), reconstruction