spamosaic.train_utils.graph_recon_crit

spamosaic.train_utils.graph_recon_crit(z, pos_edge_index, num_neg)

BCE-style reconstruction loss on sampled positive/negative edges.

Parameters:
  • z (torch.Tensor) – Node embeddings of shape [N, D].

  • pos_edge_index (torch.Tensor) – Positive edges as index tensor of shape [2, E_pos].

  • num_neg (int or None) – Number of negative edges to sample. If None, defaults to E_pos.

Returns:

Scalar loss value averaged over positives and negatives.

Return type:

torch.Tensor
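
Example (illustrative sketch):

The following shows one way a BCE-style edge reconstruction loss of this kind could be computed. It is not the SpaMosaic implementation. The function name graph_recon_loss_sketch is hypothetical, and the inner-product edge score, the uniform random negative sampling, and the single mean over all edges are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def graph_recon_loss_sketch(z, pos_edge_index, num_neg=None):
    """Hypothetical stand-in for a BCE-style edge reconstruction loss."""
    n_nodes = z.size(0)
    n_pos = pos_edge_index.size(1)
    if num_neg is None:
        num_neg = n_pos  # mirror the documented default of E_pos

    # Assumption: negatives are node pairs drawn uniformly at random.
    # A sampled pair may occasionally coincide with a true edge.
    neg_edge_index = torch.randint(0, n_nodes, (2, num_neg), device=z.device)

    # Assumption: an edge is scored by the inner product of its endpoint embeddings.
    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

    logits = torch.cat([pos_logits, neg_logits])
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])

    # Mean binary cross-entropy over all positive and negative edges.
    return F.binary_cross_entropy_with_logits(logits, labels)


torch.manual_seed(0)
z = torch.randn(6, 4)                                  # N = 6 nodes, D = 4
pos_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # E_pos = 3 edges
loss = graph_recon_loss_sketch(z, pos_edge_index)      # num_neg defaults to 3
loss_more_neg = graph_recon_loss_sketch(z, pos_edge_index, num_neg=5)
```

Both calls return a zero-dimensional tensor. The value varies from call to call because the negatives are drawn at random.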