Source code for spamosaic.build_graph

"""Graph construction utilities for spatial transcriptomics.

This module provides helpers to build spatial neighbor graphs and cross-batch
mutual nearest neighbor (MNN) matches on top of AnnData objects.
"""

import os, gc
import scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import scanpy as sc
# import squidpy as sq
import h5py
import math
import sklearn
from tqdm import tqdm
import scipy.sparse as sps
import scipy.io as sio
import seaborn as sns
import warnings
import networkx as nx

from os.path import join
from collections import Counter
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from annoy import AnnoyIndex
from sklearn.preprocessing import normalize
from sklearn.ensemble import IsolationForest

import spamosaic.MNN as MNN

# Module-level radius constant. It is not referenced by any function visible in
# this module; presumably a default distance cutoff for radius-based spatial
# graphs used by callers elsewhere -- TODO confirm intended use and units.
RAD_CUTOFF = 2000


# adapted from STAGATE
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, max_neigh=50, model='Radius', verbose=True):
    """Construct a spatial neighbor graph from ``adata.obsm['spatial']``.

    For every cell the ``max_neigh`` nearest cells are queried (ball tree).
    With ``model='Radius'`` only neighbors closer than ``rad_cutoff`` are kept;
    with ``model='KNN'`` the first ``k_cutoff`` neighbors are kept.

    Parameters
    ----------
    adata : AnnData
        Input annotated data; must provide ``.obsm['spatial']``.
    rad_cutoff : float, optional
        Distance cutoff (strict ``<``). Required when ``model='Radius'``.
    k_cutoff : int, optional
        Number of neighbors per cell. Required when ``model='KNN'``. Because
        only ``max_neigh`` neighbors are queried, values above ``max_neigh``
        are effectively capped.
    max_neigh : int
        Maximum number of neighbors queried per cell. Clamped so that the
        query never asks for more points than there are cells.
    model : str
        Type of graph construction: ``'Radius'`` or ``'KNN'``.
    verbose : bool
        Whether to print progress and edge statistics.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``model`` is not ``'Radius'``/``'KNN'`` or the cutoff required by
        the selected model is ``None``.

    Notes
    -----
    On success, stores the edge list in ``adata.uns['Spatial_Net']`` (a
    DataFrame with columns ``Cell1``, ``Cell2``, ``Distance`` holding obs
    names) and the adjacency matrix with self-loops in ``adata.uns['adj']``
    (SciPy sparse matrix).
    """
    # Explicit submodule import: a bare ``import sklearn`` does not guarantee
    # that ``sklearn.neighbors`` has been loaded.
    from sklearn.neighbors import NearestNeighbors

    if model not in ('Radius', 'KNN'):
        raise ValueError("model must be 'Radius' or 'KNN', got %r" % (model,))
    if model == 'Radius' and rad_cutoff is None:
        raise ValueError("rad_cutoff must be provided when model='Radius'")
    if model == 'KNN' and k_cutoff is None:
        raise ValueError("k_cutoff must be provided when model='KNN'")

    if verbose:
        print('------Calculating spatial graph...')

    # Plain array instead of a DataFrame with two hard-coded column names, so
    # coordinates with any number of dimensions are accepted.
    coords = np.asarray(adata.obsm['spatial'])
    n_cells = coords.shape[0]
    cell_names = np.array(adata.obs.index)

    # +1 because the query set equals the fitted set, so the nearest hit is
    # presumed to be the cell itself and is dropped below.
    n_query = min(max_neigh + 1, n_cells)
    nbrs = NearestNeighbors(n_neighbors=n_query, algorithm='ball_tree').fit(coords)
    distances, indices = nbrs.kneighbors(coords)

    if model == 'KNN':
        indices = indices[:, 1:k_cutoff + 1]
        distances = distances[:, 1:k_cutoff + 1]
    else:
        indices = indices[:, 1:]
        distances = distances[:, 1:]

    # Flatten to an edge list in row-major order (cell by cell, nearest first).
    src = np.repeat(np.arange(n_cells), indices.shape[1])
    dst = indices.ravel()
    dist = distances.ravel()

    if model == 'Radius':
        keep = dist < rad_cutoff
        src, dst, dist = src[keep], dst[keep], dist[keep]

    Spatial_Net = pd.DataFrame({
        'Cell1': cell_names[src],
        'Cell2': cell_names[dst],
        'Distance': dist,
    })

    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net

    # Build the adjacency directly from integer positions; mapping through obs
    # names and back would mis-assign edges if names were not unique.
    G = sps.coo_matrix((np.ones(src.shape[0]), (src, dst)), shape=(adata.n_obs, adata.n_obs))
    G = G + sps.eye(G.shape[0])  # self-loop
    adata.uns['adj'] = G
def build_intra_graph(ads, rad_cutoff, knns):
    """Build the intra-batch spatial graph of every AnnData in ``ads``.

    Each non-``None`` entry is passed to ``Cal_Spatial_Net`` with the shared
    ``rad_cutoff`` and its own ``max_neigh`` taken from ``knns``.

    Parameters
    ----------
    ads : list of AnnData
        Spatial AnnData objects. ``None`` entries are skipped.
    rad_cutoff : float
        Distance threshold forwarded as ``rad_cutoff``.
    knns : list of int
        Per-dataset neighbor limit forwarded as ``max_neigh``. Paired with
        ``ads`` position by position; pairing stops at the shorter list.

    Returns
    -------
    None
    """
    for max_neigh, adata in zip(knns, ads):
        if adata is None:
            continue
        Cal_Spatial_Net(adata, rad_cutoff=rad_cutoff, max_neigh=max_neigh)
def make_Ahat_sparse(A, improved=False, symm=False):
    """Return the symmetrically normalized adjacency ``D^{-1/2} A D^{-1/2}``.

    The diagonal of ``A`` is discarded, the matrix is optionally symmetrized,
    fresh self-loops are added, and every entry is scaled by the inverse
    square roots of its row and column degrees.

    Parameters
    ----------
    A : scipy.sparse.spmatrix
        Sparse adjacency matrix of shape (N, N). It is copied, not modified.
    improved : bool, optional
        If ``True``, self-loops get weight 2.0 instead of 1.0.
    symm : bool, optional
        If ``True``, ``A + A.T`` is used in place of ``A`` before self-loops
        are added.

    Returns
    -------
    scipy.sparse.csr_matrix
        ``float32`` normalized matrix of shape (N, N).
    """
    n_nodes = A.shape[0]

    # Work on a copy and drop whatever diagonal the caller supplied.
    off_diag = A.tocoo(copy=True)
    off_diag.setdiag(0)
    off_diag.eliminate_zeros()

    if symm:
        base = (off_diag + off_diag.T).tocoo()
    else:
        base = off_diag.tocoo(copy=True)

    if improved:
        loop_weight = 2.0
    else:
        loop_weight = 1.0
    self_loops = sps.eye(n_nodes, dtype=np.float32, format="coo") * loop_weight
    with_loops = (base + self_loops).tocoo()

    # Row sums are the degrees; the floor guards against division by zero.
    degree = np.asarray(with_loops.sum(axis=1)).ravel()
    inv_sqrt_degree = 1.0 / np.sqrt(np.maximum(degree, 1e-12))

    rows = with_loops.row
    cols = with_loops.col
    weights = with_loops.data.astype(np.float32)
    weights = weights * inv_sqrt_degree[rows] * inv_sqrt_degree[cols]

    return sps.csr_matrix((weights, (rows, cols)), shape=(n_nodes, n_nodes), dtype=np.float32)
def determine_kSize(adi, adj, knn_base, auto_thr):
    """Pick the pair of k values used for an MNN search between two datasets.

    The size ratio is ``min(n_obs) / max(n_obs)``. If it is at least
    ``auto_thr`` both values equal ``knn_base``. Otherwise one value is
    reduced to ``max(1, int(knn_base * size_ratio))``: the first value when
    ``adi`` has more observations than ``adj``, the second value otherwise.

    Parameters
    ----------
    adi : AnnData
        First dataset; only ``n_obs`` is read.
    adj : AnnData
        Second dataset; only ``n_obs`` is read.
    knn_base : int
        Base number of neighbors.
    auto_thr : float
        Size similarity threshold, e.g. 0.8.

    Returns
    -------
    tuple of int
        The two k values, in ``(adi, adj)`` argument order.
    """
    n_first, n_second = adi.n_obs, adj.n_obs
    ratio = min(n_first, n_second) / max(n_first, n_second)

    # Similar sizes: no asymmetry needed.
    if ratio >= auto_thr:
        return knn_base, knn_base

    scaled = max(1, int(knn_base * ratio))
    if n_first > n_second:
        return scaled, knn_base
    return knn_base, scaled
def remove_outlier(mnn_set, ad1, ad2, contamination='auto'):
    """Filter spatial outliers from an MNN pair set using Isolation Forest.

    For every pair, a feature row is built from the spatial coordinates of
    the first cell, those of the second cell, and their difference. Pairs
    that the forest labels ``-1`` are dropped; pairs labelled ``1`` are kept.

    Parameters
    ----------
    mnn_set : set of tuple of str
        Set of MNN barcode pairs ``(cell_id_1, cell_id_2)``.
    ad1 : AnnData
        Dataset of the first cell; must contain ``.obsm['spatial']``.
    ad2 : AnnData
        Dataset of the second cell; must contain ``.obsm['spatial']``.
    contamination : {'auto', float}, optional
        Expected outlier fraction; passed to
        ``sklearn.ensemble.IsolationForest``.

    Returns
    -------
    set of tuple of str
        Filtered set with spatial outliers removed. An empty input yields an
        empty set without touching ``ad1``/``ad2``.
    """
    # Materialize once so coordinates, predictions and pairs share one order.
    pairs = list(mnn_set)
    if not pairs:
        # Nothing to filter; fitting a forest on zero samples would raise.
        return set()

    X_spatial_1 = ad1[[p[0] for p in pairs]].obsm['spatial']
    X_spatial_2 = ad2[[p[1] for p in pairs]].obsm['spatial']
    data = np.c_[X_spatial_1, X_spatial_2, X_spatial_1 - X_spatial_2]

    # Fixed random_state keeps the filtering reproducible across runs.
    clf = IsolationForest(max_samples='auto', contamination=contamination, random_state=0)
    y_outlier = clf.fit_predict(data)

    return {p for p, y in zip(pairs, y_outlier) if y == 1}
def build_mnn_graph(
    bridge_ads,
    test_ads,
    use_rep,
    batch_key,
    knn_base=10,
    auto_knn=False,
    auto_thr=0.8,
    rmv_outlier=False,
    contamination='auto',
    seed=1234
):
    """Build a mutual nearest neighbor (MNN) graph across batches in one modality.

    Batches are matched pairwise with ``MNN.mnn`` (approximate search via
    Annoy, Manhattan metric, normalized input) on the embedding stored under
    ``use_rep``. Two kinds of pairs are searched: every bridge batch against
    every other bridge batch, and every bridge batch against every test
    batch. If ``bridge_ads`` is empty, ``test_ads`` is used as the bridge
    list and no bridge-test search is run. Optionally, spatial outliers are
    removed from each pair's matches.

    Parameters
    ----------
    bridge_ads : list of AnnData
        AnnData objects from bridge batches.
    test_ads : list of AnnData
        AnnData objects from test batches.
    use_rep : str
        Key in ``.obsm`` containing the embedding to search.
    batch_key : str
        Column in ``.obs``; its first value per dataset is used as the batch
        name in the progress message.
    knn_base : int, optional
        Base number of neighbors per side. Default is 10.
    auto_knn : bool, optional
        If ``True``, use ``determine_kSize`` to adapt k based on batch sizes.
    auto_thr : float, optional
        Size ratio threshold for ``auto_knn``. Default is 0.8.
    rmv_outlier : bool, optional
        Whether to run ``remove_outlier`` on each pair's matches. Default is
        ``False``.
    contamination : {'auto', float}, optional
        Outlier rate forwarded to ``remove_outlier``.
    seed : int, optional
        Accepted for API compatibility; not used by this function.

    Returns
    -------
    set of tuple of str
        Union of matched barcode pairs over all searched batch pairs.
    """
    # If no bridge is available, treat the test batches as the reference and
    # disable the bridge-test stage.
    if len(bridge_ads) == 0:
        bridge_ads = test_ads
        test_ads = []

    def compute_mnn(adi, adj):
        # Determine KNN size
        if auto_knn:
            knn1, knn2 = determine_kSize(adi, adj, knn_base, auto_thr)
        else:
            knn1 = knn2 = knn_base

        # Log pair info. ``.iloc[0]`` is explicit positional access; ``[0]`` on
        # a label-indexed Series relies on a deprecated positional fallback.
        src_name = adi.obs[batch_key].iloc[0]
        tgt_name = adj.obs[batch_key].iloc[0]
        print(f"Finding MNN between ({src_name}, {tgt_name}) using KNN ({knn1}, {knn2})")

        # Run approximate MNN search
        mnn_pairs = MNN.mnn(
            adi.obsm[use_rep], adj.obsm[use_rep],
            adi.obs_names, adj.obs_names,
            knn1=knn1, knn2=knn2,
            approx=True, way='annoy', metric='manhattan', norm=True
        )

        # Optionally filter outliers
        if rmv_outlier:
            n_before = len(mnn_pairs)
            mnn_pairs = remove_outlier(mnn_pairs, adi, adj, contamination)
            n_after = len(mnn_pairs)
            if n_before > 0:
                print(f"==> Filtered {n_before - n_after} from {n_before} ({(n_before - n_after) / n_before * 100:.2f}%)")

        return mnn_pairs

    n_bridge = len(bridge_ads)
    n_test = len(test_ads)
    mnn_bridge = set()
    mnn_cross = set()

    # Bridge-bridge connections (each unordered pair once)
    for i in range(n_bridge):
        for j in range(i + 1, n_bridge):
            mnn_bridge |= compute_mnn(bridge_ads[i], bridge_ads[j])

    # Bridge-test connections
    for i in range(n_bridge):
        for j in range(n_test):
            mnn_cross |= compute_mnn(bridge_ads[i], test_ads[j])

    # Return combined MNN graph
    return mnn_bridge | mnn_cross