spamosaic.train_utils.train_model
- spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, mod_graphs_edge, mod_graphs_edge_w, mod_graphs_edge_c, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)
Train a multi-modal graph neural network (GNN) with contrastive and reconstruction losses.
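As a rough illustration of the two loss families, here is a minimal sketch assuming a generic symmetric InfoNCE loss with temperature T as the contrastive term, and the w_rec_g mixing of graph and feature reconstruction described in the parameter list. The helper names info_nce and combine_losses are hypothetical, InfoNCE is a stand-in for whatever crit1 actually computes, and adding the contrastive term with weight 1 is an assumption, not something this page states.

```python
import torch
import torch.nn.functional as F


def info_nce(z1: torch.Tensor, z2: torch.Tensor, T: float) -> torch.Tensor:
    """Generic symmetric InfoNCE between two modalities' embeddings (hypothetical stand-in for crit1).

    Row i of z1 and row i of z2 are assumed to describe the same cell.
    """
    z1 = F.normalize(z1, dim=1)  # unit-length rows, so z1 @ z2.T is cosine similarity
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / T  # a lower T sharpens the softmax over candidate partners
    targets = torch.arange(z1.size(0), device=z1.device)  # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


def combine_losses(l_con, l_rec_graph, l_rec_feat, w_rec_g: float):
    """Mix the two reconstruction terms as documented for w_rec_g.

    Adding the contrastive term with weight 1 is an assumption of this sketch.
    """
    l_rec = w_rec_g * l_rec_graph + (1.0 - w_rec_g) * l_rec_feat
    return l_con + l_rec, l_rec


# Usage: eight cells with aligned, mutually orthogonal embeddings in two modalities
z = torch.eye(8, 16)
l_con = info_nce(z, z, T=0.1)  # small, since every cell matches only its own partner
total, l_rec = combine_losses(l_con, torch.tensor(2.0), torch.tensor(4.0), w_rec_g=0.25)
```

With w_rec_g = 0.25 the reconstruction term is 0.25 * 2.0 + 0.75 * 4.0 = 3.5; shuffling the partner rows (for example z.flip(0)) drives the contrastive term up sharply, which is the behaviour the temperature controls.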
- Parameters:
mod_model (dict[str, torch.nn.Module]) – Dictionary mapping modality to its encoder model.
mod_optims (dict[str, torch.optim.Optimizer]) – Dictionary mapping each modality to its optimizer.
crit1 (dict[int, Callable] or torch.nn.Module) – Contrastive loss function, or a dict mapping each batch index to its own contrastive loss function.
crit2 (Callable) – Reconstruction loss function (e.g., MSELoss).
loss_type (str) – Either ‘adapted’ (multi-modality) or ‘standard’ (2-modality) contrastive loss.
w_rec_g (float) – Weight of the graph reconstruction loss; the feature reconstruction loss is weighted by (1 - w_rec_g).
mod_input_feat (dict[str, torch.Tensor]) – Input features per modality.
mod_graphs_edge (dict[str, torch.Tensor]) – Edge indices per modality.
mod_graphs_edge_w (dict[str, torch.Tensor]) – Edge weights per modality.
mod_graphs_edge_c (dict[str, torch.Tensor]) – Batch-specific edge masks (used to isolate subgraphs).
batch_train_meta_numbers (dict[int, tuple]) – Per-batch metadata: whether the batch is a bridge or test batch, its size, and its modalities.
mod_batch_split (dict[str, list[int]]) – Dictionary mapping each modality to a list of cell counts, one per batch.
T (float) – Temperature for contrastive loss.
n_epochs (int) – Number of training epochs.
device (torch.device) – Device used for training.
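The per-batch parameters can be pictured with a small sketch. It assumes that each modality's cells are stacked batch by batch in mod_input_feat, so the counts in mod_batch_split can cut the matrix with torch.split. It further treats the mask in mod_graphs_edge_c as one batch label per edge and the edge indices as a (2, num_edges) tensor. Both of those are guesses about the data layout, and split_by_batch and batch_subgraph are hypothetical helpers, not SpaMosaic functions.

```python
import torch


def split_by_batch(mod_input_feat, mod_batch_split):
    """Cut each modality's stacked cell-by-feature matrix into per-batch chunks.

    Assumes cells are stored batch by batch, in the same order as the counts.
    """
    return {
        mod: torch.split(feat, mod_batch_split[mod], dim=0)
        for mod, feat in mod_input_feat.items()
    }


def batch_subgraph(edge_index, edge_weight, edge_batch, b):
    """Keep only the edges labelled with batch b.

    Hypothetical: treats the edge mask tensor as one batch label per edge and
    edge_index as a (2, num_edges) tensor.
    """
    mask = edge_batch == b
    return edge_index[:, mask], edge_weight[mask]


# Usage: one modality, two batches holding 2 and 4 cells
feats = {"rna": torch.arange(12.0).reshape(6, 2)}
chunks = split_by_batch(feats, {"rna": [2, 4]})

edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
edge_weight = torch.tensor([1.0, 1.0, 0.5, 0.5])
edge_batch = torch.tensor([0, 0, 1, 1])
ei_b1, ew_b1 = batch_subgraph(edge_index, edge_weight, edge_batch, 1)
```

The selected edges keep their global node numbering; if a batch's features are used on their own, a real implementation would likely need to shift the indices by the batch's starting offset.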
- Returns:
Trained model dict, list of contrastive losses, list of reconstruction losses.
- Return type:
tuple
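To show the shape of the returned tuple, here is a heavily simplified training-loop skeleton. It assumes exactly two modalities with aligned cells, a hypothetical ToyEncoder that returns an embedding plus a feature reconstruction, a generic InfoNCE contrastive term, and MSE feature reconstruction. Graphs, batches, crit1/crit2 and w_rec_g are deliberately left out. The function train_loop is illustrative only and records one loss value per epoch; the granularity of the lists returned by train_model is not stated on this page.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder: returns an embedding and a feature reconstruction."""

    def __init__(self, d_in: int, d_z: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_z)
        self.dec = nn.Linear(d_z, d_in)

    def forward(self, x):
        z = self.enc(x)
        return z, self.dec(z)


def train_loop(mod_model, mod_optims, mod_input_feat, T, n_epochs, device):
    """Skeleton returning (model dict, contrastive losses, reconstruction losses)."""
    m1, m2 = list(mod_model)  # assumes exactly two modalities with aligned cells
    feats = {m: x.to(device) for m, x in mod_input_feat.items()}
    con_losses, rec_losses = [], []
    for _ in range(n_epochs):
        for m in (m1, m2):
            mod_model[m].train()
            mod_optims[m].zero_grad()
        outs = {m: mod_model[m](feats[m]) for m in (m1, m2)}
        # Generic InfoNCE: row i in both modalities is the same cell
        z1 = F.normalize(outs[m1][0], dim=1)
        z2 = F.normalize(outs[m2][0], dim=1)
        logits = z1 @ z2.t() / T
        targets = torch.arange(z1.size(0), device=device)
        l_con = F.cross_entropy(logits, targets)
        # Feature reconstruction only; no graph term in this sketch
        l_rec = sum(F.mse_loss(outs[m][1], feats[m]) for m in (m1, m2))
        (l_con + l_rec).backward()
        for m in (m1, m2):
            mod_optims[m].step()  # each optimizer updates only its own encoder
        con_losses.append(l_con.item())
        rec_losses.append(l_rec.item())
    return mod_model, con_losses, rec_losses


# Usage on random data: 32 cells, an RNA-like (50 features) and a protein-like (20 features) modality
torch.manual_seed(0)
device = torch.device("cpu")
mod_input_feat = {"rna": torch.randn(32, 50), "adt": torch.randn(32, 20)}
mod_model = {"rna": ToyEncoder(50, 8).to(device), "adt": ToyEncoder(20, 8).to(device)}
mod_optims = {m: torch.optim.Adam(net.parameters(), lr=1e-2) for m, net in mod_model.items()}
model, con_losses, rec_losses = train_loop(
    mod_model, mod_optims, mod_input_feat, T=0.1, n_epochs=5, device=device
)
```

Keeping one optimizer per modality, as mod_optims does, lets each encoder use its own learning rate while a single backward pass through the shared loss still couples the modalities.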