from typing import Optional
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import sklearn
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse
import sklearn.decomposition
import sklearn.feature_extraction.text
import sklearn.neighbors
import sklearn.preprocessing
import sklearn.utils.extmath
from harmony import harmonize
from spamosaic.utils import split_adata_ob
def sparse_log1p_scale(X, scale=1e4):
    """
    Apply ``log1p(X * scale)`` element-wise to a sparse or dense matrix.

    For sparse input only the explicitly stored values are transformed;
    implicit zeros stay zero because ``log1p(0) == 0``. The input is never
    modified.

    Parameters
    ----------
    X : Union[scipy.sparse.spmatrix, np.ndarray]
        Input expression matrix.
    scale : float, default=1e4
        Scaling factor applied before log1p.

    Returns
    -------
    Transformed matrix. Sparse input yields a sparse result in the same
    storage format; any other input is passed through
    ``np.log1p(X * scale)``.
    """
    if scipy.sparse.issparse(X):
        fmt = X.format
        if fmt in ("lil", "dok"):
            # LIL keeps its values in an object array of Python lists and
            # DOK has no flat ``data`` array at all, so the in-place
            # transform below cannot be applied to them directly.
            out = X.tocsr(copy=True)
        else:
            out = X.copy()
        out.data = np.log1p(out.data * scale)
        # No-op for formats that were transformed in their native layout.
        return out.asformat(fmt)
    return np.log1p(X * scale)
# Optional step: other reasonable preprocessing approaches are also acceptable.
# CLR normalization
def clr_normalize(adata):
    """
    Perform centered log-ratio (CLR) normalization on count data.

    Each row ``x`` of ``adata.X`` is replaced by ``log1p(x / g)`` where
    ``g = exp(sum(log1p(x[x > 0])) / len(x))``. The matrix is densified if
    it is sparse, and the result is written back to ``adata.X`` as a dense
    array (the object is modified in place).

    Parameters
    ----------
    adata : AnnData
        Input data with count matrix in `.X`.

    Returns
    -------
    adata : AnnData
        The same object that was passed in, with `.X` normalized.
    """
    # ``.toarray()`` is used because the ``.A`` shorthand is no longer
    # available on sparse matrices in recent SciPy releases.
    if scipy.sparse.issparse(adata.X):
        counts = adata.X.toarray()
    else:
        counts = np.asarray(adata.X)

    # Only strictly positive entries contribute to the log-sum; masking the
    # rest to 0 is equivalent because log1p(0) == 0.
    positive_only = np.where(counts > 0, counts, 0)
    log_sum = np.log1p(positive_only).sum(axis=1)
    row_scale = np.exp(log_sum / counts.shape[1])

    adata.X = np.log1p(counts / row_scale[:, None])
    return adata
def harmony(latent, batch_labels, use_gpu=True):
    """
    Run Harmony batch correction on a low-dimensional embedding.

    The batch labels are reshaped into a single-column DataFrame named
    ``'batch'`` and handed, together with ``latent``, to ``harmonize``
    (called with ``verbose=True``).

    Parameters
    ----------
    latent : np.ndarray
        Low-dimensional representation (e.g., PCA).
    batch_labels : list or array
        Batch annotation for each row of ``latent``.
    use_gpu : bool, default=True
        Forwarded to ``harmonize`` as ``use_gpu``.

    Returns
    -------
    np.ndarray
        Whatever ``harmonize`` returns for the given embedding.
    """
    label_column = np.reshape(batch_labels, (-1, 1))
    batch_frame = pd.DataFrame(label_column, columns=['batch'])
    return harmonize(
        latent,
        batch_frame,
        batch_key="batch",
        use_gpu=use_gpu,
        verbose=True,
    )
def RNA_preprocess(
    rna_ads, batch_corr=False, favor='adapted', n_hvg=5000, lognorm=True,
    scale=False, n_comps=50, batch_key='src', key='dimred_bc', return_hvf=False
):
    """
    Preprocessing pipeline for RNA modality.

    All non-``None`` entries of ``rna_ads`` are concatenated, reduced to a
    low-dimensional embedding (PCA for ``favor='scanpy'``, LSI otherwise),
    optionally batch-corrected with Harmony, and the embedding is handed to
    ``split_adata_ob`` together with the input objects.

    Parameters
    ----------
    rna_ads : list of AnnData
        RNA modality per batch. ``None`` entries are skipped.
    batch_corr : bool
        Whether to perform batch correction. Only applied when more than
        one batch is present.
    favor : {'adapted', 'scanpy'}
        Which pipeline to use. Any value other than ``'scanpy'`` selects
        the LSI-based pipeline.
    n_hvg : int
        Number of highly variable genes. A falsy value disables HVG
        selection for ``'scanpy'`` and means "all genes" otherwise.
    lognorm : bool
        Whether to apply log-normalization (``'scanpy'`` pipeline only).
    scale : bool
        Whether to scale features (``'scanpy'`` pipeline only).
    n_comps : int
        Number of output components. The ``'scanpy'`` pipeline caps it at
        ``n_vars - 1``.
    batch_key : str
        Key in `.obs` indicating batch identity.
    key : str
        Key to store result in `.obsm`.
    return_hvf : bool
        If True, return names and indices of selected HVGs.

    Returns
    -------
    Optional[Tuple[np.ndarray, np.ndarray]]
        When ``return_hvf`` is True and HVG selection ran, the selected
        gene names and their integer positions along the feature axis of
        the concatenated data (before any subsetting). Otherwise ``None``.
    """
    measured_ads = [ad for ad in rna_ads if ad is not None]
    ad_concat = sc.concat(measured_ads)

    hvg_names = None
    hvg_idx = None

    if favor == 'scanpy':
        if lognorm:
            sc.pp.normalize_total(ad_concat, target_sum=1e4)
            sc.pp.log1p(ad_concat)
        if n_hvg:
            sc.pp.highly_variable_genes(ad_concat, n_top_genes=n_hvg, batch_key=batch_key)
            # Record names and positions BEFORE subsetting: afterwards every
            # remaining gene is flagged and the positions would collapse to
            # 0..k-1, which says nothing about the original feature axis.
            hvg_names = ad_concat.var.query('highly_variable').index.to_numpy()
            hvg_idx = np.where(ad_concat.var['highly_variable'])[0]
            ad_concat = ad_concat[:, hvg_names].copy()
        if scale:
            sc.pp.scale(ad_concat)
        sc.pp.pca(ad_concat, n_comps=min(n_comps, ad_concat.n_vars-1))
        tmp_key = 'X_pca'
    else:
        n_hvg = n_hvg if n_hvg else ad_concat.shape[1]
        sc.pp.highly_variable_genes(ad_concat, flavor='seurat_v3', n_top_genes=n_hvg, batch_key=batch_key)
        hvg_names = ad_concat.var.query('highly_variable').index.to_numpy()
        hvg_idx = np.where(ad_concat.var['highly_variable'])[0]
        transformer = lsiTransformer(
            n_components=n_comps, drop_first=False, log=True, norm=True,
            z_score=True, tfidf=False, svd=True, pcaAlgo='arpack'
        )
        ad_concat.obsm['X_lsi'] = transformer.fit_transform(ad_concat[:, hvg_names]).values
        tmp_key = 'X_lsi'

    if len(measured_ads) > 1 and batch_corr:
        ad_concat.obsm[key] = harmony(
            ad_concat.obsm[tmp_key],
            ad_concat.obs[batch_key].to_list(),
            use_gpu=True
        )
    else:
        ad_concat.obsm[key] = ad_concat.obsm[tmp_key]

    # NOTE(review): presumably copies the per-batch slices of
    # ad_concat.obsm[key] back onto the input objects -- confirm against
    # spamosaic.utils.split_adata_ob.
    split_adata_ob(measured_ads, ad_concat, ob='obsm', key=key)

    if n_hvg and return_hvf:
        return hvg_names, hvg_idx
def ADT_preprocess(
    adt_ads, batch_corr=False, favor='clr', lognorm=True, scale=False,
    n_comps=50, batch_key='src', key='dimred_bc'
):
    """
    Preprocessing pipeline for ADT (protein) modality.

    Concatenates every non-``None`` entry of ``adt_ads``, normalizes the
    result, runs PCA, optionally applies Harmony, and passes the embedding
    stored under ``key`` to ``split_adata_ob`` together with the inputs.

    Parameters
    ----------
    adt_ads : list of AnnData
        ADT modality per batch. ``None`` entries are skipped.
    batch_corr : bool
        Whether to perform batch correction. Only applied when more than
        one batch is present.
    favor : {'clr', 'lognorm'}
        ``'clr'`` selects CLR normalization; any other value selects the
        log-normalization path.
    lognorm : bool
        Apply total-count normalization followed by log1p (non-CLR path
        only).
    scale : bool
        Whether to scale features (non-CLR path only).
    n_comps : int
        Number of components for PCA, capped at ``n_vars - 1``.
    batch_key : str
        Key for batch annotation.
    key : str
        Key to store reduced dimension result.

    Returns
    -------
    None
    """
    present = [ad for ad in adt_ads if ad is not None]
    merged = sc.concat(present)

    if favor != 'clr':
        if lognorm:
            sc.pp.normalize_total(merged, target_sum=1e4)
            sc.pp.log1p(merged)
        if scale:
            sc.pp.scale(merged)
    else:
        # Scaling is intentionally not applied on the CLR path.
        merged = clr_normalize(merged)

    n_pcs = min(n_comps, merged.n_vars - 1)
    sc.pp.pca(merged, n_comps=n_pcs)

    embedding = merged.obsm['X_pca']
    if len(present) > 1 and batch_corr:
        batch_labels = merged.obs[batch_key].to_list()
        embedding = harmony(embedding, batch_labels, use_gpu=True)
    merged.obsm[key] = embedding

    split_adata_ob(present, merged, ob='obsm', key=key)
def Epigenome_preprocess(
    epi_ads, batch_corr=False, n_peak=100000,
    n_comps=50, batch_key='src', key='dimred_bc', return_hvf=False):
    """
    Preprocessing pipeline for epigenomic modality (e.g. ATAC-seq).

    Concatenates every non-``None`` entry of ``epi_ads``, flags variable
    peaks, computes an LSI embedding on the flagged peaks, optionally
    applies Harmony, and passes the embedding stored under ``key`` to
    ``split_adata_ob`` together with the inputs.

    Parameters
    ----------
    epi_ads : list of AnnData
        Epigenomic modality per batch. ``None`` entries are skipped.
    batch_corr : bool
        Whether to apply Harmony batch correction. Only applied when more
        than one batch is present.
    n_peak : int
        Number of variable peaks to keep.
    n_comps : int
        Number of LSI components requested from the transformer.
    batch_key : str
        Batch identifier key.
    key : str
        Output key in `.obsm`.
    return_hvf : bool
        Whether to return selected peak names and indices.

    Returns
    -------
    Optional[Tuple[np.ndarray, np.ndarray]]
        If return_hvf is True, returns names and indices of selected peaks;
        otherwise ``None``.
    """
    present = [ad for ad in epi_ads if ad is not None]
    merged = sc.concat(present)

    sc.pp.highly_variable_genes(
        merged, flavor='seurat_v3', n_top_genes=n_peak, batch_key=batch_key
    )

    lsi = lsiTransformer(
        n_components=n_comps, drop_first=True, log=True, norm=True,
        z_score=True, tfidf=True, svd=True, pcaAlgo='arpack'
    )
    variable_peaks = merged.var.query('highly_variable').index.to_numpy()
    lsi_result = lsi.fit_transform(merged[:, variable_peaks])
    merged.obsm['X_lsi'] = lsi_result.values

    if len(present) > 1 and batch_corr:
        batch_labels = merged.obs[batch_key].to_list()
        merged.obsm[key] = harmony(merged.obsm['X_lsi'], batch_labels, use_gpu=True)
    else:
        merged.obsm[key] = merged.obsm['X_lsi']

    split_adata_ob(present, merged, ob='obsm', key=key)

    if not return_hvf:
        return None
    peak_names = merged.var.query('highly_variable').index.to_numpy()
    peak_positions = np.where(merged.var['highly_variable'])[0]
    return peak_names, peak_positions