Source code for mne_denoise.viz.components

"""Component-level visualization primitives.

This module contains:
1. Component score/eigenvalue curves.
2. Spatial component pattern plots with topomap or line fallback.
3. Source-space summaries in time, epoch-image, and spectrogram form.

These functions are method-agnostic and can be used with any fitted
estimator exposing component attributes such as ``patterns_``, ``scores_``,
``eigenvalues_``, or component sources via ``transform``.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
         Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

import mne
import numpy as np
from matplotlib.gridspec import GridSpec
from mne.time_frequency import tfr_array_multitaper

from ._utils import _get_components, _get_info, _get_patterns, _get_scores
from .theme import (
    COLORS,
    DIVERGING_CMAP,
    FONTS,
    SEQUENTIAL_CMAP,
    _finalize_fig,
    get_series_color,
    style_axes,
    themed_figure,
    themed_legend,
)


def _resolve_component_indices(
    n_components,
    n_available,
    default_max,
):
    """Normalize component selection to an explicit list of indices."""
    if n_components is None:
        return list(range(min(default_max, n_available)))

    if isinstance(n_components, int):
        return list(range(min(n_components, n_available)))

    indices = [int(idx) for idx in n_components]
    invalid = [idx for idx in indices if idx < 0 or idx >= n_available]
    if invalid:
        raise ValueError(f"Component indices out of range: {invalid}")
    return indices


[docs] def plot_component_score_curve( estimator, mode="raw", ax=None, show=True, fname=None, ): """Plot a 1D component score curve for a fitted estimator. Parameters ---------- estimator : object Fitted estimator exposing ``eigenvalues_`` or ``scores_``. mode : {'raw', 'cumulative', 'ratio'} Score display mode: - ``'raw'``: raw score/eigenvalue per component. - ``'cumulative'``: normalized cumulative sum. - ``'ratio'``: same values as ``'raw'`` but labeled as a ratio view. ax : matplotlib.axes.Axes | None Target axes. If None, a new themed figure is created. show : bool, default=True If True, show the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``mode`` is invalid, or if scores are missing/invalid. Notes ----- When available, the function overlays a dashed vertical cutoff using ``n_selected_`` or ``n_removed_`` from the estimator. Examples -------- >>> from mne_denoise.viz import plot_component_score_curve >>> fig = plot_component_score_curve(estimator, mode="raw", show=False) """ valid_modes = {"raw", "cumulative", "ratio"} if mode not in valid_modes: raise ValueError(f"mode must be one of {sorted(valid_modes)}") scores = _get_scores(estimator) if scores is None: raise ValueError("Estimator does not expose component scores.") scores = np.asarray(scores, dtype=float) if scores.ndim != 1 or scores.size == 0: raise ValueError("Component scores must be a non-empty 1D array.") if ax is None: fig, ax = themed_figure(figsize=(7, 4)) else: fig = ax.figure x = np.arange(1, scores.size + 1) if mode == "cumulative": y = np.cumsum(scores) y = y / y[-1] ylabel = "Cumulative Score (Normalized)" elif mode == "ratio": y = scores ylabel = "Power Ratio" else: y = scores ylabel = "Score / Eigenvalue" ax.plot( x, y, ".-", color=COLORS["primary"], linewidth=1.6, markersize=5, label="Scores", ) if mode != "cumulative": mean_score = np.mean(scores) ax.axhline( mean_score, color=COLORS["muted"], linestyle="--", linewidth=0.9, label=f"Mean ({mean_score:.3g})", ) n_selected = getattr(estimator, "n_selected_", None) if n_selected is None: n_selected = getattr(estimator, "n_removed_", None) if n_selected is not None and 0 < n_selected < scores.size: ax.axvline( n_selected + 0.5, color=COLORS["accent"], linestyle="--", linewidth=1.0, label=f"Cutoff ({n_selected})", ) themed_legend(ax, loc="best") ax.set_xlabel("Component") ax.set_ylabel(ylabel) ax.set_title("Component Scores") style_axes(ax, grid=True) return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_window_score_traces( scores, threshold=None, ax=None, show=True, fname=None, ): """Plot per-window score traces from a 2D score matrix. Parameters ---------- scores : array-like of shape (n_windows, n_scores) Score matrix to display. threshold : float | None Optional horizontal threshold line. ax : matplotlib.axes.Axes | None Target axes. If None, a new themed figure is created. show : bool, default=True If True, display the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. """ scores = np.asarray(scores, dtype=float) if scores.ndim != 2 or scores.shape[0] == 0: raise ValueError("scores must be a non-empty 2D array.") if ax is None: fig, ax = themed_figure(figsize=(9, 4)) else: fig = ax.figure n_windows, n_scores = scores.shape for idx in range(n_scores): vals = scores[:, idx] valid = np.isfinite(vals) if not np.any(valid): continue ax.plot( np.where(valid)[0], vals[valid], color=get_series_color(idx), linewidth=1.2, alpha=0.85, label=f"Score {idx + 1}", ) if threshold is not None: ax.axhline( float(threshold), color=COLORS["accent"], linestyle="--", linewidth=1.0, label=f"Threshold ({float(threshold):.3g})", ) ax.set_xlabel("Window") ax.set_ylabel("Score") ax.set_title("Window Score Traces") style_axes(ax, grid=True) if n_scores <= 10 or threshold is not None: themed_legend(ax, loc="best") return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_patterns( estimator, info=None, picks=None, n_components=None, ax=None, show=True, fname=None, ): """Plot spatial component patterns. When compatible MNE channel information is available, the patterns are rendered as topomaps. Otherwise, the function falls back to plotting the selected component weights across channels on a standard axes. Parameters ---------- estimator : object Fitted estimator exposing ``patterns_``. info : mne.Info | None Measurement info used for topomap rendering. picks : array-like of int | None Channel indices used for topomap rendering. If None, no topomap is attempted and the function falls back to channel-weight line plots. n_components : int | sequence of int | None Components to plot. If an int, plot the first ``n_components``. ax : matplotlib.axes.Axes | None Optional target axes. Supported only for the line-plot fallback or when rendering a single topomap. show : bool, default=True If True, show the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If patterns are not 2D, if no components are selected, or when ``ax`` is passed while requesting multiple topomaps. Also raised when ``picks`` is provided without valid ``info``. Notes ----- Topomap rendering is explicit: pass both ``info`` and ``picks``. This function does not infer channel picks automatically. Examples -------- >>> from mne_denoise.viz import plot_component_patterns >>> fig = plot_component_patterns( ... estimator, ... info=info, ... picks=[0, 1, 2, 3], ... n_components=4, ... show=False, ... ) """ patterns = np.asarray(_get_patterns(estimator)) if patterns.ndim != 2: raise ValueError( "patterns_ must be a 2D array of shape (n_channels, n_components)." ) indices = _resolve_component_indices( n_components, patterns.shape[1], default_max=6, ) if not indices: raise ValueError("No components selected for plotting.") if picks is not None and info is None: raise ValueError("info is required when picks is provided.") if picks is not None: topo_info = mne.pick_info(info, picks) if ax is not None: if len(indices) != 1: raise ValueError("ax can only be used when plotting a single topomap.") fig = ax.figure mne.viz.plot_topomap( patterns[picks, indices[0]], topo_info, axes=ax, show=False, contours=4, ) ax.set_title(f"Comp {indices[0]}") return _finalize_fig(fig, show=show, fname=fname) n_show = len(indices) n_cols = min(4, n_show) n_rows = int(np.ceil(n_show / n_cols)) fig, axes = themed_figure( n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False, ) flat_axes = axes.ravel() for i, (plot_ax, comp_idx) in enumerate(zip(flat_axes, indices)): mne.viz.plot_topomap( patterns[picks, comp_idx], topo_info, axes=plot_ax, show=False, contours=4, ) plot_ax.set_title( f"Comp {comp_idx}", fontsize=FONTS["tick"], color=get_series_color(i), ) for plot_ax in flat_axes[len(indices) :]: plot_ax.axis("off") fig.suptitle( "Component Patterns", fontsize=FONTS["title"], fontweight="semibold" ) return _finalize_fig(fig, show=show, fname=fname) if ax is None: fig, ax = themed_figure(figsize=(8, 4.5)) else: fig = ax.figure for i, comp_idx in enumerate(indices): ax.plot( patterns[:, comp_idx], marker="o", markersize=4, linewidth=1.3, alpha=0.85, color=get_series_color(i), label=f"Comp {comp_idx}", ) ax.axhline(0, color=COLORS["muted"], linestyle="-", alpha=0.35) ax.set_xlabel("Channel") ax.set_ylabel("Pattern Weight") ax.set_title("Component Patterns") style_axes(ax, grid=True) themed_legend(ax, loc="best") return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_summary( estimator, data=None, info=None, picks=None, times=None, sfreq=None, n_components=None, psd_fmax=None, show=True, plot_ci=True, fname=None, ): """Plot a compact per-component summary dashboard. Each selected component is displayed in one row with: 1) spatial pattern, 2) time course (mean ± CI for epoched sources), 3) power spectral density. Parameters ---------- estimator : object Fitted estimator exposing component patterns and a transform/source API. data : mne.io.BaseRaw | mne.BaseEpochs | ndarray | None Input data used to compute component sources when they are not cached. info : mne.Info | None Sensor metadata for topomap rendering. picks : array-like of int | None Channel indices used for topomap rendering. If None, the pattern panel uses a text placeholder instead of topomaps. times : array-like of shape (n_times,) | None Explicit time coordinates for source time-course panels. If None, sample indices are used. sfreq : float | None Sampling frequency used for PSD computation when ``info`` is not available. Required if ``info`` cannot be resolved. n_components : int | sequence of int | None Components to plot. If None, plot up to five components. psd_fmax : float | None Maximum frequency (Hz) shown in the PSD column. If None, defaults to ``min(100, sfreq / 2)`` to preserve previous behavior. show : bool, default=True If True, show the figure. plot_ci : bool, default=True If True and sources are epoched, overlay a 95% CI band based on SEM. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If no components are selected, if ``psd_fmax`` is not positive, if ``times`` length mismatches source length, or if ``picks`` is provided without valid ``info``. Also raised when neither ``info`` nor ``sfreq`` is provided. Notes ----- Topomap rendering and time coordinates are explicit in this function. It does not infer channel picks or time axes. Examples -------- >>> from mne_denoise.viz import plot_component_summary >>> fig = plot_component_summary( ... estimator, ... data=epochs, ... sfreq=epochs.info["sfreq"], ... info=info, ... picks=[0, 1, 2, 3], ... times=epochs.times, ... n_components=3, ... psd_fmax=80, ... show=False, ... ) """ if picks is not None and info is None: raise ValueError("info is required when picks is provided.") info = _get_info(estimator, info) patterns = np.asarray(_get_patterns(estimator)) sources = _get_components(estimator, data) indices = _resolve_component_indices( n_components, patterns.shape[1], default_max=5, ) if not indices: raise ValueError("No components selected for plotting.") fig, root_ax = themed_figure(figsize=(12, 3 * len(indices))) root_ax.remove() gs = GridSpec(len(indices), 3, figure=fig, width_ratios=[1, 2, 1]) if info is not None: sfreq_eff = float(info["sfreq"]) elif sfreq is not None: sfreq_eff = float(sfreq) else: raise ValueError("sfreq is required when info is not available.") if sfreq_eff <= 0: raise ValueError("sfreq must be strictly positive.") if times is None: times_template = np.arange(sources.shape[1]) time_label = "Time (samples)" else: times_template = np.asarray(times) if times_template.shape[0] != sources.shape[1]: raise ValueError("times must have length equal to source n_times.") time_label = "Time" if psd_fmax is None: psd_fmax = min(100.0, sfreq_eff / 2.0) psd_fmax = float(psd_fmax) if psd_fmax <= 0: raise ValueError("psd_fmax must be strictly positive.") psd_fmax = min(psd_fmax, sfreq_eff / 2.0) for row_idx, comp_idx in enumerate(indices): ax_topo = fig.add_subplot(gs[row_idx, 0]) if picks is not None: topo_info = mne.pick_info(info, picks) topo_data = patterns[picks, comp_idx] mne.viz.plot_topomap(topo_data, topo_info, axes=ax_topo, show=False) ax_topo.set_title(f"Comp {comp_idx} Pattern") else: ax_topo.text(0.5, 0.5, "No topomap info", ha="center", va="center") ax_topo.set_axis_off() ax_time = fig.add_subplot(gs[row_idx, 1]) if sources.ndim == 3: comp_data = sources[comp_idx] mean_tc = comp_data.mean(axis=1) ax_time.plot(times_template, mean_tc, label="Mean", color=COLORS["before"]) if plot_ci: std_tc = comp_data.std(axis=1) / np.sqrt(comp_data.shape[1]) ax_time.fill_between( times_template, mean_tc - 2 * std_tc, mean_tc + 2 * std_tc, color=COLORS["muted"], alpha=0.3, label="95% CI (SEM)", ) themed_legend(ax_time, loc="best") else: comp_data = sources[comp_idx] ax_time.plot(times_template, comp_data, color=COLORS["before"]) ax_time.set_title(f"Comp {comp_idx} Time Course") ax_time.set_xlabel(time_label) style_axes(ax_time, grid=True) ax_psd = fig.add_subplot(gs[row_idx, 2]) if sources.ndim == 3: d_flat = sources[comp_idx].T else: d_flat = sources[comp_idx][np.newaxis, :] psd_spec, freqs = mne.time_frequency.psd_array_welch( d_flat, sfreq=sfreq_eff, fmin=0, fmax=psd_fmax, n_fft=min(2048, d_flat.shape[-1]), verbose=False, ) mean_psd = np.mean(psd_spec, axis=0) ax_psd.semilogy(freqs, mean_psd, color=COLORS["primary"]) ax_psd.set_title("PSD") ax_psd.set_xlabel("Frequency (Hz)") ax_psd.set_xlim(0, psd_fmax) style_axes(ax_psd, grid=True) return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_epochs_image( estimator, data=None, n_components=None, show=True, fname=None, ): """Plot component activity as an epoch-by-time image. Parameters ---------- estimator : object Fitted estimator exposing component sources via cache or transform. data : mne.io.BaseRaw | mne.BaseEpochs | ndarray | None Input data used to compute sources when they are not cached. n_components : int | sequence of int | None Components to plot. If None, plot up to five components. show : bool, default=True If True, show the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If sources are not 2D/3D, or if no components are selected. Notes ----- Input source shapes are interpreted as: - ``(n_components, n_times)`` for a single average/time series. - ``(n_components, n_times, n_epochs)`` for epoched sources. Examples -------- >>> from mne_denoise.viz import plot_component_epochs_image >>> fig = plot_component_epochs_image( ... estimator, data=epochs, n_components=[0, 1], show=False ... ) """ sources = _get_components(estimator, data) if sources.ndim == 2: sources = sources[:, :, np.newaxis] if sources.ndim != 3: raise ValueError("Component sources must be 2D or 3D.") indices = _resolve_component_indices( n_components, sources.shape[0], default_max=5, ) if not indices: raise ValueError("No components selected for plotting.") fig, axes = themed_figure( len(indices), 1, figsize=(8, 2 * len(indices)), sharex=True, squeeze=False, ) axes = axes.ravel() for ax, comp_idx in zip(axes, indices): img = sources[comp_idx].T ax.imshow(img, aspect="auto", origin="lower", cmap=DIVERGING_CMAP) ax.set_title(f"Comp {comp_idx}") ax.set_ylabel("Epochs") axes[-1].set_xlabel("Time (samples)") return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_time_series( estimator, data=None, n_components=None, times=None, show=True, ax=None, fname=None, ): """Plot stacked component time series with fixed vertical offsets. Parameters ---------- estimator : object Fitted estimator exposing component sources via cache or transform. data : mne.io.BaseRaw | mne.BaseEpochs | ndarray | None Input data used to compute sources when they are not cached. n_components : int | sequence of int | None Components to plot. If None, plot up to twenty components. times : array-like of shape (n_times,) | None Explicit time coordinates. If None, sample indices are used. show : bool, default=True If True, show the figure. ax : matplotlib.axes.Axes | None Optional target axes. If None, a new themed figure is created. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If no components are selected or if ``times`` length mismatches source length. Notes ----- Each component is z-scored independently before plotting so that traces are comparable in amplitude and can be stacked with a fixed offset. Examples -------- >>> from mne_denoise.viz import plot_component_time_series >>> fig = plot_component_time_series( ... estimator, ... data=raw, ... times=raw.times, ... show=False, ... ) """ sources = _get_components(estimator, data) scores = _get_scores(estimator) if sources.ndim == 3: sources = sources.mean(axis=2) indices = _resolve_component_indices( n_components, sources.shape[0], default_max=20, ) if not indices: raise ValueError("No components selected for plotting.") if ax is None: fig, ax = themed_figure(figsize=(10, max(4.0, len(indices) * 0.5))) else: fig = ax.figure if times is None: time_axis = np.arange(sources.shape[1]) time_label = "Time (samples)" else: time_axis = np.asarray(times) if time_axis.shape[0] != sources.shape[1]: raise ValueError("times must have length equal to source n_times.") time_label = "Time" x_min = float(time_axis[0]) x_max = float(time_axis[-1]) x_pad = 0.03 * (x_max - x_min if x_max != x_min else 1.0) label_x = x_max + x_pad * 0.25 offset_step = 3.0 for row_idx, comp_idx in enumerate(indices): comp = sources[comp_idx] std = np.std(comp) if std < 1e-15: std = 1.0 comp_norm = comp / std offset = -row_idx * offset_step color = get_series_color(row_idx) ax.plot(time_axis, comp_norm + offset, color=color, linewidth=1.5) label = f"Comp {comp_idx}" if scores is not None and comp_idx < len(scores): label += f" (λ={scores[comp_idx]:.2f})" ax.text( label_x, offset, label, va="center", fontsize=FONTS["tick"], color=color ) ax.set_xlim(x_min, x_max + x_pad) ax.set_yticks([]) ax.set_xlabel(time_label) ax.set_title("Component Time Series") ax.spines["left"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["top"].set_visible(False) return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_spectrogram( component_data, sfreq, freqs=None, fmax=50.0, n_cycles=None, title="Component Spectrogram", ax=None, show=True, fname=None, ): """Plot a time-frequency power view for one component. Parameters ---------- component_data : ndarray, shape (n_times,) or (n_epochs, n_times) Single-component time series or repeated epochs of one component. sfreq : float Sampling frequency. freqs : ndarray | None Frequencies to compute. If None, frequencies are generated from 1 Hz to ``fmax`` (capped at Nyquist). fmax : float | None Upper frequency bound used when ``freqs`` is None. Defaults to 50 Hz to preserve prior behavior. n_cycles : float | ndarray | None Number of cycles for multitaper estimation. title : str Axes title. ax : matplotlib.axes.Axes | None Optional target axes. If None, a new themed figure is created. show : bool, default=True If True, show the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``component_data`` is not 1D/2D, or if ``fmax`` is not positive when ``freqs`` is None. Notes ----- A 1D input is treated as one pseudo-epoch. A 2D input is interpreted as ``(n_epochs, n_times)`` and averaged across epochs in power space. Examples -------- >>> from mne_denoise.viz import plot_component_spectrogram >>> fig = plot_component_spectrogram( ... component_data, sfreq=250.0, fmax=80, show=False ... ) """ component_data = np.asarray(component_data) if component_data.ndim == 1: data = component_data[np.newaxis, np.newaxis, :] elif component_data.ndim == 2: data = component_data[:, np.newaxis, :] else: raise ValueError("component_data must be 1D or 2D.") if freqs is None: if fmax is None: upper = sfreq / 2.0 else: upper = min(float(fmax), sfreq / 2.0) if upper <= 0: raise ValueError("fmax must be strictly positive when freqs is None.") upper = max(2.0, upper) freqs = np.arange(1.0, np.floor(upper) + 1.0, 1.0) else: freqs = np.asarray(freqs, dtype=float) if n_cycles is None: n_cycles = freqs / 4.0 tfr = tfr_array_multitaper( data, sfreq=sfreq, freqs=freqs, n_cycles=n_cycles, output="power", verbose=False, ) power = tfr[:, 0].mean(axis=0) times = np.arange(power.shape[1]) / sfreq if ax is None: fig, ax = themed_figure(figsize=(10, 5)) else: fig = ax.figure im = ax.pcolormesh(times, freqs, power, shading="gouraud", cmap=SEQUENTIAL_CMAP) ax.set_ylabel("Frequency (Hz)") ax.set_xlabel("Time (s)") ax.set_title(title) fig.colorbar(im, ax=ax, label="Power") style_axes(ax, grid=False) return _finalize_fig(fig, show=show, fname=fname)