from typing import Optional
import numpy as np
import pandas as pd
from scipy import sparse
from scipy.spatial import distance_matrix
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.cm import ScalarMappable
from matplotlib.lines import Line2D
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as patches
import matplotlib.lines as mlines
import seaborn as sns
import squidpy as sq
import plotly.express as px
import plotly.graph_objects as go
from pycirclize import Circos
from adjustText import adjust_text
from plotly.subplots import make_subplots
import networkx as nx
import anndata
from .._utils import plot_cell_signaling
from .._utils import get_cmap_qualitative
[docs]
def plot_communication_flow(
    adata: anndata.AnnData,
    database_name: str,
    metabolite_name: str = None,
    metapathway_name: str = None,
    customerlist_name: str = None,
    ms_pair_name: str = None,
    summary: str = "receiver",
    plot_method: str = "grid",
    background: str = "image",
    library_id: str = None,
    group_name: str = None,
    group_cmap: dict = None,
    background_legend: bool = False,
    cmap: str = "coolwarm",
    pos_idx: np.ndarray = np.array([0,1], int),
    normalize_summary_quantile: float = 0.995,
    ndsize: float = 3,
    vmin=None,
    vmax=None,
    grid_density: float = 1.0,
    grid_knn: int = None,
    grid_scale: float = 1.0,
    grid_thresh: float = 1.0,
    arrow_color: str = "#000000",
    arrow_width: float = 0.005,
    largest_arrow: float = 0.05,
    normalize_v: bool = False,
    normalize_v_quantile: float = 0.95,
    title: str = None,
    plot_savepath: str = None,
    ax: Optional[mpl.axes.Axes] = None
):
    """
    Visualize inferred metabolic communication vector fields on tissue images or annotated backgrounds.

    This function reads a communication vector field and the matching signal
    summary from ``adata.obsm``, optionally rescales the vectors, and forwards
    everything to ``plot_cell_signaling``, which performs the drawing.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing spatial coordinates in ``adata.obsm["spatial"]``.
    database_name : str
        Name of the metabolite–sensor interaction database used for communication inference.
    metabolite_name : str, optional
        Name of a specific metabolite to visualize. Mutually exclusive with ``metapathway_name``,
        ``customerlist_name``, and ``ms_pair_name``.
    metapathway_name : str, optional
        Name of a specific metabolic pathway to visualize.
    customerlist_name : str, optional
        Name of a custom metabolite–sensor list to visualize.
    ms_pair_name : str, optional
        Name of a metabolite–sensor pair to visualize.
        If none of the four names is given, the ``"total-total"`` entry is used.
    summary : {"sender", "receiver"}, default="receiver"
        Type of summary statistic used for background coloring.
    plot_method : {"cell", "grid", "stream"}, default="grid"
        Rendering mode for vector visualization.
    background : {"summary", "image", "group"}, default="image"
        Type of background to draw behind the vectors.
    library_id : str, optional
        Visium library identifier used when ``background="image"``.
    group_name : str, optional
        Column in ``adata.obs`` used for categorical coloring when ``background="group"``.
        Required in that case.
    group_cmap : dict or str, optional
        Only inspected when ``background="group"``.
        If ``None``, a mapping is built from ``adata.obs[group_name].cat.categories`` and
        ``adata.uns[group_name + "_colors"]``. A dict is forwarded unchanged. A string is
        kept only if, lower-cased, it is one of ``'plotly'``, ``'light24'``, ``'dark24'``
        or ``'alphabet'``; any other string is replaced by ``'alphabet'``.
    background_legend : bool, default=False
        Whether to display legend for the background.
    cmap : str, default="coolwarm"
        Colormap name for summary or numeric backgrounds.
    pos_idx : np.ndarray, default=np.array([0,1])
        Column indices selected from both ``adata.obsm["spatial"]`` and the vector field.
    normalize_summary_quantile : float, default=0.995
        Quantile for clipping background values.
    ndsize : float, default=3
        Spot marker size.
    vmin, vmax : float, optional
        Color scale limits for background values.
    grid_density : float, default=1.0
        Density of interpolation grid when ``plot_method="grid"``.
    grid_knn : int, optional
        Number of neighbors used for vector interpolation.
    grid_scale : float, default=1.0
        Kernel scale factor for interpolation.
    grid_thresh : float, default=1.0
        Minimum interpolation weight threshold.
    arrow_color : str, default="#000000"
        Color of flow arrows.
    arrow_width : float, default=0.005
        Width of quiver arrows.
    largest_arrow : float, default=0.05
        Maximum arrow length after normalization.
    normalize_v : bool, default=False
        If ``True``, divide every vector by the ``normalize_v_quantile`` quantile of the
        vector norms. Skipped when that quantile is not positive.
    normalize_v_quantile : float, default=0.95
        Quantile of the vector norms used as the normalization scale.
    title : str, optional
        Title of the plot.
    plot_savepath : str, optional
        File path to save the plot; no saving if ``None``.
    ax : matplotlib.axes.Axes, optional
        Axis to draw the plot on; creates new one if ``None``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the resulting plot.
    coords_plot : np.ndarray
        Coordinates used for plotting (after interpolation).
    V_plot : np.ndarray
        Vector field used for plotting (after interpolation).

    Raises
    ------
    AssertionError
        If ``database_name`` is ``None`` or more than one target name is given.
    ValueError
        If ``summary`` is not ``"sender"``/``"receiver"``, or ``background="group"``
        is requested without ``group_name``.
    KeyError
        If the expected entries are missing from ``adata.obsm``.
    """
    # Check inputs
    assert database_name is not None, "Please at least specify database_name."

    # Each candidate target is paired with the suffix of the ``obsm`` summary
    # table that stores it; the first non-None candidate wins.
    targets = (
        (metabolite_name, '-metabolite'),
        (metapathway_name, '-pathway'),
        (customerlist_name, '-customer'),
        (ms_pair_name, ''),
    )
    not_none_count = sum(name is not None for name, _ in targets)
    assert not_none_count <= 1, "Please specify at most one of metabolite_name, metapathway_name, customerlist_name, or ms_pair_name."

    if summary == 'sender':
        summary_abbr = 's'
    elif summary == 'receiver':
        summary_abbr = 'r'
    else:
        raise ValueError("summary must be either 'sender' or 'receiver'.")

    # Default to the aggregated signal when no specific target is requested.
    target_name, obsm_suffix = 'total-total', ''
    for name, suffix in targets:
        if name is not None:
            target_name, obsm_suffix = name, suffix
            break

    vf_key = 'MetaChat-vf-' + database_name + '-' + summary + '-' + target_name
    sum_key = 'MetaChat-' + database_name + "-sum-" + summary + obsm_suffix
    V = adata.obsm[vf_key][:, pos_idx].copy()
    signal_sum = adata.obsm[sum_key][summary_abbr + '-' + target_name].copy()

    if background == 'group':
        if group_name is None:
            raise ValueError("group_name must be specified when background='group'.")
        if group_cmap is None:
            group_cmap = dict(zip(
                adata.obs[group_name].cat.categories.tolist(),
                adata.uns[group_name + '_colors'],
            ))
        elif isinstance(group_cmap, str):
            # Only strings name a qualitative palette; a dict (the documented
            # type) has no ``.lower()`` and must be passed through untouched.
            if group_cmap.lower() not in ('plotly', 'light24', 'dark24', 'alphabet'):
                group_cmap = 'alphabet'

    if ax is None:
        fig, ax = plt.subplots()

    if normalize_v:
        scale = np.quantile(np.linalg.norm(V, axis=1), normalize_v_quantile)
        # A zero scale (e.g. an all-zero field) would turn V into inf/nan.
        if scale > 0:
            V = V / scale

    ax, coords_plot, V_plot = plot_cell_signaling(
        coords = adata.obsm["spatial"][:, pos_idx],
        Vector = V,
        signal_sum = signal_sum,
        summary = summary,
        plot_method = plot_method,
        adata = adata,
        background = background,
        library_id = library_id,
        group_name = group_name,
        group_cmap = group_cmap,
        background_legend = background_legend,
        cmap = cmap,
        normalize_summary_quantile = normalize_summary_quantile,
        ndsize = ndsize,
        vmin = vmin,
        vmax = vmax,
        grid_density = grid_density,
        grid_knn = grid_knn,
        grid_scale = grid_scale,
        grid_thresh = grid_thresh,
        arrow_color = arrow_color,
        arrow_width = arrow_width,
        largest_arrow = largest_arrow,
        title = title,
        plot_savepath = plot_savepath,
        ax = ax
    )
    return ax, coords_plot, V_plot
[docs]
def plot_group_communication_chord(
    adata: anndata.AnnData,
    database_name: str = None,
    group_name: str = None,
    summary: str = 'sender',
    metabolite_name: str = None,
    metapathway_name: str = None,
    customerlist_name: str = None,
    ms_pair_name: str = None,
    permutation_spatial: bool = False,
    p_value_cutoff: float = 0.05,
    self_communication_off: bool = False,
    highlight_group_sender: str = None,
    highlight_group_receiver: str = None,
    space: int = 5,
    group_cmap: str = None,
    figsize: tuple = (5, 5),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None,
):
    """
    Plot a chord diagram representing group-level metabolic cell communication (MCC).

    This function visualizes inter-group communication strength as a chord diagram,
    based on results computed by MetaChat group-level analysis. Each arc represents
    a cell group, and links represent the strength of metabolite-mediated communication
    between sender and receiver groups.

    The communication matrix and associated p-values must be precomputed by
    :func:`mc.tl.communication_group` or :func:`mc.tl.communication_group_spatial`.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing group-level MCC results in ``adata.uns``.
    database_name : str
        Name of the metabolite–sensor interaction database used for MCC analysis.
    group_name : str
        Column name in ``adata.obs`` specifying cell group annotations.
    summary : {"sender", "receiver"}, default="sender"
        Direction of communication summary to visualize.
    metabolite_name : str, optional
        Name of a specific metabolite to visualize.
    metapathway_name : str, optional
        Name of a specific metabolic pathway to visualize.
    customerlist_name : str, optional
        Name of a specific custom metabolite list to visualize.
    ms_pair_name: str, optional
        Name of a specific metabolite–sensor pair to visualize.
        If none of the four names is given, the ``"total-total"`` entry is used.
    permutation_spatial : bool, default=False
        Whether to use results from spatially permuted communication tests
        (``mc.tl.communication_group_spatial``).
    p_value_cutoff : float, default=0.05
        Entries whose p-value is greater than this cutoff are set to zero.
    self_communication_off : bool, default=False
        Whether to remove self-communication (entries whose row and column label match).
    highlight_group_sender : str or list of str, optional
        Name(s) of sender group(s) to highlight. Links from other groups are drawn
        with low opacity. A single string is treated as one group name.
    highlight_group_receiver : str or list of str, optional
        Name(s) of receiver group(s) to highlight. Links to other groups are drawn
        with low opacity. A single string is treated as one group name.
        When both highlight options are given, a link is highlighted if it matches either.
    space : int, default=5
        Angular spacing between group arcs in degrees.
    group_cmap : dict or str, optional
        Forwarded unchanged as ``cmap`` to ``Circos.initialize_from_matrix``.
        If ``None``, a mapping is built from ``adata.obs[group_name].cat.categories``
        and ``adata.uns[group_name + "_colors"]``.
    figsize : tuple of float, default=(5, 5)
        Figure size (width, height).
    ax : matplotlib.axes.Axes, optional
        Existing axis to draw the chord diagram. If ``None``, a new one is created.
    plot_savepath : str, optional
        File path to save the resulting figure. The format is inferred from the file extension.

    Returns
    -------
    ax : matplotlib.axes.Axes or None
        Axis containing the chord diagram. When nothing is drawn because no
        communication passes the filter, the ``ax`` argument is returned as given
        (possibly ``None``).

    Notes
    -----
    - The function requires that group-level communication results are already computed and stored in:
        - ``adata.uns["MetaChat_group-{group_name}-{database_name}-{summary}-{target}"]``
          (for :func:`mc.tl.communication_group`), or
        - ``adata.uns["MetaChat_group_spatial-..."]``
          (for :func:`mc.tl.communication_group_spatial`).
    - Rows and columns whose sum is zero after filtering are dropped before plotting.
    - The figure is kept open after saving so that the returned axis stays usable.
    """
    # Check inputs
    assert database_name is not None, "Please at least specify database_name."
    assert group_name is not None, "Please at least specify group_name."
    target_names = [metabolite_name, metapathway_name, customerlist_name, ms_pair_name]
    not_none_count = sum(x is not None for x in target_names)
    assert not_none_count <= 1, ("Please specify at most one of metabolite_name, metapathway_name, customerlist_name, or ms_pair_name.")

    # First non-empty target name wins; otherwise use the aggregated result.
    uns_names = next((name for name in target_names if name), "total-total")

    key_prefix = "MetaChat_group_spatial-" if permutation_spatial else "MetaChat_group-"
    result = adata.uns[key_prefix + group_name + "-" + database_name + '-' + summary + '-' + uns_names]
    df_communMatrix = result['communication_matrix'].copy()
    df_pvalue = result['communication_pvalue'].copy()

    # Suppress non-significant links.
    df_communMatrix[df_pvalue > p_value_cutoff] = 0

    if self_communication_off:
        # Match by label rather than position so a differently ordered or
        # non-square matrix is still handled correctly.
        for label in df_communMatrix.index.intersection(df_communMatrix.columns):
            df_communMatrix.loc[label, label] = 0

    # Drop groups that neither send nor receive anything.
    df_communMatrix = df_communMatrix.loc[df_communMatrix.sum(axis=1) != 0]
    df_communMatrix = df_communMatrix.loc[:, df_communMatrix.sum(axis=0) != 0]

    def _as_label_set(groups):
        # A bare string must be one label; ``label in "some string"`` would
        # otherwise perform a substring match.
        if groups is None:
            return None
        if isinstance(groups, str):
            return {groups}
        return set(groups)

    senders = _as_label_set(highlight_group_sender)
    receivers = _as_label_set(highlight_group_receiver)

    link_kws_handler = None
    if senders is not None or receivers is not None:
        def link_kws_handler(from_label: str, to_label: str):
            highlighted = (
                (senders is not None and from_label in senders)
                or (receivers is not None and to_label in receivers)
            )
            if highlighted:
                return dict(alpha=0.7, zorder=1.0)
            return dict(alpha=0.2, zorder=0)

    if group_cmap is None:
        group_cmap = dict(zip(
            adata.obs[group_name].cat.categories.tolist(),
            adata.uns[group_name + '_colors'],
        ))

    if df_communMatrix.to_numpy().sum() == 0:
        print("There is no significant group communication in " + uns_names)
        return ax

    circos = Circos.initialize_from_matrix(
        df_communMatrix,
        space = space,
        cmap = group_cmap,
        label_kws = dict(size=12),
        link_kws = dict(ec="black", lw=0.5, direction=1),
        link_kws_handler = link_kws_handler
    )

    # Always draw on the requested axis so that a caller-supplied ``ax`` is
    # honoured even when the figure is also written to disk.
    fig = circos.plotfig(figsize=figsize, ax=ax)
    if ax is None:
        # plotfig created a fresh figure holding a single polar axis.
        ax = fig.axes[0]

    if plot_savepath is not None:
        # Same save options that Circos.savefig applies by default.
        fig.savefig(plot_savepath, dpi=100, pad_inches=0.5, bbox_inches="tight")

    return ax
[docs]
def plot_group_communication_heatmap(
    adata: anndata.AnnData,
    database_name: str = None,
    group_name: str = None,
    metabolite_name: str = None,
    metapathway_name: str = None,
    customerlist_name: str = None,
    ms_pair_name: str = None,
    summary: str = "sender",
    permutation_spatial: bool = False,
    p_value_plot: bool = True,
    p_value_cutoff: float = 0.05,
    size_scale: float = 300,
    cmap: str = "green",
    palette: Optional[list] = None,
    marker: str = "s",
    x_order: Optional[list] = None,
    y_order: Optional[list] = None,
    figsize: tuple = (10, 10),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: Optional[str] = None
):
    """
    Plot a heatmap diagram for group-level metabolic cell communication (MCC).

    This function visualizes group-level MCC intensity between sender and receiver cell groups,
    using a heatmap representation. Each point corresponds to a sender–receiver pair,
    colored by communication strength and optionally marked with significance (*).

    The group-level communication results must be computed beforehand by
    :func:`mc.tl.communication_group` or :func:`mc.tl.communication_group_spatial`.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing group-level MCC results in ``adata.uns``.
    database_name : str
        Name of the metabolite–sensor interaction database.
    group_name : str
        Column name in ``adata.obs`` specifying the cell group annotation.
    metabolite_name : str, optional
        Name of a specific metabolite to visualize (e.g., ``'HMDB0000148'``).
    metapathway_name : str, optional
        Name of a specific metabolic pathway to visualize.
    customerlist_name : str, optional
        Name of a custom metabolite list to visualize.
    ms_pair_name : str, optional
        Name of a specific metabolite–sensor pair to visualize.
        If none of the four names is given, the ``"total-total"`` entry is used.
    summary : {"sender", "receiver"}, default="sender"
        Direction of communication summary to visualize.
    permutation_spatial : bool, default=False
        Whether to use results from ``mc.tl.communication_group_spatial``.
    p_value_plot : bool, default=True
        Whether to mark significant interactions with an asterisk ``*``.
    p_value_cutoff : float, default=0.05
        Significance threshold for group-level MCC visualization.
    size_scale : float, default=300
        Marker size assigned to the highest-ranked MCC score.
    cmap : {"green", "red", "blue"}, default="green"
        Predefined color theme for communication strength. Ignored when ``palette`` is given.
    palette : list, optional
        Custom color palette list. If specified, overrides ``cmap``.
    marker : str, default="s"
        Matplotlib marker style used in the scatter plot.
    x_order : list of str, optional
        Order of sender groups along the x-axis. Defaults to sorted sender labels.
    y_order : list of str, optional
        Order of receiver groups along the y-axis. Defaults to sorted receiver labels.
    figsize : tuple of float, default=(10, 10)
        Size of the output figure (width, height). If ``None``, one unit per group is used.
    ax : matplotlib.axes.Axes, optional
        Existing axis to draw on. If ``None``, a new figure and axis are created.
    plot_savepath : str, optional
        File path to save the figure. The format is inferred from the extension.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the heatmap plot.

    Raises
    ------
    AssertionError
        If ``database_name`` or ``group_name`` is ``None``, or more than one target is given.
    ValueError
        If ``palette`` is ``None`` and ``cmap`` is not one of the predefined themes,
        or the communication matrix is empty.

    Notes
    -----
    - The function expects communication matrices and p-values stored in:
        - ``adata.uns["MetaChat_group-{group_name}-{database_name}-{summary}-{target}"]`` or
        - ``adata.uns["MetaChat_group_spatial-..."]`` for spatial permutation results.
    - Rows of the communication matrix are senders, columns are receivers.
    - Significance marks (``*``) are added when ``p_value_plot=True`` and p < ``p_value_cutoff``.
    - Marker size and color follow the *rank* of each score among all scores, whereas the
      colorbar is labelled linearly between the minimum and maximum score.
    """
    # ==== Check inputs ====
    assert database_name is not None, "Please at least specify database_name."
    assert group_name is not None, "Please at least specify group_name."
    target_names = [metabolite_name, metapathway_name, customerlist_name, ms_pair_name]
    not_none_count = sum(x is not None for x in target_names)
    assert not_none_count <= 1, (
        "Please specify at most one of metabolite_name, metapathway_name, "
        "customerlist_name, or ms_pair_name."
    )
    # First non-empty target name wins; otherwise use the aggregated result.
    uns_names = next((name for name in target_names if name), "total-total")

    # ==== Load communication results ====
    key_prefix = "MetaChat_group_spatial-" if permutation_spatial else "MetaChat_group-"
    result = adata.uns[key_prefix + group_name + "-" + database_name + '-' + summary + '-' + uns_names]
    df_communMatrix = result['communication_matrix'].copy()
    # Align the p-values to the score matrix by label so both flatten in the same order.
    df_pvalue = result['communication_pvalue'].reindex(
        index=df_communMatrix.index, columns=df_communMatrix.columns
    )

    # ==== Flatten to one entry per (sender, receiver) pair ====
    n_senders, n_receivers = df_communMatrix.shape
    sender = np.repeat(df_communMatrix.index.to_numpy(), n_receivers)
    receiver = np.tile(df_communMatrix.columns.to_numpy(), n_senders)
    score = df_communMatrix.to_numpy(dtype=float).ravel()
    p_value = df_pvalue.to_numpy(dtype=float).ravel()
    if score.size == 0:
        raise ValueError("The communication matrix is empty; nothing to plot.")

    # ==== Prepare plot settings ====
    themes = {
        'green': ["#D8E6E5", "#94C1BE", "#49A59D"],
        'red': ["#FCF5B8", "#EBA55A", "#C23532"],
        'blue': ["#CCE8F9", "#72BBE7", "#4872B4"],
    }
    if palette is not None:
        n_colors = len(palette)
    else:
        if cmap not in themes:
            raise ValueError(
                "cmap must be one of 'green', 'red' or 'blue' when no palette is given."
            )
        n_colors = 256
        palette = sns.blend_palette(themes[cmap], n_colors=n_colors)

    color_min, color_max = score.min(), score.max()

    # Relative rank of every score in [0, 1]; sorting once instead of once per point.
    if color_min == color_max:
        rank_position = np.ones(score.size)
    else:
        ranks = np.searchsorted(np.sort(score), score, side='left')
        rank_position = np.clip(ranks / (score.size - 1), 0.0, 1.0)
    point_colors = [palette[int(pos * (n_colors - 1))] for pos in rank_position]
    point_sizes = rank_position * size_scale

    x_names = x_order if x_order is not None else sorted(set(sender))
    y_names = y_order if y_order is not None else sorted(set(receiver))
    x_to_num = {name: i for i, name in enumerate(x_names)}
    y_to_num = {name: i for i, name in enumerate(y_names)}

    if figsize is None:
        figsize = (len(x_names), len(y_names))

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # ==== Plot heatmap ====
    ax.scatter(
        x = [x_to_num[v] for v in sender],
        y = [y_to_num[v] for v in receiver],
        marker = marker,
        s = point_sizes,
        c = point_colors
    )

    # NOTE: the colorbar is linear in the score while the markers are rank-colored.
    norm = mpl.colors.Normalize(vmin=color_min, vmax=color_max)
    cmap_mpl = mpl.colors.ListedColormap(palette)
    sm = ScalarMappable(cmap=cmap_mpl, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("MCC Score", fontsize=10)

    if p_value_plot:
        for isender, ireceiver, ipvalue in zip(sender, receiver, p_value):
            if ipvalue < p_value_cutoff:
                ax.text(x_to_num[isender], y_to_num[ireceiver], '*', color='black', ha='center', va='center')

    ax.set_xticks(list(x_to_num.values()))
    ax.set_xticklabels(list(x_to_num), rotation=45, horizontalalignment='right')
    ax.set_yticks(list(y_to_num.values()))
    ax.set_yticklabels(list(y_to_num))

    # Grid lines sit between cells (minor ticks), not through their centres.
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')

    ax.set_xlim([-0.5, max(x_to_num.values()) + 0.5])
    ax.set_ylim([-0.5, max(y_to_num.values()) + 0.5])
    ax.set_facecolor('white')
    ax.set_xlabel('Sender')
    ax.set_ylabel('Receiver')

    if figsize[0] == figsize[1]:
        ax.set_box_aspect(1)

    if plot_savepath is not None:
        # Save the figure that owns ``ax``; pyplot's current figure may be a different one.
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return ax
[docs]
def plot_group_communication_compare_hierarchy_diagram(
    adata_A: anndata.AnnData,
    adata_B: anndata.AnnData,
    condition_name_A: str = None,
    condition_name_B: str = None,
    database_name: str = None,
    group_name: str = None,
    summary: str = 'sender',
    metabolite_name: str = None,
    metapathway_name: str = None,
    customerlist_name: str = None,
    ms_pair_name: str = None,
    permutation_spatial: bool = False,
    p_value_cutoff: float = 0.05,
    node_sizes_limit: tuple = (50, 300),
    edge_sizes_limit: tuple = (0.5,10),
    group_cmap: dict = None,
    alpha: float = 0.5,
    figsize: tuple = (10, 3),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None,
):
    """
    Plot a hierarchy-style diagram comparing group-level MCC between two conditions.

    Three columns of nodes are drawn: senders of condition A on the left, receivers
    in the middle, and senders of condition B on the right. Directed edges run from
    each sender column to the shared receiver column.

    Parameters
    ----------
    adata_A : anndata.AnnData
        Annotated data matrix for condition A. Must contain group-level MCC results.
    adata_B : anndata.AnnData
        Annotated data matrix for condition B. Must contain the same structure as ``adata_A``.
    condition_name_A : str, optional
        Label for condition A, shown above the left half (e.g., "Control").
    condition_name_B : str, optional
        Label for condition B, shown above the right half (e.g., "Disease").
    database_name : str
        Name of the metabolite–sensor interaction database used for MCC computation.
    group_name : str
        Column name in ``adata.obs`` containing group/cell-type annotations.
    summary : {"sender", "receiver"}, default="sender"
        Direction of communication summary to visualize.
    metabolite_name : str, optional
        Specific metabolite to visualize.
    metapathway_name : str, optional
        Specific metabolic pathway to visualize.
    customerlist_name : str, optional
        Custom metabolite list name to visualize.
    ms_pair_name : str, optional
        Name of a specific metabolite–sensor pair to visualize.
        If none of the four names is given, the ``"total-total"`` entry is used.
    permutation_spatial : bool, default=False
        Whether to use spatially permuted results from :func:`mc.tl.communication_group_spatial`.
    p_value_cutoff : float, default=0.05
        Edges with a p-value below this cutoff are drawn fully opaque.
    node_sizes_limit : tuple of float, default=(50, 300)
        Minimum and maximum node size for scaling group strength.
    edge_sizes_limit : tuple of float, default=(0.5, 10)
        Minimum and maximum edge width for scaling communication intensity.
    group_cmap : dict, optional
        Mapping from group names to colors. If ``None``, built from
        ``adata.obs[group_name].cat.categories`` and ``adata.uns[group_name + "_colors"]``
        of both datasets (condition B takes precedence on conflicts).
    alpha : float, default=0.5
        Opacity of non-significant edges.
    figsize : tuple of float, default=(10, 3)
        Figure size (width, height).
    ax : matplotlib.axes.Axes, optional
        Existing Matplotlib axis to draw the hierarchy diagram on.
        If ``None``, a new figure and axis are created.
    plot_savepath : str, optional
        File path to save the figure (e.g., ``"results/group_compare_hierarchy.pdf"``).
        The format is inferred from the file extension. If ``None``, nothing is saved.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the hierarchy comparison diagram.

    Notes
    -----
    - Groups present in only one condition are added to the other with zero
      communication and a p-value of 1.
    - Sender node size reflects the row sum of the condition's matrix; receiver node
      size reflects the mean of the two conditions' column sums.
    - Edge width scales with communication strength; edge color is the sender's group color.
    - When all node sizes (or all edge weights) are equal, the midpoint of the
      corresponding size limits is used.
    """
    # ==== Check inputs ====
    assert database_name is not None, "Please specify `database_name`."
    assert group_name is not None, "Please specify `group_name`."
    target_names = [metabolite_name, metapathway_name, customerlist_name, ms_pair_name]
    not_none_count = sum(x is not None for x in target_names)
    assert not_none_count <= 1, (
        "Please specify at most one of metabolite_name, metapathway_name, customerlist_name or ms_pair_name."
    )
    # First non-empty target name wins; otherwise use the aggregated result.
    uns_names = next((name for name in target_names if name), "total-total")

    key_prefix = "MetaChat_group_spatial-" if permutation_spatial else "MetaChat_group-"
    result_key = key_prefix + group_name + "-" + database_name + '-' + summary + '-' + uns_names

    def _rescale(value, src_min, src_max, dst_min, dst_max):
        # Linear map from [src_min, src_max] to [dst_min, dst_max]; a degenerate
        # source range maps to the midpoint instead of dividing by zero.
        span = src_max - src_min
        if span == 0:
            return (dst_min + dst_max) / 2
        return dst_min + (value - src_min) * (dst_max - dst_min) / span

    # ==== Prepare data ====
    raw_matrix_A = adata_A.uns[result_key]['communication_matrix']
    raw_pvalue_A = adata_A.uns[result_key]['communication_pvalue']
    raw_matrix_B = adata_B.uns[result_key]['communication_matrix']
    raw_pvalue_B = adata_B.uns[result_key]['communication_pvalue']

    classes = sorted(
        set(raw_matrix_A.index) | set(raw_matrix_A.columns)
        | set(raw_matrix_B.index) | set(raw_matrix_B.columns)
    )

    # Expand every table to the common set of groups. Float fill values keep the
    # frames float-typed: missing communication is 0, missing p-value is 1.
    matrix_condition_A = raw_matrix_A.reindex(index=classes, columns=classes, fill_value=0.0).astype(float)
    matrix_condition_B = raw_matrix_B.reindex(index=classes, columns=classes, fill_value=0.0).astype(float)
    pvalue_condition_A = raw_pvalue_A.reindex(index=classes, columns=classes, fill_value=1.0).astype(float)
    pvalue_condition_B = raw_pvalue_B.reindex(index=classes, columns=classes, fill_value=1.0).astype(float)

    if group_cmap is None:
        group_cmap_A = dict(zip(adata_A.obs[group_name].cat.categories.tolist(), adata_A.uns[group_name + '_colors']))
        group_cmap_B = dict(zip(adata_B.obs[group_name].cat.categories.tolist(), adata_B.uns[group_name + '_colors']))
        group_cmap = {**group_cmap_A, **group_cmap_B}

    # ==== Build nodes ====
    sent_A = matrix_condition_A.sum(axis=1)
    sent_B = matrix_condition_B.sum(axis=1)
    received_A = matrix_condition_A.sum(axis=0)
    received_B = matrix_condition_B.sum(axis=0)

    side_suffix = {'left': '_L', 'middle': '_M', 'right': '_R'}
    side_x = {'left': 2, 'middle': 4, 'right': 6}
    row_of = {cls: rank for rank, cls in enumerate(classes)}

    G = nx.DiGraph()
    node_sizes = {}
    node_colors = {}
    pos = {}
    nodes_by_side = {'left': [], 'middle': [], 'right': []}
    for cls in classes:
        sizes = {
            'left': sent_A[cls],
            'middle': (received_A[cls] + received_B[cls]) / 2,
            'right': sent_B[cls],
        }
        color = group_cmap[cls]
        for side, size in sizes.items():
            node = cls + side_suffix[side]
            G.add_node(node, side=side, size=size, color=color)
            node_sizes[node] = size
            node_colors[node] = color
            # The column comes from the recorded side, never from the node name:
            # a group called e.g. "T_Lymph" would otherwise look like a left node.
            pos[node] = (side_x[side], len(classes) - row_of[cls])
            nodes_by_side[side].append(node)

    node_sizes_min_value = min(node_sizes.values())
    node_sizes_max_value = max(node_sizes.values())
    node_sizes_visual = {
        node: _rescale(size, node_sizes_min_value, node_sizes_max_value,
                       node_sizes_limit[0], node_sizes_limit[1])
        for node, size in node_sizes.items()
    }

    # ==== Build edges ====
    edge_sizes_min_value = min(matrix_condition_A.to_numpy().min(), matrix_condition_B.to_numpy().min())
    edge_sizes_max_value = max(matrix_condition_A.to_numpy().max(), matrix_condition_B.to_numpy().max())

    edges_signif = []       # (source, target, width) with p < cutoff
    edges_non_signif = []   # all remaining (source, target, width)
    conditions = (
        ('_L', matrix_condition_A, pvalue_condition_A),
        ('_R', matrix_condition_B, pvalue_condition_B),
    )
    for cls_sender in classes:
        for cls_receiver in classes:
            for suffix, matrix, pvalue in conditions:
                width = _rescale(
                    matrix.loc[cls_sender, cls_receiver],
                    edge_sizes_min_value, edge_sizes_max_value,
                    edge_sizes_limit[0], edge_sizes_limit[1],
                )
                significant = pvalue.loc[cls_sender, cls_receiver] < p_value_cutoff
                edge = (cls_sender + suffix, cls_receiver + '_M', width)
                (edges_signif if significant else edges_non_signif).append(edge)
                G.add_edge(edge[0], edge[1], weight=width, significant=bool(significant))

    # ==== Draw plot ====
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Receivers (middle column) are squares, senders are circles.
    for side, shape in (('left', 'o'), ('middle', 's'), ('right', 'o')):
        nodelist = nodes_by_side[side]
        if not nodelist:
            continue
        nx.draw_networkx_nodes(
            G, pos, nodelist=nodelist,
            node_color=[node_colors[node] for node in nodelist],
            node_shape=shape,
            node_size=[node_sizes_visual[node] for node in nodelist],
            ax=ax,
        )

    # Group names are written once, to the left of the left-hand column.
    labels = {cls + '_L': cls for cls in classes}
    labels_pos = {node: (pos[node][0] - 0.2, pos[node][1]) for node in labels}
    nx.draw_networkx_labels(G, labels_pos, labels=labels, horizontalalignment='right', ax=ax)

    for edges, edge_alpha in ((edges_signif, 1), (edges_non_signif, alpha)):
        if not edges:
            continue
        nx.draw_networkx_edges(
            G, pos,
            edgelist=[(source, target) for source, target, _ in edges],
            edge_color=[node_colors[source] for source, _, _ in edges],
            width=[width for _, _, width in edges],
            arrowstyle='-|>', arrowsize=10, alpha=edge_alpha, ax=ax,
        )

    ax.axis('off')
    ax.set_frame_on(False)
    ax.set_xlim([-1,6.5])

    header_y = len(classes) + 0.8
    ax.text(2, header_y, "Sender", ha='center', va='center', fontsize=12)
    ax.text(4, header_y, "Receiver", ha='center', va='center', fontsize=12)
    ax.text(6, header_y, "Sender", ha='center', va='center', fontsize=12)
    ax.arrow(2.35, header_y, 1.1, 0, head_width=0.3, head_length=0.15, fc='#4F9B79', ec='#4F9B79', linewidth=2)
    ax.arrow(5.65, header_y, -1.1, 0, head_width=0.3, head_length=0.15, fc='#253071', ec='#253071', linewidth=2)
    ax.text(2.9, len(classes) + 1.4, condition_name_A, ha='center', va='center', fontsize=14)
    ax.text(5.1, len(classes) + 1.4, condition_name_B, ha='center', va='center', fontsize=14)

    # ==== Save ====
    if plot_savepath:
        # Save the figure that owns ``ax``; pyplot's current figure may be a different one.
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return ax
[docs]
def plot_MSpair_contribute_group(
    adata: anndata.AnnData,
    database_name: str = None,
    group_name: str = None,
    metabolite_name: str = None,
    summary: str = 'sender',
    cmap: str = "green",
    group_cmap: dict = None,
    figsize: tuple = (4,6),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None
):
    """
    Plot a heatmap showing group-level contributions of metabolite–sensor pairs for a specific metabolite.

    This function visualizes how each cell group contributes to the overall communication
    strength of all metabolite–sensor pairs associated with a given metabolite.
    It is particularly useful for identifying dominant sender or receiver groups for specific metabolites.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing MetaChat results.
        Must include:

        - ``adata.uns['df_metasen_filtered']`` : table linking metabolites (HMDB IDs) to sensors.
        - ``adata.obsm['MetaChat-<database_name>-sum-<summary>']`` : communication summary matrix.
    database_name : str
        Name of the metabolite–sensor interaction database used for MCC computation.
    group_name : str
        Column name in ``adata.obs`` specifying the cell group or cluster identity.
    metabolite_name : str
        HMDB ID of the metabolite to visualize.
    summary : {"sender", "receiver"}, default="sender"
        Whether to visualize sender- or receiver-side contributions.
    cmap : {"green", "red", "blue"}, default="green"
        Colormap for the heatmap.

        - "green": metabolic enrichment style
        - "red": activity intensity style
        - "blue": signal pathway style
    group_cmap : dict, optional
        Mapping from group names to colors. If ``None``, derived automatically from
        ``adata.uns[group_name + '_colors']``. A non-dict value is passed to seaborn
        unchanged as the per-group color sequence.
    figsize : tuple of float, default=(4, 6)
        Figure size (width, height).
    ax : matplotlib.axes.Axes, optional
        Accepted for signature compatibility only. ``seaborn.clustermap`` is a
        figure-level function that always builds its own figure, so this axis
        is not drawn on.
    plot_savepath : str, optional
        File path to save the figure (e.g., ``"results/MSpair_contribution_heatmap.pdf"``).
        The format is inferred from the file extension. If ``None``, the plot is not saved.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Heatmap axis of the cluster grid containing the contribution heatmap.

    Raises
    ------
    ValueError
        If ``summary`` or ``cmap`` is not one of the supported values, or if no
        sensor is registered for ``metabolite_name``.

    Notes
    -----
    - The function requires prior execution of:

      1. :func:`mc.tl.metabolic_communication`
      2. :func:`mc.tl.summary_communication`
      3. :func:`mc.tl.communication_group` or :func:`mc.tl.communication_group_spatial`
    - Each column represents a metabolite–sensor pair, and each row represents a cell group.
    - Values are log-transformed group-level summed communication scores.
    """
    # ==== Check inputs ====
    assert database_name is not None, "Please at least specify `database_name`."
    assert group_name is not None, "Please at least specify `group_name`."
    assert metabolite_name is not None, "Please at least specify `metabolite_name`."

    if summary == 'sender':
        arrv = 's'
    elif summary == 'receiver':
        arrv = 'r'
    else:
        raise ValueError("`summary` must be either 'sender' or 'receiver'.")

    palette_stops = {
        'green': ["#F4FAFC", "#CAE7E0", "#80C0A5", "#48884B", "#1E4621"],
        'blue': ["#FAFDFE", "#B7CDE9", "#749FD2", "#4967AC", "#3356A2"],
        'red': ["#FFFEF7", "#FCF5B8", "#EBA55A", "#C23532"],
    }
    if not isinstance(cmap, str) or cmap not in palette_stops:
        raise ValueError("`cmap` must be one of {'green', 'red', 'blue'}.")

    # ==== Prepare data ====
    df_metasen = adata.uns['df_metasen_filtered']
    name_sensor = df_metasen.loc[df_metasen['HMDB.ID'] == metabolite_name, 'Sensor.Gene'].tolist()
    if not name_sensor:
        # Fail early with a readable message instead of an opaque error from
        # plotting an empty matrix.
        raise ValueError(
            f"No sensor found for metabolite '{metabolite_name}' in `adata.uns['df_metasen_filtered']`."
        )

    ms_pair = sorted(f"{arrv}-{metabolite_name}-{sensor}" for sensor in name_sensor)

    df_MCC = adata.obsm[f"MetaChat-{database_name}-sum-{summary}"].loc[:, ms_pair].copy()
    df_MCC[group_name] = adata.obs[group_name].copy()
    df_contribute = df_MCC.groupby(group_name).sum()
    df_contribute = np.log(df_contribute + 1)

    heat_palette = sns.blend_palette(palette_stops[cmap], n_colors=256)

    # Resolve one color per group row. This must also happen when the caller
    # supplies the mapping, otherwise the raw dict would reach seaborn.
    if group_cmap is None:
        group_cmap = dict(
            zip(adata.obs[group_name].cat.categories.tolist(), adata.uns[group_name + '_colors'])
        )
    if isinstance(group_cmap, dict):
        col_colors = [group_cmap[g] for g in df_contribute.index]
    else:
        col_colors = group_cmap

    # ==== Draw plot ====
    grid = sns.clustermap(
        df_contribute.T,
        row_cluster=False,
        col_cluster=False,
        col_colors=col_colors,
        cmap=heat_palette,
        figsize=figsize,
        cbar_pos=(0.02, 0.3, 0.02, 0.4)
    )

    # ==== Save ====
    if plot_savepath:
        # Save the cluster grid's own figure rather than whatever pyplot
        # currently considers the active figure.
        grid.ax_heatmap.figure.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return grid.ax_heatmap
[docs]
def plot_summary_pathway(
    ms_result: pd.DataFrame = None,
    metapathway_rank: pd.DataFrame = None,
    senspathway_rank: pd.DataFrame = None,
    plot_metapathway_index: list = None,
    plot_senspathway_index: list = None,
    figsize: tuple = (10,10),
    plot_savepath: str = None
):
    """
    Plot a Sankey diagram summarizing metabolic cell communication between metabolite and sensor pathways.

    This function visualizes how metabolite pathways contribute to sensor pathways based on
    group-level MCC results. Each link represents the overall communication score between
    a metabolite pathway (source) and a sensor pathway (target), scaled by interaction intensity.

    Parameters
    ----------
    ms_result : pandas.DataFrame
        A matrix of pathway-level metabolic communication scores.
        Rows correspond to metabolite pathways, columns to sensor pathways.
        Typically generated from :func:`mc.tl.summary_pathway`.
    metapathway_rank : pandas.DataFrame
        DataFrame containing metabolite pathway rankings
        (read from its ``"Metabolite.Pathway"`` column).
    senspathway_rank : pandas.DataFrame
        DataFrame containing sensor pathway rankings
        (read from its ``"Sensor.Pathway"`` column).
    plot_metapathway_index : list of int
        Indices (in ``metapathway_rank``) of metabolite pathways to include in the Sankey diagram.
    plot_senspathway_index : list of int
        Indices (in ``senspathway_rank``) of sensor pathways to include in the Sankey diagram.
    figsize : tuple of float, default=(10, 10)
        Figure size (width, height) in inches; converted to pixels at 100 px per inch.
    plot_savepath : str, optional
        Path to save the plot as a static image (e.g., ``"results/pathway_summary.pdf"``).
        The format is inferred from the file extension.
        If ``None``, the figure is only displayed.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The Sankey diagram figure object.

    Raises
    ------
    ValueError
        If every selected communication score is zero.

    Notes
    -----
    - This function requires results computed by :func:`mc.tl.summary_pathway`.
    - Each node represents a pathway; link width corresponds to relative communication strength.
    - Scores are log-scaled and min-max rescaled for better visual contrast.
    - Source/link colors come from the 10 color pairs of ``tab20`` and are recycled
      when more than 10 metabolite pathways are selected.
    """
    # ==== Prepare data ====
    usename_metapathway = list(metapathway_rank.iloc[plot_metapathway_index]["Metabolite.Pathway"])
    usename_senspathway = list(senspathway_rank.iloc[plot_senspathway_index]["Sensor.Pathway"])
    n_meta = len(usename_metapathway)
    n_sens = len(usename_senspathway)

    # tab20 alternates dark/light shades: dark for nodes, light for links.
    tab20_hex = [mcolors.to_hex(color) for color in sns.color_palette("tab20", 20)]
    base_source = tab20_hex[0::2]
    base_link = tab20_hex[1::2]
    # Recycle the 10 available shades so any number of pathways gets a color.
    hex_colors_source = [base_source[i % len(base_source)] for i in range(n_meta)]
    hex_colors_line = [base_link[i % len(base_link)] for i in range(n_meta)]
    hex_colors_target = [mcolors.to_hex(color) for color in sns.color_palette("YlGnBu", n_sens)][::-1]

    ms_result_new = ms_result.loc[usename_metapathway, usename_senspathway].copy()
    all_values = ms_result_new.values.flatten()
    non_zero_values = all_values[all_values != 0]
    if non_zero_values.size == 0:
        raise ValueError("All selected pathway communication scores are zero; nothing to plot.")
    min_non_zero_value = np.min(non_zero_values)

    # Log-scale relative to the smallest non-zero score.
    ms_result_new = np.log(ms_result_new / min_non_zero_value + 1)
    # Name the index explicitly so the melt below does not depend on how the
    # caller labelled the input frame.
    ms_result_new = ms_result_new.rename_axis('Metabolite.Pathway').reset_index()

    result_all_melted = ms_result_new.melt(
        id_vars='Metabolite.Pathway', var_name='Sensor.Pathway', value_name='communication_score'
    )

    metapathway_color = pd.DataFrame({
        'Metabolite.Pathway': np.array(ms_result_new.loc[:, "Metabolite.Pathway"]),
        'color_source': np.array(hex_colors_source),
        'color_link': np.array(hex_colors_line)
    })
    result_all_melted = pd.merge(result_all_melted, metapathway_color, on='Metabolite.Pathway', how='outer')

    NODES = dict(
        label=usename_metapathway + usename_senspathway,
        color=hex_colors_source + hex_colors_target
    )
    Node_index = pd.DataFrame({
        'node': usename_metapathway + usename_senspathway,
        'index': list(range(n_meta + n_sens))
    })
    # After both merges `index_x` is the source node id and `index_y` the target node id.
    result_all_melted = pd.merge(result_all_melted, Node_index, left_on='Metabolite.Pathway', right_on='node', how='inner')
    result_all_melted = pd.merge(result_all_melted, Node_index, left_on='Sensor.Pathway', right_on='node', how='inner')

    scores = np.array(result_all_melted["communication_score"], dtype=float)
    score_span = scores.max() - scores.min()
    if score_span > 0:
        scores_scaled = (scores - scores.min()) / score_span
    else:
        # All links equally strong: avoid 0/0 and draw them at full width.
        scores_scaled = np.ones_like(scores)
    scaled_width = scores_scaled * 5

    LINKS = dict(
        source=np.array(result_all_melted["index_x"]).tolist(),
        target=np.array(result_all_melted["index_y"]).tolist(),
        value=scaled_width.tolist(),
        color=np.array(result_all_melted["color_link"]).tolist()
    )

    fig = go.Figure(go.Sankey(node=NODES, link=LINKS))
    # Figure size belongs to the layout; `width`/`height` are not supported
    # keys of the `config` dict accepted by `fig.show`.
    fig.update_layout(width=figsize[0] * 100, height=figsize[1] * 100)
    fig.show()

    if plot_savepath:
        fig.write_image(plot_savepath, width=figsize[0] * 100, height=figsize[1] * 100)

    return fig
[docs]
def plot_communication_responseGenes(
    df_deg: pd.DataFrame,
    df_yhat: pd.DataFrame,
    show_gene_names: bool = True,
    top_ngene_per_cluster: int = -1,
    colormap: str = 'magma',
    cluster_colormap: str = 'Plotly',
    color_range: tuple = None,
    font_scale: float = 1,
    figsize: tuple = (10,10),
    plot_savepath: str = None,
    return_genes: bool = False
):
    """
    Plot the smoothed gene expression profiles of metabolic cell communication response genes.

    This function visualizes the response genes inferred from metabolic cell communication analysis.
    Genes are grouped by response clusters and displayed as a heatmap, where columns
    correspond to genes and rows represent the order of received signals.
    The function allows highlighting top-ranked genes per cluster based on the Wald statistic
    and sorting both genes and clusters according to their expression peak locations.
    This function should be run after :func:`mc.tl.communication_responseGenes` and
    :func:`mc.tl.communication_responseGenes_cluster`.

    Parameters
    ----------
    df_deg : pandas.DataFrame
        DataFrame results of response gene analysis from :func:`mc.tl.communication_responseGenes`
        and :func:`mc.tl.communication_responseGenes_cluster`. The ``'waldStat'`` and
        ``'cluster'`` columns are used.
    df_yhat : pandas.DataFrame
        DataFrame containing smoothed and normalized expression values of the same genes
        (one row per gene, in the same order as ``df_deg``). It is not modified.
    show_gene_names : bool, default=True
        Whether to show gene names as x-axis tick labels in the heatmap.
    top_ngene_per_cluster : int, default=-1
        Number of top genes to display per cluster.
        If set to a non-negative value, only the top-ranked genes (by Wald statistic)
        within each cluster are plotted.
        If negative, all genes are included.
    colormap : str, default='magma'
        Colormap used for expression values (passed to seaborn/Matplotlib).
    cluster_colormap : str, default='Plotly'
        Qualitative colormap used to color gene clusters.
        Accepts predefined names such as ``'Plotly'``, ``'Light24'``, ``'Dark24'``, or ``'Alphabet'``.
    color_range : tuple, optional
        Tuple ``(vmin, vmax)`` specifying the lower and upper limits for expression color scaling.
        Values outside this range are clipped. If ``None``, full dynamic range is used.
    font_scale : float, default=1
        Scaling factor for seaborn font sizes (applied globally via ``sns.set``).
    figsize : tuple of float, default=(10, 10)
        Size of the output figure (width, height).
    plot_savepath : str, optional
        File path to save the figure.
        If ``None``, the figure is not saved.
    return_genes : bool, default=False
        Whether to return the list of genes displayed in the plot.

    Returns
    -------
    selected_genes : list of str, optional
        If ``return_genes=True``, returns a list of gene names in the order they appear
        in the heatmap. Otherwise, returns ``None``.

    Notes
    -----
    - Clusters are first sorted by the mean position of the peak expression (``np.argmax`` of smoothed values).
    - Within each cluster, genes are ranked by the Wald statistic, optionally truncated
      to the top ``N`` genes, and then ordered by their own peak position.
    - The resulting heatmap displays genes grouped and color-coded by cluster assignment.
    - This function relies on :func:`get_cmap_qualitative` to obtain qualitative color palettes.
    """
    cluster_palette = get_cmap_qualitative(cluster_colormap)
    wald_stats = df_deg['waldStat'].values
    labels = np.array(df_deg['cluster'].values, int)
    nlabel = np.max(labels) + 1

    # Clip on a copy: writing into `df_yhat.values` would either silently
    # modify the caller's frame or (when pandas hands back a copy) leave the
    # plotted data unclipped.
    if color_range is not None:
        df_plot = df_yhat.clip(lower=color_range[0], upper=color_range[1])
    else:
        df_plot = df_yhat
    yhat_mat = df_plot.values

    # Sort clusters by mean peak location
    peak_locs = []
    for i in range(nlabel):
        members = np.where(labels == i)[0]
        peak_locs.append(np.mean(np.argmax(yhat_mat[members, :], axis=1)))
    cluster_order = np.argsort(peak_locs)

    # Get peak per gene
    gene_peak_positions = np.argmax(yhat_mat, axis=1)

    idx = np.array([], dtype=int)
    col_colors = []
    for i in cluster_order:
        members = np.where(labels == i)[0]

        # Step 1: sort by waldStat descending and keep the requested top genes
        top_order = np.argsort(-wald_stats[members])
        if top_ngene_per_cluster >= 0:
            top_ngene = min(len(members), top_ngene_per_cluster)
        else:
            top_ngene = len(members)
        selected_idx = members[top_order[:top_ngene]]

        # Step 2: sort those by peak position
        final_order = np.argsort(gene_peak_positions[selected_idx])
        sorted_idx = selected_idx[final_order]

        idx = np.concatenate((idx, sorted_idx))
        col_colors.extend([cluster_palette[i % len(cluster_palette)]] * len(sorted_idx))

    # Plot
    sns.set(font_scale=font_scale)
    g = sns.clustermap(df_plot.iloc[idx].T,
                       row_cluster=False,
                       col_cluster=False,
                       col_colors=col_colors,
                       cmap=colormap,
                       xticklabels=show_gene_names,
                       yticklabels=False,
                       linewidths=0,
                       figsize=figsize)
    g.ax_heatmap.invert_yaxis()
    g.cax.set_position([.1, .2, .03, .45])

    if plot_savepath is not None:
        # Save the cluster grid's own figure rather than pyplot's current figure.
        g.ax_heatmap.figure.savefig(plot_savepath, dpi=300)

    if return_genes:
        return list(df_deg.iloc[idx].index)
[docs]
def plot_communication_responseGenes_keggEnrich(
    df_result: pd.DataFrame,
    organism: str = "Human",
    show_term_order: list = None,
    cmap: str = 'green',
    maxshow_gene: int = 5,
    figsize: tuple = (6,6),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None
):
    """
    Plot a horizontal bar chart summarizing KEGG enrichment results of MCC response genes.

    This function visualizes the top enriched KEGG pathways identified from
    :func:`mc.tl.communication_responseGenes`, with bar length representing
    -log10(p-value) and text annotations showing top associated genes.

    Parameters
    ----------
    df_result : pandas.DataFrame
        The KEGG enrichment results table returned by :func:`mc.tl.communication_responseGenes`.
        The ``'Term'``, ``'Genes'`` (``;``-separated) and ``'P-value'`` columns are used.
    organism : {"Human", "Mouse"}, default="Human"
        If set to ``"Mouse"`` (case-insensitive), gene names are capitalized
        (first letter uppercase, rest lowercase).
    show_term_order : list of int, optional
        List of row positions (in ``df_result``) specifying which pathways to show and their order.
        If ``None`` (default), the first 5 rows are displayed, or all rows when
        ``df_result`` has fewer than 5.
    cmap : {"green", "blue", "red"}, default="green"
        Color theme for the barplot.
    maxshow_gene : int, default=5
        Maximum number of genes to display under each pathway.
    figsize : tuple of float, default=(6, 6)
        Figure size (width, height). Only used when ``ax`` is ``None``.
    ax : matplotlib.axes.Axes, optional
        Existing Matplotlib axis to draw the barplot on. If ``None``, a new figure and axis are created.
    plot_savepath : str, optional
        File path to save the figure (e.g., ``"results/kegg_enrich.pdf"``).
        The format is inferred from the file extension. If ``None``, the figure is not saved.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the bar chart.

    Raises
    ------
    ValueError
        If ``cmap`` is not one of the supported themes.

    Notes
    -----
    - This visualization summarizes KEGG enrichment performed by
      :func:`mc.tl.communication_responseGenes`.
    - Each bar represents one KEGG pathway, with its significance shown as -log10(p-value).
    - Up to ``maxshow_gene`` top genes per pathway are displayed below the bar label.
    """
    bar_color = {'blue': '#C9E3F6',
                 'green': '#ACD3B7',
                 'red': '#F0C3AC'}
    text_color = {'blue': '#2D3A8C',
                  'green': '#2E5731',
                  'red': '#AD392F'}
    # Validate before any figure is created so a typo does not leave an
    # empty figure behind.
    if not isinstance(cmap, str) or cmap not in bar_color:
        raise ValueError("`cmap` must be one of {'green', 'red', 'blue'}.")

    # ==== Prepare data ====
    if show_term_order is None:
        show_term_order = list(range(min(5, len(df_result))))
    df_show = df_result.iloc[show_term_order].copy()

    # Reverse so the first requested term ends up at the top of the chart.
    x_names = df_show['Term'].tolist()
    x_names.reverse()
    x_to_num = {name: pos for pos, name in enumerate(x_names)}

    capitalize_genes = organism.lower() == "mouse"
    path_to_genes = {}
    path_to_value = {}
    for _, row in df_show.iterrows():
        gene_list = row['Genes'].split(';')
        # Capitalize gene names if organism is Mouse
        if capitalize_genes:
            gene_list = [g.capitalize() for g in gene_list]
        path_to_genes[row['Term']] = ';'.join(gene_list[:maxshow_gene])
        path_to_value[row['Term']] = -np.log10(row['P-value'])

    # ==== Draw plot ====
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    x = [x_to_num[v] for v in x_names]
    y = [path_to_value[v] for v in x_names]
    ax.barh(x, y, color=bar_color[cmap], height=0.5)
    ax.set_facecolor('white')

    for v in x_names:
        # Pathway name on the bar, gene list just below it.
        ax.text(0 + 0.05, x_to_num[v], v, color='black', ha='left', va='center')
        ax.text(0 + 0.05, x_to_num[v] - 0.4, path_to_genes[v], color=text_color[cmap], ha='left', va='center')

    ax.set_yticklabels([])
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_ylim([-0.7, max(x) + 1])
    ax.set_xlabel('-log10(p-value)')
    ax.set_ylabel('KEGG pathway')

    # ==== Save ====
    if plot_savepath:
        # Save the figure that owns `ax`; pyplot's current figure may be a
        # different one when the caller passed in an axis.
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return ax
[docs]
def plot_DEG_volcano(
    deg_result: pd.DataFrame,
    name_col: str = "names",
    logfc_col: str = "logfoldchanges",
    pval_col: str = "pvals_adj",
    logfc_thresh: float = 1.0,
    pval_thresh: float = 0.05,
    neglog10_cap: float = 50.0,
    pointsize_min: float = 5.0,
    pointsize_max: float = 30.0,
    label_thresh_neglog10: float = 10.0,
    label_fontsize: int = 6,
    figsize: tuple = (5, 5),
    cbar_orientation: str = "horizontal",
    cbar_pad: float = 0.2,
    title: str = "Volcano Plot",
    show_labels: bool = True,
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None
):
    """
    Plot a volcano plot from differential MCC results.

    This function visualizes log2 fold changes versus statistical significance
    (-log10 adjusted p-values) to highlight up- and down-regulated genes.

    Parameters
    ----------
    deg_result : pandas.DataFrame
        DataFrame containing differential MCC results. It is copied, not modified.
    name_col : str, default='names'
        Column name for gene or feature names.
    logfc_col : str, default='logfoldchanges'
        Column name for log2 fold changes.
    pval_col : str, default='pvals_adj'
        Column name for adjusted p-values.
    logfc_thresh : float, default=1.0
        Threshold for absolute log2 fold change used to draw vertical reference lines
        and to select labelled points.
    pval_thresh : float, default=0.05
        Threshold for adjusted p-value used to draw a horizontal reference line.
    neglog10_cap : float, default=50.0
        Cap for -log10(p-value) to avoid extreme values distorting visualization.
    pointsize_min : float, default=5.0
        Lower clip applied to -log10(p-value) when deriving point sizes.
    pointsize_max : float, default=30.0
        Upper clip applied to -log10(p-value) when deriving point sizes.
    label_thresh_neglog10 : float, default=10.0
        Threshold of -log10(p-value) above which gene labels are annotated.
    label_fontsize : int, default=6
        Font size for gene labels.
    figsize : tuple of float, default=(5, 5)
        Figure size (width, height) in inches. Only used when ``ax`` is ``None``.
    cbar_orientation : {'horizontal', 'vertical'}, default='horizontal'
        Orientation of the colorbar.
    cbar_pad : float, default=0.2
        Padding distance between colorbar and axis.
    title : str, default='Volcano Plot'
        Title of the figure.
    show_labels : bool, default=True
        Whether to show text labels for significant genes.
    ax : matplotlib.axes.Axes, optional
        Existing Matplotlib axis to draw on.
        If None, a new figure and axis are created.
    plot_savepath : str, optional
        File path to save the figure (e.g. ``'results/volcano_plot.pdf'``).
        The format is inferred from the file extension.
        If None, the figure is not saved.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis object containing the rendered volcano plot.

    Notes
    -----
    - Points are colored by log2 fold change using a diverging colormap centered at 0.
    - Vertical and horizontal dashed lines represent significance thresholds.
    - Extreme -log10(p-values) are capped at ``neglog10_cap`` to maintain scale.
    - The marker area passed to Matplotlib is 5 times the clipped -log10(p-value).
    """
    # ==== Copy and preprocess ====
    deg_result = deg_result.copy()
    # p == 0 yields +inf, which the cap below handles; silence the warning.
    with np.errstate(divide='ignore'):
        neg_log10 = -np.log10(deg_result[pval_col])
    deg_result['neg_log10_pValue'] = np.clip(neg_log10, None, neglog10_cap)

    # ==== Point size and color normalization ====
    deg_result['point_size'] = np.clip(deg_result['neg_log10_pValue'], pointsize_min, pointsize_max)

    # TwoSlopeNorm requires vmin < vcenter < vmax. Nudge the limits past zero
    # so data whose fold changes all share one sign does not raise.
    tiny = 1e-6
    vmin = min(deg_result[logfc_col].min(), -tiny)
    vmax = max(deg_result[logfc_col].max(), tiny)
    norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

    # ==== Initialize figure ====
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # ==== Draw scatter ====
    scatter = ax.scatter(
        x=deg_result[logfc_col],
        y=deg_result['neg_log10_pValue'],
        c=deg_result[logfc_col],
        s=5 * deg_result['point_size'],
        cmap='Spectral_r',
        norm=norm
    )

    # ==== Add colorbar ====
    # Use the figure that owns `ax` instead of pyplot's current figure.
    cbar = fig.colorbar(scatter, ax=ax, orientation=cbar_orientation, pad=cbar_pad, shrink=0.6)
    cbar.set_label('Log2 Fold Change')

    # ==== Axis formatting ====
    max_fc = np.max(np.abs(deg_result[logfc_col])) + 0.5
    ax.set_xlim(-max_fc, max_fc)
    ax.set_ylim(-3, neglog10_cap + 5)
    ax.set_title(title)
    ax.set_xlabel('Log2 Fold Change')
    ax.set_ylabel('-Log10 p-value')

    # Add threshold lines
    ax.axhline(y=-np.log10(pval_thresh), color='black', linestyle='--')
    ax.axvline(x=logfc_thresh, color='black', linestyle='--')
    ax.axvline(x=-logfc_thresh, color='black', linestyle='--')

    # Annotate significant points with adjustText
    if show_labels:
        is_labelled = (
            (np.abs(deg_result[logfc_col]) > logfc_thresh) &
            (deg_result['neg_log10_pValue'] > label_thresh_neglog10)
        )
        sig_df = deg_result[is_labelled]
        texts = [
            ax.text(fc, nlp, str(name), fontsize=label_fontsize)
            for fc, nlp, name in zip(sig_df[logfc_col], sig_df['neg_log10_pValue'], sig_df[name_col])
        ]
        # Nothing to de-overlap when no point passes the thresholds.
        if texts:
            adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', lw=0.5))

    ax.set_box_aspect(1)
    fig.tight_layout()

    # ==== Save ====
    if plot_savepath:
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return ax
[docs]
def plot_3d_feature(
    adata: anndata.AnnData,
    feature: str,
    spatial_key: str = "spatial_3d",
    cmap_continuous: str = "Viridis",
    layer: str = None,
    use_raw: bool = False,
    point_size: int = 3,
    opacity = 0.6,
    aspectratio: tuple = (1.1, 1.0, 1.3),
    show_axes: bool = True,
    figsize: tuple = (9, 9),
    plot_savepath: str = None
):
    """
    Visualize a spatial feature (gene or obs annotation) in 3D using Plotly.

    This function renders a 3D scatter plot for any feature stored in `.obs`,
    `.var`, `.layers`, or `.raw`, using the spatial coordinates stored in
    ``adata.obsm[spatial_key]``.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix with 3D spatial coordinates stored in ``adata.obsm[spatial_key]``.
    feature : str
        Name of the feature to visualize. It can be:

        - A column name in ``adata.obs`` (categorical or continuous);
        - A gene name in ``adata.var_names`` or ``adata.raw.var_names``.

        ``adata.obs`` columns take precedence over genes of the same name.
    spatial_key : str, default='spatial_3d'
        Key in ``adata.obsm`` containing the 3D spatial coordinates (shape: ``(n_obs, 3)``).
    cmap_continuous : str, default='Viridis'
        Colormap for continuous features (numeric obs columns and gene expression).
    layer : str, optional
        If provided, use expression values from ``adata.layers[layer]``.
    use_raw : bool, default=False
        Whether to use ``adata.raw`` instead of ``adata.X`` when available.
    point_size : int, default=3
        Size of scatter plot markers.
    opacity : float or list, default=0.6
        Global opacity (float) or per-category opacity (list/tuple ordered like the
        categories of the feature, at least as long as the number of categories).
        Any other value falls back to 0.6.
    aspectratio : tuple of float, default=(1.1, 1.0, 1.3)
        Aspect ratio (x, y, z) for 3D axes scaling.
    show_axes : bool, default=True
        Whether to display 3D axis lines and labels.
    figsize : tuple of float, default=(9, 9)
        Figure size (width, height) in inches; converted to pixels at 100 px per inch.
    plot_savepath : str, optional
        Path to save the 3D figure (e.g. ``"results/3D_feature_geneA.pdf"``).
        If None, the figure is only displayed.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        The generated Plotly 3D scatter figure.

    Raises
    ------
    KeyError
        If ``spatial_key`` or ``layer`` is missing, or ``feature`` cannot be found.
    ValueError
        If the coordinates are not of shape ``(n_cells, 3)`` or the extracted
        expression vector does not match the number of cells.

    Notes
    -----
    - If the feature is categorical, it is colored using discrete color sequences
      (``adata.uns[feature + '_colors']`` when present).
    - If the feature is continuous (e.g., gene expression), a continuous colormap is applied.
    - This function is ideal for visualizing 3D tissue reconstructions, such as those derived
      from MetaChat’s 3D MCC analysis.
    """
    if spatial_key not in adata.obsm:
        raise KeyError(f"`obsm['{spatial_key}']` not found or not set to 3D coords.")
    coords = np.asarray(adata.obsm[spatial_key])
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(f"`obsm['{spatial_key}']` must be of shape (n_cells, 3).")
    n = coords.shape[0]

    df = adata.obs.copy()
    df["x_3d"], df["y_3d"], df["z_3d"] = coords[:, 0], coords[:, 1], coords[:, 2]

    def _to_1d_array(x):
        # Flatten a (possibly sparse) single-gene expression slice to 1-D.
        if sparse.issparse(x):
            x = x.toarray()
        x = np.asarray(x)
        if x.ndim == 2:
            if x.shape[0] == 1:
                x = x[0]
            elif x.shape[1] == 1:
                x = x[:, 0]
        return x.reshape(-1)

    def _axis_style(label):
        # Shared styling for the three scene axes.
        return dict(
            title=label if show_axes else "", showbackground=False, showgrid=False,
            zeroline=show_axes, showline=show_axes,
            linecolor="black", linewidth=3 if show_axes else 0,
            showticklabels=show_axes,
        )

    # Opacity applied whenever no valid per-category list is available.
    uniform_opacity = opacity if isinstance(opacity, (int, float)) else 0.6

    is_obs_feature = feature in adata.obs.columns

    # --- resolve gene expression when the feature is not an obs column ---
    gene_vec = None
    if not is_obs_feature:
        if use_raw and getattr(adata, "raw", None) is not None and feature in adata.raw.var_names:
            gene_vec = _to_1d_array(adata.raw[:, feature].X)
        elif feature in adata.var_names:
            if layer is not None:
                if layer not in adata.layers:
                    raise KeyError(f"Layer '{layer}' not found in `adata.layers`.")
                X = adata[:, feature].layers[layer]
            else:
                X = adata[:, feature].X
            gene_vec = _to_1d_array(X)
        if gene_vec is not None and gene_vec.shape[0] != n:
            raise ValueError(f"Gene vector length ({gene_vec.shape[0]}) does not match number of cells ({n}).")

    # --- main plotting ---
    if is_obs_feature:
        df["feature"] = adata.obs[feature]
        is_categorical = isinstance(df["feature"].dtype, pd.CategoricalDtype)
        if is_categorical:
            colors = adata.uns.get(feature + "_colors", None)
            category_order = list(df["feature"].cat.categories)
            fig = px.scatter_3d(
                df, x="x_3d", y="y_3d", z="z_3d",
                color="feature",
                color_discrete_sequence=list(colors) if colors is not None else None,
                category_orders={"feature": category_order},
                title=f"3D Spatial - obs: {feature}",
            )
            # --- opacity by category ---
            if isinstance(opacity, (list, tuple)) and len(opacity) >= len(category_order):
                # Match by trace name rather than position, so a category
                # without any cells cannot shift the remaining opacities.
                opacity_by_name = {str(cat): op for cat, op in zip(category_order, opacity)}
                for i, trace in enumerate(fig.data):
                    trace.opacity = opacity_by_name.get(trace.name, opacity[i])
            else:
                for trace in fig.data:
                    trace.opacity = uniform_opacity
        else:
            # Best effort: numeric-looking obs columns get a continuous scale.
            try:
                df["feature"] = pd.to_numeric(df["feature"])
            except Exception:
                pass
            fig = px.scatter_3d(
                df, x="x_3d", y="y_3d", z="z_3d",
                color="feature",
                color_continuous_scale=cmap_continuous,
                title=f"3D Spatial - obs: {feature}",
            )
            for trace in fig.data:
                trace.opacity = uniform_opacity
    else:
        if gene_vec is None:
            candidates = []
            if getattr(adata, "raw", None) is not None:
                candidates.append("adata.raw.var_names")
            candidates.append("adata.var_names")
            raise KeyError(f"Feature '{feature}' not found in `adata.obs` or as a gene in {', '.join(candidates)}.")
        df["feature"] = gene_vec
        fig = px.scatter_3d(
            df, x="x_3d", y="y_3d", z="z_3d",
            color="feature",
            color_continuous_scale=cmap_continuous,
            title=f"3D Spatial - gene: {feature}",
        )
        for trace in fig.data:
            trace.opacity = uniform_opacity

    fig.update_traces(marker_size=point_size)
    fig.update_layout(
        # figsize is (width, height), matching the `write_image` call below.
        width=figsize[0]*100,
        height=figsize[1]*100,
        scene=dict(
            aspectratio=dict(
                x=float(aspectratio[0]),
                y=float(aspectratio[1]),
                z=float(aspectratio[2])
            ),
            xaxis=_axis_style("X"),
            yaxis=_axis_style("Y"),
            zaxis=_axis_style("Z"),
        ),
        margin=dict(r=0, l=0, b=0, t=30),
    )
    fig.show()

    if plot_savepath:
        fig.write_image(plot_savepath, width=figsize[0]*100, height=figsize[1]*100)

    return fig
[docs]
def plot_dis_thr(
    adata: anndata.AnnData,
    dis_thr: float,
    spot_index: int,
    use_existing_distance: bool = False,
    figsize: tuple = (6, 5),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None,
):
    """
    Visualize spatial neighborhood within a specified distance threshold around a selected spot.

    This function highlights which spots in the spatial omics dataset fall within a given
    Euclidean distance (`dis_thr`) from a selected spot.
    The center spot is shown in a distinct color, while neighboring and non-neighboring
    spots are visually separated.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing spatial coordinates in ``adata.obsm["spatial"]``.
        The object is copied; the caller's ``adata`` is not modified.
    dis_thr : float
        Distance threshold (in pixel or micron units, depending on ``adata`` scaling) used
        to define spatial neighbors. The comparison is strict (``distance < dis_thr``).
    spot_index : int
        Index (0-based) of the reference/center spot.
    use_existing_distance : bool, default=False
        Whether to reuse an existing precomputed distance matrix stored in
        ``adata.obsp["spatial_distance"]``.
        If False, or if that key is absent, Euclidean distances from the center spot
        are computed from ``adata.obsm["spatial"]``.
    figsize : tuple of float, default=(6, 5)
        Figure size (width, height) in inches. Only used when ``ax`` is ``None``.
    ax : matplotlib.axes.Axes, optional
        Existing Matplotlib axis to draw on.
        If None, a new figure and axis are created.
    plot_savepath : str, optional
        File path to save the plot (e.g., ``"results/dis_thr_spots.pdf"``).
        The format is inferred from the file extension.
        If None, the figure is not saved.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the rendered spatial scatter plot.

    Notes
    -----
    - This function helps visualize local spot neighborhoods used in spatial communication
      or proximity-based analyses (e.g., for barrier-aware OT or diffusion modeling).
    - The color coding is:

      * **center** – the selected reference spot
      * **neighbor** – spots within `dis_thr` distance
      * **outside** – all other spots
    """
    adata_vis = adata.copy()

    # Distances from the center spot to every spot. Only this single row is
    # needed, so the full n x n matrix is never built.
    if use_existing_distance and 'spatial_distance' in adata_vis.obsp:
        dist_to_center = np.asarray(adata_vis.obsp['spatial_distance'][spot_index, :]).reshape(-1)
    else:
        coords = np.asarray(adata_vis.obsm["spatial"])
        dist_to_center = distance_matrix(coords[[spot_index]], coords)[0]

    # compute mask: 0 = outside, 1 = within dis_thr, 2 = center
    mask = (dist_to_center < dis_thr).astype(int)
    spot_name = adata_vis.obs.index[spot_index]
    mask_series = pd.Series(mask, index=adata_vis.obs.index)
    mask_series.loc[spot_name] = 2

    # map numeric to string labels and assign to obs
    label_map = {0: 'outside', 1: 'neighbor', 2: 'center'}
    adata_vis.obs['within_dis_thr'] = mask_series.map(label_map).astype('category')

    # Honor a caller-supplied axis instead of always opening a new figure.
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    sq.pl.spatial_scatter(
        adata_vis,
        color='within_dis_thr',
        title=f'Spots within dis_thr={dis_thr} of spot {spot_index}',
        ax=ax
    )
    ax.set_box_aspect(1)

    if plot_savepath:
        # Save the figure that owns `ax`, which may not be pyplot's current figure.
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")

    return ax
[docs]
def plot_LRC_markers(
    adata: anndata.AnnData,
    LRC_name: str,
    LRC_marker_genes: list,
    avg: bool = False,
    figsize = (10, 5),
    plot_savepath: str = None
):
    """
    Visualize expression of LRC (Long-Range Channel) marker genes in spatial omics data.

    Either plots every selected marker gene in its own panel, or plots the
    per-spot average expression of all selected markers in a single panel.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix with spatial coordinates stored in ``adata.obsm["spatial"]``.
    LRC_name : str
        Name tag for the LRC type (e.g., "CSF", "Blood").
        Used in plot titles and in the names of the ``adata.obs`` columns written
        by this function.
    LRC_marker_genes : list
        List of LRC marker gene names to visualize.
        Genes not found in ``adata.var_names`` are skipped.
    avg : bool, default=False
        If True, plots the average expression of all valid marker genes.
        If False, plots each gene separately.
    figsize : tuple of float, default=(10, 5)
        Figure size (width, height) in inches. Only used when ``avg=True``;
        in per-gene mode the figure is sized as ``(6 * n_genes, 5)``.
    plot_savepath : str, optional
        Path to save the figure (e.g. ``"results/LRC_markers_Blood.pdf"``).
        If None, nothing is written to disk.

    Returns
    -------
    ax : matplotlib.axes.Axes or list of Axes
        The single axis when ``avg=True`` or when exactly one valid gene is
        plotted; otherwise the list of per-gene axes, in the order of the
        valid genes.

    Raises
    ------
    ValueError
        If none of ``LRC_marker_genes`` is present in ``adata.var_names``.

    Notes
    -----
    - Expression values are read from ``adata.X`` (sparse matrices are densified).
    - This function writes into the caller's ``adata.obs``:
      ``"LRC_<LRC_name>_avg_markers"`` when ``avg=True``, or one
      ``"LRC_<LRC_name>_marker_<gene>"`` column per gene otherwise.
    - When multiple genes are plotted, they are arranged in a single row of subplots.
    """
    def to_dense(mat):
        # adata.X may be sparse or dense; normalise to a plain ndarray.
        if sparse.issparse(mat):
            mat = mat.toarray()
        return np.asarray(mat)

    # Keep only the markers that exist in this dataset.
    valid_genes = [g for g in LRC_marker_genes if g in adata.var_names]
    if not valid_genes:
        raise ValueError("None of the input marker genes were found in adata.var_names.")

    if avg:
        obs_key = f"LRC_{LRC_name}_avg_markers"
        adata.obs[obs_key] = to_dense(adata[:, valid_genes].X).mean(axis=1)

        fig, ax = plt.subplots(figsize=figsize)
        sq.pl.spatial_scatter(adata, color=obs_key, ax=ax)
        ax.set_title(obs_key)
        ax.set_box_aspect(1)
        fig.tight_layout()
        result = ax
    else:
        n_genes = len(valid_genes)
        # squeeze=False always yields a 2D array, so one gene needs no special case.
        fig, axes_grid = plt.subplots(
            1, n_genes, figsize=(6 * n_genes, 5), squeeze=False
        )
        axes = list(axes_grid[0])
        for gene, gene_ax in zip(valid_genes, axes):
            obs_key = f"LRC_{LRC_name}_marker_{gene}"
            adata.obs[obs_key] = to_dense(adata[:, gene].X).ravel()
            sq.pl.spatial_scatter(adata, color=gene, ax=gene_ax)
            gene_ax.set_title(f"{gene}")
            gene_ax.set_box_aspect(1)
        fig.suptitle(f"LRC markers for {LRC_name}", fontsize=14)
        fig.tight_layout()
        # Return every panel (not just the last loop variable) when there are several.
        result = axes[0] if n_genes == 1 else axes

    if plot_savepath:
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")
    return result
[docs]
def plot_spot_distance(
    adata: anndata.AnnData,
    dist_matrix_key: str,
    spot_index: int,
    figsize: tuple = (6, 5),
    ax: Optional[mpl.axes.Axes] = None,
    plot_savepath: str = None
):
    """
    Visualize the spatial distance from a selected spot to all other spots.

    The row ``spot_index`` of a precomputed distance matrix stored in
    ``adata.obsp`` is attached to a copy of ``adata`` and rendered with
    ``squidpy.pl.spatial_scatter``.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix containing spatial coordinates in
        ``adata.obsm["spatial"]`` and a distance matrix stored in ``adata.obsp``.
    dist_matrix_key : str
        Key name in ``adata.obsp`` where the spatial distance matrix is stored,
        e.g., ``"spatial_distance"`` or ``"spatial_distance_LRC_base"``.
    spot_index : int
        Index of the target spot (0-based) whose distances to all other spots
        will be visualized.
    figsize : tuple of float, default=(6, 5)
        Figure size (width, height) in inches. Ignored when ``ax`` is given.
    ax : matplotlib.axes.Axes, optional
        Existing Matplotlib axis to draw the plot on.
        If ``None``, a new figure and axis will be created.
    plot_savepath : str, optional
        Path to save the plot (e.g. ``"results/distance_to_spot120.pdf"``).
        The figure is written before it is shown.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the rendered spatial distance visualization.

    Raises
    ------
    KeyError
        If ``dist_matrix_key`` is not present in ``adata.obsp``.

    Notes
    -----
    - Distances are taken directly from the precomputed matrix in ``adata.obsp``.
    - The distances are stored in ``obs["distances_from_target"]`` of an internal
      copy of ``adata``; the input object is not modified.
    """
    # Fail before paying for the AnnData copy.
    if dist_matrix_key not in adata.obsp:
        raise KeyError(f"Distance matrix '{dist_matrix_key}' not found in adata.obsp.")

    # ==== Extract distance vector ====
    row = adata.obsp[dist_matrix_key][spot_index]
    if sparse.issparse(row):
        # A sparse row cannot be assigned to an obs column; densify it first.
        # NOTE(review): implicit zeros then read as distance 0 - confirm this is intended.
        row = row.toarray()
    distances_from_target = np.asarray(row).ravel()

    adata_vis = adata.copy()
    adata_vis.obs["distances_from_target"] = distances_from_target

    # ==== Plot ====
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    sq.pl.spatial_scatter(
        adata_vis,
        color="distances_from_target",
        ax=ax,
        title=f"Distance to spot index {spot_index}"
    )
    ax.set_box_aspect(1)
    ax.set_xlabel("X coordinate")
    ax.set_ylabel("Y coordinate")
    # Operate on the figure that owns `ax`, which need not be pyplot's current figure.
    fig.tight_layout()

    # Save before showing: the figure may be closed once plt.show() returns.
    if plot_savepath:
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")
    plt.show()
    return ax
[docs]
def plot_graph_connectivity(
    G: nx.Graph,
    node_size: float = 10.0,
    linewidths: float = 1.0,
    width: float = 1.0,
    show_weights: bool = False,
    weight_fontsize: int = 8,
    figsize: tuple = (8, 8),
    plot_savepath: str = None
):
    """
    Plot a 2D visualization of a graph showing connectivity between nodes and edges.

    Each node is drawn at the position given by the node itself, so nodes are
    expected to be ``(x, y)`` coordinate tuples. Optionally, edge weights are
    displayed as labels.

    Parameters
    ----------
    G : networkx.Graph
        Input graph where each node represents a 2D point (tuple of x, y coordinates)
        and edges may have an optional ``'weight'`` attribute.
    node_size : float, default=10.0
        Marker size of nodes.
    linewidths : float, default=1.0
        Line width of node borders.
    width : float, default=1.0
        Width of edges.
    show_weights : bool, default=False
        Whether to display numerical edge weights on the plot.
    weight_fontsize : int, default=8
        Font size of edge weight annotations (effective only if ``show_weights=True``).
    figsize : tuple of float, default=(8, 8)
        Figure size (width, height) in inches.
    plot_savepath : str, optional
        File path to save the figure (e.g., ``"results/graph_connectivity.png"``).
        The figure is written before it is shown.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Matplotlib axis containing the rendered graph.

    Notes
    -----
    - Node coordinates are taken from the node keys, e.g. ``(x, y)`` tuples.
    - Edges without a ``'weight'`` attribute get no label when ``show_weights=True``.
    """
    # ==== Prepare data ====
    # Nodes double as their own coordinates.
    pos = {node: node for node in G.nodes()}

    fig = plt.figure(figsize=figsize)
    nx.draw(
        G,
        pos,
        node_size=node_size,
        edge_color="gray",
        alpha=0.5,
        linewidths=linewidths,
        width=width
    )
    # Pick up the axes nx.draw rendered into so it can be reused and returned.
    ax = fig.gca()

    # Display edge weights if enabled
    if show_weights:
        edge_labels = {
            (u, v): f"{data['weight']:.2f}"
            for u, v, data in G.edges(data=True)
            if 'weight' in data
        }
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels=edge_labels, font_size=weight_fontsize, ax=ax
        )

    ax.set_title("Graph Connectivity with Edge Weights" if show_weights else "Graph Connectivity")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")

    # Save before showing: the figure may be closed once plt.show() returns.
    if plot_savepath:
        fig.savefig(plot_savepath, dpi=300, bbox_inches="tight")
    plt.show()
    return ax
[docs]
def plot_direction_similarity(
    df_direction: pd.DataFrame,
    cluster_labels: np.ndarray,
    cmap="mako",
    figsize=(9, 7),
    title="Flow pattern similarity (Euclidean-based)",
    savepath=None
):
    """
    Plot a block-ordered similarity heatmap for direction-based flow clusters.

    Pairwise similarity between the rows of ``df_direction`` is computed as
    ``S = 1 - D / D.max()``, where ``D`` is the pairwise Euclidean distance
    matrix, so ``S`` ranges from 0 (most dissimilar pair) to 1 (identical).
    Rows and columns are then sorted by cluster label so that clusters appear
    as diagonal blocks.

    Parameters
    ----------
    df_direction : pandas.DataFrame
        Matrix of direction histogram features (rows = metabolite-sensor pairs,
        columns = direction bins).
    cluster_labels : numpy.ndarray or array-like
        Cluster assignment for each row of ``df_direction``. Lists and pandas
        Series are accepted; values are used positionally.
    cmap : str or matplotlib Colormap, default="mako"
        Colormap for the similarity heatmap.
    figsize : tuple of (float, float), default=(9, 7)
        Figure size in inches.
    title : str, default="Flow pattern similarity (Euclidean-based)"
        Title displayed at the top of the plot.
    savepath : str or None, optional
        If provided, saves the figure to the specified file path.

    Returns
    -------
    None
        Displays a block-ordered heatmap showing inter-pair similarity
        and cluster boundaries.

    Raises
    ------
    ValueError
        If ``df_direction`` has no rows, or if the number of cluster labels
        differs from the number of rows.

    Notes
    -----
    - If all rows are identical (``D.max() == 0``) every similarity is set to 1.
    - The color range is clipped to the 2nd-98th percentile of the similarity values.
    - Colored sidebars indicate cluster membership along both axes.
    - A legend summarizes each flow pattern cluster with its sample size.
    """
    # ---------- Validate inputs ----------
    X = df_direction.values.astype(float)
    # Positional array: avoids label-based lookup on a Series and supports plain lists.
    labels = np.asarray(cluster_labels).ravel()
    if X.shape[0] == 0:
        raise ValueError("df_direction must contain at least one row.")
    if labels.shape[0] != X.shape[0]:
        raise ValueError(
            f"cluster_labels has {labels.shape[0]} entries but df_direction "
            f"has {X.shape[0]} rows."
        )

    # ---------- Compute similarity ----------
    from sklearn.metrics import pairwise_distances
    D = pairwise_distances(X, metric="euclidean")
    d_max = D.max()
    # Guard the normalisation: identical rows would otherwise give 0/0 = NaN.
    S = 1 - D / d_max if d_max > 0 else np.ones_like(D)
    np.fill_diagonal(S, 1.0)

    # ---------- Order by cluster ----------
    order = np.argsort(labels)
    S_ord = S[order][:, order]
    labs_ord = labels[order]
    uniq = np.unique(labs_ord)

    # ---------- Cluster color mapping ----------
    palette = list(mcolors.TABLEAU_COLORS.values())
    cmap_by_cluster = {c: palette[i % len(palette)] for i, c in enumerate(uniq)}

    # ---------- Cluster bounds ----------
    # (label, start, end) row ranges of each cluster in the sorted matrix.
    bounds, sizes = [], {}
    start = 0
    for c in uniq:
        k = int((labs_ord == c).sum())
        sizes[c] = k
        bounds.append((c, start, start + k))
        start += k

    # ---------- Plot ----------
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        S_ord,
        vmin=np.percentile(S, 2),
        vmax=np.percentile(S, 98),
        cmap=cmap,
        square=True,
        xticklabels=False,
        yticklabels=False,
        cbar_kws={"label": "Similarity (1 - normalized Euclidean distance)"},
        ax=ax
    )

    # White diagonal and cluster boundary lines
    n = S_ord.shape[0]
    ax.plot(np.arange(n) + 0.5, np.arange(n) + 0.5, color="white", lw=0.5, zorder=5)
    for _, s, e in bounds:
        ax.axhline(s, color="white", lw=0.6)
        ax.axhline(e, color="white", lw=0.6)
        ax.axvline(s, color="white", lw=0.6)
        ax.axvline(e, color="white", lw=0.6)

    # Sidebars drawn just outside the heatmap (hence clip_on=False)
    bar_width = 3.0
    for c, s, e in bounds:
        ax.add_patch(patches.Rectangle((-bar_width, s), bar_width, (e - s),
                                       facecolor=cmap_by_cluster[c], edgecolor="none",
                                       transform=ax.transData, clip_on=False, zorder=10))
        ax.add_patch(patches.Rectangle((s, -bar_width), (e - s), bar_width,
                                       facecolor=cmap_by_cluster[c], edgecolor="none",
                                       transform=ax.transData, clip_on=False, zorder=10))

    # Legend
    handles = [
        mlines.Line2D([], [], color=cmap_by_cluster[c], marker="s", linestyle="None",
                      markersize=8, label=f"Pattern {c} (n={sizes[c]})")
        for c in uniq
    ]
    ax.legend(
        handles=handles,
        title="Flow patterns",
        frameon=False,
        ncol=1,
        loc="upper left",
        bbox_to_anchor=(1.2, 1.0),
        borderaxespad=0.0
    )
    ax.set_title(title)
    fig.tight_layout()

    if savepath:
        fig.savefig(savepath, dpi=300, bbox_inches="tight")
    plt.show()
[docs]
def plot_3d_LRC_with_two_slices(
    adata,
    mask_key: str,
    spatial_key: str = "spatial_3d",
    z_levels: list = None,            # e.g., [z1, z2]; if None, auto quantiles (0.3, 0.7)
    slab_ratio: float = 0.02,         # thinner slab so it's cleaner
    flatten_z: bool = True,           # flatten points in slice to exactly z0
    bg_color: str = "#b0b0b0",
    ch_color: str = "#ff7f0e",
    plane_opacity: float = 0.35,
):
    """
    Visualize 3D long-range channels (LRCs) with two representative z-slice views.

    Builds a Plotly figure with three 3D scenes in one row:

    1. **Left panel** - all points, with LRC points highlighted and the
       selected z-levels marked by semi-transparent grey planes.
    2. **Middle and right panels** - top-down views of the points lying
       within a thin slab around each of the first two z-levels.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data object containing 3D spatial coordinates in ``.obsm[spatial_key]``
        and a mask column in ``.obs[mask_key]`` indicating LRC membership.
    mask_key : str
        Key in ``adata.obs`` whose values are cast to bool to mark LRC points
        (e.g., ``'LRC_CSF_manual_filtered'`` or ``'LRC_Blood_auto'``).
    spatial_key : str, default="spatial_3d"
        Key in ``adata.obsm`` specifying 3D spatial coordinates; the first
        three columns are used as x, y, z.
    z_levels : list of float, optional
        z-levels to use for slice visualization; at least two are required.
        A plane is drawn for every level, but only the first two get a slice panel.
        If ``None``, the 0.3 and 0.7 quantiles of the z-coordinates are used.
    slab_ratio : float, default=0.02
        Fraction of the z-range defining the slab thickness for each slice.
    flatten_z : bool, default=True
        If True, all points within a slice are placed exactly at the slice's
        z-level. If False, their original z-coordinates are kept.
    bg_color : str, default="#b0b0b0"
        Color for non-LRC (background) points.
    ch_color : str, default="#ff7f0e"
        Color for highlighted LRC points (channels).
    plane_opacity : float, default=0.35
        Opacity of the grey slicing planes shown in the main 3D view.

    Returns
    -------
    fig : plotly.graph_objects.Figure
        A Plotly figure containing the main view and two z-slice panels.

    Raises
    ------
    ValueError
        If ``z_levels`` is given with fewer than two values.

    Notes
    -----
    - Each slice subplot uses an orthographic camera looking down the z-axis,
      with all axes hidden.
    """
    # --- prepare data ---
    coords = np.asarray(adata.obsm[spatial_key])
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
    m = adata.obs[mask_key].astype(bool).to_numpy()

    zmin, zmax = float(z.min()), float(z.max())
    if z_levels is None:
        z_levels = np.quantile(z, [0.3, 0.7])
    z_levels = [float(level) for level in z_levels]
    # Two slice panels are always drawn; fail clearly instead of with an IndexError below.
    if len(z_levels) < 2:
        raise ValueError(
            f"z_levels must contain at least two values, got {len(z_levels)}."
        )
    thick = (zmax - zmin) * slab_ratio

    # --- layout: 1 row (main + two slices) ---
    fig = make_subplots(
        rows=1, cols=3,
        specs=[[{"type": "scene"}, {"type": "scene"}, {"type": "scene"}]],
        column_widths=[0.5, 0.25, 0.25],
        subplot_titles=("LRC overview",
                        f"Slice z≈{z_levels[0]:.2f}",
                        f"Slice z≈{z_levels[1]:.2f}")
    )

    # ===== Main plot =====
    # background
    fig.add_trace(go.Scatter3d(
        x=x[~m], y=y[~m], z=z[~m],
        mode="markers",
        marker=dict(size=1, color=bg_color, opacity=0.1),
        name="background"
    ), row=1, col=1)

    # channel
    fig.add_trace(go.Scatter3d(
        x=x[m], y=y[m], z=z[m],
        mode="markers",
        marker=dict(size=3, color=ch_color, opacity=0.7),
        name="channel"
    ), row=1, col=1)

    # planes: a 2x2 grid spanning the x/y extent is enough for a flat surface
    xx = np.linspace(x.min(), x.max(), 2)
    yy = np.linspace(y.min(), y.max(), 2)
    X, Y = np.meshgrid(xx, yy)
    for z0 in z_levels:
        Z = np.full_like(X, z0)
        fig.add_trace(go.Surface(
            x=X, y=Y, z=Z,
            showscale=False, opacity=plane_opacity,
            surfacecolor=np.zeros_like(X),
            colorscale=[[0, "grey"], [1, "grey"]],
            hoverinfo="skip"
        ), row=1, col=1)

    # ===== helper: add one slice subplot =====
    def add_slice(col, z0):
        in_slab = (z >= z0 - thick / 2) & (z <= z0 + thick / 2)
        bg_sel = in_slab & ~m
        ch_sel = in_slab & m
        xb, yb, zb = x[bg_sel], y[bg_sel], z[bg_sel]
        xc, yc, zc = x[ch_sel], y[ch_sel], z[ch_sel]
        if flatten_z:
            # Explicit float dtype: full_like would truncate z0 for integer coordinates.
            zb = np.full(zb.shape, z0, dtype=float)
            zc = np.full(zc.shape, z0, dtype=float)

        # background (faint)
        fig.add_trace(go.Scatter3d(
            x=xb, y=yb, z=zb,
            mode="markers",
            marker=dict(size=2.5, color=bg_color, opacity=0.1),
            showlegend=False
        ), row=1, col=col)

        # channel (highlight)
        fig.add_trace(go.Scatter3d(
            x=xc, y=yc, z=zc,
            mode="markers",
            marker=dict(size=3, color=ch_color, opacity=0.9),
            showlegend=False
        ), row=1, col=col)

        # top-down orthographic view to avoid multi-layer look
        fig.update_layout(**{
            f"scene{col}": dict(
                camera=dict(projection=dict(type="orthographic"),
                            eye=dict(x=0., y=0., z=2.0), up=dict(x=0, y=1, z=0)),
                xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)
            )
        })

    # Scenes 2 and 3 hold the slices for the first two z-levels.
    for col, z0 in zip((2, 3), z_levels):
        add_slice(col, z0)

    # main scene style
    fig.update_layout(
        scene=dict(
            camera=dict(projection=dict(type="orthographic")),
            xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)
        ),
        margin=dict(l=0, r=0, t=40, b=0),
        paper_bgcolor="white"
    )

    return fig