Source code for mne_denoise.viz.summary

"""Generic multi-panel summary composers.

This module provides method-agnostic summary figures that compose
precomputed inputs only. Estimator fitting and benchmark-specific defaults
are intentionally out of scope for this plotting layer.

This module contains:
1. Dashboard-style denoising summaries.
2. Metric trade-off summary composers.
3. Generic component/signal/interaction/endpoint summary composers.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
         Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

from collections.abc import Mapping

import mne
import numpy as np

from ._summary_panels import (
    _draw_summary_table,
    _new_summary_figure,
    _new_summary_grid,
    _plot_before_after_psd_panel,
    _plot_patterns_panel,
    _plot_removed_power_panel,
    _plot_segment_boundaries_panel,
    _plot_segment_counts_panel,
    _plot_segment_metric_panel,
    _plot_selection_scores_panel,
    _plot_source_trace_panel,
)
from ._utils import _compute_gfp
from .signals import plot_power_ratio_map
from .spectra import plot_psd_comparison
from .stats import plot_metric_comparison, plot_tradeoff_scatter
from .theme import (
    COLORS,
    FONTS,
    _finalize_fig,
    get_series_color,
    style_axes,
    themed_figure,
    themed_legend,
)


def plot_denoising_summary(
    inst_before,
    inst_after,
    *,
    info,
    times,
    title="Denoising Summary",
    show=True,
    fname=None,
):
    """Plot a generic denoising diagnostics dashboard.

    Parameters
    ----------
    inst_before, inst_after : MNE object | ndarray
        Signals before and after denoising. Supported array shapes are
        ``(n_channels, n_times)`` and ``(n_epochs, n_channels, n_times)``.
        MNE Raw/Epochs/Evoked inputs are also accepted.
    info : mne.Info
        Channel info used for the power-ratio map panel.
    times : array-like of shape (n_times,)
        Explicit time axis for GFP traces.
    title : str
        Figure title.
    show : bool
        Whether to display the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    ValueError
        If ``times`` is not 1D, does not match the GFP length, or the GFP
        vectors computed from ``inst_before`` and ``inst_after`` differ in
        length.

    Notes
    -----
    The figure is composed of three panels:

    1. Per-channel power ratio map (before/after).
    2. PSD comparison (before/after).
    3. GFP overlay with difference shading.

    The GFP inputs are validated before the figure is allocated, so an
    invalid ``times`` argument does not leave a half-built figure behind.

    Examples
    --------
    >>> from mne_denoise.viz import plot_denoising_summary
    >>> fig = plot_denoising_summary(
    ...     before,
    ...     after,
    ...     info=info,
    ...     times=times,
    ...     show=False,
    ... )
    """
    from matplotlib.gridspec import GridSpec

    # Compute and validate everything the GFP panel needs up front; the
    # original check ran only after two panels had already been drawn and
    # never compared the "after" vector against the time axis.
    times = np.asarray(times, dtype=float)
    gfp_before = _compute_gfp(inst_before)
    gfp_after = _compute_gfp(inst_after)
    if times.ndim != 1 or times.size != gfp_before.size:
        raise ValueError("times must be 1D and match the GFP vector length.")
    if gfp_after.size != gfp_before.size:
        raise ValueError(
            "inst_before and inst_after must produce GFP vectors of equal length."
        )

    fig = _new_summary_figure(figsize=(12, 10), constrained_layout=True)
    gs = GridSpec(2, 2, figure=fig, height_ratios=[1, 1])

    # Top-left: per-channel power ratio map.
    ax_map = fig.add_subplot(gs[0, 0])
    plot_power_ratio_map(inst_before, inst_after, info=info, ax=ax_map, show=False)

    # Top-right: PSD comparison.
    ax_psd = fig.add_subplot(gs[0, 1])
    plot_psd_comparison(inst_before, inst_after, ax=ax_psd, show=False)

    # Bottom row (full width): GFP overlay with the difference shaded.
    ax_gfp = fig.add_subplot(gs[1, :])
    ax_gfp.plot(times, gfp_before, label="Before", color=COLORS["before"], alpha=0.7)
    ax_gfp.plot(times, gfp_after, label="After", color=COLORS["after"], alpha=0.7)
    ax_gfp.fill_between(
        times,
        gfp_before,
        gfp_after,
        color=COLORS["muted"],
        alpha=0.2,
        label="Difference",
    )
    themed_legend(ax_gfp, loc="best")
    ax_gfp.set_xlabel("Time", fontsize=FONTS["label"])
    ax_gfp.set_ylabel("Global Field Power", fontsize=FONTS["label"])
    ax_gfp.set_title("Temporal Signal Comparison (GFP)", fontsize=FONTS["title"])
    style_axes(ax_gfp, grid=True)

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_metric_tradeoff_summary(
    data,
    *,
    group_col,
    subject_col,
    x_col,
    y_col,
    metric_col,
    x_label=None,
    y_label=None,
    metric_label=None,
    group_order=None,
    group_colors=None,
    group_labels=None,
    tradeoff_title="Metric Trade-off",
    metric_title="Metric Comparison",
    reference_x=None,
    reference_y=None,
    reference_value=None,
    reference_label="Reference",
    suptitle="Trade-off Analysis",
    show=True,
    fname=None,
):
    """Compose a trade-off scatter and a metric comparison side by side.

    Parameters
    ----------
    data : mapping of str to array-like
        Columnar data; every value is converted with ``np.asarray`` before
        being handed to the panel functions.
    group_col : str
        Name of the group column.
    subject_col : str
        Name of the subject column, forwarded to the metric comparison panel.
    x_col, y_col : str
        Columns plotted on the x and y axes of the trade-off scatter.
    metric_col : str
        Column shown in the right-hand comparison panel.
    x_label, y_label, metric_label : str | None
        Optional axis labels.
    group_order : sequence of str | None
        Optional display order of the groups.
    group_colors, group_labels : mapping | None
        Optional per-group color and label overrides.
    tradeoff_title, metric_title, suptitle : str
        Titles of the left panel, the right panel and the whole figure.
    reference_x, reference_y, reference_value : float | None
        Optional reference lines (scatter x/y and metric panel).
    reference_label : str
        Label attached to the metric reference line.
    show : bool
        Whether to display the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Notes
    -----
    This function only wires together
    :func:`mne_denoise.viz.plot_tradeoff_scatter` (left panel) and
    :func:`mne_denoise.viz.plot_metric_comparison` (right panel).
    """
    # Normalise every column to an ndarray once; both panels share the result.
    columns = {}
    for column_name, column_values in data.items():
        columns[column_name] = np.asarray(column_values)

    fig, axes = themed_figure(1, 2, figsize=(16, 6))

    # Grouping options are identical for both panels.
    grouping = {
        "group_col": group_col,
        "group_order": group_order,
        "group_colors": group_colors,
        "group_labels": group_labels,
    }

    plot_tradeoff_scatter(
        columns,
        x_col=x_col,
        y_col=y_col,
        x_label=x_label,
        y_label=y_label,
        title=tradeoff_title,
        reference_x=reference_x,
        reference_y=reference_y,
        ax=axes[0],
        show=False,
        **grouping,
    )
    plot_metric_comparison(
        columns,
        metric_col=metric_col,
        metric_label=metric_label,
        subject_col=subject_col,
        title=metric_title,
        reference_value=reference_value,
        reference_label=reference_label,
        ax=axes[1],
        show=False,
        **grouping,
    )

    fig.suptitle(suptitle, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_component_cleaning_summary(
    scores=None,
    selected_count=0,
    patterns=None,
    info=None,
    channel_names=None,
    removed=None,
    sources=None,
    sfreq=None,
    freqs=None,
    psd_before=None,
    psd_after=None,
    line_freq=None,
    fmax=100.0,
    segment_info=None,
    summary_rows=None,
    title="Component Cleaning Summary",
    figsize=None,
    dpi=200,
    show=True,
    fname=None,
):
    """Compose a six-panel component-cleaning dashboard.

    Parameters
    ----------
    scores : array-like | None
        Component score vector.
    selected_count : int
        Number of selected/removed components.
    patterns : array-like | None
        Component patterns with shape ``(n_channels, n_components)``.
    info : mne.Info | None
        Optional MNE info for topomap rendering.
    channel_names : sequence of str | None
        Channel labels used by non-topomap fallback panels.
    removed : array-like | None
        Removed signal with shape ``(n_channels, n_times)`` or
        ``(n_epochs, n_channels, n_times)``.
    sources : array-like | None
        Component source traces with shape ``(n_components, n_times)`` or
        ``(n_epochs, n_components, n_times)``.
    sfreq : float | None
        Sampling frequency for source/segment time axes.
    freqs : array-like | None
        Frequency axis for the PSD panel.
    psd_before, psd_after : array-like | None
        PSD values for the before/after comparison.
    line_freq : float | None
        Optional line-frequency marker in the PSD panel.
    fmax : float
        Upper frequency bound for the PSD panel.
    segment_info : list[dict] | None
        Optional segmented metadata. When truthy, the segment panels replace
        the score/removed/source panels.
    summary_rows : list[tuple[str, object]] | None
        Optional explicit table rows. When missing or empty, rows are derived
        from the other arguments.
    title : str
        Figure title.
    figsize : tuple[float, float] | None
        Figure size.
    dpi : int
        Figure resolution.
    show : bool
        Whether to show the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    ValueError
        If panel inputs are incompatible with expected dimensions or metadata
        requirements (for example invalid source/pattern/segment shapes).

    Notes
    -----
    Every panel is drawn by a private painter from
    :mod:`mne_denoise.viz._summary_panels`; this function only lays them out.
    No fitting or denoising happens here, so all inputs must be precomputed.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.viz import plot_component_cleaning_summary
    >>> rng = np.random.default_rng(0)
    >>> freqs = np.linspace(0, 80, 161)
    >>> fig = plot_component_cleaning_summary(
    ...     scores=np.array([2.0, 1.2, 0.7]),
    ...     selected_count=1,
    ...     patterns=rng.standard_normal((5, 3)),
    ...     removed=rng.standard_normal((5, 200)),
    ...     sources=rng.standard_normal((3, 200)),
    ...     sfreq=200.0,
    ...     freqs=freqs,
    ...     psd_before=rng.random((5, freqs.size)),
    ...     psd_after=rng.random((5, freqs.size)),
    ...     show=False,
    ... )
    """
    fig, gs = _new_summary_grid(figsize=figsize, dpi=dpi, hspace=0.40)

    # Channel metadata shared by the pattern and removed-power painters.
    channel_meta = {"info": info, "channel_names": channel_names}

    # Rows 0-1: either the segmented layout or the per-component layout.
    if segment_info:
        counts_ax = fig.add_subplot(gs[0, 0])
        _plot_segment_counts_panel(
            counts_ax, segment_info, title="(a) Segment selection counts"
        )

        metric_ax = fig.add_subplot(gs[0, 1])
        _plot_segment_metric_panel(metric_ax, segment_info, title="(b) Segment metric")

        bounds_ax = fig.add_subplot(gs[1, 0])
        _plot_segment_boundaries_panel(
            bounds_ax,
            segment_info,
            sfreq=sfreq,
            title="(c) Segment boundaries",
        )

        _plot_patterns_panel(
            fig,
            gs[1, 1],
            patterns,
            title="(d) Component patterns",
            **channel_meta,
        )
    else:
        n_selected = int(selected_count)

        scores_ax = fig.add_subplot(gs[0, 0])
        _plot_selection_scores_panel(
            scores_ax,
            scores,
            n_selected,
            selected_label=f"Selected ({n_selected})",
            title="(a) Component scores",
        )

        _plot_patterns_panel(
            fig,
            gs[0, 1],
            patterns,
            title="(b) Component patterns",
            **channel_meta,
        )

        removed_ax = fig.add_subplot(gs[1, 0])
        _plot_removed_power_panel(
            fig,
            removed_ax,
            removed,
            title="(c) Removed signal power",
            **channel_meta,
        )

        traces_ax = fig.add_subplot(gs[1, 1])
        _plot_source_trace_panel(
            traces_ax,
            sources,
            sfreq=sfreq,
            title="(d) Component traces",
        )

    # Row 2, left: before/after PSD.
    psd_ax = fig.add_subplot(gs[2, 0])
    _plot_before_after_psd_panel(
        psd_ax,
        freqs=freqs,
        psd_before=psd_before,
        psd_after=psd_after,
        line_freq=line_freq,
        fmax=fmax,
        title="(e) Power spectral density",
    )

    # Row 2, right: summary table. Caller-supplied rows win; otherwise the
    # table is assembled from whichever metadata was provided.
    table_ax = fig.add_subplot(gs[2, 1])
    table_rows = [] if summary_rows is None else list(summary_rows)
    if not table_rows:
        if scores is not None:
            table_rows.append(("n_components", np.asarray(scores).size))
        if selected_count:
            table_rows.append(("selected_count", int(selected_count)))
        if segment_info:
            table_rows.append(("n_segments", len(segment_info)))
        if line_freq is not None:
            table_rows.append(("line_freq_hz", f"{float(line_freq):.3g}"))
        if sfreq is not None:
            table_rows.append(("sfreq_hz", f"{float(sfreq):.6g}"))
    if not table_rows:
        # Nothing to report at all: show a placeholder row.
        table_rows = [("status", "No metadata provided")]
    _draw_summary_table(table_ax, table_rows)

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_signal_diagnostics_summary(
    signals,
    *,
    channel,
    times,
    group_order,
    reference_group,
    group_colors,
    group_labels,
    channel_names=None,
    channel_label=None,
    windows=None,
    title="Signal Diagnostics Summary",
    show=True,
    fname=None,
):
    """Plot grouped time-domain signal diagnostics.

    Parameters
    ----------
    signals : mapping[str, array-like | MNE object]
        Mapping from group name to signal data. Each value must be 2D
        ``(n_channels, n_times)`` or 3D ``(n_epochs, n_channels, n_times)``;
        3D inputs are averaged over the first axis.
    channel : int | str
        Channel selector. String selectors require explicit ``channel_names``.
    times : array-like of shape (n_times,)
        Explicit time axis.
    group_order : sequence of str
        Explicit group plotting order.
    reference_group : str
        Explicit reference group used for the difference panel.
    group_colors, group_labels : mapping
        Explicit color and display-label mappings keyed by group.
    channel_names : sequence of str | None
        Explicit channel names used for string ``channel`` selectors and
        channel labeling. The function does not infer names from MNE inputs.
    channel_label : str | None
        Explicit channel label when ``channel_names`` is not provided.
    windows : sequence of tuple[float, float, str] | None
        Optional highlighted windows. Each entry must be
        ``(start, stop, label)``.
    title : str
        Figure title.
    show : bool
        Whether to show the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    ValueError
        If ``signals`` is empty, ``times`` is shape-incompatible, channel
        metadata is insufficient for the requested selector, a string
        ``channel`` is not in ``channel_names``, group ordering is invalid,
        or a ``windows`` entry is not a 3-tuple.
    KeyError
        If a name in ``group_order`` is missing from ``signals``,
        ``group_colors`` or ``group_labels``.
    IndexError
        If an integer ``channel`` is outside the available channel range.

    Notes
    -----
    This is a strict explicit API: ``group_order``, ``reference_group``,
    ``group_colors``, and ``group_labels`` are caller-owned inputs.

    All inputs are resolved and validated before the figure is created, so
    invalid arguments do not leave an orphaned figure behind.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.viz import plot_signal_diagnostics_summary
    >>> rng = np.random.default_rng(0)
    >>> n_times = 200
    >>> times = np.arange(n_times) / 200.0
    >>> signals = {
    ...     "before": rng.standard_normal((3, n_times)),
    ...     "after": rng.standard_normal((3, n_times)),
    ... }
    >>> fig = plot_signal_diagnostics_summary(
    ...     signals,
    ...     channel=1,
    ...     channel_label="C2",
    ...     times=times,
    ...     group_order=["before", "after"],
    ...     reference_group="before",
    ...     group_colors={"before": "#4C72B0", "after": "#55A868"},
    ...     group_labels={"before": "Before", "after": "After"},
    ...     windows=[(0.08, 0.14, "early")],
    ...     show=False,
    ... )
    """
    if not isinstance(signals, Mapping) or len(signals) == 0:
        raise ValueError("signals must be a non-empty mapping.")

    group_order = list(group_order)
    if len(group_order) == 0:
        raise ValueError("group_order must contain at least one group name.")
    if reference_group not in group_order:
        raise ValueError("reference_group must be present in group_order.")

    # Reduce every group to a 2D (n_channels, n_times) float array and make
    # sure all groups agree on that shape.
    prepared = {}
    n_channels = None
    n_times = None
    for name in group_order:
        data_in = signals[name]
        if isinstance(data_in, (mne.io.BaseRaw, mne.Evoked, mne.BaseEpochs)):
            arr = np.asarray(data_in.get_data(), dtype=float)
        else:
            arr = np.asarray(data_in, dtype=float)
        if arr.ndim == 3:
            arr = arr.mean(axis=0)
        if arr.ndim != 2:
            raise ValueError(
                "Each signal must be 2D (n_channels, n_times) or 3D "
                "(n_epochs, n_channels, n_times)."
            )
        if n_channels is None:
            n_channels, n_times = arr.shape
        elif arr.shape != (n_channels, n_times):
            raise ValueError("All signals must have matching channel/time dimensions.")
        prepared[name] = arr

    names = None
    if channel_names is not None:
        names = list(channel_names)
        if len(names) != n_channels:
            raise ValueError(
                "channel_names length must match the number of signal channels."
            )

    # Resolve the channel selector to a row index.
    if isinstance(channel, str):
        if names is None:
            raise ValueError("String channel selectors require channel names.")
        if channel not in names:
            # Same exception type as list.index, but with an actionable message.
            raise ValueError(f"channel {channel!r} was not found in channel_names.")
        channel_idx = names.index(channel)
    else:
        channel_idx = int(channel)
        # Negative indices keep their usual Python meaning; anything outside
        # the valid range is rejected here rather than failing mid-plot.
        if not -n_channels <= channel_idx < n_channels:
            raise IndexError(
                f"channel index {channel_idx} is out of range for "
                f"{n_channels} channels."
            )

    if channel_label is not None:
        channel_label = str(channel_label)
    elif names is not None:
        channel_label = str(names[channel_idx])
    else:
        raise ValueError(
            "channel_label must be provided when channel_names is not provided."
        )

    time_axis = np.asarray(times, dtype=float)
    if time_axis.ndim != 1 or time_axis.size != n_times:
        raise ValueError("times must be 1D and match n_times.")

    # Unpack windows and look up per-group styling now, so malformed entries
    # or missing keys surface before any figure is allocated.
    if windows is None:
        window_specs = []
    else:
        window_specs = [(start, stop, label) for start, stop, label in windows]
    colors = {group: group_colors[group] for group in group_order}
    labels = {group: group_labels[group] for group in group_order}

    fig, axes = themed_figure(2, 2, figsize=(14, 8))
    if isinstance(axes, np.ndarray):
        ax_trace, ax_gfp, ax_diff, ax_table = axes.ravel()
    else:
        ax_trace = ax_gfp = ax_diff = ax_table = axes

    ref_name = reference_group
    ref_data = prepared[ref_name]

    for group in group_order:
        color = colors[group]
        label = labels[group]
        data = prepared[group]
        ax_trace.plot(
            time_axis,
            data[channel_idx],
            color=color,
            lw=1.3,
            alpha=0.9,
            label=label,
        )
        ax_gfp.plot(
            time_axis,
            _compute_gfp(data),
            color=color,
            lw=1.3,
            alpha=0.9,
            label=label,
        )
        # The reference group has no difference trace against itself.
        if group != ref_name:
            ax_diff.plot(
                time_axis,
                data[channel_idx] - ref_data[channel_idx],
                color=color,
                lw=1.1,
                alpha=0.85,
                label=f"{label} - {ref_name}",
            )

    # Shade every highlighted window on the three time-domain panels.
    for win_idx, (start, stop, _) in enumerate(window_specs):
        shade = get_series_color(win_idx)
        for panel in (ax_trace, ax_gfp, ax_diff):
            panel.axvspan(start, stop, color=shade, alpha=0.10, lw=0)

    ax_trace.set_xlabel("Time", fontsize=FONTS["label"])
    ax_trace.set_ylabel("Amplitude", fontsize=FONTS["label"])
    ax_trace.set_title(
        f"(a) Channel overlay ({channel_label})", fontsize=FONTS["title"]
    )
    style_axes(ax_trace, grid=True)
    themed_legend(ax_trace)

    ax_gfp.set_xlabel("Time", fontsize=FONTS["label"])
    ax_gfp.set_ylabel("Global Field Power", fontsize=FONTS["label"])
    ax_gfp.set_title("(b) GFP comparison", fontsize=FONTS["title"])
    style_axes(ax_gfp, grid=True)
    themed_legend(ax_gfp)

    ax_diff.axhline(0.0, color=COLORS["stat_reference"], ls="--", lw=0.8, alpha=0.5)
    ax_diff.set_xlabel("Time", fontsize=FONTS["label"])
    ax_diff.set_ylabel("Difference", fontsize=FONTS["label"])
    ax_diff.set_title(f"(c) Difference vs {ref_name}", fontsize=FONTS["title"])
    style_axes(ax_diff, grid=True)
    themed_legend(ax_diff)

    # Summary table: per-group RMS on the selected channel, plus the
    # reference-group mean inside each window that overlaps the time axis.
    rows = [("reference_group", ref_name), ("channel", channel_label)]
    for group in group_order:
        rms = np.sqrt(np.mean(prepared[group][channel_idx] ** 2))
        rows.append((f"{group} RMS", f"{rms:.4g}"))
    for start, stop, label in window_specs:
        mask = (time_axis >= start) & (time_axis <= stop)
        if not mask.any():
            continue
        ref_mean = np.mean(ref_data[channel_idx, mask])
        rows.append((f"{label} mean ({ref_name})", f"{ref_mean:.4g}"))
    _draw_summary_table(ax_table, rows, title="(d) Summary")

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_condition_interaction_summary(
    traces,
    times,
    errors=None,
    condition_order=None,
    group_order=None,
    group_colors=None,
    group_labels=None,
    windows=None,
    title="Condition Interaction Summary",
    show=True,
    fname=None,
):
    """Draw one panel per condition with the groups overlaid as lines.

    Parameters
    ----------
    traces : mapping[str, mapping[str, array-like]]
        Nested mapping ``condition -> group -> trace``.
    times : array-like of shape (n_times,)
        Explicit time axis.
    errors : mapping | None
        Optional nested mapping ``condition -> group -> standard error``;
        when present for a trace, a ``trace ± error`` band is shaded.
    condition_order : sequence of str | None
        Optional condition order. Defaults to the key order of ``traces``.
    group_order : sequence of str | None
        Optional group order. Defaults to first-seen order across conditions.
    group_colors, group_labels : mapping | None
        Optional color/label overrides keyed by group. Groups without an
        override use ``get_series_color`` and their own name.
    windows : sequence of tuple[float, float, str] | None
        Optional highlighted windows as explicit ``(start, stop, label)``
        tuples.
    title : str
        Figure title.
    show : bool
        Whether to show the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    ValueError
        If ``traces`` is empty, ``times`` is not 1D, or trace/error arrays do
        not match the expected time dimension.

    Notes
    -----
    Groups that are absent from a given condition are skipped in that panel.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.viz import plot_condition_interaction_summary
    >>> times = np.linspace(-0.2, 0.5, 300)
    >>> traces = {
    ...     "A": {"before": np.sin(times), "after": 0.7 * np.sin(times)},
    ...     "B": {"before": np.cos(times), "after": 0.6 * np.cos(times)},
    ... }
    >>> fig = plot_condition_interaction_summary(
    ...     traces,
    ...     times=times,
    ...     windows=[(0.08, 0.14, "w1")],
    ...     show=False,
    ... )
    """
    if not (isinstance(traces, Mapping) and len(traces) > 0):
        raise ValueError("traces must be a non-empty mapping of conditions.")

    # One panel per condition.
    if condition_order is None:
        panel_conditions = list(traces.keys())
    else:
        panel_conditions = list(condition_order)

    # One line per group; default to the order in which groups first appear.
    if group_order is None:
        line_groups = list(
            dict.fromkeys(
                group for condition in panel_conditions for group in traces[condition]
            )
        )
    else:
        line_groups = list(group_order)

    time_axis = np.asarray(times, dtype=float)
    if time_axis.ndim != 1:
        raise ValueError("times must be 1D.")

    n_panels = len(panel_conditions)
    fig, axes = themed_figure(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
    window_specs = list(windows) if windows is not None else []

    for panel_idx, condition in enumerate(panel_conditions):
        ax = axes[0, panel_idx]

        # Window shading goes underneath the traces.
        for window_idx, (start, stop, _) in enumerate(window_specs):
            ax.axvspan(
                start, stop, color=get_series_color(window_idx), alpha=0.10, lw=0
            )

        for line_idx, group in enumerate(line_groups):
            condition_traces = traces[condition]
            if group not in condition_traces:
                continue

            trace = np.asarray(condition_traces[group], dtype=float)
            if trace.ndim != 1 or trace.size != time_axis.size:
                raise ValueError("All traces must be 1D and match times length.")

            # Caller overrides first, palette / raw name otherwise.
            if group_colors is not None and group in group_colors:
                color = group_colors[group]
            else:
                color = get_series_color(line_idx)
            if group_labels is not None and group in group_labels:
                label = group_labels[group]
            else:
                label = group

            ax.plot(time_axis, trace, color=color, lw=1.4, alpha=0.9, label=label)

            # Optional standard-error band around the trace.
            if errors is None or condition not in errors:
                continue
            if group not in errors[condition]:
                continue
            band = np.asarray(errors[condition][group], dtype=float)
            if band.shape != trace.shape:
                raise ValueError("errors traces must match corresponding trace shape.")
            ax.fill_between(
                time_axis, trace - band, trace + band, color=color, alpha=0.15, lw=0
            )

        ax.axhline(0.0, color=COLORS["stat_reference"], ls="--", lw=0.8, alpha=0.5)
        ax.set_xlabel("Time", fontsize=FONTS["label"])
        if panel_idx == 0:
            # Only the leftmost panel carries the y-axis label.
            ax.set_ylabel("Amplitude", fontsize=FONTS["label"])
        ax.set_title(str(condition), fontsize=FONTS["title"])
        style_axes(ax, grid=True)
        themed_legend(ax)

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_group_condition_interaction_summary(
    traces,
    times,
    errors=None,
    group_order=None,
    condition_order=None,
    condition_labels=None,
    windows=None,
    title="Group Condition Interaction Summary",
    show=True,
    fname=None,
):
    """Draw one panel per group with the conditions overlaid as lines.

    Parameters
    ----------
    traces : mapping[str, mapping[str, array-like]]
        Nested mapping ``group -> condition -> trace``.
    times : array-like of shape (n_times,)
        Explicit time axis.
    errors : mapping | None
        Optional nested mapping ``group -> condition -> standard error``;
        when present for a trace, a ``trace ± error`` band is shaded.
    group_order : sequence of str | None
        Optional group order. Defaults to the key order of ``traces``.
    condition_order : sequence of str | None
        Optional condition order. Defaults to first-seen order across groups.
    condition_labels : mapping | None
        Optional label overrides keyed by condition.
    windows : sequence of tuple[float, float, str] | None
        Optional highlighted windows as explicit ``(start, stop, label)``
        tuples.
    title : str
        Figure title.
    show : bool
        Whether to show the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    ValueError
        If ``traces`` is empty, ``times`` is not 1D, or trace/error arrays do
        not match the expected time dimension.

    Notes
    -----
    Line colors come from ``get_series_color`` indexed by the condition's
    position in ``condition_order``. Conditions absent from a given group are
    skipped in that panel.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.viz import plot_group_condition_interaction_summary
    >>> times = np.linspace(-0.2, 0.5, 300)
    >>> traces = {
    ...     "before": {"A": np.sin(times), "B": np.cos(times)},
    ...     "after": {"A": 0.7 * np.sin(times), "B": 0.6 * np.cos(times)},
    ... }
    >>> fig = plot_group_condition_interaction_summary(
    ...     traces,
    ...     times=times,
    ...     windows=[(0.08, 0.14, "w1")],
    ...     show=False,
    ... )
    """
    if not (isinstance(traces, Mapping) and len(traces) > 0):
        raise ValueError("traces must be a non-empty mapping of groups.")

    # One panel per group.
    if group_order is None:
        panel_groups = list(traces.keys())
    else:
        panel_groups = list(group_order)

    # One line per condition; default to first appearance across groups.
    if condition_order is None:
        line_conditions = list(
            dict.fromkeys(
                condition for group in panel_groups for condition in traces[group]
            )
        )
    else:
        line_conditions = list(condition_order)

    time_axis = np.asarray(times, dtype=float)
    if time_axis.ndim != 1:
        raise ValueError("times must be 1D.")

    window_specs = list(windows) if windows is not None else []
    n_panels = len(panel_groups)
    fig, axes = themed_figure(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)

    for panel_idx, group in enumerate(panel_groups):
        ax = axes[0, panel_idx]

        # Window shading goes underneath the traces.
        for window_idx, (start, stop, _) in enumerate(window_specs):
            ax.axvspan(
                start, stop, color=get_series_color(window_idx), alpha=0.10, lw=0
            )

        for line_idx, condition in enumerate(line_conditions):
            group_traces = traces[group]
            if condition not in group_traces:
                continue

            trace = np.asarray(group_traces[condition], dtype=float)
            if trace.ndim != 1 or trace.size != time_axis.size:
                raise ValueError("All traces must be 1D and match times length.")

            if condition_labels is not None and condition in condition_labels:
                label = condition_labels[condition]
            else:
                label = condition
            color = get_series_color(line_idx)

            ax.plot(time_axis, trace, color=color, lw=1.4, alpha=0.9, label=label)

            # Optional standard-error band around the trace.
            if errors is None or group not in errors:
                continue
            if condition not in errors[group]:
                continue
            band = np.asarray(errors[group][condition], dtype=float)
            if band.shape != trace.shape:
                raise ValueError("errors traces must match corresponding trace shape.")
            ax.fill_between(
                time_axis, trace - band, trace + band, color=color, alpha=0.15, lw=0
            )

        ax.axhline(0.0, color=COLORS["stat_reference"], ls="--", lw=0.8, alpha=0.5)
        ax.set_xlabel("Time", fontsize=FONTS["label"])
        if panel_idx == 0:
            # Only the leftmost panel carries the y-axis label.
            ax.set_ylabel("Amplitude", fontsize=FONTS["label"])
        ax.set_title(str(group), fontsize=FONTS["title"])
        style_axes(ax, grid=True)
        themed_legend(ax)

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)
def plot_endpoint_metrics_summary(
    data,
    metric_col,
    group_col="group",
    subject_col="subject",
    group_order=None,
    group_colors=None,
    group_labels=None,
    reference_value=None,
    reference_label="Reference",
    null_distribution=None,
    observed_value=None,
    title="Endpoint Metrics Summary",
    show=True,
    fname=None,
):
    """Plot a generic endpoint-metric storyboard.

    Parameters
    ----------
    data : mapping[str, array-like]
        Columnar mapping with at least ``metric_col``, ``group_col``, and
        ``subject_col``.
    metric_col : str
        Metric column to summarize.
    group_col : str
        Group identifier column.
    subject_col : str
        Subject identifier column.
    group_order : sequence of str | None
        Optional group order. Defaults to first-seen order.
    group_colors, group_labels : mapping | None
        Optional color/label overrides keyed by group.
    reference_value : float | None
        Optional horizontal reference line for metric panels.
    reference_label : str
        Label for ``reference_value``.
    null_distribution : array-like | None
        Optional null distribution for the null panel. Non-finite entries are
        dropped. When ``None``, panel (d) shows a summary table instead.
    observed_value : float | None
        Observed statistic for the null panel. Only used when
        ``null_distribution`` is given.
    title : str
        Figure title.
    show : bool
        Whether to display the figure.
    fname : path-like | None
        Optional output path.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure handle.

    Raises
    ------
    KeyError
        If required columns are missing from ``data``.
    ValueError
        If numeric conversion or plotting operations receive incompatible
        shapes.

    Notes
    -----
    The output is a 2x2 storyboard:

    - (a) group mean ± SEM bars (non-finite values are ignored),
    - (b) per-subject trajectories across groups, using the first finite value
      of each subject/group pair, plus the across-subject mean,
    - (c) per-group box plots,
    - (d) the null-distribution histogram with a two-sided empirical p-value
      ``mean(|null| >= |observed|)``, or a summary table when no null
      distribution is supplied.

    Box-plot tick labels are applied with ``Axes.set_xticklabels`` rather
    than the ``labels`` keyword of ``Axes.boxplot``, which Matplotlib 3.9
    deprecated in favor of ``tick_labels``.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.viz import plot_endpoint_metrics_summary
    >>> data = {
    ...     "subject": np.array(["s1", "s1", "s2", "s2"]),
    ...     "group": np.array(["A", "B", "A", "B"]),
    ...     "score": np.array([1.2, 0.9, 1.1, 0.8]),
    ... }
    >>> fig = plot_endpoint_metrics_summary(
    ...     data,
    ...     metric_col="score",
    ...     show=False,
    ... )
    """
    columns = {name: np.asarray(values) for name, values in data.items()}
    metric = np.asarray(columns[metric_col], dtype=float)
    groups = np.asarray(columns[group_col], dtype=object)
    subjects = np.asarray(columns[subject_col], dtype=object)

    if group_order is None:
        # dict.fromkeys keeps first-seen order while dropping duplicates.
        group_order = list(dict.fromkeys(groups.tolist()))
    else:
        group_order = list(group_order)

    # Display labels and colors are shared by panels (a)-(c); build them once.
    tick_labels = [
        group_labels[group]
        if group_labels is not None and group in group_labels
        else group
        for group in group_order
    ]
    bar_colors = [
        group_colors[group]
        if group_colors is not None and group in group_colors
        else get_series_color(i)
        for i, group in enumerate(group_order)
    ]
    finite = np.isfinite(metric)
    positions = np.arange(len(group_order))

    fig, axes = themed_figure(2, 2, figsize=(14, 9))
    if isinstance(axes, np.ndarray):
        ax_a, ax_b, ax_c, ax_d = axes.ravel()
    else:
        ax_a = ax_b = ax_c = ax_d = axes

    # (a) Group mean ± SEM over finite values.
    means = []
    sems = []
    for group in group_order:
        vals = metric[(groups == group) & finite]
        means.append(vals.mean() if vals.size else np.nan)
        # A standard error needs at least two observations.
        if vals.size > 1:
            sems.append(vals.std(ddof=1) / np.sqrt(vals.size))
        else:
            sems.append(0.0)

    ax_a.bar(
        positions,
        means,
        yerr=sems,
        color=bar_colors,
        edgecolor=COLORS["edge"],
        linewidth=0.6,
        alpha=0.85,
    )
    if reference_value is not None:
        ax_a.axhline(
            reference_value,
            color=COLORS["stat_reference"],
            ls="--",
            lw=0.8,
            alpha=0.6,
            label=reference_label,
        )
        # The reference line is the only labelled artist on this panel.
        themed_legend(ax_a)
    ax_a.set_xticks(positions)
    ax_a.set_xticklabels(tick_labels, fontsize=FONTS["tick"])
    ax_a.set_ylabel(metric_col, fontsize=FONTS["label"])
    ax_a.set_title("(a) Group mean ± SEM", fontsize=FONTS["title"])
    style_axes(ax_a, grid=True)

    # (b) Paired subject trajectories: one row per subject, one column per
    # group; missing subject/group pairs stay NaN.
    subj_order = list(dict.fromkeys(subjects.tolist()))
    traces = np.full((len(subj_order), len(group_order)), np.nan, dtype=float)
    for subj_idx, subject in enumerate(subj_order):
        mask_subject = subjects == subject
        for grp_idx, group in enumerate(group_order):
            vals = metric[mask_subject & (groups == group) & finite]
            if vals.size:
                traces[subj_idx, grp_idx] = vals[0]
        ax_b.plot(
            range(len(group_order)),
            traces[subj_idx],
            "o-",
            color=COLORS["stat_subject"],
            alpha=0.35,
            markersize=4,
            lw=0.8,
        )
    mean_trace = np.nanmean(traces, axis=0)
    ax_b.plot(
        range(len(group_order)),
        mean_trace,
        "s-",
        color=COLORS["stat_mean"],
        markersize=6,
        lw=1.6,
        label="Mean",
    )
    if reference_value is not None:
        ax_b.axhline(
            reference_value,
            color=COLORS["stat_reference"],
            ls="--",
            lw=0.8,
            alpha=0.6,
            label=reference_label,
        )
    ax_b.set_xticks(range(len(group_order)))
    ax_b.set_xticklabels(tick_labels, fontsize=FONTS["tick"])
    ax_b.set_ylabel(metric_col, fontsize=FONTS["label"])
    ax_b.set_title("(b) Paired subject trajectories", fontsize=FONTS["title"])
    style_axes(ax_b, grid=True)
    themed_legend(ax_b)

    # (c) Distribution by group. Tick labels are set after the call so the
    # code does not depend on the deprecated ``labels`` keyword.
    bp = ax_c.boxplot(
        [metric[(groups == group) & finite] for group in group_order],
        patch_artist=True,
    )
    ax_c.set_xticklabels(tick_labels)
    for patch, color in zip(bp["boxes"], bar_colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.65)
        patch.set_edgecolor(COLORS["edge"])
    if reference_value is not None:
        ax_c.axhline(
            reference_value, color=COLORS["stat_reference"], ls="--", lw=0.8, alpha=0.6
        )
    ax_c.set_ylabel(metric_col, fontsize=FONTS["label"])
    ax_c.set_title("(c) Distribution by group", fontsize=FONTS["title"])
    style_axes(ax_c, grid=True)

    # (d) Null distribution when supplied, otherwise a summary table.
    if null_distribution is not None:
        null_distribution = np.asarray(null_distribution, dtype=float)
        null_distribution = null_distribution[np.isfinite(null_distribution)]
        ax_d.hist(
            null_distribution,
            bins=30,
            color=COLORS["muted"],
            alpha=0.65,
            edgecolor=COLORS["edge"],
            linewidth=0.4,
            label="Null",
        )
        if observed_value is not None:
            ax_d.axvline(
                observed_value,
                color=COLORS["accent"],
                lw=1.2,
                ls="--",
                label="Observed",
            )
            # An empirical p-value is undefined for an empty null sample.
            if null_distribution.size:
                p_val = np.mean(np.abs(null_distribution) >= abs(observed_value))
                ax_d.text(
                    0.98,
                    0.95,
                    f"p={p_val:.4g}",
                    transform=ax_d.transAxes,
                    ha="right",
                    va="top",
                    fontsize=FONTS["annotation"],
                )
        ax_d.set_xlabel(metric_col, fontsize=FONTS["label"])
        ax_d.set_title("(d) Null distribution", fontsize=FONTS["title"])
        themed_legend(ax_d)
        style_axes(ax_d, grid=True)
    else:
        rows = [
            ("metric", metric_col),
            ("n_subjects", len(subj_order)),
            ("n_groups", len(group_order)),
            ("reference", reference_value if reference_value is not None else "None"),
        ]
        _draw_summary_table(ax_d, rows, title="(d) Summary")

    fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold")
    return _finalize_fig(fig, show=show, fname=fname)