"""Preprocessing utilities for SpaMosaic.
Implements TF-IDF/LSI pipelines, CLR normalization, Harmony batch correction, and
modality-specific preprocessing for RNA/ADT/epigenome.
"""
from typing import Optional
import os, gc
import torch
import sklearn
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse
import sklearn.decomposition
import sklearn.feature_extraction.text
import sklearn.neighbors
import sklearn.preprocessing
import sklearn.utils.extmath
from harmony import harmonize
from spamosaic.utils import split_adata_ob
def sparse_log1p_scale(X, scale=1e4):
    """
    Compute ``log1p(X * scale)``, keeping sparse input sparse.

    Parameters
    ----------
    X : Union[scipy.sparse.spmatrix, np.ndarray]
        Input expression matrix. Never modified.
    scale : float, default=1e4
        Factor applied to every entry before ``log1p``.

    Returns
    -------
    Union[scipy.sparse.spmatrix, np.ndarray]
        Transformed matrix; a sparse copy for sparse input, a new dense array
        otherwise.
    """
    if not scipy.sparse.issparse(X):
        return np.log1p(X * scale)
    # Only stored entries need work: log1p(0 * scale) == 0, so implicit zeros
    # stay zero. Operate on a copy so the caller's matrix is left untouched.
    result = X.copy()
    result.data = np.log1p(result.data * scale)
    return result
# CLR normalization (optional for ADT data; other reasonable preprocessing
# steps are also acceptable).
def clr_normalize(adata):
    """
    Perform Seurat-style centered log-ratio (CLR) normalization on count data.

    Every row ``x`` of ``adata.X`` (one observation) is transformed as::

        log1p(x / exp(sum(log1p(x[x > 0])) / len(x)))

    Entries that are not strictly positive (zeros, negatives, NaN) are left
    out of the sum. Sparse input is densified, so ``adata.X`` ends up as a
    dense ``np.ndarray``.

    Parameters
    ----------
    adata : AnnData
        Input data with count matrix in ``.X``. Modified in place.

    Returns
    -------
    AnnData
        The same object, with normalized ``.X``.
    """
    X = adata.X
    # `.toarray()` works for both sparse matrices and sparse arrays; the old
    # `.A` shortcut was removed from scipy sparse matrices and never existed
    # on sparse arrays.
    dense = X.toarray() if scipy.sparse.issparse(X) else np.asarray(X)
    n_features = dense.shape[1]
    # Masking non-positive entries to 0 is equivalent to Seurat's `x[x > 0]`
    # filter because log1p(0) == 0; NaN > 0 is False, so NaNs are dropped too.
    log_pos = np.log1p(np.where(dense > 0, dense, 0))
    # Per-row divisor, computed for all rows at once (also works for 0 rows,
    # where np.apply_along_axis would raise).
    factor = np.exp(log_pos.sum(axis=1, keepdims=True) / n_features)
    adata.X = np.log1p(dense / factor)
    return adata
def harmony(latent, batch_labels, use_gpu=True):
    """
    Batch correction using Harmony (``harmony.harmonize``).

    Parameters
    ----------
    latent : np.ndarray
        Low-dimensional representation (e.g., PCA), one row per cell.
    batch_labels : list or array
        Batch annotation for each row of ``latent``.
    use_gpu : bool, default=True
        Forwarded to ``harmonize`` to request GPU acceleration.

    Returns
    -------
    np.ndarray
        Batch-corrected latent representation as returned by ``harmonize``.
    """
    df_batches = pd.DataFrame(np.reshape(batch_labels, (-1, 1)), columns=['batch'])
    bc_latent = harmonize(
        latent, df_batches, batch_key="batch", use_gpu=use_gpu, verbose=True
    )
    gc.collect()
    # Free cached GPU memory only when CUDA is actually usable: these
    # torch.cuda calls raise on CPU-only machines/builds, which would crash
    # the caller after Harmony has already produced its result.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    return bc_latent
def RNA_preprocess(
    rna_ads, batch_corr=False, favor='adapted', n_hvg=5000, lognorm=True,
    scale=False, n_comps=50, batch_key='src', key='dimred_bc', return_hvf=False
):
    """
    Preprocessing pipeline for RNA modality.

    All non-``None`` batches are concatenated and reduced to a low-dimensional
    embedding, which is optionally Harmony-corrected and stored under
    ``.obsm[key]``. ``split_adata_ob`` then distributes it back to the
    per-batch objects (presumably writing ``.obsm[key]`` on each input
    AnnData -- confirm against ``spamosaic.utils``).

    Parameters
    ----------
    rna_ads : list of AnnData or None
        RNA modality per batch; ``None`` entries (unmeasured batches) are skipped.
    batch_corr : bool
        Whether to perform Harmony batch correction. It is only applied when
        more than one batch is measured.
    favor : {'adapted', 'scanpy'}
        Which pipeline to use.

        - ``'scanpy'``: optional normalize_total/log1p, HVG selection, optional
          scaling, then PCA.
        - Any other value: ``'seurat_v3'`` HVG selection on ``.X`` (scanpy's
          seurat_v3 flavor expects counts), then ``lsiTransformer``
          (``log=True, norm=True, z_score=True, tfidf=False``).
    n_hvg : int
        Number of highly variable genes. If falsy, the ``'scanpy'`` pipeline
        skips HVG selection and the ``'adapted'`` pipeline keeps all genes.
    lognorm : bool
        Whether to apply log-normalization (``'scanpy'`` pipeline only).
    scale : bool
        Whether to scale features (``'scanpy'`` pipeline only).
    n_comps : int
        Number of output components (capped at ``n_vars - 1`` for PCA).
    batch_key : str
        Key in ``.obs`` indicating batch identity. Used for HVG selection and
        for Harmony.
    key : str
        Key to store result in ``.obsm``.
    return_hvf : bool
        If ``True``, return the selected HVGs.

    Returns
    -------
    Optional[Tuple[np.ndarray, np.ndarray]]
        If ``return_hvf`` is ``True`` and HVGs were selected, returns
        ``(gene_names, indices)``. ``indices`` are positions in the gene axis
        of the concatenated data, before any subsetting. Otherwise ``None``.

    Raises
    ------
    ValueError
        If every entry of ``rna_ads`` is ``None``.
    """

    def _selected(adata):
        # Names and integer positions of genes flagged `highly_variable`.
        mask = adata.var['highly_variable'].to_numpy(dtype=bool)
        return adata.var_names[mask].to_numpy(), np.flatnonzero(mask)

    measured_ads = [ad for ad in rna_ads if ad is not None]
    if not measured_ads:
        raise ValueError(
            'RNA_preprocess: every entry of `rna_ads` is None; nothing to preprocess.'
        )
    ad_concat = sc.concat(measured_ads)
    hvf_names, hvf_idx = None, None

    if favor == 'scanpy':
        if lognorm:
            sc.pp.normalize_total(ad_concat, target_sum=1e4)
            sc.pp.log1p(ad_concat)
        if n_hvg:
            sc.pp.highly_variable_genes(ad_concat, n_top_genes=n_hvg, batch_key=batch_key)
            # Record the selection BEFORE subsetting so the indices refer to
            # the full gene axis (as in the 'adapted' branch); after the subset
            # they would always be 0..n_hvg-1.
            hvf_names, hvf_idx = _selected(ad_concat)
            ad_concat = ad_concat[:, hvf_names].copy()
        if scale:
            sc.pp.scale(ad_concat)
        sc.pp.pca(ad_concat, n_comps=min(n_comps, ad_concat.n_vars - 1))
        tmp_key = 'X_pca'
    else:
        n_hvg = n_hvg if n_hvg else ad_concat.shape[1]
        sc.pp.highly_variable_genes(
            ad_concat, flavor='seurat_v3', n_top_genes=n_hvg, batch_key=batch_key
        )
        hvf_names, hvf_idx = _selected(ad_concat)
        transformer = lsiTransformer(
            n_components=n_comps, drop_first=False, log=True, norm=True,
            z_score=True, tfidf=False, svd=True, pcaAlgo='arpack'
        )
        ad_concat.obsm['X_lsi'] = transformer.fit_transform(ad_concat[:, hvf_names]).values
        tmp_key = 'X_lsi'

    if len(measured_ads) > 1 and batch_corr:
        ad_concat.obsm[key] = harmony(
            ad_concat.obsm[tmp_key],
            ad_concat.obs[batch_key].to_list(),
            use_gpu=True
        )
    else:
        ad_concat.obsm[key] = ad_concat.obsm[tmp_key]
    split_adata_ob(measured_ads, ad_concat, ob='obsm', key=key)

    if return_hvf and hvf_names is not None:
        return hvf_names, hvf_idx
def ADT_preprocess(
    adt_ads, batch_corr=False, favor='clr', lognorm=True, scale=False,
    n_comps=50, batch_key='src', key='dimred_bc'
):
    """
    Preprocessing pipeline for ADT (protein) modality.

    Non-``None`` batches are concatenated, normalized, and reduced with PCA.
    The result is optionally Harmony-corrected and stored in ``.obsm[key]``
    before being passed to ``split_adata_ob`` together with the per-batch
    objects.

    Parameters
    ----------
    adt_ads : list of AnnData or None
        ADT modality per batch; ``None`` entries are skipped.
    batch_corr : bool
        Whether to perform Harmony batch correction (only with >1 measured batch).
    favor : {'clr', 'lognorm'}
        ``'clr'`` applies ``clr_normalize``. Any other value uses the
        normalize_total/log1p path.
    lognorm : bool
        Apply log-normalization (non-CLR path only).
    scale : bool
        Whether to scale features (non-CLR path only; ignored for ``'clr'``).
    n_comps : int
        Number of components for PCA (capped at ``n_vars - 1``).
    batch_key : str
        Key for batch annotation in ``.obs``.
    key : str
        Key to store reduced dimension result.

    Returns
    -------
    None
    """
    present = [ad for ad in adt_ads if ad is not None]
    merged = sc.concat(present)

    if favor != 'clr':
        if lognorm:
            sc.pp.normalize_total(merged, target_sum=1e4)
            sc.pp.log1p(merged)
        if scale:
            sc.pp.scale(merged)
    else:
        # Scaling is intentionally not applied on the CLR path.
        merged = clr_normalize(merged)

    sc.pp.pca(merged, n_comps=min(n_comps, merged.n_vars - 1))
    pcs = merged.obsm['X_pca']
    do_correct = batch_corr and len(present) > 1
    merged.obsm[key] = (
        harmony(pcs, merged.obs[batch_key].to_list(), use_gpu=True)
        if do_correct
        else pcs
    )
    split_adata_ob(present, merged, ob='obsm', key=key)
def Epigenome_preprocess(
    epi_ads, batch_corr=False, n_peak=100000,
    n_comps=50, batch_key='src', key='dimred_bc', return_hvf=False):
    """
    Preprocessing pipeline for epigenomic modality (e.g., ATAC-seq).

    Non-``None`` batches are concatenated, and variable peaks are selected with
    scanpy's ``'seurat_v3'`` flavor. The data are reduced with
    ``lsiTransformer`` (``tfidf=True, drop_first=True``), optionally
    Harmony-corrected, stored in ``.obsm[key]``, and passed to
    ``split_adata_ob`` together with the per-batch objects.

    Parameters
    ----------
    epi_ads : list of AnnData or None
        Epigenomic modality per batch; ``None`` entries are skipped.
    batch_corr : bool
        Whether to apply Harmony batch correction (only with >1 measured batch).
    n_peak : int
        Number of variable peaks to keep. If falsy, all peaks are kept
        (mirrors ``n_hvg`` handling in ``RNA_preprocess``).
    n_comps : int
        Number of LSI components.
    batch_key : str
        Batch identifier key in ``.obs`` (used for peak selection and Harmony).
    key : str
        Output key in ``.obsm``.
    return_hvf : bool
        Whether to return selected peak names and indices.

    Returns
    -------
    Optional[Tuple[np.ndarray, np.ndarray]]
        If ``return_hvf`` is ``True``, returns ``(peak_names, indices)``, with
        indices into the concatenated var axis; otherwise ``None``.

    Raises
    ------
    ValueError
        If every entry of ``epi_ads`` is ``None``.
    """
    measured_ads = [ad for ad in epi_ads if ad is not None]
    if not measured_ads:
        raise ValueError(
            'Epigenome_preprocess: every entry of `epi_ads` is None; nothing to preprocess.'
        )
    ad_concat = sc.concat(measured_ads)
    # 'seurat_v3' requires a concrete n_top_genes; a falsy n_peak keeps every
    # peak, consistent with RNA_preprocess.
    n_peak = n_peak if n_peak else ad_concat.shape[1]
    sc.pp.highly_variable_genes(
        ad_concat, flavor='seurat_v3', n_top_genes=n_peak, batch_key=batch_key
    )
    hv_mask = ad_concat.var['highly_variable'].to_numpy(dtype=bool)
    hvf_names = ad_concat.var_names[hv_mask].to_numpy()

    transformer = lsiTransformer(
        n_components=n_comps, drop_first=True, log=True, norm=True,
        z_score=True, tfidf=True, svd=True, pcaAlgo='arpack'
    )
    ad_concat.obsm['X_lsi'] = transformer.fit_transform(ad_concat[:, hvf_names]).values

    if len(measured_ads) > 1 and batch_corr:
        ad_concat.obsm[key] = harmony(
            ad_concat.obsm['X_lsi'], ad_concat.obs[batch_key].to_list(), use_gpu=True
        )
    else:
        ad_concat.obsm[key] = ad_concat.obsm['X_lsi']
    split_adata_ob(measured_ads, ad_concat, ob='obsm', key=key)

    if return_hvf:
        return hvf_names, np.flatnonzero(hv_mask)