Source code for metachat.tools._spatial_communication

# ============================================================
from typing import Optional
import gc

import numpy as np
import pandas as pd
from scipy import sparse

import anndata

from .._optimal_transport import fot_combine_sparse
# ============================================================

### MetaChat cell communication
class CellCommunication(object):

    def __init__(self,
        adata,
        df_metasen,
        LRC_type,
        dis_thr,
        cost_scale,
        cost_type
    ):
        # Find overlap metabolites and sensors in df_metasen
        data_var = set(adata.var_names)
        self.mets = list(set(df_metasen["HMDB.ID"]).intersection(data_var))
        self.sens = list(set(df_metasen["Sensor.Gene"]).intersection(data_var))

        # Generate an infinite matrix A. If the metabolite and the sensor can interact, 
        # let the corresponding position be 1
        A = np.inf * np.ones([len(self.mets), len(self.sens)], float)
        
        LRC = {}
        for i in range(len(df_metasen)):
            tmp_met = df_metasen.loc[i,"HMDB.ID"]
            tmp_sen = df_metasen.loc[i,"Sensor.Gene"]
            LRC[tmp_met] = df_metasen.loc[i,"Long.Range.Channel"]
            if tmp_met in self.mets and tmp_sen in self.sens:
                if cost_scale is None:
                    A[self.mets.index(tmp_met), self.sens.index(tmp_sen)] = 1.0
                else:
                    A[self.mets.index(tmp_met), self.sens.index(tmp_sen)] = cost_scale[(tmp_met, tmp_sen)]
        self.A = A.copy()
        self.LRC = LRC.copy()

        LRC_type = ["base"] + LRC_type
        self.LRC_type = LRC_type.copy()

        # Generate expression matrices of metabolites and sensors for all spots 
        self.S = adata[:,self.mets].X.toarray()
        self.D = adata[:,self.sens].X.toarray()
        
        # The dictionary approach to storing the distance matrix, since there are multiple LRC channels,
        # allows you to store the new distances generated by the corresponding channel in dmat according
        # to the name of the channel to be inferred.   
        if cost_type == 'euc':
            dmat = {}
            for tLRC in LRC_type:
                dmat[tLRC] = adata.obsp['spatial_distance_LRC_' + tLRC].copy()
        elif cost_type == 'euc_square':
            dmat = {}
            for tLRC in LRC_type:
                dmat[tLRC] = adata.obsp['spatial_distance_LRC_' + tLRC].copy() ** 2   
        self.M = dmat.copy()
        
        if np.isscalar(dis_thr):
            if cost_type == 'euc_square':
                dis_thr = dis_thr ** 2
            self.cutoff = float(dis_thr) * np.ones_like(A)
        elif type(dis_thr) is dict:
            self.cutoff = np.zeros_like(A)
            for i in range(A.shape[0]):
                for j in range(A.shape[1]):
                    if A[i,j] > 0:
                        if cost_type == 'euc_square':
                            self.cutoff[i,j] = dis_thr[(self.mets[i], self.sens[j])] ** 2
                        else:
                            self.cutoff[i,j] = dis_thr[(self.mets[i], self.sens[j])]
        self.nmet = self.S.shape[1]; self.nsen = self.D.shape[1]
        self.npts = adata.shape[0]

    def run_fot_signaling(self,
        fot_eps_p=1e-1, 
        fot_eps_mu=None, 
        fot_eps_nu=None, 
        fot_rho=1e1, 
        fot_nitermax=1e4, 
        fot_weights=(1.0,0.0,0.0,0.0),
    ):
        self.comm_network_sender, self.comm_network_receiver = fot_combine_sparse(self.S, self.mets, self.D, self.A, self.M, self.LRC, self.LRC_type, self.cutoff, \
            eps_p=fot_eps_p, eps_mu=fot_eps_mu, eps_nu=fot_eps_nu, rho=fot_rho, weights=fot_weights, nitermax=fot_nitermax)

def metabolic_communication(
    adata: anndata.AnnData,
    database_name: str = None,
    df_metasen: pd.DataFrame = None,
    LRC_type: list = None,
    dis_thr: Optional[float] = None,
    cost_scale: Optional[dict] = None,
    cost_type: str = 'euc',
    fot_eps_p: float = 2.5e-1,
    fot_eps_mu: Optional[float] = None,
    fot_eps_nu: Optional[float] = None,
    fot_rho: float = 1e-1,
    fot_nitermax: int = 10000,
    fot_weights: tuple = (1.0,0.0,0.0,0.0),
    copy: bool = False
):
    """
    Infer spatial metabolic cell communication (MCC) using the Flow Optimal
    Transport (FOT) framework.

    Parameters
    ----------
    adata : anndata.AnnData
        The data matrix of shape ``n_obs × n_var`` (cells/spots × genes).
        Rows correspond to cells or spots and columns to genes.
    database_name : str
        Name of the metabolite-sensor interaction database (used as prefix
        for storing results).
    df_metasen : pd.DataFrame
        DataFrame describing metabolite-sensor pairs, typically from
        MetaChatDB or related sources. Must include columns:
        ``['HMDB.ID', 'Sensor.Gene', 'Metabolite.Pathway', 'Sensor.Pathway',
        'Metabolite.Names', 'Long.Range.Channel']``.
    LRC_type : list of str, optional
        Names of long-range channels (LRCs) such as ``["Blood"]``,
        ``["CSF"]`` or ``["Blood", "CSF"]``. If None, long-range
        communication is not considered.
    dis_thr : float
        Distance threshold for defining the region influenced by each LRC.
    cost_scale : dict, optional
        Weight coefficients of the cost matrix for each metabolite-sensor
        pair, e.g. ``cost_scale[('metA','senA')]`` specifies weight for the
        pair metA and senA. If None, all pairs have the same weight.
    cost_type : str, {'euc', 'euc_square'}, default='euc'
        Defines how spatial distances are used as cost matrix.
    fot_eps_p : float, default=2.5e-1
        The coefficient of entropy regularization for transport plan.
    fot_eps_mu : float, optional
        The coefficient of entropy regularization for untransported source
        (metabolite). Set to equal to fot_eps_p for fast algorithm.
    fot_eps_nu : float, optional
        The coefficient of entropy regularization for unfulfilled target
        (sensor). Set to equal to fot_eps_p for fast algorithm.
    fot_rho : float, default=1e-1
        The coefficient of penalty for unmatched mass.
    fot_nitermax : int, default=10000
        Maximum iteration for flow optimal transport algorithm.
    fot_weights : tuple of float, default=(1.0, 0.0, 0.0, 0.0)
        A tuple of four weights that add up to one. The weights correspond
        to four setups of flow optimal transport:
        1) all metabolite-all sensors, 2) each metabolite-all sensors,
        3) all metabolite-each sensor, 4) each metabolite-each sensor.
    copy : bool, default=False
        Whether to additionally return a copy of the updated AnnData object.
        Note that ``adata`` itself is always updated in place.

    Returns
    -------
    adata : anndata.AnnData or None
        Updates the following fields:

        - ``.uns['df_metasen_filtered']`` : the metabolite-sensor pairs
          found in the data.
        - ``.obsp['MetaChat-{database}-sender-{met}-{sen}']`` : n_obs × n_obs
          sparse matrix of sender signals.
        - ``.obsp['MetaChat-{database}-receiver-{met}-{sen}']`` : n_obs × n_obs
          sparse matrix of receiver signals.
        - ``.obsp['MetaChat-{database}-sender-total-total']`` and
          ``.obsp['MetaChat-{database}-receiver-total-total']`` : sums over
          all pairs.
        - ``.obsm['MetaChat-{database}-sum-sender']`` : DataFrame of per-cell
          sender signal sums.
        - ``.obsm['MetaChat-{database}-sum-receiver']`` : DataFrame of
          per-cell receiver signal sums.
        - ``.uns['MetaChat-{database}-info']`` : metadata including distance
          threshold.

        If ``copy=True``, return a copy of the updated `adata`; otherwise
        return None.

    Raises
    ------
    AssertionError
        If ``database_name``, ``df_metasen`` or ``dis_thr`` is missing.
    ValueError
        If no metabolite-sensor pair of ``df_metasen`` is present in
        ``adata.var_names``.
    """
    # ==== Basic checks ====
    assert database_name is not None, "Please provide a database_name."
    assert df_metasen is not None, "Please provide a metabolite-sensor database."
    assert dis_thr is not None, "Please provide a dis_thr (distance threshold)."
    if LRC_type is None:
        print("You didn't input LRC_type, so long-range communication will not be consider in inference.")
        # Pass an empty channel list on, so only the "base" channel is used.
        LRC_type = []

    # ==== Filter valid metabolite–sensor pairs ====
    # Boolean mask (positional) instead of label-based .loc[i, ...], so the
    # database may carry any index; set membership keeps the test O(1).
    data_var = set(adata.var_names)
    keep = (
        df_metasen["HMDB.ID"].isin(data_var)
        & df_metasen["Sensor.Gene"].isin(data_var)
    ).to_numpy()
    if not keep.any():
        raise ValueError(
            "None of the metabolite-sensor pairs in df_metasen have both "
            "'HMDB.ID' and 'Sensor.Gene' present in adata.var_names."
        )
    # All entries are stored as strings; the index is reset so that rows are
    # numbered 0..n-1 again after duplicates have been dropped.
    df_metasen_filtered = (
        df_metasen.loc[keep]
        .astype(str)
        .drop_duplicates()
        .reset_index(drop=True)
    )
    adata.uns["df_metasen_filtered"] = df_metasen_filtered.copy()

    print("There are %d pairs were found from the spatial data." %df_metasen_filtered.shape[0])

    # ==== Initialize and run FOT-based communication model ====
    model = CellCommunication(
        adata,
        df_metasen_filtered,
        LRC_type,
        dis_thr,
        cost_scale,
        cost_type
    )
    model.run_fot_signaling(
        fot_eps_p = fot_eps_p,
        fot_eps_mu = fot_eps_mu,
        fot_eps_nu = fot_eps_nu,
        fot_rho = fot_rho,
        fot_nitermax = fot_nitermax,
        fot_weights = fot_weights
    )

    # ==== Store metadata ====
    adata.uns[f"MetaChat-{database_name}-info"] = {}
    adata.uns[f"MetaChat-{database_name}-info"]['distance_threshold'] = dis_thr

    # ==== Aggregate sender and receiver matrices ====
    ncell = adata.shape[0]
    prefix = f"MetaChat-{database_name}"
    tmp_mets = model.mets
    tmp_sens = model.sens

    # Per-pair column vectors are collected first and concatenated once.
    sender_cols = [np.empty([ncell, 0], float)]
    receiver_cols = [np.empty([ncell, 0], float)]
    col_names_sender = []
    col_names_receiver = []
    P_sender_total = sparse.csr_matrix((ncell, ncell), dtype=float)
    P_receiver_total = sparse.csr_matrix((ncell, ncell), dtype=float)

    for (i, j), P_sender in model.comm_network_sender.items():
        P_receiver = model.comm_network_receiver[(i, j)]
        met, sen = tmp_mets[i], tmp_sens[j]

        P_sender_total = P_sender_total + P_sender
        P_receiver_total = P_receiver_total + P_receiver

        adata.obsp[f"{prefix}-sender-{met}-{sen}"] = P_sender.tocsr()
        adata.obsp[f"{prefix}-receiver-{met}-{sen}"] = P_receiver.tocsr()

        # Row sums of the sender matrix / column sums of the receiver matrix,
        # each shaped as an (ncell, 1) column.
        sender_cols.append(np.asarray(P_sender.sum(axis=1)).reshape(-1, 1))
        receiver_cols.append(np.asarray(P_receiver.sum(axis=0)).reshape(-1, 1))

        col_names_sender.append("s-%s-%s" % (met, sen))
        col_names_receiver.append("r-%s-%s" % (met, sen))

    X_sender = np.concatenate(sender_cols, axis=1)
    X_receiver = np.concatenate(receiver_cols, axis=1)

    # ==== Add total summaries ====
    X_sender = np.concatenate((X_sender, X_sender.sum(axis=1).reshape(-1,1)), axis=1)
    X_receiver = np.concatenate((X_receiver, X_receiver.sum(axis=1).reshape(-1,1)), axis=1)
    col_names_sender.append("s-total-total")
    col_names_receiver.append("r-total-total")

    adata.obsp[f"{prefix}-sender-total-total"] = P_sender_total
    adata.obsp[f"{prefix}-receiver-total-total"] = P_receiver_total

    df_sender = pd.DataFrame(data=X_sender, columns=col_names_sender, index=adata.obs_names)
    df_receiver = pd.DataFrame(data=X_receiver, columns=col_names_receiver, index=adata.obs_names)
    adata.obsm[f"{prefix}-sum-sender"] = df_sender
    adata.obsm[f"{prefix}-sum-receiver"] = df_receiver

    # The model holds dense expression and distance matrices; release them.
    del model
    gc.collect()

    return adata.copy() if copy else None