spamosaic.train_utils

spamosaic.train_utils.graph_decode(z, edge_index)

Compute edge-wise similarity scores from the dot product of the two endpoint embeddings of each edge.

Parameters:
  • z (torch.Tensor) – Node embeddings of shape [num_nodes, embedding_dim].

  • edge_index (torch.Tensor) – Edge index tensor with shape [2, num_edges].

Returns:

Edge probabilities computed via sigmoid(dot(z_i, z_j)).

Return type:

torch.Tensor
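A minimal sketch of the computation described above. It assumes `edge_index` stores source nodes in row 0 and target nodes in row 1 (the PyTorch Geometric convention). `graph_decode_sketch` is a hypothetical name, and this is not SpaMosaic's source.

```python
import torch


def graph_decode_sketch(z: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: sigmoid of the per-edge dot product."""
    src, dst = edge_index  # row 0: source nodes, row 1: target nodes (assumed)
    # Row-wise dot product between the embeddings of each edge's endpoints.
    logits = (z[src] * z[dst]).sum(dim=-1)  # shape [num_edges]
    return torch.sigmoid(logits)


z = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 0], [1, 2]])  # edges (0, 1) and (0, 2)
probs = graph_decode_sketch(z, edge_index)
# edge (0, 1): dot = 0 -> 0.5; edge (0, 2): dot = 1 -> sigmoid(1) ≈ 0.731
```

Scoring only the listed edges keeps memory linear in `num_edges`, whereas decoding the full matrix `sigmoid(z @ z.T)` would materialize `num_nodes²` entries.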

spamosaic.train_utils.graph_recon_crit(z, pos_edge_index)

Compute the binary cross-entropy loss for graph reconstruction from positive and negative edges. Only positive edges are passed in; the negative edges are generated inside the function.

Parameters:
  • z (torch.Tensor) – Node embeddings of shape [num_nodes, embedding_dim].

  • pos_edge_index (torch.Tensor) – Positive edge indices of shape [2, num_pos_edges].

Returns:

Scalar reconstruction loss combining positive and negative samples.

Return type:

torch.Tensor
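The docstring does not say how negative edges are obtained. The sketch below assumes uniform random node pairs, one per positive edge, and a graph-autoencoder style loss: `-log σ(z_i·z_j)` averaged over positive edges plus `-log(1 - σ(z_i·z_j))` averaged over negative pairs. Uniform pairs can occasionally coincide with true edges; PyTorch Geometric's `torch_geometric.utils.negative_sampling` is designed to avoid existing edges, and the real function may use it or something else. `graph_recon_crit_sketch` and its `generator` argument are hypothetical.

```python
import torch

EPS = 1e-15  # keeps log() finite when a probability hits 0 or 1


def graph_recon_crit_sketch(z, pos_edge_index, generator=None):
    """Hypothetical BCE graph-reconstruction loss with uniform negative sampling."""
    # Positive term: observed edges should receive probabilities near 1.
    src, dst = pos_edge_index
    pos_prob = torch.sigmoid((z[src] * z[dst]).sum(dim=-1))
    pos_loss = -torch.log(pos_prob + EPS).mean()

    # Negative term (assumed scheme): as many random node pairs as positive
    # edges, treated as non-edges whose probabilities should be near 0.
    neg_edge_index = torch.randint(
        0, z.size(0), tuple(pos_edge_index.shape), generator=generator
    )
    n_src, n_dst = neg_edge_index
    neg_prob = torch.sigmoid((z[n_src] * z[n_dst]).sum(dim=-1))
    neg_loss = -torch.log(1 - neg_prob + EPS).mean()

    return pos_loss + neg_loss


g = torch.Generator().manual_seed(0)
z = torch.randn(6, 4, generator=g)
pos_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
loss = graph_recon_crit_sketch(z, pos_edge_index, generator=g)
```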

spamosaic.train_utils.set_seeds(seed, dt=True)

Set random seeds for reproducibility across multiple libraries.

Parameters:
  • seed (int) – Random seed to use for reproducibility.

  • dt (bool, default=True) – Whether to enforce deterministic algorithms in PyTorch.
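The docstring does not list which libraries are seeded. A common pattern, sketched here as an assumption, seeds Python's `random`, NumPy, and PyTorch, and for `dt=True` switches PyTorch into deterministic mode. `set_seeds_sketch` is a hypothetical name, not the library function.

```python
import os
import random

import numpy as np
import torch


def set_seeds_sketch(seed: int, dt: bool = True) -> None:
    """Hypothetical stand-in for set_seeds; the libraries seeded are assumed."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's legacy global RNG
    torch.manual_seed(seed)  # PyTorch RNG on all devices (CPU and CUDA)
    if dt:
        # cuBLAS needs this workspace setting for deterministic results on CUDA >= 10.2.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Raise a RuntimeError if an op has no deterministic implementation.
        torch.use_deterministic_algorithms(True)
```

Calling the function twice with the same seed should make subsequent random draws from all three libraries repeat.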

spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, mod_graphs_edge, mod_graphs_edge_w, mod_graphs_edge_c, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)

Train a multi-modal GNN model with contrastive and reconstruction losses.

Parameters:
  • mod_model (dict[str, torch.nn.Module]) – Dictionary mapping each modality to its encoder model.

  • mod_optims (dict[str, torch.optim.Optimizer]) – Dictionary mapping each modality to its optimizer.

  • crit1 (dict[int, Callable] or torch.nn.Module) – Contrastive loss function, or a dict of such functions keyed by batch.

  • crit2 (Callable) – Reconstruction loss function (e.g., MSELoss).

  • loss_type (str) – Contrastive loss variant: ‘adapted’ (multi-modality) or ‘standard’ (two modalities).

  • w_rec_g (float) – Weight of the graph reconstruction loss; the feature reconstruction loss is weighted by (1 - w_rec_g).

  • mod_input_feat (dict[str, torch.Tensor]) – Input features per modality.

  • mod_graphs_edge (dict[str, torch.Tensor]) – Edge indices per modality.

  • mod_graphs_edge_w (dict[str, torch.Tensor]) – Edge weights per modality.

  • mod_graphs_edge_c (dict[str, torch.Tensor]) – Batch-specific edge masks (used to isolate subgraphs).

  • batch_train_meta_numbers (dict[int, tuple]) – Per-batch metadata: whether the batch is a bridge or test batch, its size, and its modalities.

  • mod_batch_split (dict[str, list[int]]) – Dictionary mapping each modality to a list of cell counts, one per batch.

  • T (float) – Temperature for contrastive loss.

  • n_epochs (int) – Number of training epochs.

  • device (torch.device) – Device used for training.

Returns:

A tuple of the trained model dict, the list of contrastive losses, and the list of reconstruction losses.

Return type:

tuple
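The docstring does not spell out the loss formulas. The sketch below shows two pieces under stated assumptions: the reconstruction blend implied by `w_rec_g`, and a generic symmetric InfoNCE loss with temperature `T` for the two-modality case. The InfoNCE form is a common choice used here only for illustration; it is not necessarily what `crit1` computes, the ‘adapted’ variant is not covered, and both helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_recon_loss(graph_rec: torch.Tensor, feat_rec: torch.Tensor, w_rec_g: float) -> torch.Tensor:
    """Blend the two reconstruction terms as described for w_rec_g."""
    return w_rec_g * graph_rec + (1.0 - w_rec_g) * feat_rec


def info_nce(z_a: torch.Tensor, z_b: torch.Tensor, T: float) -> torch.Tensor:
    """Hypothetical symmetric InfoNCE between two modalities of the same cells.

    Row i of z_a and row i of z_b are assumed to describe the same cell.
    """
    z_a = F.normalize(z_a, dim=-1)
    z_b = F.normalize(z_b, dim=-1)
    logits = z_a @ z_b.t() / T  # [n, n] cosine similarities scaled by 1/T
    targets = torch.arange(z_a.size(0), device=z_a.device)
    # Average the a->b and b->a cross-entropy terms.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```

With this blend, `w_rec_g = 1` keeps only the graph term and `w_rec_g = 0` keeps only the feature term. A smaller `T` sharpens the softmax over cells, penalizing mismatched pairs more strongly.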
