spamosaic.framework

SpaMosaic framework for multi-modal spatial omics integration.

Builds intra-/inter-batch spatial graphs, smooths features, aligns modalities, trains encoders, and provides embedding and imputation utilities.

class spamosaic.framework.SpaMosaic(modBatch_dict={}, input_key='dimred_bc', mnn_rep_key=None, batch_key='batch', radius_cutoff=2000, intra_knns=10, inter_knn_base=10, smooth_input=False, smooth_L=1, inter_auto_knn=False, inter_auto_thr=0.8, rmv_outlier=False, contamination='auto', w_g=0.8, log_dir=None, seed=1234, num_workers=6, device='cuda:0')

Bases: object

SpaMosaic: a modular framework for multi-modal spatial omics integration.

This class orchestrates data pre-processing, intra- and inter-batch graph construction, optional feature smoothing, model initialization/training, cross-modality alignment, and downstream embedding/imputation.

Parameters:
  • modBatch_dict (dict) – Maps each modality name (e.g., 'rna', 'adt') to a list of AnnData objects, one entry per batch, with None where a batch lacks that modality. Example: {'rna': [batch1, None, ...], 'adt': [batch1, batch2, ...]}.

  • input_key (str) – Key in .obsm where input features are stored (e.g., 'dimred_bc').

  • mnn_rep_key (str, optional) – Representation key used for MNN search. If None, defaults to input_key.

  • batch_key (str) – Column name in .obs denoting batch identity.

  • radius_cutoff (int) – Radius threshold used to construct spatial neighbor graphs.

  • intra_knns (int or list of int) – Number of neighbors for intra-batch graphs (single int or per-batch list).

  • inter_knn_base (int) – Base KNN size for inter-batch MNN search.

  • smooth_input (bool) – If True, apply WLGCN-based input feature smoothing.

  • smooth_L (int) – Number of WLGCN layers used for smoothing.

  • inter_auto_knn (bool) – If True, adapt inter-batch KNN size based on batch-size ratio.

  • inter_auto_thr (float) – Size-ratio threshold for adaptive KNN.

  • rmv_outlier (bool) – If True, remove outlier MNN pairs via Isolation Forest.

  • contamination (str or float) – Contamination level for outlier detection (IsolationForest).

  • w_g (float) – Weight for inter-batch expression edges in the merged graph.

  • log_dir (str, optional) – Directory for saving logs or results.

  • seed (int) – Random seed.

  • num_workers (int) – Number of workers for computation.

  • device (str) – Device string, e.g., 'cuda:0' or 'cpu'.
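
The layout of modBatch_dict described above can be sketched without the library. The helper below is hypothetical (it is not part of spamosaic) and uses plain strings as stand-ins for AnnData objects; it assumes that every modality list has one entry per batch, in the same batch order, which is what the example in the parameter description suggests.

```python
def summarize_layout(mod_batch_dict):
    """Return, for each batch position, the sorted modalities that are present."""
    lengths = {len(batches) for batches in mod_batch_dict.values()}
    if len(lengths) != 1:
        # Assumption: all modality lists are aligned by batch position.
        raise ValueError("every modality list must have one entry per batch")
    n_batches = lengths.pop()
    return [
        sorted(mod for mod, batches in mod_batch_dict.items() if batches[i] is not None)
        for i in range(n_batches)
    ]


# Strings stand in for AnnData objects; None marks a missing modality.
layout = summarize_layout({
    "rna": ["rna_b1", None, "rna_b3"],
    "adt": ["adt_b1", "adt_b2", None],
})
print(layout)  # [['adt', 'rna'], ['adt'], ['rna']]
```

Here batch 0 measures both modalities, while batches 1 and 2 each measure only one.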

apply_smoothing(modBatch_dict, mod_graphs, key, added_key, symm=False)

Apply WLGCN feature smoothing per modality.

Parameters:
  • modBatch_dict (dict) – Modality -> list of AnnData objects.

  • mod_graphs (dict) – Modality -> adjacency graph (scipy sparse).

  • key (str) – Key for input features in .obsm.

  • added_key (str) – Key to store smoothed features in .obsm.

  • symm (bool, optional) – If True, symmetrize the input adjacency before normalization. Default is False.
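
To illustrate what graph-based feature smoothing with an optional symmetrization step looks like, here is a minimal sketch. It is not the WLGCN used by SpaMosaic: it applies the standard GCN normalization D^{-1/2}(A + I)D^{-1/2} to a small dense adjacency and repeats the propagation n_layers times. The function name, the dense list-of-lists representation, and the use of an element-wise maximum for symmetrization are all assumptions made for the example.

```python
import math


def smooth_features(adj, feats, n_layers=1, symm=False):
    """Propagate features over a normalized adjacency (generic stand-in, not WLGCN)."""
    n = len(adj)
    a = [[float(x) for x in row] for row in adj]
    if symm:
        # Symmetrize: keep an edge if it exists in either direction.
        a = [[max(a[i][j], a[j][i]) for j in range(n)] for i in range(n)]
    for i in range(n):
        a[i][i] = 1.0  # add self-loops
    deg = [sum(row) for row in a]  # row sums of A + I
    a_hat = [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    out = [list(row) for row in feats]
    n_feat = len(out[0])
    for _ in range(n_layers):
        out = [
            [sum(a_hat[i][k] * out[k][d] for k in range(n)) for d in range(n_feat)]
            for i in range(n)
        ]
    return out


# Two nodes joined by a single directed edge; symm=True makes it bidirectional.
smoothed = smooth_features([[0, 1], [0, 0]], [[1.0], [3.0]], n_layers=1, symm=True)
print(smoothed)  # [[2.0], [2.0]]
```

With symmetrization both nodes average their own feature with their neighbor's, so the two values meet in the middle.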

cache_spatial_graph()

Cache per-modality, per-batch intra subgraphs using batch index ranges.

Assumes each batch occupies a contiguous node range in the merged intra graph.

Returns:

Nested mapping {modality -> {batch_id -> {'nodes': Tensor, 'edge_index': Tensor or None}}}.

Return type:

dict
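
The contiguous-range assumption can be made concrete with a small sketch. The helper below is hypothetical and works on plain Python lists rather than tensors; it also assumes that cached edges are relabeled to batch-local indices, which the description above does not state.

```python
def slice_batch_subgraphs(edges, batch_sizes):
    """Split a merged edge list into per-batch subgraphs via contiguous index ranges."""
    cache, start = {}, 0
    for batch_id, size in enumerate(batch_sizes):
        stop = start + size
        # Keep edges whose endpoints both fall inside this batch's node range.
        local = [
            (u - start, v - start)
            for u, v in edges
            if start <= u < stop and start <= v < stop
        ]
        cache[batch_id] = {
            "nodes": list(range(start, stop)),
            "edge_index": local or None,  # None when the batch has no edges
        }
        start = stop
    return cache


merged_edges = [(0, 1), (1, 0), (2, 3), (3, 2)]
cache = slice_batch_subgraphs(merged_edges, batch_sizes=[2, 2, 1])
print(cache[1])  # {'nodes': [2, 3], 'edge_index': [(0, 1), (1, 0)]}
print(cache[2])  # {'nodes': [4], 'edge_index': None}
```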

check_integrity()

Verify that all modalities across batches form a connected integration graph.

Raises:

RuntimeError – If the graph of shared modalities across batches is not fully connected.
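
One way to picture this check: treat batches as nodes and link two batches whenever they share a measured modality, then test whether everything ends up in one connected component. The sketch below uses a small union-find for that. It is an interpretation of the description above, not the library's implementation, and the function name is hypothetical.

```python
def batches_connected(mod_batch_dict):
    """Return True if all batches are linked through shared modalities."""
    n_batches = len(next(iter(mod_batch_dict.values())))
    parent = list(range(n_batches))

    def find(i):
        # Find the component root, compressing the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for batches in mod_batch_dict.values():
        present = [i for i, b in enumerate(batches) if b is not None]
        # All batches that measure this modality belong to one component.
        for i in present[1:]:
            parent[find(i)] = find(present[0])

    return len({find(i) for i in range(n_batches)}) == 1


# Batch 0 bridges batch 1 (via 'adt') and batch 2 (via 'rna').
bridged = batches_connected({"rna": ["b0", None, "b2"], "adt": ["b0", "b1", None]})
# No modality is shared between the two batches.
split = batches_connected({"rna": ["b0", None], "adt": [None, "b1"]})
print(bridged, split)  # True False
```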

g2ts(A_hat)

Convert a SciPy sparse matrix to torch_sparse.SparseTensor.

Parameters:

A_hat (scipy.sparse.spmatrix) – Symmetrized normalized adjacency.

Returns:

Coalesced sparse tensor with the same shape as A_hat.

Return type:

torch_sparse.SparseTensor

impute(modBatch_dict, emb_key='emb', layer_key='counts', imp_knn=10)

Impute missing modalities using the aligned embedding space and KNN.

Parameters:
  • modBatch_dict (dict) – Input dictionary of modalities and batches.

  • emb_key (str) – Key where latent embeddings are stored.

  • layer_key (str) – Which layer to impute (e.g., 'counts').

  • imp_knn (int) – Number of neighbors to use in KNN-based imputation.

Returns:

Imputed data dictionary: {modality -> list of arrays (or None)}.

Return type:

dict
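
The idea behind KNN-based imputation in a shared embedding space can be sketched as follows: for each spot that lacks a modality, find its nearest neighbors among spots that have it and combine their measured values. The helper below is hypothetical and uses an unweighted mean over Euclidean neighbors; the weighting and distance actually used by impute are not specified here.

```python
import math


def knn_impute(query_emb, ref_emb, ref_values, k):
    """Impute values for query spots as the mean over their k nearest reference spots."""
    n_feat = len(ref_values[0])
    imputed = []
    for q in query_emb:
        # Indices of the k reference embeddings closest to this query.
        nearest = sorted(range(len(ref_emb)), key=lambda j: math.dist(q, ref_emb[j]))[:k]
        imputed.append(
            [sum(ref_values[j][d] for j in nearest) / len(nearest) for d in range(n_feat)]
        )
    return imputed


ref_emb = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0)]            # spots with the modality
ref_counts = [[2.0, 0.0], [4.0, 2.0], [10.0, 10.0]]       # their measured values
query_emb = [(0.0, 0.4)]                                   # spot missing the modality
result = knn_impute(query_emb, ref_emb, ref_counts, k=2)
print(result)  # [[3.0, 1.0]]
```

The query spot sits between the first two reference spots, so its imputed profile is their average and the distant third spot is ignored.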

infer_emb(modBatch_dict, emb_key='emb', final_latent_key='merged_emb', cat=False)

Infer latent embeddings for each cell and return a list of merged AnnData objects.

Parameters:
  • modBatch_dict (dict) – Original input dictionary of modalities and batches.

  • emb_key (str) – Key to store intermediate embeddings.

  • final_latent_key (str) – Key to store final merged embedding in returned AnnData.

  • cat (bool) – If True, concatenate modality embeddings; otherwise average them.

Returns:

Reconstructed AnnData objects with merged embeddings.

Return type:

list of AnnData
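
The difference between the two merge modes controlled by cat can be shown on a single spot with two modality embeddings. This is only an illustration with a hypothetical helper; how infer_emb handles spots with a missing modality is not covered.

```python
def merge_embeddings(per_modality, cat=False):
    """Merge one spot's per-modality embeddings by concatenation or averaging."""
    if cat:
        # Concatenate: output dimension is the sum of the input dimensions.
        return [x for emb in per_modality for x in emb]
    # Average: output dimension equals the (shared) input dimension.
    n = len(per_modality)
    return [sum(vals) / n for vals in zip(*per_modality)]


rna_emb = [1.0, 2.0]
adt_emb = [3.0, 4.0]
averaged = merge_embeddings([rna_emb, adt_emb], cat=False)
concatenated = merge_embeddings([rna_emb, adt_emb], cat=True)
print(averaged)      # [2.0, 3.0]
print(concatenated)  # [1.0, 2.0, 3.0, 4.0]
```

Averaging keeps the latent dimension fixed regardless of how many modalities are present, whereas concatenation grows it with each modality.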

prepare_inputs(modBatch_dict)

Merge AnnData objects and construct final PyTorch graph inputs.

Notes

  • Concatenate features and adjacency matrices.

  • Add intra- and inter-batch edges.

  • Smooth features on merged graphs.

prepare_inter_graphs(modBatch_dict)

Build mutual nearest neighbor (MNN) graphs between batches for each modality.

Notes

  • Identify bridge vs. non-bridge batches.

  • Compute MNN pairs within each modality.

  • Optionally filter outliers.
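
The mutual nearest neighbor step in the notes above can be sketched in a few lines. Two points form an MNN pair when each is among the other's k nearest neighbors in the opposite batch. The helpers below are hypothetical, use brute-force Euclidean search, and leave out the bridge-batch logic and outlier filtering.

```python
import math


def knn_indices(queries, refs, k):
    """For each query point, return the set of indices of its k nearest refs."""
    return [
        {j for _, j in sorted((math.dist(q, r), j) for j, r in enumerate(refs))[:k]}
        for q in queries
    ]


def mnn_pairs(batch_a, batch_b, k):
    """Return (i, j) pairs where a[i] and b[j] are in each other's k-NN sets."""
    a_to_b = knn_indices(batch_a, batch_b, k)
    b_to_a = knn_indices(batch_b, batch_a, k)
    return [
        (i, j)
        for i, nbrs in enumerate(a_to_b)
        for j in sorted(nbrs)
        if i in b_to_a[j]
    ]


batch_a = [(0.0, 0.0), (5.0, 5.0)]
batch_b = [(0.1, 0.0), (5.0, 5.1), (9.0, 9.0)]
pairs = mnn_pairs(batch_a, batch_b, k=1)
print(pairs)  # [(0, 0), (1, 1)]
```

The third point of batch_b picks batch_a[1] as its nearest neighbor, but the choice is not reciprocated, so it yields no pair.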

prepare_intra_graphs(modBatch_dict)

Build spatial neighbor graphs for each modality across batches.

Parameters:

modBatch_dict (dict) – Mapping from modality name to a list of AnnData objects.
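
As a rough picture of a spatial neighbor graph within one batch, the sketch below connects each spot to at most k nearest spots and drops neighbors farther than a radius. How SpaMosaic actually combines radius_cutoff and intra_knns is an assumption here, and the function name is hypothetical.

```python
import math


def spatial_knn_edges(coords, k, radius):
    """Directed edges from each spot to its k nearest spots within `radius`."""
    edges = []
    for i, p in enumerate(coords):
        # Distances from spot i to every other spot, nearest first.
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(coords) if j != i)
        edges.extend((i, j) for d, j in dists[:k] if d <= radius)
    return edges


coords = [(0, 0), (1, 0), (0, 1), (10, 10)]
edges = spatial_knn_edges(coords, k=2, radius=2.0)
print(edges)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```

The isolated spot at (10, 10) has no neighbor within the radius, so it contributes no edges.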

prepare_net(net)

Instantiate the architecture for each modality from a config.

Parameters:

net (str) – Name of model architecture (must match a YAML config).

Returns:

Mapping {modality_name -> torch.nn.Module}.

Return type:

dict

train(net, lr, use_mini_thr=8000, mini_batch_size=1024, loss_type='adapted', T=0.01, bias=0, n_epochs=100, w_rec_g=0.0)

Train SpaMosaic using contrastive and reconstruction losses.

Parameters:
  • net (str) – Architecture name (used to load config).

  • lr (float) – Learning rate.

  • use_mini_thr (int) – Threshold above which mini-batch training is used.

  • mini_batch_size (int) – Size of mini-batches if needed.

  • loss_type ({'adapted', 'ce'}) – Contrastive loss type. 'adapted' supports ≥3 modalities.

  • T (float) – Temperature for contrastive loss.

  • bias (float) – Bias term in adapted contrastive loss.

  • n_epochs (int) – Number of training epochs.

  • w_rec_g (float) – Weight for graph reconstruction loss on test batches.

Return type:

None
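
To show the role a temperature such as T plays in a contrastive objective, here is a generic InfoNCE-style loss in plain Python. It is not the 'adapted' or 'ce' loss implemented by train; the function name and the use of cosine similarity are assumptions made for the illustration.

```python
import math


def info_nce(anchors, positives, temperature):
    """Mean InfoNCE loss where positives[i] is the positive for anchors[i]."""

    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    a = [normalize(v) for v in anchors]
    p = [normalize(v) for v in positives]
    losses = []
    for i, ai in enumerate(a):
        # Cosine similarity to every candidate, scaled by the temperature.
        logits = [sum(x * y for x, y in zip(ai, pj)) / temperature for pj in p]
        m = max(logits)  # subtract the max for a stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        losses.append(log_denominator - logits[i])
    return sum(losses) / len(losses)


emb = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(emb, emb, temperature=0.5)
mismatched = info_nce(emb, emb[::-1], temperature=0.5)
print(aligned < mismatched)  # True
```

Correctly paired embeddings give a lower loss than swapped ones, and a smaller temperature makes the penalty for a mismatch grow sharply.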
