spamosaic.train_utils.train_model

spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, mod_graphs_edge, mod_graphs_edge_w, mod_graphs_edge_c, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)

Train a multi-modal GNN model using contrastive and reconstruction losses.

Parameters:
  • mod_model (dict[str, torch.nn.Module]) – Dictionary mapping modality to its encoder model.

  • mod_optims (dict[str, torch.optim.Optimizer]) – Optimizer for each modality.

  • crit1 (dict[int, Callable] or torch.nn.Module) – Contrastive loss function, or a dict mapping each batch to its own contrastive loss function.

  • crit2 (Callable) – Reconstruction loss function (e.g., MSELoss).

  • loss_type (str) – Contrastive loss variant: either ‘adapted’ (multi-modality) or ‘standard’ (two-modality).

  • w_rec_g (float) – Weight of the graph reconstruction loss; the feature reconstruction loss is weighted by (1 - w_rec_g).

  • mod_input_feat (dict[str, torch.Tensor]) – Input features per modality.

  • mod_graphs_edge (dict[str, torch.Tensor]) – Edge indices per modality.

  • mod_graphs_edge_w (dict[str, torch.Tensor]) – Edge weights per modality.

  • mod_graphs_edge_c (dict[str, torch.Tensor]) – Batch-specific edge masks (used to isolate subgraphs).

  • batch_train_meta_numbers (dict[int, tuple]) – Per-batch metadata indicating whether the batch is a bridge or test batch, its size, and its modalities.

  • mod_batch_split (dict[str, list[int]]) – Mapping modality to list of cell counts per batch.

  • T (float) – Temperature for the contrastive loss.

  • n_epochs (int) – Number of training epochs.

  • device (torch.device) – Device used for training.
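The minimal sketch below illustrates three of the parameters above: how w_rec_g blends the graph and feature reconstruction terms, how the per-batch cell counts in mod_batch_split could partition a modality's cells, and what a temperature-scaled (T) contrastive loss looks like. The helper names (weighted_reconstruction_loss, split_by_batch, info_nce) are hypothetical, the code uses plain Python lists instead of tensors, it assumes cells are stored in batch order, and the symmetric InfoNCE form used for the two-modality case is an assumption, not necessarily what spamosaic implements.

```python
import math
from typing import Any, List, Sequence

Vector = Sequence[float]


def weighted_reconstruction_loss(graph_rec: float, feat_rec: float, w_rec_g: float) -> float:
    """Graph term weighted by w_rec_g, feature term by (1 - w_rec_g), as documented."""
    return w_rec_g * graph_rec + (1.0 - w_rec_g) * feat_rec


def split_by_batch(rows: Sequence[Any], counts: Sequence[int]) -> List[List[Any]]:
    """Hypothetical helper: cut consecutive rows into batches of the given sizes.

    Assumes rows are ordered by batch, like one entry of mod_batch_split.
    """
    if sum(counts) != len(rows):
        raise ValueError("batch counts must sum to the number of rows")
    out, start = [], 0
    for n in counts:
        out.append(list(rows[start:start + n]))
        start += n
    return out


def _l2_normalize(v: Vector) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # avoid dividing by zero
    return [x / norm for x in v]


def _cross_entropy_diag(sim: List[List[float]], T: float) -> float:
    """Mean of -log softmax(sim[i] / T)[i]: the positive pair sits on the diagonal."""
    total = 0.0
    for i, row in enumerate(sim):
        logits = [s / T for s in row]
        m = max(logits)  # log-sum-exp trick for numerical stability
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[i]
    return total / len(sim)


def info_nce(z_a: Sequence[Vector], z_b: Sequence[Vector], T: float) -> float:
    """Assumed symmetric InfoNCE; row i of z_a and z_b embeds the same cell."""
    a = [_l2_normalize(v) for v in z_a]
    b = [_l2_normalize(v) for v in z_b]
    sim_ab = [[sum(x * y for x, y in zip(u, v)) for v in b] for u in a]  # cosine similarities
    sim_ba = [list(col) for col in zip(*sim_ab)]  # transpose for the b -> a direction
    return 0.5 * (_cross_entropy_diag(sim_ab, T) + _cross_entropy_diag(sim_ba, T))
```

With torch tensors, `torch.split(x, counts)` performs the same row partition as split_by_batch, and a lower T sharpens the softmax so mismatched pairs are penalized more strongly.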

Returns:

A 3-tuple containing the trained model dict, the list of contrastive losses, and the list of reconstruction losses.

Return type:

tuple
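As a hedged usage sketch, the returned 3-tuple can be unpacked in the documented order. The helper summarize_training and its output keys are hypothetical, and the sketch makes no assumption about whether the loss lists hold one value per epoch or per step; it only reads their last entries.

```python
from typing import Any, Dict, Sequence, Tuple

TrainResult = Tuple[Dict[str, Any], Sequence[float], Sequence[float]]


def summarize_training(result: TrainResult) -> Dict[str, Any]:
    """Hypothetical helper: report modalities and the last recorded loss values."""
    mod_model, contrastive_losses, reconstruction_losses = result  # documented order
    if not contrastive_losses or not reconstruction_losses:
        raise ValueError("expected non-empty loss lists")
    return {
        "modalities": sorted(mod_model),  # keys of the trained model dict
        "final_contrastive": float(contrastive_losses[-1]),
        "final_reconstruction": float(reconstruction_losses[-1]),
    }


# Hypothetical call, with arguments as documented above:
# result = train_model(mod_model, mod_optims, crit1, crit2, "standard", 0.5, ...)
# print(summarize_training(result))
```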