"""SpaMosaic framework for multi-modal spatial omics integration.
Builds intra-batch spatial graphs and inter-batch MNN (expression) graphs, optionally smooths input
features, trains per-modality encoders to align modalities, and provides embedding and imputation utilities.
"""
import os, gc
from collections.abc import Iterable
from pathlib import Path, PurePath
import scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import scanpy as sc
import math
from tqdm import tqdm
import scipy.sparse as sps
import warnings
import itertools
from os.path import join
import torch
import logging
import torch
from matplotlib import rcParams
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score
from scipy.sparse.csgraph import connected_components
from torch_sparse import SparseTensor
from torch_geometric.data import HeteroData, Data
from torch_geometric.nn import GAE
from torch_geometric.utils import train_test_split_edges, negative_sampling
from spamosaic.train_utils import set_seeds, train_model
import spamosaic.utils as utls
from spamosaic.build_graph import make_Ahat_sparse
import spamosaic.build_graph as build_graph
import spamosaic.architectures as archs
from spamosaic.loss import CL_loss
# Absolute directory of this module; model configs are loaded from ``<cur_dir>/configs/<net>.yaml``.
cur_dir = os.path.split(os.path.abspath(__file__))[0]
class SpaMosaic(object):
    """
    SpaMosaic: a modular framework for multi-modal spatial omics integration.

    This class orchestrates data pre-processing, intra- and inter-batch graph construction,
    optional feature smoothing, model initialization/training, cross-modality alignment,
    and downstream embedding/imputation.

    Parameters
    ----------
    modBatch_dict : dict
        Mapping modality name (e.g., ``'rna'``, ``'adt'``) to a list of AnnData batches,
        with ``None`` for batches in which that modality was not measured. All lists must
        have the same length (the number of batches).
        Example: ``{'rna': [batch1, None, ...], 'adt': [batch1, batch2, ...]}``.
    input_key : str
        Key in ``.obsm`` where input features are stored (e.g., ``'dimred_bc'``).
    mnn_rep_key : str, optional
        Representation key used for MNN search. If ``None``, defaults to ``input_key``.
        Replaced by the intra-smoothed key when ``smooth_input=True``.
    batch_key : str
        Column name in ``.obs`` denoting batch identity.
    radius_cutoff : int
        Radius threshold used to construct spatial neighbor graphs.
    intra_knns : int or list of int
        Number of neighbors for intra-batch graphs (single int or per-batch list).
    inter_knn_base : int
        Base KNN size for inter-batch MNN search.
    smooth_input : bool
        If ``True``, apply WLGCN-based input feature smoothing over the intra-batch graphs.
    smooth_L : int
        Number of WLGCN layers used for smoothing.
    inter_auto_knn : bool
        If ``True``, adapt inter-batch KNN size based on batch-size ratio.
    inter_auto_thr : float
        Size-ratio threshold for adaptive KNN.
    rmv_outlier : bool
        If ``True``, remove outlier MNN pairs via Isolation Forest.
    contamination : str or float
        Contamination level for outlier detection (IsolationForest).
    w_g : float
        Weight for inter-batch expression edges in the merged graph. ``0`` disables
        the MNN search entirely.
    log_dir : str, optional
        Directory for saving logs or results; created if it does not exist.
    seed : int
        Random seed.
    num_workers : int
        Number of workers for computation (stored as ``self.num_workers``).
    device : str or torch.device
        Device, e.g., ``'cuda:0'`` or ``'cpu'``. A non-CUDA request, or a CUDA request
        without an available GPU, results in CPU.
    """

    def __init__(
        self,
        modBatch_dict=None,  # {'rna': [batch1, None, ...], 'adt': [batch1, batch2, ...], ...}
        input_key='dimred_bc',
        mnn_rep_key=None,
        batch_key='batch',
        radius_cutoff=2000, intra_knns=10, inter_knn_base=10,
        smooth_input=False, smooth_L=1,
        inter_auto_knn=False, inter_auto_thr=0.8,
        rmv_outlier=False, contamination='auto',
        w_g=0.8, log_dir=None,
        seed=1234, num_workers=6,
        device='cuda:0'
    ):
        """
        Initialize the SpaMosaic framework and prepare graphs/inputs.

        Notes
        -----
        This constructor will:
        1) check dataset integrity across modalities/batches,
        2) build per-modality intra-batch spatial graphs,
        3) (optionally) smooth input features and update the MNN representation key,
        4) build inter-batch MNN graphs per modality,
        5) assemble final inputs (features/graphs) for training.

        Parameters
        ----------
        modBatch_dict, input_key, mnn_rep_key, batch_key, radius_cutoff, intra_knns,
        inter_knn_base, smooth_input, smooth_L, inter_auto_knn, inter_auto_thr,
        rmv_outlier, contamination, w_g, log_dir, seed, num_workers, device
            See class parameters.

        Raises
        ------
        ValueError
            If ``modBatch_dict`` is empty/``None`` or its per-modality batch lists differ in length.
        RuntimeError
            If the modalities are not connected through shared batches (see ``check_integrity``).
        """
        # Validate the layout up front (before any expensive work): every modality must list
        # the same batches, using None where the modality is missing.
        if not modBatch_dict:
            raise ValueError('modBatch_dict must map at least one modality to a list of batches')
        n_batches_per_mod = {k: len(v) for k, v in modBatch_dict.items()}
        if len(set(n_batches_per_mod.values())) != 1:
            raise ValueError(
                'All modalities must list the same number of batches (use None for missing ones); '
                f'got {n_batches_per_mod}'
            )

        # Accept both strings and torch.device objects; anything that is not CUDA runs on CPU.
        device_str = str(device)
        use_cuda = 'cuda' in device_str
        if use_cuda and not torch.cuda.is_available():
            warnings.warn("CUDA was requested but no CUDA device is available. Falling back to CPU.")
        self.device = torch.device(device_str) if use_cuda and torch.cuda.is_available() else torch.device('cpu')

        self.log_dir = log_dir
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        set_seeds(seed)

        self.radius_cutoff = radius_cutoff
        self.w_g = w_g
        self.seed = seed
        self.num_workers = num_workers
        self.rmv_outlier = rmv_outlier
        self.contamination = contamination
        self.smooth_input = smooth_input
        self.smooth_L = smooth_L
        self.input_key = input_key
        self.mnn_rep_key = input_key if mnn_rep_key is None else mnn_rep_key
        self.mod_list = np.array(list(modBatch_dict.keys()))
        self.n_mods = len(self.mod_list)
        self.n_batches = len(modBatch_dict[self.mod_list[0]])
        self.batch_key = batch_key
        self.barc2batch = utls.get_barc2batch(modBatch_dict)
        self.intra_knns = intra_knns if isinstance(intra_knns, Iterable) else [intra_knns] * self.n_batches
        self.inter_knn_base = inter_knn_base
        self.inter_auto_knn = inter_auto_knn
        self.inter_auto_thr = inter_auto_thr
        # inter-batch expression (MNN) edges are only searched when they carry weight
        self.use_expr_adj = (w_g != 0)

        # check if there is an empty batch; the result holds, per batch, the modality indices
        # used by check_integrity to index rows/columns of an n_mods x n_mods matrix
        self.batch_contained_mod_ids = utls.check_batch_empty(modBatch_dict)
        # check if this dataset can be integrated
        self.check_integrity()
        # prepare spot-spot graphs for each modality
        self.prepare_intra_graphs(modBatch_dict)

        self.intra_smth_suffix = '_intraSmth'
        self.global_smth_suffix = '_globalSmth'
        if self.smooth_input:
            self.apply_smoothing(modBatch_dict, self.mod_intraGraphs, self.input_key,
                                 self.input_key + self.intra_smth_suffix)
            # MNN search then runs on the smoothed representation
            self.mnn_rep_key = self.input_key + self.intra_smth_suffix

        self.prepare_inter_graphs(modBatch_dict)
        # NOTE(review): prepare_inputs is not defined in this listing; presumably defined
        # elsewhere, it must set self.mod_feats and self.mod_graphs_attr_dim (read by
        # train / prepare_net / infer_emb).
        self.prepare_inputs(modBatch_dict)

    def check_integrity(self):
        """
        Verify that all modalities are connected through batches that measure them jointly.

        Two modalities are linked when some batch measures both; the dataset is integrable
        only if this modality graph has a single connected component.

        Raises
        ------
        RuntimeError
            If the graph of shared modalities across batches is not fully connected.
        """
        mod_graph = np.zeros((self.n_mods, self.n_mods))
        for bi in range(self.n_batches):
            mod_ids = np.asarray(list(self.batch_contained_mod_ids[bi]), dtype=int)
            # np.ix_ connects every pair of co-measured modalities (incl. self-loops);
            # a batch with no ids simply contributes nothing instead of crashing
            mod_graph[np.ix_(mod_ids, mod_ids)] = 1

        n_cs, labels = connected_components(mod_graph, directed=False, return_labels=True)
        if n_cs > 1:
            for ci in np.unique(labels):
                print(f'conn {ci}:', self.mod_list[labels == ci])
            raise RuntimeError('Dataset not connected, cannot be integrated')

    def prepare_intra_graphs(self, modBatch_dict):
        """
        Build spatial neighbor graphs for each modality across batches.

        For every modality, the per-batch adjacencies (``adx.uns['adj']``) of the measured
        batches are stacked block-diagonally, so node indices follow the concatenation of
        measured batches in list order.

        Parameters
        ----------
        modBatch_dict : dict
            Mapping modality name to list of AnnData objects (``None`` for missing batches).

        Sets
        ----
        mod_intraGraphs : dict
            Modality -> block-diagonal COO adjacency.
        mod_nodename2idx, mod_nodeidx2name : dict
            Modality -> mapping between obs names and row indices of that adjacency.
        """
        mod_intraGraphs = {}
        mod_nodename2idx, mod_nodeidx2name = {}, {}
        for key in self.mod_list:
            # spatial-neighbor graph; presumably stores adx.uns['adj'] (read just below)
            build_graph.build_intra_graph(modBatch_dict[key], rad_cutoff=self.radius_cutoff, knns=self.intra_knns)
            measured = [adx for adx in modBatch_dict[key] if adx is not None]
            # intra graph: block diagonal, one block per measured batch
            mod_intraGraphs[key] = sps.block_diag([adx.uns['adj'] for adx in measured]).tocoo()
            obs_names = np.hstack([adx.obs_names.to_list() for adx in measured])
            mod_nodename2idx[key] = dict(zip(obs_names, np.arange(len(obs_names))))
            mod_nodeidx2name[key] = {v: k for k, v in mod_nodename2idx[key].items()}
        self.mod_intraGraphs = mod_intraGraphs
        self.mod_nodename2idx = mod_nodename2idx
        self.mod_nodeidx2name = mod_nodeidx2name

    def apply_smoothing(self, modBatch_dict, mod_graphs, key, added_key, symm=False):
        """
        Apply WLGCN feature smoothing per modality.

        Parameters
        ----------
        modBatch_dict : dict
            Modality -> list of AnnData objects (``None`` for missing batches).
        mod_graphs : dict
            Modality -> adjacency graph (scipy sparse) whose node order matches the
            vertical stacking of the modality's measured batches.
        key : str
            Key for input features in ``.obsm``.
        added_key : str
            Key to store smoothed features in ``.obsm`` (written in place on each AnnData).
        symm : bool, optional
            If ``True``, symmetrize the input adjacency before normalization. Default is ``False``.
        """
        model = archs.wlgcn.WLGCN_vanilla(K=self.smooth_L).to(self.device).eval()
        for mod in self.mod_list:
            ads = [ad for ad in modBatch_dict[mod] if ad is not None]
            if not ads:
                # modality not measured in any batch: nothing to smooth (np.vstack would fail)
                continue
            attr = torch.as_tensor(np.vstack([ad.obsm[key] for ad in ads]), dtype=torch.float32, device=self.device)
            A_hat = self.g2ts(make_Ahat_sparse(mod_graphs[mod].tocoo(), symm=symm)).to(self.device)
            with torch.inference_mode():
                smth = model(attr, A_hat).cpu().numpy()
            # rows were stacked batch by batch, so cut them back at the batch boundaries
            boundaries = np.cumsum([ad.n_obs for ad in ads])[:-1]
            for ad, chunk in zip(ads, np.split(smth, boundaries, axis=0)):
                ad.obsm[added_key] = chunk
            del attr, A_hat
            if self.device.type == "cuda":
                torch.cuda.synchronize()

    def g2ts(self, A_hat):
        """
        Convert a SciPy sparse matrix to ``torch_sparse.SparseTensor``.

        Parameters
        ----------
        A_hat : scipy.sparse.spmatrix
            Normalized adjacency.

        Returns
        -------
        torch_sparse.SparseTensor
            Coalesced sparse tensor with the same shape as ``A_hat`` and float32 values.
        """
        A = A_hat.tocoo()
        row = torch.from_numpy(A.row).long()
        col = torch.from_numpy(A.col).long()
        val = torch.from_numpy(A.data.astype(np.float32))
        return SparseTensor(row=row, col=col, value=val, sparse_sizes=A.shape).coalesce()

    def prepare_inter_graphs(self, modBatch_dict):
        """
        Build mutual nearest neighbor (MNN) graphs between batches for each modality.

        Notes
        -----
        - Batches measuring at least two modalities are bridges; the rest are non-bridge.
        - Compute MNN pairs within each modality (only when ``self.use_expr_adj``).
        - Optionally filter outliers (handled inside ``build_graph.build_mnn_graph``).

        Sets
        ----
        bridge_batch_num_ids, non_bridge_batch_num_ids : numpy.ndarray
            Batch indices of bridge / non-bridge batches.
        mod_batch_split : dict
            Modality -> per-batch number of spots (0 where the modality is missing).
        bridge_batch_num_ids2mod, non_bridge_batch_num_ids2mod : dict
            Batch index -> list of measured modalities.
        mod_mnn_set : dict
            Modality -> MNN pairs (empty lists when ``w_g == 0``).
        """
        def measured_mods(bi):
            # modalities with an AnnData (not None) in batch ``bi``, in mod_list order
            return [key for key in self.mod_list if modBatch_dict[key][bi] is not None]

        # determine which batches act as bridges (>= 2 measured modalities), which do not
        n_measured = np.array([len(measured_mods(bi)) for bi in range(self.n_batches)])
        self.bridge_batch_num_ids = np.where(n_measured >= 2)[0]
        self.non_bridge_batch_num_ids = np.where(n_measured < 2)[0]

        # per-modality batch sizes, used to split stacked features/embeddings during training
        self.mod_batch_split = {
            key: [
                modBatch_dict[key][bi].n_obs if modBatch_dict[key][bi] is not None else 0
                for bi in range(self.n_batches)
            ]
            for key in self.mod_list
        }
        # mapping batch id to its modality set
        self.bridge_batch_num_ids2mod = {bi: measured_mods(bi) for bi in self.bridge_batch_num_ids}
        self.non_bridge_batch_num_ids2mod = {bi: measured_mods(bi) for bi in self.non_bridge_batch_num_ids}

        if not self.use_expr_adj:
            self.mod_mnn_set = {key: [] for key in self.mod_list}
            return

        def present(key, batch_ids):
            return [modBatch_dict[key][bi] for bi in batch_ids if modBatch_dict[key][bi] is not None]

        mod_mnn_set = {}
        for key in self.mod_list:
            print(f'Searching MNN within {key}')
            mod_mnn_set[key] = build_graph.build_mnn_graph(
                present(key, self.bridge_batch_num_ids), present(key, self.non_bridge_batch_num_ids),
                self.mnn_rep_key, self.batch_key,
                self.inter_knn_base, self.inter_auto_knn, self.inter_auto_thr,
                self.rmv_outlier, self.contamination,
                self.seed
            )
            print('Number of mnn pairs for {}:{}'.format(key, len(mod_mnn_set[key])))
        self.mod_mnn_set = mod_mnn_set

    def cache_spatial_graph(self):
        """
        Cache per-modality, per-batch intra subgraphs using batch index ranges.

        Assumes each batch occupies a contiguous node range in the merged intra graph, in
        batch order, with sizes given by ``self.mod_batch_split`` (0 for missing batches).
        Self-loops are dropped.

        Returns
        -------
        dict
            Nested mapping ``{modality -> {batch_id -> {'nodes': Tensor or None,
            'edge_index': Tensor or None}}}``. ``nodes`` is ``None`` for missing batches;
            ``edge_index`` is a ``(2, E)`` tensor of batch-local indices, or ``None`` when the
            batch has no edges.
        """
        cache = {k: {} for k in self.mod_list}
        for k in self.mod_list:
            A = self.mod_intraGraphs[k].tocoo()
            row = torch.from_numpy(A.row).long()
            col = torch.from_numpy(A.col).long()
            not_loop = row != col  # drop self-loops if present
            row, col = row[not_loop], col[not_loop]
            offsets = np.cumsum([0] + list(self.mod_batch_split[k]))  # len = n_batches + 1
            for bi in range(self.n_batches):
                start, end = int(offsets[bi]), int(offsets[bi + 1])
                if end <= start:
                    cache[k][bi] = {"nodes": None, "edge_index": None}
                    continue
                inside = (row >= start) & (row < end) & (col >= start) & (col < end)
                edge_index = None
                if bool(inside.any()):
                    edge_index = torch.stack([row[inside] - start, col[inside] - start], dim=0)
                cache[k][bi] = {"nodes": torch.arange(start, end, dtype=torch.long), "edge_index": edge_index}
        return cache

    def prepare_net(self, net):
        """
        Instantiate one encoder per modality from a config.

        Parameters
        ----------
        net : str
            Name of model architecture; ``configs/<net>.yaml`` in this module's directory is loaded.

        Returns
        -------
        dict
            Mapping ``{modality_name -> torch.nn.Module}``, each moved to ``self.device``.
        """
        config = utls.load_config(f'{cur_dir}/configs/{net}.yaml')
        mod_model = {}
        for k in self.mod_list:
            # input width comes from self.mod_graphs_attr_dim (set outside this listing)
            encoder = archs.wlgcn.HEAD(
                self.mod_graphs_attr_dim[k],
                config.model.out_dim,
                dec_l=config.model.n_dec_l,
                hidden_size=config.model.hid_dim,
                dropout=config.model.dropout,
                slope=config.model.slope
            )
            mod_model[k] = encoder.to(self.device)
        return mod_model

    def train(self, net, lr, use_mini_thr=8000, mini_batch_size=1024, loss_type='adapted', T=0.01, bias=0,
              n_epochs=100, w_rec_g=0.):
        """
        Train SpaMosaic using contrastive and reconstruction losses.

        Parameters
        ----------
        net : str
            Architecture name (used to load config).
        lr : float
            Learning rate (Adam, weight decay 5e-4, one optimizer per modality).
        use_mini_thr : int
            Bridge batches with more spots than this use mini-batches for the contrastive loss.
        mini_batch_size : int
            Size of mini-batches if needed.
        loss_type : {'adapted', 'ce'}
            Contrastive loss type. ``'adapted'`` supports ≥3 modalities; any other value
            selects the cross-entropy variant.
        T : float
            Temperature for contrastive loss.
        bias : float
            Bias term in adapted contrastive loss.
        n_epochs : int
            Number of training epochs.
        w_rec_g : float
            Weight for graph reconstruction loss on test batches; when > 0 the per-batch
            spatial subgraphs are cached first.

        Returns
        -------
        None
            Sets ``self.mod_model``, ``self.loss_cl`` and ``self.loss_rec``.
        """
        # set model architectures
        mod_model = self.prepare_net(net)
        # set optimizer
        mod_optims = {k: torch.optim.Adam(mod_model[k].parameters(), lr=lr, weight_decay=5e-4) for k in self.mod_list}

        # Per-batch meta tuples:
        #   bridge batch: (is_bridge=True, n_spots, mini-batch size, number of mini-batches,
        #                  number of measured modalities, measured modalities)
        #   test batch:   (is_bridge=False, measured modalities)
        batch_train_meta_numbers = {}
        for bi, ms in self.bridge_batch_num_ids2mod.items():
            n_cell = self.mod_batch_split[ms[0]][bi]
            n_loss_batch = n_cell if n_cell <= use_mini_thr else mini_batch_size  # mini-batch size for this batch
            n_batch = n_cell // n_loss_batch  # number of mini-batches
            batch_train_meta_numbers[bi] = (True, n_cell, n_loss_batch, n_batch, len(ms), ms)
        for bi, ms in self.non_bridge_batch_num_ids2mod.items():
            batch_train_meta_numbers[bi] = (False, ms)

        if w_rec_g > 0:  # using graph reconstruction loss for test batches
            self.cached_modBatch_intra_subgraphs = self.cache_spatial_graph()
        else:
            self.cached_modBatch_intra_subgraphs = None

        # determine contrastive learning loss
        if loss_type == 'adapted':
            # Proposed contrastive loss; handles ≥3-modality alignment.
            # One criterion per bridge batch, sized by (mini-batch size, number of modalities).
            crit1 = {
                bi: CL_loss(batch_train_meta_numbers[bi][2], rep=batch_train_meta_numbers[bi][4], bias=bias).to(self.device)
                for bi in self.bridge_batch_num_ids
            }
        else:
            # Canonical CE-based CL; only handles 2 modalities
            crit1 = nn.CrossEntropyLoss().to(self.device)
        # feature reconstruction loss
        crit2 = nn.MSELoss().to(self.device)

        # Train
        mod_model, loss_cl, loss_rec = train_model(
            mod_model, mod_optims, crit1, crit2, loss_type, w_rec_g,
            self.mod_feats, self.cached_modBatch_intra_subgraphs,
            batch_train_meta_numbers, self.mod_batch_split,
            T, n_epochs, self.device
        )
        self.mod_model = mod_model
        self.loss_cl = loss_cl
        self.loss_rec = loss_rec

    def infer_emb(self, modBatch_dict, emb_key='emb', final_latent_key='merged_emb', cat=False):
        """
        Infer latent embeddings for each cell and return merged AnnData list.

        Per-modality embeddings are written in place to ``.obsm[emb_key]`` of every measured
        AnnData; then, for each batch, the embeddings of its measured modalities are merged.

        Parameters
        ----------
        modBatch_dict : dict
            Original input dictionary of modalities and batches.
        emb_key : str
            Key to store intermediate (per-modality) embeddings.
        final_latent_key : str
            Key to store final merged embedding in returned AnnData.
        cat : bool
            If ``True``, concatenate modality embeddings; otherwise average them.

        Returns
        -------
        list of AnnData
            One AnnData per batch (placeholder ``X`` of zeros) carrying ``obs``,
            ``obsm['spatial']`` and ``obsm[final_latent_key]``.

        Raises
        ------
        ValueError
            If a batch has no measured modality.
        """
        # Forward pass only: no autograd graph is needed, which saves memory on large inputs.
        with torch.inference_mode():
            for k in self.mod_list:
                self.mod_model[k].eval()
                z, _ = self.mod_model[k](self.mod_feats[k].to(self.device))
                z_split = torch.split(z, self.mod_batch_split[k])
                for bi in range(self.n_batches):
                    if modBatch_dict[k][bi] is not None:
                        modBatch_dict[k][bi].obsm[emb_key] = z_split[bi].detach().cpu().numpy()

        # merge embs from measured modalities
        ad_finals = []
        for bi in range(self.n_batches):
            measured = [modBatch_dict[m][bi] for m in self.mod_list if modBatch_dict[m][bi] is not None]
            if not measured:
                raise ValueError(f'Batch {bi} has no measured modality; cannot build a merged embedding.')
            embs = [ad.obsm[emb_key] for ad in measured]
            emb = np.hstack(embs) if cat else np.mean(embs, axis=0)
            # obs and spatial coordinates come from the last measured modality of this batch
            ref = measured[-1]
            ad = sc.AnnData(np.zeros((emb.shape[0], 2)), obs=ref.obs.copy(), obsm={'spatial': ref.obsm['spatial']})
            ad.obsm[final_latent_key] = emb
            ad_finals.append(ad)
        return ad_finals

    def impute(self, modBatch_dict, emb_key='emb', layer_key='counts', imp_knn=10):
        """
        Impute missing modalities using the aligned embedding space and KNN.

        For each batch and each modality it lacks, every measured modality of that batch
        queries its ``imp_knn`` nearest neighbours (in ``emb_key`` space) among all spots that
        do measure the missing modality; the neighbours' ``layer_key`` profiles are averaged,
        then averaged again across the measured modalities.

        Parameters
        ----------
        modBatch_dict : dict
            Input dictionary of modalities and batches.
        emb_key : str
            Key where latent embeddings are stored.
        layer_key : str
            Which layer to impute (e.g., ``'counts'``); sparse layers are densified.
        imp_knn : int
            Number of neighbors to use in KNN-based imputation.

        Returns
        -------
        dict
            Imputed data dictionary: ``{modality -> list of arrays (or None)}``, where ``None``
            marks batches in which the modality was measured.
        """
        mod_names = list(modBatch_dict.keys())

        def to_dense(x):
            # ``.A`` is deprecated/removed in recent SciPy; ``toarray()`` works for all sparse types
            return x.toarray() if sps.issparse(x) else x

        aligned_pool = {
            k: np.vstack([ad.obsm[emb_key] for ad in modBatch_dict[k] if ad is not None])
            for k in mod_names
        }
        target_pool = {
            k: np.vstack([to_dense(ad.layers[layer_key]) for ad in modBatch_dict[k] if ad is not None])
            for k in mod_names
        }
        imputed_batchDict = {k: [None] * self.n_batches for k in mod_names}

        for bi in range(self.n_batches):
            measured = [k for k in mod_names if modBatch_dict[k][bi] is not None]
            # deterministic (dict) order rather than set iteration order
            missing = [k for k in mod_names if modBatch_dict[k][bi] is None]
            for k_q in missing:
                print(f'impute {k_q}-{layer_key} for batch-{bi+1}')
                # visit all measured mods to impute missing k_q
                imps = []
                for k_v in measured:
                    knn_ind = utls.nn_approx(modBatch_dict[k_v][bi].obsm[emb_key], aligned_pool[k_q], knn=imp_knn)
                    # assumes knn_ind is an integer array of shape (n_query, imp_knn);
                    # gather -> (n_query, imp_knn, n_features), then average the neighbours
                    imps.append(np.mean(target_pool[k_q][knn_ind], axis=1))
                imputed_batchDict[k_q][bi] = np.mean(imps, axis=0)
        return imputed_batchDict