# Source code for spamosaic.architectures.wlgcn

"""Weighted Light graph convolution (WLGCN) layers for SpaMosaic.

Implements a parameter-free propagation layer on pre-normalized sparse adjacencies
and an MLP head for embedding + reconstruction.
"""

from __future__ import annotations
from typing import Optional, Literal, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor

class WLGCN_vanilla(nn.Module):
    r"""Parameter-free LightGCN-style propagation.

    Accepts a precomputed normalized sparse adjacency
    :math:`\hat{A} \approx D^{-1/2}(A + I\cdot\text{fill})D^{-1/2}` as a
    ``torch_sparse.SparseTensor``.

    Parameters
    ----------
    K : int, default=1
        Number of propagation steps (returns ``K+1`` representations
        including the 0-th order input). Must be non-negative.
    agg : {'cat', 'sum', 'mean'}, default='cat'
        How to aggregate representations across steps.

    Raises
    ------
    ValueError
        If ``K`` is negative or ``agg`` is not one of the supported modes.
    """

    # Supported aggregation modes; anything else is rejected rather than
    # silently treated as 'mean'.
    _AGG_MODES = ("cat", "sum", "mean")

    def __init__(
        self,
        K: int = 1,
        agg: Literal["cat", "sum", "mean"] = "cat",
    ):
        super().__init__()
        K = int(K)
        if K < 0:
            raise ValueError(f"K must be >= 0, got {K}")
        if agg not in self._AGG_MODES:
            raise ValueError(
                f"agg must be one of {self._AGG_MODES}, got {agg!r}"
            )
        self.K = K
        self.agg = agg

    def forward(self, x: torch.Tensor, A_hat: SparseTensor) -> torch.Tensor:
        """Propagate features via ``A_hat`` for ``K`` steps and aggregate.

        Parameters
        ----------
        x : torch.Tensor
            Node features of shape ``[N, F]``.
        A_hat : torch_sparse.SparseTensor
            Pre-normalized sparse adjacency on the correct device. Only its
            ``matmul`` method is used.

        Returns
        -------
        torch.Tensor
            Aggregated output. Shapes depend on ``agg``:

            - ``'cat'``  → ``[N, F * (K+1)]``
            - ``'sum'``  → ``[N, F]``
            - ``'mean'`` → ``[N, F]``

        Raises
        ------
        ValueError
            If ``self.agg`` was changed to an unsupported mode after
            construction.
        """
        # Re-check here because ``agg`` is a plain public attribute that can
        # be reassigned after __init__ validated it.
        if self.agg not in self._AGG_MODES:
            raise ValueError(
                f"agg must be one of {self._AGG_MODES}, got {self.agg!r}"
            )

        H = x
        if self.agg == "cat":
            # Keep every hop (0..K) and concatenate along the feature axis.
            outs: List[torch.Tensor] = [H]
            for _ in range(self.K):
                H = A_hat.matmul(H)
                outs.append(H)
            return torch.cat(outs, dim=-1)

        # 'sum' / 'mean': keep a running total instead of storing each hop.
        acc = H
        for _ in range(self.K):
            H = A_hat.matmul(H)
            acc = acc + H

        if self.agg == "sum":
            return acc
        # 'mean': K >= 0 is enforced in __init__, so K + 1 >= 1 and no
        # epsilon is needed (an epsilon would only bias the result).
        return acc / (self.K + 1)