spamosaic.train_utils.train_model
- spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, modBatch_intra_subgraphs, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)
End-to-end training loop with contrastive loss on bridge batches and feature/graph reconstruction on test batches.
- Parameters:
mod_model (dict) – Mapping {modality_name -> torch.nn.Module} (encoder/decoder).
mod_optims (dict) – Mapping {modality_name -> torch.optim.Optimizer}.
crit1 (dict or torch.nn.Module) – Contrastive loss. If a dict, it should be keyed by bridge batch id.
crit2 (torch.nn.Module) – Feature reconstruction criterion (e.g., nn.MSELoss).
loss_type ({'adapted', 'ce'}) – Type of contrastive loss to apply.
w_rec_g (float) – Weight in [0, 1] for graph vs. feature reconstruction on test batches.
mod_input_feat (dict) – Modality → full-graph input features tensor.
modBatch_intra_subgraphs (dict or None) – Cached intra subgraphs; expects entries like modBatch_intra_subgraphs[k][bi] = {'edge_index': LongTensor[2, E] or None}.
batch_train_meta_numbers (dict) – Training metadata per batch (sizes, modality sets, etc.).
mod_batch_split (dict) – Per-modality list of batch sizes for splitting full-graph tensors.
T (float) – Temperature used for contrastive logits.
n_epochs (int) – Number of epochs.
device (torch.device or str) – Training device.
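A minimal sketch of how two of these arguments might be interpreted; it is not the library's implementation. It assumes that mod_batch_split[modality] lists consecutive batch sizes along the row axis of mod_input_feat[modality], and that w_rec_g mixes the two reconstruction terms as a convex combination. The source confirms neither detail. The helpers split_by_batch and combine_reconstruction and the "rna" modality are hypothetical names, and plain Python lists stand in for tensors.

```python
from itertools import accumulate


def split_by_batch(rows, batch_sizes):
    """Split a full-graph sequence of rows into consecutive per-batch chunks."""
    if sum(batch_sizes) != len(rows):
        raise ValueError("batch sizes must sum to the number of rows")
    ends = list(accumulate(batch_sizes))
    starts = [0] + ends[:-1]
    return [rows[s:e] for s, e in zip(starts, ends)]


def combine_reconstruction(loss_graph, loss_feat, w_rec_g):
    """ASSUMPTION: graph and feature losses are mixed as a convex combination."""
    if not 0.0 <= w_rec_g <= 1.0:
        raise ValueError("w_rec_g must lie in [0, 1]")
    return w_rec_g * loss_graph + (1.0 - w_rec_g) * loss_feat


# Hypothetical modality "rna": 5 cells spread over two batches of sizes 2 and 3.
mod_input_feat = {"rna": [[0.1], [0.2], [0.3], [0.4], [0.5]]}
mod_batch_split = {"rna": [2, 3]}

chunks = split_by_batch(mod_input_feat["rna"], mod_batch_split["rna"])
mixed = combine_reconstruction(loss_graph=2.0, loss_feat=4.0, w_rec_g=0.5)
```

With real tensors, torch.split(tensor, batch_sizes) performs the same consecutive split along dimension 0.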
- Returns:
(mod_model, loss_cl, loss_rec), where loss_cl / loss_rec are lists of scalars.
- Return type:
tuple
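A short sketch of consuming the returned tuple. The values below are stand-ins rather than real training output, and summarize is a hypothetical helper. Whether each list holds one entry per epoch or per step is not stated here, so the sketch only assumes non-empty lists of scalars.

```python
def summarize(losses):
    """Return the first, last, and minimum value of a non-empty loss history."""
    if not losses:
        raise ValueError("empty loss history")
    # float() also accepts zero-dimensional tensors, not only Python numbers.
    values = [float(v) for v in losses]
    return {"first": values[0], "last": values[-1], "min": min(values)}


# Stand-in for: mod_model, loss_cl, loss_rec = train_model(...)
mod_model, loss_cl, loss_rec = {}, [1.5, 0.9, 0.6], [2.0, 1.2, 1.1]

cl_summary = summarize(loss_cl)
rec_summary = summarize(loss_rec)
```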