spamosaic.train_utils

Training utilities for SpaMosaic.

Provides seed control, edge decoding, a graph reconstruction loss, and the main training loop.

spamosaic.train_utils.graph_decode(z, edge_index)

Compute an edge-wise score for each edge as the sigmoid of the dot product between the embeddings of its two endpoint nodes.

Parameters:
  • z (torch.Tensor) – Node embeddings of shape [num_nodes, embedding_dim].

  • edge_index (torch.Tensor) – Edge index tensor with shape [2, num_edges].

Returns:

Edge probabilities of shape [num_edges], computed as sigmoid(dot(z_i, z_j)) for each edge (i, j).

Return type:

torch.Tensor
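The decoding step described above can be sketched in a few lines of PyTorch. This is a minimal illustration that assumes `edge_index` follows the usual COO convention (row 0 holds source indices, row 1 holds target indices); it is not necessarily the SpaMosaic source.

```python
import torch


def graph_decode(z: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Sketch: sigmoid of the dot product of the two endpoint embeddings of each edge."""
    src, dst = edge_index[0], edge_index[1]      # endpoint node indices, each [num_edges]
    logits = (z[src] * z[dst]).sum(dim=-1)       # row-wise dot products, [num_edges]
    return torch.sigmoid(logits)                 # edge probabilities in (0, 1)


# Toy example: 3 nodes with 2-D embeddings and two edges (0->1 and 1->2).
z = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 2.0]])
edge_index = torch.tensor([[0, 1], [1, 2]])
probs = graph_decode(z, edge_index)
```

Here the first edge joins orthogonal embeddings (dot product 0, probability 0.5), while the second has dot product 2 and therefore a probability of sigmoid(2), roughly 0.88.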

spamosaic.train_utils.graph_recon_crit(z, pos_edge_index, num_neg)

Binary cross-entropy (BCE) style reconstruction loss over the given positive edges and randomly sampled negative edges.

Parameters:
  • z (torch.Tensor) – Node embeddings of shape [N, D].

  • pos_edge_index (torch.Tensor) – Positive edges as an index tensor of shape [2, E_pos].

  • num_neg (int or None) – Number of negative edges to sample. If None, defaults to E_pos.

Returns:

Scalar loss value averaged over positives and negatives.

Return type:

torch.Tensor
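A minimal sketch of this kind of loss, assuming negatives are drawn uniformly at random over node pairs (no rejection of pairs that happen to be true edges) and that the loss is a single BCE mean over the concatenated positive and negative edges. The actual sampling strategy in SpaMosaic may differ.

```python
import torch
import torch.nn.functional as F


def graph_recon_crit(z, pos_edge_index, num_neg=None):
    """Sketch: BCE on positive edges vs. uniformly sampled negative edges."""
    num_nodes = z.size(0)
    if num_neg is None:
        num_neg = pos_edge_index.size(1)          # default: as many negatives as positives
    # Assumption: uniform random node pairs; these may occasionally coincide with true edges.
    neg_edge_index = torch.randint(0, num_nodes, (2, num_neg), device=z.device)

    def edge_logits(edge_index):
        # Dot product of endpoint embeddings, left as raw logits (no sigmoid yet).
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    pos_logits = edge_logits(pos_edge_index)
    neg_logits = edge_logits(neg_edge_index)
    all_logits = torch.cat([pos_logits, neg_logits])
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
    # Positives are labelled 1, negatives 0; the mean runs over all sampled edges.
    return F.binary_cross_entropy_with_logits(all_logits, labels)
```

Working on logits with `binary_cross_entropy_with_logits`, rather than applying BCE to the sigmoid outputs of `graph_decode`, is a design choice made here for numerical stability; both express the same objective.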

spamosaic.train_utils.set_seeds(seed, dt=True)

Set random seeds for reproducibility across multiple libraries.

Parameters:
  • seed (int) – Random seed to use for reproducibility.

  • dt (bool, default=True) – Whether to enforce deterministic algorithms in PyTorch.
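A seeding helper of this kind is commonly written as below. This sketch assumes the "multiple libraries" are Python's `random`, NumPy and PyTorch, and that `dt=True` maps onto PyTorch's deterministic-algorithm switches; the exact set of calls in SpaMosaic is not confirmed here.

```python
import os
import random

import numpy as np
import torch


def set_seeds(seed: int, dt: bool = True) -> None:
    """Sketch: seed Python, NumPy and PyTorch; optionally force deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds the CPU generator and all CUDA devices
    torch.cuda.manual_seed_all(seed)   # explicit for clarity; safe to call without a GPU
    if dt:
        # Recommended by the PyTorch reproducibility notes for deterministic cuBLAS
        # on CUDA >= 10.2; it only takes effect if set before cuBLAS is first used.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```

With deterministic algorithms enabled, PyTorch raises an error when an operation has no deterministic implementation, which trades some speed and flexibility for repeatable runs.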

spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, modBatch_intra_subgraphs, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)

End-to-end training loop that applies a contrastive loss on bridge batches and feature and graph reconstruction losses on test batches.
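To make the contrastive part concrete, here is a minimal sketch of a temperature-scaled cross-entropy (InfoNCE-style) loss between two modalities' embeddings of the same spots in a bridge batch. The helper name `contrastive_ce_loss`, the cosine normalization and the symmetric averaging are assumptions for illustration; this is not necessarily what `loss_type='ce'` computes in SpaMosaic, and the `'adapted'` variant is not reproduced.

```python
import torch
import torch.nn.functional as F


def contrastive_ce_loss(z_a: torch.Tensor, z_b: torch.Tensor, T: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: symmetric InfoNCE loss between paired embeddings.

    Row i of z_a and row i of z_b are assumed to describe the same spot measured
    in two modalities (a positive pair); every other row serves as a negative.
    """
    z_a = F.normalize(z_a, dim=1)                            # unit vectors -> cosine similarity
    z_b = F.normalize(z_b, dim=1)
    logits = z_a @ z_b.t() / T                               # [n, n] similarities scaled by temperature
    targets = torch.arange(z_a.size(0), device=z_a.device)   # positives lie on the diagonal
    # Average the a->b and b->a cross-entropy terms.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```

A smaller `T` sharpens the softmax over candidates, so the loss penalizes near-miss negatives more strongly.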

Parameters:
  • mod_model (dict) – Mapping {modality_name -> torch.nn.Module} (encoder/decoder).

  • mod_optims (dict) – Mapping {modality_name -> torch.optim.Optimizer}.

  • crit1 (dict or torch.nn.Module) – Contrastive loss. If a dict, it should be keyed by bridge batch id.

  • crit2 (torch.nn.Module) – Feature reconstruction criterion (e.g., nn.MSELoss).

  • loss_type ({'adapted', 'ce'}) – Type of contrastive loss to apply.

  • w_rec_g (float) – Weight in [0, 1] for graph vs. feature reconstruction on test batches.

  • mod_input_feat (dict) – Mapping {modality_name -> input feature tensor for the full graph}.

  • modBatch_intra_subgraphs (dict or None) – Cached intra subgraphs. Entries are expected in the form modBatch_intra_subgraphs[k][bi] = {'edge_index': LongTensor[2, E] or None}, where k indexes modalities and bi indexes batches.

  • batch_train_meta_numbers (dict) – Training metadata per batch (sizes, modality sets, etc.).

  • mod_batch_split (dict) – Per-modality list of batch sizes for splitting full-graph tensors.

  • T (float) – Temperature used for contrastive logits.

  • n_epochs (int) – Number of epochs.

  • device (torch.device or str) – Training device.
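Two of the parameters above correspond to simple tensor operations. This sketch assumes that the rows of each full-graph tensor are stacked batch by batch in the order given by `mod_batch_split[k]`, and that `w_rec_g` forms the convex combination `w_rec_g * graph_loss + (1 - w_rec_g) * feature_loss`. Both helper names are hypothetical, and the weighting formula is inferred from the parameter description rather than confirmed.

```python
from typing import List

import torch
import torch.nn.functional as F


def split_by_batch(full: torch.Tensor, batch_sizes: List[int]) -> List[torch.Tensor]:
    """Hypothetical helper: split a stacked full-graph tensor into per-batch chunks."""
    if sum(batch_sizes) != full.size(0):
        raise ValueError("batch sizes must cover every row of the full-graph tensor")
    return list(torch.split(full, list(batch_sizes), dim=0))


def combine_recon_losses(loss_graph: torch.Tensor, loss_feat: torch.Tensor, w_rec_g: float) -> torch.Tensor:
    """Hypothetical convex combination of graph and feature reconstruction losses."""
    if not 0.0 <= w_rec_g <= 1.0:
        raise ValueError("w_rec_g must lie in [0, 1]")
    return w_rec_g * loss_graph + (1.0 - w_rec_g) * loss_feat


# Toy example: 5 spots from two batches (3 + 2 spots) with 4 features each.
x_full = torch.arange(20, dtype=torch.float32).reshape(5, 4)
chunks = split_by_batch(x_full, [3, 2])
x_hat = torch.zeros_like(chunks[1])                  # stand-in reconstruction of batch 1
loss_feat = F.mse_loss(x_hat, chunks[1])             # feature reconstruction (MSE)
loss_total = combine_recon_losses(torch.tensor(0.7), loss_feat, w_rec_g=0.5)
```

Setting `w_rec_g=1.0` would train test batches on graph reconstruction alone, and `w_rec_g=0.0` on feature reconstruction alone.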

Returns:

(mod_model, loss_cl, loss_rec), where loss_cl and loss_rec are lists of scalar contrastive and reconstruction loss values, respectively.

Return type:

tuple
