spamosaic.train_utils
Training utilities for SpaMosaic.
Provides seed control, edge decoding & graph reconstruction loss, and the main training loop.
- spamosaic.train_utils.graph_decode(z, edge_index)
Compute edge-wise similarity scores using dot product of node embeddings.
- Parameters:
z (torch.Tensor) – Node embeddings of shape [num_nodes, embedding_dim].
edge_index (torch.Tensor) – Edge index tensor with shape [2, num_edges].
- Returns:
Edge probabilities computed via sigmoid(dot(z_i, z_j)).
- Return type:
torch.Tensor
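As an illustration of the decoding rule described above, here is a minimal sketch of a dot-product edge decoder. The function name `graph_decode_sketch` and the toy tensors are assumptions made for this example; it is not the library's source and the actual implementation may differ in detail.

```python
import torch


def graph_decode_sketch(z: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: sigmoid of the dot product for each edge's endpoints."""
    src, dst = edge_index                    # each has shape [num_edges]
    logits = (z[src] * z[dst]).sum(dim=-1)   # row-wise dot product, shape [num_edges]
    return torch.sigmoid(logits)


# Toy data (assumed): three nodes with 2-dimensional embeddings.
z = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 0], [1, 2]])  # edges 0-1 and 0-2
probs = graph_decode_sketch(z, edge_index)   # one probability per edge
```

Orthogonal embeddings give a probability of 0.5, and the score rises as the endpoints' embeddings align.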
- spamosaic.train_utils.graph_recon_crit(z, pos_edge_index, num_neg)
BCE-style reconstruction loss on sampled positive/negative edges.
- Parameters:
z (torch.Tensor) – Node embeddings of shape [N, D].
pos_edge_index (torch.Tensor) – Positive edges as index tensor of shape [2, E_pos].
num_neg (int or None) – Number of negative edges to sample. If None, defaults to E_pos.
- Returns:
Scalar loss value averaged over positives and negatives.
- Return type:
torch.Tensor
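The following sketch shows one way a BCE-style edge reconstruction loss of this kind can be written. It assumes negatives are drawn as uniformly random node pairs, which may occasionally coincide with true edges; the library may use a different sampler. The name `graph_recon_loss_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F


def graph_recon_loss_sketch(z, pos_edge_index, num_neg=None):
    """Hypothetical stand-in: mean BCE over positive and sampled negative edges."""
    num_nodes = z.size(0)
    if num_neg is None:
        num_neg = pos_edge_index.size(1)     # default to E_pos, as documented
    # Assumption: uniform random pairs serve as negatives.
    neg_edge_index = torch.randint(0, num_nodes, (2, num_neg), device=z.device)

    edges = torch.cat([pos_edge_index, neg_edge_index], dim=1)
    logits = (z[edges[0]] * z[edges[1]]).sum(dim=-1)
    labels = torch.cat([
        torch.ones(pos_edge_index.size(1), device=z.device),
        torch.zeros(num_neg, device=z.device),
    ])
    # Averaged over all positive and negative edges together.
    return F.binary_cross_entropy_with_logits(logits, labels)


torch.manual_seed(0)
z = torch.randn(6, 4)                                   # toy embeddings
pos_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])   # toy positive edges
loss = graph_recon_loss_sketch(z, pos_edge_index)
```

Working on logits with `binary_cross_entropy_with_logits` avoids the numerical issues of applying a sigmoid first and then taking a logarithm.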
- spamosaic.train_utils.set_seeds(seed, dt=True)
Set random seeds for reproducibility across multiple libraries.
- Parameters:
seed (int) – Random seed to use for reproducibility.
dt (bool, default=True) – Whether to enforce deterministic algorithms in PyTorch.
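A seed helper of this kind typically looks like the sketch below. Which libraries the real function seeds and what `dt` toggles exactly are assumptions here; the sketch seeds Python's `random`, NumPy, and PyTorch, and uses the cuDNN flags for determinism. The name `set_seeds_sketch` is hypothetical.

```python
import random

import numpy as np
import torch


def set_seeds_sketch(seed: int, dt: bool = True) -> None:
    """Hypothetical stand-in: seed common RNGs and optionally request determinism."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if dt:
        # Assumption: determinism is requested through the cuDNN backend flags.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# Re-seeding with the same value reproduces the same random draw.
set_seeds_sketch(0)
a = torch.rand(3)
set_seeds_sketch(0)
b = torch.rand(3)
```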
- spamosaic.train_utils.train_model(mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g, mod_input_feat, modBatch_intra_subgraphs, batch_train_meta_numbers, mod_batch_split, T, n_epochs, device)
End-to-end training loop with contrastive loss on bridge batches and feature/graph reconstruction on test batches.
- Parameters:
mod_model (dict) – Mapping {modality_name -> torch.nn.Module} (encoder/decoder).
mod_optims (dict) – Mapping {modality_name -> torch.optim.Optimizer}.
crit1 (dict or torch.nn.Module) – Contrastive loss. If a dict, it should be keyed by bridge batch id.
crit2 (torch.nn.Module) – Feature reconstruction criterion (e.g., nn.MSELoss).
loss_type ({'adapted', 'ce'}) – Type of contrastive loss to apply.
w_rec_g (float) – Weight in [0, 1] for graph vs. feature reconstruction on test batches.
mod_input_feat (dict) – Modality → full-graph input features tensor.
modBatch_intra_subgraphs (dict or None) – Cached intra subgraphs; expects entries like modBatch_intra_subgraphs[k][bi] = {'edge_index': LongTensor[2, E] or None}.
batch_train_meta_numbers (dict) – Training metadata per batch (sizes, modality sets, etc.).
mod_batch_split (dict) – Per-modality list of batch sizes for splitting full-graph tensors.
T (float) – Temperature used for contrastive logits.
n_epochs (int) – Number of epochs.
device (torch.device or str) – Training device.
- Returns:
(mod_model, loss_cl, loss_rec), where loss_cl / loss_rec are lists of scalars.
- Return type:
tuple
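To make two of the arguments more concrete, the sketch below shows how a per-modality list of batch sizes such as `mod_batch_split` can split a full-graph tensor with `torch.split`, and one plausible reading of `w_rec_g` as a convex mix of graph and feature reconstruction losses. The modality key `"rna"`, the toy sizes, and the helper `mix_recon_losses` are assumptions for illustration; the actual combination inside `train_model` may differ.

```python
import torch
import torch.nn as nn

# Hypothetical inputs: one modality whose 10 cells come from batches of 4 and 6.
mod_input_feat = {"rna": torch.randn(10, 8)}
mod_batch_split = {"rna": [4, 6]}

# Split each full-graph tensor back into per-batch chunks along the cell axis.
per_batch = {
    k: torch.split(x, mod_batch_split[k], dim=0)
    for k, x in mod_input_feat.items()
}
shapes = [tuple(t.shape) for t in per_batch["rna"]]


def mix_recon_losses(loss_feat, loss_graph, w_rec_g):
    """Assumed convex combination: w_rec_g weights the graph term."""
    return w_rec_g * loss_graph + (1.0 - w_rec_g) * loss_feat


crit2 = nn.MSELoss()                         # feature reconstruction criterion
x = per_batch["rna"][0]
loss_feat = crit2(torch.zeros_like(x), x)    # placeholder "reconstruction" of zeros
loss_graph = torch.tensor(0.5)               # placeholder graph loss value
total = mix_recon_losses(loss_feat, loss_graph, w_rec_g=0.2)
```

Under this reading, `w_rec_g = 0` trains on feature reconstruction only and `w_rec_g = 1` on graph reconstruction only.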
Functions
graph_decode | Compute edge-wise similarity scores using dot product of node embeddings. |
graph_recon_crit | BCE-style reconstruction loss on sampled positive/negative edges. |
set_seeds | Set random seeds for reproducibility across multiple libraries. |
train_model | End-to-end training loop with contrastive loss on bridge batches and feature/graph reconstruction on test batches. |