Source code for spamosaic.architectures.wlgcn

from typing import List, Optional, Union, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv, GCNConv, GAE
from torch_geometric.utils import negative_sampling

[docs]def sym_norm(edge_index:torch.Tensor, num_nodes:int, edge_weight:Optional[Union[Any,torch.Tensor]]=None, improved:Optional[bool]=False, dtype:Optional[Any]=None )-> List: """ Compute the symmetric normalized adjacency matrix with optional edge weights. Parameters ---------- edge_index : torch.Tensor The edge indices of the graph (2, num_edges). num_nodes : int Number of nodes in the graph. edge_weight : torch.Tensor, optional Edge weights corresponding to edge_index. If None, assumes unweighted graph. improved : bool, optional Whether to use improved self-loops (value=2 instead of 1). Default is False. dtype : torch.dtype, optional Data type for the output tensor. Returns ------- edge_index : torch.Tensor Edge indices with added self-loops. norm : torch.Tensor Normalized edge weights. """ if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) fill_value = 1 if not improved else 2 edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes) row, col = edge_index deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
[docs]class WLGCN_vanilla(MessagePassing): """ Vanilla Weighted Light Graph Convolutional Network (WLGCN) layer. This layer performs K steps of message passing and returns a concatenated representation of all intermediate embeddings. Parameters ---------- K : int, optional Number of propagation steps. Default is 1. cached : bool, optional Not used; for compatibility. bias : bool, optional Not used; for compatibility. """ def __init__(self, K:Optional[int]=1, cached:Optional[bool]=False, bias:Optional[bool]=True, **kwargs): super(WLGCN_vanilla, self).__init__(aggr='add', **kwargs) self.K = K
[docs] def forward(self, x:torch.Tensor, edge_index:torch.Tensor, edge_weight:Union[torch.Tensor,None]=None): """ Forward pass for WLGCN_vanilla. Parameters ---------- x : torch.Tensor Input node features of shape (N, F). edge_index : torch.Tensor Edge indices in COO format. edge_weight : torch.Tensor or None Optional edge weights. Returns ------- torch.Tensor Concatenated output features from all propagation steps. """ edge_index, norm = sym_norm(edge_index, x.size(0), edge_weight, dtype=x.dtype) xs = [x] for k in range(self.K): xs.append(self.propagate(edge_index, x=xs[-1], norm=norm)) return torch.cat(xs, dim = 1)
[docs] def message(self, x_j, norm): return norm.view(-1, 1) * x_j
def __repr__(self): return '{}({}, {}, K={})'.format(self.__class__.__name__, # self.in_channels, self.out_channels, self.K)
[docs]class WLGCN(torch.nn.Module): """ Deep WLGCN with encoder-decoder structure for representation learning. Parameters ---------- input_size : int Input feature dimension. output_size : int Output embedding dimension. K : int, optional Number of GCN propagation steps (default: 8). dec_l : int, optional Number of layers in the decoder (1 or 2). Default is 1. hidden_size : int, optional Hidden layer size for encoder. Default is 512. dropout : float, optional Dropout rate. Default is 0.2. slope : float, optional LeakyReLU negative slope. Default is 0.2. """ def __init__(self, input_size, output_size, K=8, dec_l=1, hidden_size=512, dropout=0.2, slope=0.2): super(WLGCN, self).__init__() self.conv1 = WLGCN_vanilla(K=K) self.fc1 = torch.nn.Linear(input_size * (K + 1), hidden_size) self.bn = torch.nn.BatchNorm1d(hidden_size) self.dropout1 = torch.nn.Dropout(p=dropout) self.fc2 = torch.nn.Linear(hidden_size, output_size) self.negative_slope = slope if dec_l == 1: self.decoder = torch.nn.Linear(output_size, input_size) else: self.decoder = torch.nn.Sequential( torch.nn.Linear(output_size, output_size), torch.nn.ReLU(), torch.nn.Linear(output_size, input_size) )
[docs] def forward(self, feature, edge_index, edge_weight=None): """ Forward pass of the WLGCN model. Parameters ---------- feature : torch.Tensor Input node features of shape (N, F). edge_index : torch.Tensor Edge indices in COO format. edge_weight : torch.Tensor or None Optional edge weights. Returns ------- Tuple[torch.Tensor, torch.Tensor] - Normalized latent embeddings. - Reconstructed input features. """ x = self.conv1(feature, edge_index, edge_weight) x = F.leaky_relu(self.fc1(x), negative_slope=self.negative_slope) x = self.bn(x) x = self.dropout1(x) x = self.fc2(x) r = self.decoder(x) x = F.normalize(x, p=2, dim=1) return x, r
# class FGAE(GAE): # def __init__(self, encoder, w_g): # super(FGAE, self).__init__(encoder) # self.w_f = 1-w_g # self.w_g = w_g # self.rec_crit = nn.MSELoss() # def encode(self, x, edge_index, edge_weight=None): # return self.encoder(x, edge_index, edge_weight) # def decode(self, z, edge_index): # # z = F.normalize(z, p=2, dim=1) # logit = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1) # return torch.sigmoid(logit) # def recon_loss(self, x, r, z, pos_edge_index): # if self.w_g > 0: # # graph reconstruction loss # pos_loss = -torch.log(self.decode(z, pos_edge_index)+1e-20) # pos_loss = pos_loss.mean() # # print(f'pos_edge_index.shape={pos_edge_index.shape}, max_edge_index=({pos_edge_index[0].max()}, {pos_edge_index[1].max()}), z.size={z.size(0)}') # neg_edge_index = negative_sampling( # pos_edge_index, z.size(0), num_neg_samples=pos_edge_index.size(1) # ) # # neg_edge_index = self.subgraph_negative_sampling(pos_edge_index) # neg_loss = -torch.log(1 - self.decode(z, neg_edge_index) + 1e-20) # neg_loss = neg_loss.mean() # graph_rec_loss = pos_loss + neg_loss # else: # graph_rec_loss = 0. # if self.w_f > 0: # # feature reconstruction loss # feat_rec_loss = self.rec_crit(r, x) # else: # feat_rec_loss = 0. # return self.w_f * feat_rec_loss + self.w_g * graph_rec_loss