"""Visualization helpers for grouped metrics and summary statistics.
This module provides reusable, study-agnostic metric plots for grouped
comparisons, paired subject trajectories, and distribution summaries.
Input model
-----------
Grouped-stat functions in this module assume column-oriented input:
1. Mapping-like object with ``.items()`` (for example: ``dict``).
2. Columns should be 1D and aligned by row.
3. Metric columns should be numeric when used in computations.
Public plots
------------
1. :func:`plot_metric_bars`
2. :func:`plot_tradeoff_scatter`
3. :func:`plot_metric_comparison`
4. :func:`plot_metric_slopes`
5. :func:`plot_metric_violins`
6. :func:`plot_null_distribution`
7. :func:`plot_forest`
8. :func:`plot_harmonic_attenuation` (line-noise-specific helper)
Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""
from __future__ import annotations
from types import MappingProxyType
import numpy as np
from ..qa import peak_attenuation_db
from ._seaborn import _suppress_seaborn_plot_warnings, _try_import_seaborn
from .theme import (
COLORS,
FONTS,
_finalize_fig,
get_series_color,
style_axes,
themed_figure,
themed_legend,
use_theme,
)
_STATS_STYLE = MappingProxyType(
{
"bar_alpha": 0.85,
"bar_linewidth": 0.5,
"bar_capsize": 3,
"scatter_size": 80,
"scatter_alpha": 0.8,
"scatter_edge_linewidth": 0.5,
"mean_scatter_size": 200,
"mean_marker_size": 8,
"mean_linewidth": 2.0,
"subject_trace_alpha": 0.3,
"subject_trace_marker_size": 4,
"paired_line_alpha": 0.1,
"paired_linewidth": 0.4,
"reference_linewidth": 0.8,
"reference_alpha": 0.5,
"annotation_star_size": 14,
"strip_size": 3,
"strip_alpha": 0.7,
"strip_jitter": 0.12,
"forest_marker_size": 5,
"forest_baseline_mean_marker_size": 9,
"forest_pooled_marker_size": 10,
"hist_alpha": 0.5,
"hist_linewidth": 0.5,
"legend_fontsize_small": 7,
}
)
def _plot_subject_trajectories(
ax,
subject_values,
groups,
metric_values,
group_order,
*,
style="o-",
alpha=None,
linewidth=None,
markersize=None,
zorder=None,
):
"""Plot paired subject trajectories and return finite group means."""
subjects = list(dict.fromkeys(np.asarray(subject_values, dtype=object).tolist()))
traces = np.full((len(subjects), len(group_order)), np.nan, dtype=float)
for subject_idx, subject in enumerate(subjects):
subject_mask = subject_values == subject
for group_idx, group in enumerate(group_order):
vals = metric_values[subject_mask & (groups == group)]
vals = vals[np.isfinite(vals)]
if vals.size:
traces[subject_idx, group_idx] = vals[0]
kws = {
"color": COLORS["stat_subject"],
"alpha": _STATS_STYLE["subject_trace_alpha"] if alpha is None else alpha,
}
if markersize is not None:
kws["markersize"] = markersize
if linewidth is not None:
kws["lw"] = linewidth
if zorder is not None:
kws["zorder"] = zorder
ax.plot(range(len(group_order)), traces[subject_idx], style, **kws)
return np.nanmean(traces, axis=0)
[docs]
def plot_window_count_series(
counts,
ax=None,
show=True,
fname=None,
):
"""Plot a per-window count or metric series."""
counts = np.asarray(counts, dtype=float)
if counts.ndim != 1 or counts.size == 0:
raise ValueError("counts must be a non-empty 1D array.")
if ax is None:
fig, ax = themed_figure(figsize=(9, 3.5))
else:
fig = ax.figure
x = np.arange(counts.size)
ax.bar(
x,
counts,
color=COLORS["primary"],
alpha=_STATS_STYLE["bar_alpha"],
linewidth=_STATS_STYLE["bar_linewidth"],
)
ax.axhline(
float(np.mean(counts)),
color=COLORS["accent"],
linestyle="--",
linewidth=1.0,
label=f"Mean ({np.mean(counts):.3g})",
)
ax.set_xlabel("Window")
ax.set_ylabel("Count")
ax.set_title("Window Count Series")
style_axes(ax, grid=True)
themed_legend(ax, loc="best")
return _finalize_fig(fig, show=show, fname=fname)
[docs]
def plot_metric_bars(
data,
metric_cols,
metric_labels=None,
lower_better=None,
group_col="group",
group_order=None,
group_colors=None,
group_labels=None,
title="Metric Comparison (group mean ± SEM)",
fname=None,
show=True,
):
"""Plot grouped bar charts for one or more scalar metrics.
Parameters
----------
data : mapping of str to array-like
Columnar data mapping. All columns must be 1D and have equal length.
metric_cols : list of str
Metric columns to visualize.
metric_labels : list of str | None
Axis labels corresponding to ``metric_cols``. If None, labels are
derived from metric names.
lower_better : list of bool | None
Whether smaller values indicate better performance for each metric.
If None, no best-marker star is added.
group_col : str
Column name identifying comparison groups.
group_order : list of str | None
Explicit order for group bars. If None, first-seen order is used.
group_colors, group_labels : dict | None
Optional color/label overrides keyed by group name.
title : str
Figure-level title.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_metric_bars
>>> data = {
... "group": np.array(["A", "A", "B", "B"]),
... "score": np.array([0.9, 1.0, 0.7, 0.8]),
... }
>>> fig = plot_metric_bars(
... data, metric_cols=["score"], group_col="group", show=False
... )
"""
columns = {name: np.asarray(values) for name, values in data.items()}
groups = np.asarray(columns[group_col], dtype=object)
metric_cols = list(metric_cols)
if metric_labels is None:
metric_labels = [str(col).replace("_", " ").strip() for col in metric_cols]
else:
metric_labels = list(metric_labels)
if lower_better is None:
lower_better = [None] * len(metric_cols)
else:
lower_better = list(lower_better)
if group_order is None:
group_order = list(dict.fromkeys(groups.tolist()))
else:
group_order = list(group_order)
with use_theme():
n_metrics = len(metric_cols)
fig, axes = themed_figure(1, n_metrics, figsize=(4 * n_metrics, 5))
if n_metrics == 1:
axes = np.array([axes])
for axis_index, (col, label, is_lower_better) in enumerate(
zip(metric_cols, metric_labels, lower_better)
):
ax = axes[axis_index]
metric_values = np.asarray(columns[col], dtype=float)
means, sems = [], []
for group in group_order:
vals = metric_values[groups == group]
vals = vals[np.isfinite(vals)]
means.append(vals.mean() if vals.size else np.nan)
if vals.size > 1:
sems.append(float(vals.std(ddof=1) / np.sqrt(vals.size)))
else:
sems.append(0.0)
x = np.arange(len(group_order))
colors = [
group_colors[group]
if group_colors and group in group_colors
else get_series_color(idx)
for idx, group in enumerate(group_order)
]
bars = ax.bar(
x,
means,
yerr=sems,
color=colors,
edgecolor=COLORS["edge"],
linewidth=_STATS_STYLE["bar_linewidth"],
capsize=_STATS_STYLE["bar_capsize"],
alpha=_STATS_STYLE["bar_alpha"],
)
ax.set_xticks(x)
ax.set_xticklabels(
[
group_labels[group]
if group_labels and group in group_labels
else group
for group in group_order
],
fontsize=FONTS["tick"],
)
ax.set_ylabel(label, fontsize=FONTS["label"])
style_axes(ax, grid=True)
for bar, mean_value in zip(bars, means):
if np.isnan(mean_value):
continue
ax.text(
bar.get_x() + bar.get_width() / 2,
mean_value,
f"{mean_value:.2f}",
ha="center",
va="bottom",
fontsize=FONTS["annotation"],
)
finite_means = np.asarray(means, dtype=float)
if is_lower_better is not None and np.isfinite(finite_means).any():
best_index = (
int(np.nanargmin(finite_means))
if is_lower_better
else int(np.nanargmax(finite_means))
)
ax.annotate(
"★",
xy=(best_index, finite_means[best_index]),
fontsize=_STATS_STYLE["annotation_star_size"],
ha="center",
va="bottom",
color=COLORS["stat_highlight"],
)
fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
return _finalize_fig(fig, show=show, fname=fname)
[docs]
def plot_tradeoff_scatter(
data,
x_col,
y_col,
group_col="group",
group_order=None,
group_colors=None,
group_labels=None,
x_label=None,
y_label=None,
title="Metric Trade-off",
reference_x=None,
reference_y=None,
ax=None,
fname=None,
show=True,
):
"""Plot a grouped x/y trade-off scatter with optional group means.
Parameters
----------
data : mapping of str to array-like
Columnar mapping with group and metric columns.
x_col, y_col : str
Metric columns for x and y axes.
group_col : str
Grouping column name.
group_order : list of str | None
Optional group order. If None, first-seen order is used.
group_colors, group_labels : dict | None
Optional style overrides keyed by group name.
x_label, y_label : str | None
Axis labels. If None, derived from metric names.
title : str
Axes title.
reference_x, reference_y : float | None
Optional vertical/horizontal reference lines.
ax : matplotlib.axes.Axes | None
Existing axes. If None, create a new figure.
fname : path-like | None
Optional output path when creating a new figure.
show : bool
Whether to display the figure when creating a new figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_tradeoff_scatter
>>> data = {
... "group": np.array(["A", "A", "B", "B"]),
... "distortion": np.array([0.1, 0.2, 0.4, 0.3]),
... "attenuation": np.array([8.0, 9.0, 5.0, 6.0]),
... }
>>> fig = plot_tradeoff_scatter(
... data, group_col="group", x_col="distortion", y_col="attenuation", show=False
... )
"""
columns = {name: np.asarray(values) for name, values in data.items()}
groups = np.asarray(columns[group_col], dtype=object)
x = np.asarray(columns[x_col], dtype=float)
y = np.asarray(columns[y_col], dtype=float)
if x_label is None:
x_label = str(x_col).replace("_", " ").strip()
if y_label is None:
y_label = str(y_col).replace("_", " ").strip()
if group_order is None:
group_order = list(dict.fromkeys(groups.tolist()))
else:
group_order = list(group_order)
with use_theme():
finalize = ax is None
if ax is None:
fig, ax = themed_figure(1, 1, figsize=(8, 6))
if isinstance(ax, np.ndarray):
ax = ax.flat[0]
else:
fig = ax.figure
for idx, group in enumerate(group_order):
mask = groups == group
x_group = x[mask]
y_group = y[mask]
finite_mask = np.isfinite(x_group) & np.isfinite(y_group)
x_group = x_group[finite_mask]
y_group = y_group[finite_mask]
color = (
group_colors[group]
if group_colors and group in group_colors
else get_series_color(idx)
)
label = (
group_labels[group] if group_labels and group in group_labels else group
)
ax.scatter(
x_group,
y_group,
color=color,
s=_STATS_STYLE["scatter_size"],
alpha=_STATS_STYLE["scatter_alpha"],
edgecolors=COLORS["edge"],
linewidth=_STATS_STYLE["scatter_edge_linewidth"],
label=label,
zorder=3,
)
if x_group.size > 1:
ax.scatter(
x_group.mean(),
y_group.mean(),
color=color,
s=_STATS_STYLE["mean_scatter_size"],
marker="*",
edgecolors=COLORS["edge"],
linewidth=_STATS_STYLE["mean_linewidth"] / 2,
zorder=4,
)
ax.set_xlabel(x_label, fontsize=FONTS["label"])
ax.set_ylabel(y_label, fontsize=FONTS["label"])
ax.set_title(title, fontsize=FONTS["title"], fontweight="bold")
if reference_y is not None:
ax.axhline(
reference_y,
color=COLORS["stat_reference"],
ls=":",
lw=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["reference_alpha"],
)
if reference_x is not None:
ax.axvline(
reference_x,
color=COLORS["stat_reference"],
ls=":",
lw=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["reference_alpha"],
)
themed_legend(ax)
style_axes(ax, grid=True)
if finalize:
return _finalize_fig(fig, show=show, fname=fname)
return fig
[docs]
def plot_metric_comparison(
data,
metric_col,
metric_label=None,
group_col="group",
subject_col="subject",
group_order=None,
group_colors=None,
group_labels=None,
title="Metric Comparison",
reference_value=None,
reference_label="Reference",
ax=None,
fname=None,
show=True,
):
"""Plot one metric as grouped bars or paired subject trajectories.
Parameters
----------
data : mapping of str to array-like
Columnar mapping with at least ``group_col`` and one numeric metric.
metric_col : str
Metric column to visualize.
metric_label : str | None
Y-axis label. If None, derived from ``metric_col``.
group_col : str
Grouping column name.
subject_col : str
Subject identifier column for paired overlays.
group_order : list of str | None
Optional explicit group order. If None, first-seen order is used.
group_colors, group_labels : dict | None
Optional style overrides keyed by group.
title : str
Axes title.
reference_value : float | None
Optional horizontal reference line.
reference_label : str
Legend label for ``reference_value``.
ax : matplotlib.axes.Axes | None
Existing axes. If None, create a new figure.
fname : path-like | None
Optional output path when creating a new figure.
show : bool
Whether to display the figure when creating a new figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure containing the plot.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_metric_comparison
>>> data = {
... "subject": np.array(["s1", "s1", "s2", "s2"]),
... "group": np.array(["A", "B", "A", "B"]),
... "score": np.array([1.1, 0.8, 1.0, 0.7]),
... }
>>> fig = plot_metric_comparison(
... data,
... group_col="group",
... subject_col="subject",
... metric_col="score",
... show=False,
... )
"""
columns = {name: np.asarray(values) for name, values in data.items()}
groups = np.asarray(columns[group_col], dtype=object)
subjects = np.asarray(columns[subject_col], dtype=object)
metric_values = np.asarray(columns[metric_col], dtype=float)
if metric_label is None:
metric_label = str(metric_col).replace("_", " ").strip()
if group_order is None:
group_order = list(dict.fromkeys(groups.tolist()))
else:
group_order = list(group_order)
with use_theme():
finalize = ax is None
if ax is None:
fig, ax = themed_figure(1, 1, figsize=(8, 6))
if isinstance(ax, np.ndarray):
ax = ax.flat[0]
else:
fig = ax.figure
multi_subject = len(list(dict.fromkeys(subjects.tolist()))) > 1
if multi_subject:
means = _plot_subject_trajectories(
ax,
subjects,
groups,
metric_values,
group_order,
style="o-",
markersize=_STATS_STYLE["subject_trace_marker_size"],
)
ax.plot(
range(len(group_order)),
means,
"s-",
color=COLORS["stat_mean"],
markersize=_STATS_STYLE["mean_marker_size"],
lw=_STATS_STYLE["mean_linewidth"],
label="Group mean",
zorder=5,
)
else:
metric_vals = []
for group in group_order:
values = metric_values[groups == group]
values = values[np.isfinite(values)]
metric_vals.append(values[0] if values.size else np.nan)
x = np.arange(len(group_order))
colors = [
group_colors[group]
if group_colors and group in group_colors
else get_series_color(idx)
for idx, group in enumerate(group_order)
]
ax.bar(
x,
metric_vals,
color=colors,
edgecolor=COLORS["edge"],
linewidth=_STATS_STYLE["bar_linewidth"],
alpha=_STATS_STYLE["bar_alpha"],
)
for xi, metric_value in zip(x, metric_vals):
if np.isnan(metric_value):
continue
ax.text(
xi,
metric_value,
f"{metric_value:.2f}",
ha="center",
va="bottom",
fontsize=FONTS["annotation"],
)
if reference_value is not None:
ax.axhline(
reference_value,
color=COLORS["stat_reference"],
ls="--",
lw=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["reference_alpha"],
label=reference_label,
)
ax.set_xticks(range(len(group_order)))
ax.set_xticklabels(
[
group_labels[group] if group_labels and group in group_labels else group
for group in group_order
],
fontsize=FONTS["tick"],
)
ax.set_ylabel(metric_label, fontsize=FONTS["label"])
ax.set_title(title, fontsize=FONTS["title"], fontweight="bold")
themed_legend(ax)
style_axes(ax, grid=True)
if finalize:
return _finalize_fig(fig, show=show, fname=fname)
return fig
[docs]
def plot_harmonic_attenuation(
freqs_before,
gm_before,
cleaned_psds,
harmonics_hz,
subject="",
series_order=None,
series_colors=None,
series_labels=None,
title=None,
fname=None,
show=True,
):
"""Plot grouped per-harmonic attenuation bars for line-noise studies.
Parameters
----------
freqs_before : array-like
Frequency axis of the reference PSD.
gm_before : array-like
Reference geometric-mean PSD.
cleaned_psds : dict[str, tuple[array-like, array-like]]
Mapping from series name to ``(freqs, psd)`` after denoising.
harmonics_hz : array-like of float
Harmonic frequencies to evaluate.
subject : str
Optional subject label included in default title.
series_order : list[str] | None
Plotting order for series. If None, keys from ``cleaned_psds`` are used.
series_colors, series_labels : dict | None
Optional color/label overrides keyed by series name.
title : str | None
Custom axes title.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Notes
-----
This helper is intentionally domain-specific (line-frequency harmonics)
and complements the otherwise study-agnostic grouped-stat plots.
"""
if series_order is None:
series_order = list(cleaned_psds.keys())
with use_theme():
fig, ax = themed_figure(1, 1, figsize=(10, 5))
if isinstance(ax, np.ndarray):
ax = ax.flat[0]
bar_width = 0.8 / max(len(series_order), 1)
x = np.arange(len(harmonics_hz))
for idx, series_name in enumerate(series_order):
if series_name not in cleaned_psds:
continue
_, gm_clean = cleaned_psds[series_name]
attenuation = [
peak_attenuation_db(freqs_before, gm_before, gm_clean, harmonic)
for harmonic in harmonics_hz
]
ax.bar(
x + idx * bar_width,
attenuation,
bar_width,
color=series_colors[series_name]
if series_colors and series_name in series_colors
else get_series_color(idx),
edgecolor=COLORS["edge"],
linewidth=_STATS_STYLE["bar_linewidth"],
label=series_labels[series_name]
if series_labels and series_name in series_labels
else series_name,
alpha=_STATS_STYLE["bar_alpha"],
)
ax.set_xticks(x + bar_width * (len(series_order) - 1) / 2)
ax.set_xticklabels(
[f"{harmonic:.0f} Hz" for harmonic in harmonics_hz],
fontsize=FONTS["tick"],
)
ax.set_ylabel("Peak Attenuation (dB)", fontsize=FONTS["label"])
if title is None:
title = (
f"Per-Harmonic Attenuation — {subject}"
if subject
else "Per-Harmonic Attenuation"
)
ax.set_title(title, fontsize=FONTS["title"], fontweight="bold")
themed_legend(ax)
style_axes(ax, grid=True)
return _finalize_fig(fig, show=show, fname=fname)
[docs]
def plot_metric_slopes(
data,
metric_cols=None,
metric_labels=None,
metric_specs=None,
group_col="group",
subject_col="subject",
group_order=None,
group_colors=None,
group_labels=None,
reference_lines=None,
suptitle=None,
title="Paired Subject-Level Comparison",
fname=None,
show=True,
):
"""Plot subject-level paired trajectories for one or more metrics.
Parameters
----------
data : mapping of str to array-like
Columnar mapping with subject/group identifiers and metric columns.
metric_cols : list of str | None
Metric columns to plot. Used only when ``metric_specs`` is None.
metric_labels : list of str | None
Display labels aligned with ``metric_cols``.
metric_specs : list[tuple[str, str]] | None
Explicit list of ``(metric_col, metric_label)`` pairs.
group_col : str
Grouping column name.
subject_col : str
Subject identifier column name.
group_order : list of str | None
Optional group order. If None, first-seen order is used.
group_colors, group_labels : dict | None
Optional style overrides keyed by group.
reference_lines : dict | None
Optional horizontal reference lines per metric:
``{metric_col: [(y_value, style_dict), ...]}``.
suptitle, title : str | None
Figure title. ``suptitle`` overrides ``title`` when provided.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_metric_slopes
>>> data = {
... "subject": np.array(["s1", "s1", "s2", "s2"]),
... "group": np.array(["A", "B", "A", "B"]),
... "metric": np.array([1.0, 0.8, 1.1, 0.7]),
... }
>>> fig = plot_metric_slopes(
... data, metric_cols=["metric"], group_col="group", show=False
... )
"""
columns = {name: np.asarray(values) for name, values in data.items()}
subject_values = np.asarray(columns[subject_col], dtype=object)
groups = np.asarray(columns[group_col], dtype=object)
if metric_specs is None:
metric_cols = list(metric_cols)
if metric_labels is None:
metric_labels = [
str(metric_col).replace("_", " ").strip() for metric_col in metric_cols
]
else:
metric_labels = list(metric_labels)
metric_specs = list(zip(metric_cols, metric_labels))
else:
metric_specs = list(metric_specs)
if group_order is None:
group_order = list(dict.fromkeys(groups.tolist()))
else:
group_order = list(group_order)
with use_theme():
fig, axes = themed_figure(
1, len(metric_specs), figsize=(6 * len(metric_specs), 5)
)
if not isinstance(axes, np.ndarray):
axes = np.array([axes])
tick_labels = [
group_labels[group] if group_labels and group in group_labels else group
for group in group_order
]
for ax, (metric_col, metric_label) in zip(axes.flat, metric_specs):
metric = np.asarray(columns[metric_col], dtype=float)
means = _plot_subject_trajectories(
ax,
subject_values,
groups,
metric,
group_order,
style="o-",
markersize=_STATS_STYLE["subject_trace_marker_size"],
)
ax.plot(
range(len(group_order)),
means,
"s-",
color=COLORS["stat_mean"],
markersize=_STATS_STYLE["mean_marker_size"],
lw=_STATS_STYLE["mean_linewidth"],
label="Group mean",
zorder=5,
)
ax.set_xticks(range(len(group_order)))
ax.set_xticklabels(tick_labels, fontsize=FONTS["tick"])
ax.set_ylabel(metric_label, fontsize=FONTS["label"])
if reference_lines and metric_col in reference_lines:
for y_val, style in reference_lines[metric_col]:
ax.axhline(y_val, **style)
themed_legend(ax)
style_axes(ax, grid=True)
fig.suptitle(
suptitle or title,
fontsize=FONTS["suptitle"],
fontweight="bold",
)
return _finalize_fig(fig, show=show, fname=fname)
[docs]
def plot_metric_violins(
data,
metric_cols,
metric_labels=None,
group_col="group",
subject_col="subject",
group_order=None,
group_colors=None,
group_labels=None,
baseline_group=None,
reference_lines=None,
show_paired=True,
suptitle=None,
figsize=None,
fname=None,
show=True,
):
"""Plot violin + strip distributions with optional paired subject lines.
Parameters
----------
data : mapping of str to array-like
Columnar mapping with group/subject columns and one or more metrics.
metric_cols : list of str
Metric columns to render.
metric_labels : list of str | None
Labels corresponding to ``metric_cols``.
group_col : str
Grouping column name.
subject_col : str
Subject identifier column name.
group_order : list of str | None
Optional group order. If None, first-seen order is used.
group_colors, group_labels : dict | None
Optional style overrides keyed by group.
baseline_group : str | None
Optional group used to draw a baseline mean line.
reference_lines : dict | None
Optional horizontal reference lines per metric.
show_paired : bool
Whether to draw subject-level paired lines.
suptitle : str | None
Figure-level title.
figsize : tuple | None
Figure size in inches. Defaults to ``(4 * n_metrics, 5.5)``.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Raises
------
ImportError
If seaborn is unavailable.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_metric_violins
>>> data = {
... "subject": np.array(["s1", "s1", "s2", "s2"]),
... "group": np.array(["A", "B", "A", "B"]),
... "metric": np.array([0.2, 0.6, 0.1, 0.5]),
... }
>>> fig = plot_metric_violins(
... data, ["metric"], group_col="group", subject_col="subject", show=False
... )
"""
columns = {name: np.asarray(values) for name, values in data.items()}
sns = _try_import_seaborn()
metric_cols = list(metric_cols)
if metric_labels is None:
metric_labels = [
str(metric_col).replace("_", " ").strip() for metric_col in metric_cols
]
else:
metric_labels = list(metric_labels)
groups = np.asarray(columns[group_col], dtype=object)
subject_values = np.asarray(columns[subject_col], dtype=object)
if group_order is None:
group_order = list(dict.fromkeys(groups.tolist()))
else:
group_order = list(group_order)
n_metrics = len(metric_cols)
n_subjects = len(list(dict.fromkeys(subject_values.tolist())))
if figsize is None:
figsize = (4 * n_metrics, 5.5)
with use_theme():
fig, axes = themed_figure(1, n_metrics, figsize=figsize)
if n_metrics == 1:
axes = np.array([axes])
pretty_order = [
group_labels[group] if group_labels and group in group_labels else group
for group in group_order
]
palette = {
(
group_labels[group] if group_labels and group in group_labels else group
): (
group_colors[group]
if group_colors and group in group_colors
else get_series_color(idx)
)
for idx, group in enumerate(group_order)
}
for ax, metric_col, metric_label in zip(axes.flat, metric_cols, metric_labels):
metric = np.asarray(columns[metric_col], dtype=float)
x_vals = []
y_vals = []
for group in group_order:
group_vals = metric[groups == group]
group_vals = group_vals[np.isfinite(group_vals)]
if group_vals.size == 0:
continue
group_name = (
group_labels[group]
if group_labels and group in group_labels
else group
)
x_vals.extend([group_name] * group_vals.size)
y_vals.extend(group_vals.tolist())
if not y_vals:
ax.text(
0.5,
0.5,
"No data",
transform=ax.transAxes,
ha="center",
fontsize=FONTS["label"],
)
style_axes(ax)
continue
x_arr = np.asarray(x_vals, dtype=object)
y_arr = np.asarray(y_vals, dtype=float)
with _suppress_seaborn_plot_warnings():
sns.violinplot(
x=x_arr,
y=y_arr,
hue=x_arr,
order=pretty_order,
hue_order=pretty_order,
palette=palette,
inner=None,
linewidth=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["subject_trace_alpha"],
ax=ax,
cut=0,
density_norm="width",
legend=False,
)
sns.stripplot(
x=x_arr,
y=y_arr,
hue=x_arr,
order=pretty_order,
hue_order=pretty_order,
palette=palette,
size=_STATS_STYLE["strip_size"],
alpha=_STATS_STYLE["strip_alpha"],
jitter=_STATS_STYLE["strip_jitter"],
ax=ax,
zorder=5,
legend=False,
)
if show_paired and n_subjects > 1:
_plot_subject_trajectories(
ax,
subject_values,
groups,
metric,
group_order,
style="-",
alpha=_STATS_STYLE["paired_line_alpha"],
linewidth=_STATS_STYLE["paired_linewidth"],
zorder=1,
)
if baseline_group is not None:
base_values = metric[groups == baseline_group]
base_values = base_values[np.isfinite(base_values)]
else:
base_values = np.array([])
if base_values.size:
ax.axhline(
base_values.mean(),
color=COLORS["stat_reference"],
ls="--",
lw=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["reference_alpha"],
)
if reference_lines and metric_col in reference_lines:
for y_val, style in reference_lines[metric_col]:
ax.axhline(y_val, **style)
ax.set_xlabel("")
ax.set_ylabel(metric_label, fontsize=FONTS["label"])
ax.tick_params(axis="x", labelsize=FONTS["tick"], rotation=30)
style_axes(ax, grid=True)
ax.xaxis.grid(False)
fig.suptitle(
suptitle or f"Metric Distributions (N = {n_subjects})",
fontsize=FONTS["suptitle"],
fontweight="bold",
)
return _finalize_fig(fig, show=show, fname=fname)
[docs]
def plot_null_distribution(
null_values,
observed,
metric_label="Statistic",
ci=95,
n_bins=60,
suptitle=None,
series_color=None,
figsize=None,
fname=None,
show=True,
):
"""Plot a null-distribution histogram with observed statistic and CI.
Parameters
----------
null_values : array-like
Samples from the null distribution.
observed : float
Observed statistic to compare against the null.
metric_label : str
Label for the x-axis.
ci : float
Central interval width in percent.
n_bins : int
Number of histogram bins.
suptitle : str | None
Figure title override.
series_color : str | None
Color for the observed-statistic marker/annotation.
figsize : tuple | None
Figure size in inches.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
p_value : float
Two-sided empirical p-value under ``null_values``.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_null_distribution
>>> rng = np.random.default_rng(0)
>>> null = rng.normal(0.0, 0.1, 1000)
>>> fig, p = plot_null_distribution(null, observed=0.25, show=False)
"""
null_values = np.asarray(null_values)
if figsize is None:
figsize = (8, 5)
if series_color is None:
series_color = COLORS["stat_mean"]
with use_theme():
fig, ax = themed_figure(1, 1, figsize=figsize)
ax.hist(
null_values,
bins=n_bins,
color=COLORS["stat_subject"],
alpha=_STATS_STYLE["hist_alpha"],
edgecolor=COLORS["separator"],
linewidth=_STATS_STYLE["hist_linewidth"],
density=True,
zorder=2,
label=f"Null (N = {len(null_values):,})",
)
alpha_tail = (100 - ci) / 2
lo, hi = np.percentile(null_values, [alpha_tail, 100 - alpha_tail])
ax.axvspan(
lo,
hi,
color=COLORS["stat_ci"],
alpha=0.12,
zorder=1,
label=f"{ci}% CI [{lo:+.3f}, {hi:+.3f}]",
)
ax.axvline(
lo,
color=COLORS["stat_reference"],
ls=":",
lw=_STATS_STYLE["reference_linewidth"],
alpha=0.6,
)
ax.axvline(
hi,
color=COLORS["stat_reference"],
ls=":",
lw=_STATS_STYLE["reference_linewidth"],
alpha=0.6,
)
ax.axvline(
observed,
color=series_color,
lw=2.5,
ls="--",
zorder=5,
label=f"Observed = {observed:.3f}",
)
p_value = float(np.mean(np.abs(null_values) >= np.abs(observed)))
ax.annotate(
f"p = {p_value:.4f}",
xy=(observed, ax.get_ylim()[1] * 0.92),
fontsize=FONTS["annotation"],
fontweight="bold",
ha="left" if observed > np.median(null_values) else "right",
va="top",
color=series_color,
xytext=(8, 0),
textcoords="offset points",
)
ax.set_xlabel(metric_label, fontsize=FONTS["label"])
ax.set_ylabel("Density", fontsize=FONTS["label"])
themed_legend(ax, fontsize=_STATS_STYLE["legend_fontsize_small"])
style_axes(ax)
fig.suptitle(
suptitle or f"Null Distribution - {metric_label}",
fontsize=FONTS["suptitle"],
fontweight="bold",
)
return _finalize_fig(fig, show=show, fname=fname), p_value
[docs]
def plot_forest(
data,
metric_col,
ci_col=None,
se_col=None,
group_col="group",
subject_col="subject",
target_group=None,
baseline_group=None,
group_colors=None,
group_labels=None,
metric_label=None,
reference_line=0.0,
suptitle=None,
figsize=None,
fname=None,
show=True,
):
"""Plot per-subject point estimates with confidence intervals.
Parameters
----------
data : mapping of str to array-like
Columnar mapping with group, subject, and metric columns.
metric_col : str
Metric column to display on the x-axis.
ci_col : str | None
Optional half-width CI column for each subject estimate.
se_col : str | None
Optional SE column. If provided and ``ci_col`` is absent, CI is
approximated as ``1.96 * SE``.
group_col : str
Grouping column name.
subject_col : str
Subject identifier column name.
target_group : str | None
Group to plot as primary forest series. Defaults to the last
first-seen group.
baseline_group : str | None
Optional baseline group to overlay with faint points and mean marker.
group_colors, group_labels : dict | None
Optional style overrides keyed by group.
metric_label : str | None
X-axis label. If None, derived from ``metric_col``.
reference_line : float | None
Optional vertical reference line value.
suptitle : str | None
Figure title override.
figsize : tuple | None
Figure size in inches.
fname : path-like | None
Optional output path.
show : bool
Whether to display the figure.
Returns
-------
fig : matplotlib.figure.Figure
Figure handle.
Examples
--------
>>> import numpy as np
>>> from mne_denoise.viz import plot_forest
>>> data = {
... "subject": np.array(["s1", "s2", "s1", "s2"]),
... "group": np.array(["A", "A", "B", "B"]),
... "effect": np.array([0.2, 0.4, 0.8, 0.9]),
... }
>>> fig = plot_forest(data, metric_col="effect", group_col="group", show=False)
"""
columns = {name: np.asarray(values) for name, values in data.items()}
groups_col = np.asarray(columns[group_col], dtype=object)
subject_col_values = np.asarray(columns[subject_col], dtype=object)
metric = np.asarray(columns[metric_col], dtype=float)
groups = list(dict.fromkeys(groups_col.tolist()))
if target_group is None:
target_group = groups[-1]
if metric_label is None:
metric_label = str(metric_col).replace("_", " ").strip().title()
target_mask = groups_col == target_group
subjects = np.asarray(subject_col_values[target_mask], dtype=object)
values = metric[target_mask]
order_idx = np.argsort(np.where(np.isfinite(values), values, np.inf))
subjects = subjects[order_idx]
values = values[order_idx]
n_subjects = len(subjects)
if ci_col is not None:
ci_values = np.asarray(columns[ci_col], dtype=float)[target_mask]
half_width = ci_values[order_idx]
elif se_col is not None:
se_values = np.asarray(columns[se_col], dtype=float)[target_mask]
half_width = (1.96 * se_values)[order_idx]
else:
sd = values[np.isfinite(values)].std(ddof=1) if n_subjects > 1 else 1.0
half_width = np.full(n_subjects, 1.96 * sd / np.sqrt(max(n_subjects, 1)))
if figsize is None:
figsize = (8, max(4, n_subjects * 0.35 + 2))
with use_theme():
fig, ax = themed_figure(1, 1, figsize=figsize)
y_pos = np.arange(n_subjects)
target_color = (
group_colors[target_group]
if group_colors and target_group in group_colors
else get_series_color(groups.index(target_group))
)
target_label = (
group_labels[target_group]
if group_labels and target_group in group_labels
else target_group
)
if baseline_group is not None:
base_color = (
group_colors[baseline_group]
if group_colors and baseline_group in group_colors
else get_series_color(groups.index(baseline_group))
)
base_label = (
group_labels[baseline_group]
if group_labels and baseline_group in group_labels
else baseline_group
)
base_metric = metric[groups_col == baseline_group]
base_subjects = np.asarray(
subject_col_values[groups_col == baseline_group], dtype=object
)
for i, subject in enumerate(subjects):
mask = base_subjects == subject
base_values = base_metric[mask]
base_values = base_values[np.isfinite(base_values)]
if base_values.size:
ax.plot(
base_values[0],
y_pos[i],
"o",
color=base_color,
markersize=_STATS_STYLE["forest_marker_size"],
alpha=0.35,
zorder=2,
)
finite_base = base_metric[np.isfinite(base_metric)]
if finite_base.size:
base_mean = finite_base.mean()
ax.plot(
base_mean,
-1.2,
"D",
color=base_color,
markersize=_STATS_STYLE["forest_baseline_mean_marker_size"],
zorder=6,
alpha=0.5,
label=f"{base_label} mean = {base_mean:.3f}",
)
ax.errorbar(
values,
y_pos,
xerr=half_width,
fmt="o",
color=target_color,
ecolor=target_color,
elinewidth=1.2,
capsize=_STATS_STYLE["bar_capsize"],
markersize=_STATS_STYLE["forest_marker_size"],
alpha=_STATS_STYLE["bar_alpha"],
zorder=4,
label=target_label,
)
finite_values = values[np.isfinite(values)]
target_mean = float(finite_values.mean()) if finite_values.size else np.nan
target_se = (
float(finite_values.std(ddof=1) / np.sqrt(finite_values.size))
if finite_values.size > 1
else 0.0
)
ax.errorbar(
target_mean,
-1.2,
xerr=1.96 * target_se,
fmt="D",
color=target_color,
ecolor=target_color,
elinewidth=_STATS_STYLE["mean_linewidth"],
capsize=_STATS_STYLE["bar_capsize"] + 1,
markersize=_STATS_STYLE["forest_pooled_marker_size"],
zorder=6,
label=f"Pooled mean = {target_mean:.3f}",
)
if reference_line is not None:
ax.axvline(
reference_line,
color=COLORS["stat_reference"],
ls="--",
lw=_STATS_STYLE["reference_linewidth"],
alpha=_STATS_STYLE["reference_alpha"],
)
ax.set_yticks(list(y_pos) + [-1.2])
ax.set_yticklabels(list(subjects) + ["Pooled"], fontsize=FONTS["tick"])
ax.set_xlabel(metric_label, fontsize=FONTS["label"])
ax.set_ylabel("")
ax.invert_yaxis()
style_axes(ax, grid=True)
ax.yaxis.grid(False)
themed_legend(
ax, fontsize=_STATS_STYLE["legend_fontsize_small"], loc="lower right"
)
fig.suptitle(
suptitle
or f"Forest Plot - {group_labels[target_group] if group_labels and target_group in group_labels else target_group}",
fontsize=FONTS["suptitle"],
fontweight="bold",
)
return _finalize_fig(fig, show=show, fname=fname)