# Source code for metachat.tools._spatial_communication
# ============================================================
from typing import Optional
import gc
import numpy as np
import pandas as pd
from scipy import sparse
import anndata
from .._optimal_transport import fot_combine_sparse
# ============================================================
### MetaChat cell communication
class CellCommunication(object):
    """Inputs and results of the flow-optimal-transport (FOT) MCC model.

    The constructor intersects a metabolite-sensor table with the variables
    of ``adata`` and prepares everything ``fot_combine_sparse`` needs:

    - ``mets`` / ``sens``: sorted ``HMDB.ID`` / ``Sensor.Gene`` values that are
      also present in ``adata.var_names``;
    - ``A``: ``len(mets) x len(sens)`` matrix holding ``np.inf`` for pairs not
      listed in the table and 1.0 (or ``cost_scale[(met, sen)]``) otherwise;
    - ``LRC``: metabolite -> ``Long.Range.Channel`` value (last row wins);
    - ``LRC_type``: ``"base"`` followed by the requested long-range channels;
    - ``S`` / ``D``: dense expression of ``mets`` / ``sens`` for every spot;
    - ``M``: channel -> distance matrix read from
      ``adata.obsp['spatial_distance_LRC_<channel>']`` (squared element-wise
      when ``cost_type == 'euc_square'``);
    - ``cutoff``: per-pair distance threshold (squared for ``'euc_square'``).

    ``run_fot_signaling`` then fills ``comm_network_sender`` and
    ``comm_network_receiver``.
    """

    def __init__(self,
                 adata,
                 df_metasen,
                 LRC_type,
                 dis_thr,
                 cost_scale,
                 cost_type
                 ):
        """Build the model inputs.

        Parameters
        ----------
        adata : anndata.AnnData
            Must provide ``var_names``, ``shape``, ``obsp`` and
            ``adata[:, names].X`` (dense or scipy sparse).
        df_metasen : pd.DataFrame
            Table with ``HMDB.ID``, ``Sensor.Gene`` and ``Long.Range.Channel``
            columns. Any index is accepted; rows are read in order.
        LRC_type : list of str, str or None
            Long-range channels used in addition to ``"base"``. ``None`` means
            no extra channel; a single string is treated as a one-item list.
        dis_thr : float or dict
            One distance threshold for all pairs, or a dict
            ``(met, sen) -> threshold`` covering every interacting pair.
        cost_scale : dict or None
            Optional dict ``(met, sen) -> weight`` stored in ``A`` for every
            listed pair; ``None`` gives every listed pair weight 1.0.
        cost_type : {'euc', 'euc_square'}
            Use the distances as-is or squared element-wise.

        Raises
        ------
        ValueError
            If ``cost_type`` is unknown or ``dis_thr`` is neither a scalar nor
            a dict.
        KeyError
            If ``cost_scale`` or a dict ``dis_thr`` lacks a required pair.
        """
        if cost_type not in ('euc', 'euc_square'):
            raise ValueError(
                "cost_type must be 'euc' or 'euc_square', got %r." % (cost_type,))
        squared = cost_type == 'euc_square'

        met_col = df_metasen["HMDB.ID"]
        sen_col = df_metasen["Sensor.Gene"]
        lrc_col = df_metasen["Long.Range.Channel"]

        # Metabolites / sensors present in the data. Sorted so the row/column
        # order of A, S and D is reproducible across runs (set order is not).
        data_var = set(adata.var_names)
        self.mets = sorted(set(met_col).intersection(data_var))
        self.sens = sorted(set(sen_col).intersection(data_var))
        met_pos = {met: k for k, met in enumerate(self.mets)}
        sen_pos = {sen: k for k, sen in enumerate(self.sens)}

        # inf marks metabolite-sensor pairs that cannot interact.
        A = np.full((len(self.mets), len(self.sens)), np.inf, dtype=float)
        LRC = {}
        # Read rows by value (not ``.loc[i]``) so frames whose index is not
        # 0..n-1, e.g. after ``drop_duplicates``, are handled correctly.
        for tmp_met, tmp_sen, tmp_lrc in zip(met_col, sen_col, lrc_col):
            LRC[tmp_met] = tmp_lrc
            if tmp_met in met_pos and tmp_sen in sen_pos:
                A[met_pos[tmp_met], sen_pos[tmp_sen]] = (
                    1.0 if cost_scale is None else cost_scale[(tmp_met, tmp_sen)])
        self.A = A
        self.LRC = LRC

        if LRC_type is None:
            LRC_type = []
        elif isinstance(LRC_type, str):
            LRC_type = [LRC_type]
        self.LRC_type = ["base"] + list(LRC_type)

        # Dense expression matrices of metabolites (S) and sensors (D).
        self.S = self._to_dense(adata[:, self.mets].X)
        self.D = self._to_dense(adata[:, self.sens].X)

        # One distance matrix per channel, keyed by channel name, so the FOT
        # solver can pick the distances of the channel being inferred.
        self.M = {}
        for tLRC in self.LRC_type:
            dist = adata.obsp['spatial_distance_LRC_' + tLRC]
            if squared:
                # Element-wise square: ``**`` on a scipy ``spmatrix`` is a
                # matrix power, not an element-wise power.
                dist = dist.power(2) if sparse.issparse(dist) else np.square(dist)
            else:
                dist = dist.copy()
            self.M[tLRC] = dist

        if np.isscalar(dis_thr):
            thr = float(dis_thr)
            self.cutoff = np.full_like(A, thr ** 2 if squared else thr)
        elif isinstance(dis_thr, dict):
            self.cutoff = np.zeros_like(A)
            for i, met in enumerate(self.mets):
                for j, sen in enumerate(self.sens):
                    # Only interacting pairs (finite, positive weight) need a
                    # threshold; the others keep 0 and may be absent from dis_thr.
                    if np.isfinite(A[i, j]) and A[i, j] > 0:
                        thr = dis_thr[(met, sen)]
                        self.cutoff[i, j] = thr ** 2 if squared else thr
        else:
            raise ValueError(
                "dis_thr must be a scalar or a dict keyed by (metabolite, sensor), "
                "got %s." % type(dis_thr).__name__)

        self.nmet = self.S.shape[1]
        self.nsen = self.D.shape[1]
        self.npts = adata.shape[0]

    @staticmethod
    def _to_dense(X):
        """Return ``X`` as a dense ndarray, whether it is scipy-sparse or not."""
        return X.toarray() if sparse.issparse(X) else np.asarray(X)

    def run_fot_signaling(self,
                          fot_eps_p=1e-1,
                          fot_eps_mu=None,
                          fot_eps_nu=None,
                          fot_rho=1e1,
                          fot_nitermax=1e4,
                          fot_weights=(1.0, 0.0, 0.0, 0.0),
                          ):
        """Run flow optimal transport and store the communication networks.

        Delegates to ``fot_combine_sparse`` with the matrices prepared in
        ``__init__`` and stores its two results in ``comm_network_sender`` and
        ``comm_network_receiver``. The caller in this module indexes them by
        ``(met_index, sen_index)`` pairs (assumption based on that usage).

        Parameters
        ----------
        fot_eps_p, fot_eps_mu, fot_eps_nu : float or None
            Entropy-regularization coefficients forwarded as ``eps_p``,
            ``eps_mu`` and ``eps_nu``.
        fot_rho : float
            Forwarded as ``rho``.
        fot_nitermax : int or float
            Forwarded as ``nitermax``.
        fot_weights : tuple of float
            Forwarded as ``weights``.
        """
        self.comm_network_sender, self.comm_network_receiver = fot_combine_sparse(
            self.S, self.mets, self.D, self.A, self.M, self.LRC, self.LRC_type,
            self.cutoff, eps_p=fot_eps_p, eps_mu=fot_eps_mu, eps_nu=fot_eps_nu,
            rho=fot_rho, weights=fot_weights, nitermax=fot_nitermax)
# [docs]
def metabolic_communication(
    adata: anndata.AnnData,
    database_name: str = None,
    df_metasen: pd.DataFrame = None,
    LRC_type: list = None,
    dis_thr: Optional[float] = None,
    cost_scale: Optional[dict] = None,
    cost_type: str = 'euc',
    fot_eps_p: float = 2.5e-1,
    fot_eps_mu: Optional[float] = None,
    fot_eps_nu: Optional[float] = None,
    fot_rho: float = 1e-1,
    fot_nitermax: int = 10000,
    fot_weights: tuple = (1.0,0.0,0.0,0.0),
    copy: bool = False
):
    """
    Infer spatial metabolic cell communication (MCC) using the Flow Optimal Transport (FOT) framework.

    Parameters
    ----------
    adata : anndata.AnnData
        The data matrix of shape ``n_obs × n_var`` (cells/spots × genes).
        Rows correspond to cells or spots and columns to genes.
    database_name : str
        Name of the metabolite–sensor interaction database (used as prefix for storing results).
    df_metasen : pd.DataFrame
        DataFrame describing metabolite–sensor pairs, typically from MetaChatDB or related sources.
        Must include columns: ``['HMDB.ID', 'Sensor.Gene', 'Metabolite.Pathway', 'Sensor.Pathway', 'Metabolite.Names', 'Long.Range.Channel']``.
        Any row index is accepted.
    LRC_type : list of str, optional
        Names of long-range channels (LRCs) such as ``["Blood"]``, ``["CSF"]`` or ``["Blood", "CSF"]``.
        If None, only the ``"base"`` channel is used (no long-range communication).
    dis_thr : float or dict
        Distance threshold for defining the region influenced by each LRC. A dict maps
        ``(metabolite, sensor)`` to a pair-specific threshold.
    cost_scale : dict, optional
        Weight coefficients of the cost matrix for each metabolite-sensor pair, e.g. cost_scale[('metA','senA')] specifies weight for the pair metA and senA.
        If None, all pairs have the same weight.
    cost_type : str, {'euc', 'euc_square'}, default='euc'
        Defines how spatial distances are used as cost matrix.
    fot_eps_p : float, default=2.5e-1
        The coefficient of entropy regularization for transport plan.
    fot_eps_mu : float, optional
        The coefficient of entropy regularization for untransported source (metabolite). Set to equal to fot_eps_p for fast algorithm.
    fot_eps_nu : float, optional
        The coefficient of entropy regularization for unfulfilled target (sensor). Set to equal to fot_eps_p for fast algorithm.
    fot_rho : float, default=1e-1
        The coefficient of penalty for unmatched mass.
    fot_nitermax : int, default=10000
        Maximum iteration for flow optimal transport algorithm.
    fot_weights : tuple of float, default=(1.0, 0.0, 0.0, 0.0)
        A tuple of four weights that add up to one. The weights correspond to four setups of flow optimal transport:
        1) all metabolite-all sensors, 2) each metabolite-all sensors, 3) all metabolite-each sensor, 4) each metabolite-each sensor.
    copy : bool, default=False
        If True, work on and return a copy of `adata`, leaving the input untouched;
        otherwise modify `adata` in place.

    Returns
    -------
    adata : anndata.AnnData or None
        Updates the following fields:

        - ``.uns['df_metasen_filtered']`` : the metabolite–sensor pairs found in the data
          (values cast to ``str``, duplicates dropped).
        - ``.obsp['MetaChat-{database}-sender-{met}-{sen}']`` : n_obs × n_obs sparse matrix of sender signals.
        - ``.obsp['MetaChat-{database}-receiver-{met}-{sen}']`` : n_obs × n_obs sparse matrix of receiver signals.
        - ``.obsp['MetaChat-{database}-sender-total-total']`` / ``...-receiver-total-total`` : sums over all pairs.
        - ``.obsm['MetaChat-{database}-sum-sender']`` : DataFrame of per-cell sender signal sums.
        - ``.obsm['MetaChat-{database}-sum-receiver']`` : DataFrame of per-cell receiver signal sums.
        - ``.uns['MetaChat-{database}-info']`` : metadata including distance threshold.

        If ``copy=True``, return the modified copy; otherwise return None.

    Raises
    ------
    AssertionError
        If ``database_name``, ``df_metasen`` or ``dis_thr`` is missing.
    ValueError
        If no metabolite–sensor pair of ``df_metasen`` is present in ``adata.var_names``.
    """
    # ==== Basic checks ====
    assert database_name is not None, "Please provide a database_name."
    assert df_metasen is not None, "Please provide a metabolite-sensor database."
    assert dis_thr is not None, "Please provide a dis_thr (distance threshold)."
    if LRC_type is None:
        print("You didn't input LRC_type, so long-range communication will not be consider in inference.")
        # An empty channel list means only the "base" channel is used downstream.
        LRC_type = []

    # Honour ``copy=True`` before any write so the caller's object stays untouched.
    if copy:
        adata = adata.copy()

    # ==== Filter valid metabolite–sensor pairs ====
    df_metasen_filtered = _filter_metasen_pairs(adata, df_metasen)
    print("There are %d pairs were found from the spatial data." % df_metasen_filtered.shape[0])
    if df_metasen_filtered.shape[0] == 0:
        raise ValueError(
            "No metabolite-sensor pair of df_metasen has both its 'HMDB.ID' and "
            "'Sensor.Gene' in adata.var_names.")
    adata.uns["df_metasen_filtered"] = df_metasen_filtered.copy()

    # ==== Initialize and run FOT-based communication model ====
    model = CellCommunication(
        adata,
        df_metasen_filtered,
        LRC_type,
        dis_thr,
        cost_scale,
        cost_type
    )
    model.run_fot_signaling(
        fot_eps_p = fot_eps_p,
        fot_eps_mu = fot_eps_mu,
        fot_eps_nu = fot_eps_nu,
        fot_rho = fot_rho,
        fot_nitermax = fot_nitermax,
        fot_weights = fot_weights
    )

    # ==== Store metadata ====
    adata.uns[f"MetaChat-{database_name}-info"] = {}
    adata.uns[f"MetaChat-{database_name}-info"]['distance_threshold'] = dis_thr

    # ==== Aggregate and store sender / receiver matrices ====
    _store_communication_results(adata, model, database_name)

    del model
    gc.collect()
    return adata if copy else None


def _filter_metasen_pairs(adata, df_metasen):
    """Return the rows of ``df_metasen`` whose metabolite and sensor are both in ``adata.var_names``.

    All values are cast to ``str`` (a missing value becomes ``'nan'``),
    duplicate rows are dropped and the index is reset to ``0..n-1``.
    """
    data_var = set(adata.var_names)
    keep = (df_metasen["HMDB.ID"].isin(data_var)
            & df_metasen["Sensor.Gene"].isin(data_var))
    # Boolean ndarray avoids index alignment, so duplicate/non-range indices are fine.
    filtered = df_metasen.loc[keep.to_numpy()].astype(str)
    return filtered.drop_duplicates().reset_index(drop=True)


def _store_communication_results(adata, model, database_name):
    """Write per-pair and total sender/receiver networks and per-spot sums into ``adata``."""
    prefix = 'MetaChat-' + database_name
    ncell = adata.shape[0]
    sender_sums, receiver_sums = [], []
    col_names_sender, col_names_receiver = [], []
    P_sender_total = sparse.csr_matrix((ncell, ncell), dtype=float)
    P_receiver_total = sparse.csr_matrix((ncell, ncell), dtype=float)

    for (i, j), P_sender in model.comm_network_sender.items():
        P_receiver = model.comm_network_receiver[(i, j)]
        met, sen = model.mets[i], model.sens[j]
        P_sender_total = P_sender_total + P_sender
        P_receiver_total = P_receiver_total + P_receiver
        adata.obsp[prefix + '-sender-' + met + '-' + sen] = P_sender.tocsr()
        adata.obsp[prefix + '-receiver-' + met + '-' + sen] = P_receiver.tocsr()
        # Row sums of the sender matrix and column sums of the receiver matrix,
        # flattened to 1-D so both matrix- and array-returning sums work.
        sender_sums.append(np.asarray(P_sender.sum(axis=1)).reshape(-1))
        receiver_sums.append(np.asarray(P_receiver.sum(axis=0)).reshape(-1))
        col_names_sender.append("s-%s-%s" % (met, sen))
        col_names_receiver.append("r-%s-%s" % (met, sen))

    # Collect columns once instead of re-concatenating inside the loop.
    X_sender = np.column_stack(sender_sums) if sender_sums else np.empty((ncell, 0), float)
    X_receiver = np.column_stack(receiver_sums) if receiver_sums else np.empty((ncell, 0), float)

    # ==== Add total summaries ====
    X_sender = np.column_stack([X_sender, X_sender.sum(axis=1)])
    X_receiver = np.column_stack([X_receiver, X_receiver.sum(axis=1)])
    col_names_sender.append("s-total-total")
    col_names_receiver.append("r-total-total")
    adata.obsp[prefix + '-sender-total-total'] = P_sender_total.tocsr()
    adata.obsp[prefix + '-receiver-total-total'] = P_receiver_total.tocsr()

    adata.obsm[prefix + '-sum-sender'] = pd.DataFrame(
        data=X_sender, columns=col_names_sender, index=adata.obs_names)
    adata.obsm[prefix + '-sum-receiver'] = pd.DataFrame(
        data=X_receiver, columns=col_names_receiver, index=adata.obs_names)