"""spamosaic.loss: contrastive loss utilities for multi-modal representation alignment."""

import torch
import torch.nn.functional as F

def set_diag(matrix, v):
    """
    Set the main diagonal of a 2D matrix to a specified value, in place.

    Parameters
    ----------
    matrix : torch.Tensor
        A 2D tensor whose main diagonal will be modified in place. Square
        matrices are the typical case; a rectangular ``(n, m)`` matrix is also
        accepted, in which case its first ``min(n, m)`` diagonal entries are set.
    v : float
        Value to set on the diagonal.

    Returns
    -------
    torch.Tensor
        The same tensor object as ``matrix`` (mutated in place), returned for
        convenience.
    """
    # Build the mask on matrix's own device (not the default device) so GPU
    # inputs are not indexed with a host-side mask; eye(n, m) also covers
    # rectangular inputs.
    mask = torch.eye(
        matrix.size(0), matrix.size(1), dtype=torch.bool, device=matrix.device
    )
    matrix[mask] = v
    return matrix
class CL_loss(torch.nn.Module):
    """
    Contrastive Loss for multi-view representation alignment.

    This loss function is designed to encourage representations from the same
    sample across multiple modalities to be similar, while pushing apart
    representations from different samples.

    Rows/columns of the similarity matrix are expected in modality-major
    order: index ``m * B + i`` refers to sample ``i`` in modality ``m`` (the
    layout implied by ``torch.arange(B).repeat(rep)`` in the mask builder).

    Parameters
    ----------
    batch_size : int
        Number of samples in a mini-batch (per modality). Masks for this size
        are precomputed as buffers; a similarity matrix of a different size
        (e.g. a final partial batch) gets masks built on the fly.
    rep : int, optional
        Number of modalities or views (default: 3). Should be at least 2:
        with ``rep=1`` there are no positive pairs and the loss is NaN (0/0).
    bias : float, optional
        Constant added to the summed exponentiated similarities inside the
        log, guarding against ``log(0)`` / exploding gradients (default: 0).
    """

    def __init__(self, batch_size, rep=3, bias=0):
        super().__init__()
        self.batch_size = batch_size
        self.n_mods = rep
        negatives_mask, pos_mask = self._build_masks(batch_size, rep)
        # Registered as buffers (same names/order as before) so they follow
        # .to(device) and stay compatible with existing state_dicts.
        self.register_buffer("negatives_mask", negatives_mask)
        self.register_buffer("pos_mask", pos_mask)
        self.bias = bias

    @staticmethod
    def _build_masks(batch_size, rep, device=None):
        """Return float ``(negatives_mask, pos_mask)``, each of shape [B*rep, B*rep]."""
        n = batch_size * rep
        # Every entry except the diagonal (self-similarity).
        not_self = ~torch.eye(n, n, dtype=torch.bool, device=device)
        # Entry (r, c) is True when r and c are views of the same sample.
        iids = torch.arange(batch_size, device=device).repeat(rep)
        same_sample = iids.view(-1, 1) == iids.view(1, -1)
        return not_self.float(), (same_sample & not_self).float()

    def forward(self, simi):
        """
        Compute contrastive loss from a similarity matrix.

        Parameters
        ----------
        simi : torch.Tensor
            Pairwise similarity matrix of shape [B' * rep, B' * rep]. ``B'``
            is usually ``batch_size`` but may differ (e.g. a last partial
            batch).

        Returns
        -------
        torch.Tensor
            Scalar loss value.

        Raises
        ------
        ValueError
            If ``simi.size(0)`` is not a multiple of ``rep``.
        """
        n = simi.size(0)
        if n == self.negatives_mask.size(0):
            negatives_mask, pos_mask = self.negatives_mask, self.pos_mask
        else:
            if n % self.n_mods != 0:
                raise ValueError(
                    f"similarity matrix size {n} is not a multiple of rep={self.n_mods}"
                )
            negatives_mask, pos_mask = self._build_masks(
                n // self.n_mods, self.n_mods, device=simi.device
            )

        # Shift each row by its (detached) max so exp() cannot overflow; with
        # bias == 0 the shift cancels between the two terms below.
        simi_max, _ = torch.max(simi, dim=1, keepdim=True)
        simi = simi - simi_max.detach()
        # Mean similarity to the other views of the same sample.
        positives = (simi * pos_mask).sum(dim=1) / pos_mask.sum(dim=1)
        # Denominator sums over every column except self, so it also
        # includes the positives.
        negatives = (torch.exp(simi) * negatives_mask).sum(dim=1)
        # A non-zero bias keeps log() finite if all terms underflow.
        loss = -(positives - torch.log(negatives + self.bias)).mean()
        return loss