spamosaic.train_utils.train_model

spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, modBatch_intra_subgraphs, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)

End-to-end training loop that applies a contrastive loss on bridge batches and feature/graph reconstruction losses on test batches.
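The control flow summarized above can be sketched in plain Python, without PyTorch, to show how the inputs fit together. This is a hypothetical outline, not SpaMosaic's implementation. `split_rows`, `train_loop_sketch`, `mod_batch_ids`, `bridge_ids` and the toy loss callables are invented for illustration. The convex mix `w_rec_g * graph + (1 - w_rec_g) * feature` and the per-epoch averaging of the recorded losses are assumptions. Gradient steps are omitted.

```python
from typing import Callable, Dict, List, Sequence, Tuple

def split_rows(rows: Sequence, sizes: Sequence[int]) -> List[list]:
    """Cut full-graph rows into consecutive per-batch chunks, like torch.split(x, sizes)."""
    if sum(sizes) != len(rows):
        raise ValueError("batch sizes must sum to the number of rows")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(list(rows[start:start + size]))
        start += size
    return chunks

def train_loop_sketch(
    mod_input_feat: Dict[str, Sequence],        # modality -> full-graph rows
    mod_batch_split: Dict[str, Sequence[int]],  # modality -> per-batch row counts
    mod_batch_ids: Dict[str, Sequence[int]],    # HYPOTHETICAL: batch id of each chunk, in order
    bridge_ids: set,                            # HYPOTHETICAL: ids of bridge batches
    contrastive_fn: Callable[[Dict[str, list]], float],
    feat_rec_fn: Callable[[list], float],
    graph_rec_fn: Callable[[list], float],
    w_rec_g: float,
    n_epochs: int,
) -> Tuple[List[float], List[float]]:
    # Regroup per-modality chunks into {batch_id: {modality: rows}}.
    batches: Dict[int, Dict[str, list]] = {}
    for k, rows in mod_input_feat.items():
        for bi, chunk in zip(mod_batch_ids[k], split_rows(rows, mod_batch_split[k])):
            batches.setdefault(bi, {})[k] = chunk

    loss_cl, loss_rec = [], []
    for _ in range(n_epochs):
        cl_terms, rec_terms = [], []
        for bi, mods in sorted(batches.items()):
            if bi in bridge_ids:
                # Bridge batch: align the modalities with a contrastive loss.
                cl_terms.append(contrastive_fn(mods))
            else:
                # Test batch: ASSUMED convex mix of graph and feature reconstruction.
                for x in mods.values():
                    rec_terms.append(w_rec_g * graph_rec_fn(x) + (1 - w_rec_g) * feat_rec_fn(x))
            # A real loop would typically backpropagate and step each modality's optimizer here.
        # ASSUMED: one averaged value per epoch is recorded.
        loss_cl.append(sum(cl_terms) / len(cl_terms) if cl_terms else 0.0)
        loss_rec.append(sum(rec_terms) / len(rec_terms) if rec_terms else 0.0)
    return loss_cl, loss_rec

# Toy usage: batch 0 holds both modalities (bridge), batch 1 holds only "rna" (test).
cl, rec = train_loop_sketch(
    {"rna": [1.0, 2.0, 3.0], "adt": [10.0, 20.0]},
    mod_batch_split={"rna": [2, 1], "adt": [2]},
    mod_batch_ids={"rna": [0, 1], "adt": [0]},
    bridge_ids={0},
    contrastive_fn=lambda mods: float(len(mods)),  # toy value: number of modalities
    feat_rec_fn=lambda x: 1.0,
    graph_rec_fn=lambda x: 3.0,
    w_rec_g=0.25,
    n_epochs=2,
)
```

In this sketch, the full-graph inputs are cut into per-batch pieces once, before the epoch loop. Each batch then contributes either a contrastive term or a reconstruction term, depending on whether it is a bridge batch.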

Parameters:
  • mod_model (dict) – Mapping {modality_name -> torch.nn.Module} holding each modality's encoder/decoder.

  • mod_optims (dict) – Mapping {modality_name -> torch.optim.Optimizer}.

  • crit1 (dict or torch.nn.Module) – Contrastive loss. If a dict, it should be keyed by bridge batch id.

  • crit2 (torch.nn.Module) – Feature reconstruction criterion (e.g., nn.MSELoss).

  • loss_type ({'adapted', 'ce'}) – Type of contrastive loss to apply.

  • w_rec_g (float) – Weight in [0, 1] that balances graph reconstruction against feature reconstruction on test batches.

  • mod_input_feat (dict) – Mapping {modality_name -> full-graph input feature tensor}.

  • modBatch_intra_subgraphs (dict or None) – Cached intra subgraphs, with entries of the form modBatch_intra_subgraphs[k][bi] = {'edge_index': LongTensor[2, E] or None}, where k indexes the modality and bi the batch.

  • batch_train_meta_numbers (dict) – Training metadata per batch (sizes, modality sets, etc.).

  • mod_batch_split (dict) – Mapping {modality_name -> list of batch sizes}, used to split full-graph tensors into per-batch chunks.

  • T (float) – Temperature used for contrastive logits.

  • n_epochs (int) – Number of epochs.

  • device (torch.device or str) – Training device.
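One common way a temperature enters a contrastive loss is the InfoNCE-style cross-entropy below. Embeddings are L2-normalized and pairwise cosine similarities are divided by `T`. Row i of one modality must then pick row i of the other modality as its positive. Whether `loss_type='ce'` computes exactly this is an assumption, and the `'adapted'` variant is not reproduced. `info_nce_ce` and `l2_normalize` are hypothetical names. Only one direction (a to b) is shown; symmetric variants average both directions.

```python
import math
from typing import List, Sequence

def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / norm for x in v]

def info_nce_ce(za: Sequence[Sequence[float]], zb: Sequence[Sequence[float]], T: float) -> float:
    """Mean cross-entropy where row i of za should match row i of zb (a -> b only)."""
    a = [l2_normalize(v) for v in za]
    b = [l2_normalize(v) for v in zb]
    total = 0.0
    for i, ai in enumerate(a):
        # Temperature-scaled cosine similarities of a[i] against every row of b.
        logits = [sum(x * y for x, y in zip(ai, bj)) / T for bj in b]
        m = max(logits)  # log-sum-exp stabilization
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]  # -log softmax(logits)[i], target class i
    return total / len(a)

# Perfectly aligned pairs with T = 0.5: each row's loss is log(1 + e^-2), about 0.127.
loss = info_nce_ce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], T=0.5)
```

A smaller `T` sharpens the softmax over similarities, which amplifies the gap between the positive pair and the negatives.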

Returns:

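(mod_model, loss_cl, loss_rec), where mod_model is the trained model mapping, and loss_cl and loss_rec are lists of scalar contrastive and reconstruction loss values recorded during training.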

Return type:

tuple