Source code for mne_denoise.viz.spectra

"""Spectral and time-frequency visualization primitives.

This module contains reusable, method-agnostic plots focused on
frequency-domain and time-frequency diagnostics.

This module contains:
1. PSD comparisons for before/after denoising outputs.
2. Component-spectrum comparisons for extracted sources.
3. Spectrogram and time-frequency mask visualizations.
4. Narrowband scan summaries for spectral sweeps.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
         Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
"""

from __future__ import annotations

import mne
import numpy as np
from scipy import signal

from .theme import (
    COLORS,
    DIVERGING_CMAP,
    FONTS,
    SEQUENTIAL_CMAP,
    _finalize_fig,
    get_series_color,
    style_axes,
    themed_figure,
    themed_legend,
)


def _compute_array_psd(data, sfreq, fmin, fmax):
    """Compute PSDs for array-like inputs using Welch's method."""
    data = np.asarray(data, dtype=float)
    if data.ndim == 1:
        data = data[np.newaxis, :]
    elif data.ndim > 2:
        data = data.reshape(-1, data.shape[-1])

    nperseg = min(data.shape[-1], int(sfreq * 2))
    freqs, psd = signal.welch(data, fs=sfreq, nperseg=nperseg, axis=-1)
    keep = (freqs >= fmin) & (freqs <= fmax)
    return freqs[keep], psd[..., keep]


def _compute_psd_matrix(inst, sfreq, fmin, fmax):
    """Return PSD matrix with shape ``(n_series, n_freqs)``."""
    if isinstance(inst, (mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked)):
        spectrum = inst.compute_psd(fmin=fmin, fmax=fmax)
        freqs = np.asarray(spectrum.freqs, dtype=float)
        psd = np.asarray(spectrum.get_data(return_freqs=False), dtype=float)
    else:
        if sfreq is None:
            raise ValueError("sfreq must be provided when plotting PSDs from arrays.")
        freqs, psd = _compute_array_psd(inst, sfreq=sfreq, fmin=fmin, fmax=fmax)
        freqs = np.asarray(freqs, dtype=float)
        psd = np.asarray(psd, dtype=float)

    return freqs, psd.reshape(-1, psd.shape[-1])


def _as_component_data(components):
    """Normalize component inputs to canonical 2D shape ``(n_components, n_times)``."""
    if isinstance(components, (mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked)):
        data = np.asarray(components.get_data(), dtype=float)
    else:
        data = np.asarray(components, dtype=float)

    if data.ndim == 1:
        return data[np.newaxis, :]
    if data.ndim == 2:
        return data
    if data.ndim == 3:
        return data.mean(axis=0)
    raise ValueError(
        "components must be 1D, 2D, or 3D data, or an MNE Raw/Epochs/Evoked object."
    )


def _compute_array_spectrogram(data, picks, sfreq, fmin, fmax, n_freqs):
    """Compute a mean channel spectrogram for canonical array inputs."""
    data = np.asarray(data, dtype=float)
    if data.ndim == 2:
        selected = data[picks, :]
    elif data.ndim == 3:
        selected = data[:, picks, :].reshape(-1, data.shape[-1])
    else:
        raise ValueError(
            "Array spectrogram inputs must be 2D (n_channels, n_times) "
            "or 3D (n_epochs, n_channels, n_times)."
        )

    nperseg = min(selected.shape[-1], int(sfreq * 2))
    noverlap = max(0, nperseg // 2)
    freqs, _, spec = signal.spectrogram(
        selected, fs=sfreq, nperseg=nperseg, noverlap=noverlap, axis=-1
    )
    spec = np.asarray(spec, dtype=float)
    spec = spec.mean(axis=0)

    keep = (freqs >= fmin) & (freqs <= fmax)
    freqs = freqs[keep]
    spec = spec[keep]
    if freqs.size < 2:
        raise ValueError("Could not compute a valid frequency grid in [fmin, fmax].")

    target_freqs = np.linspace(fmin, fmax, n_freqs, dtype=float)
    interp_spec = np.empty((target_freqs.size, spec.shape[1]), dtype=float)
    for time_idx in range(spec.shape[1]):
        interp_spec[:, time_idx] = np.interp(
            target_freqs,
            freqs,
            spec[:, time_idx],
        )
    return target_freqs, interp_spec


def _add_colorbar(fig, ax, image, label):
    """Add a lightly styled colorbar to an axis."""
    colorbar = fig.colorbar(image, ax=ax, pad=0.02)
    colorbar.set_label(label, fontsize=FONTS["label"])
    colorbar.ax.tick_params(labelsize=FONTS["tick"])
    colorbar.outline.set_edgecolor(COLORS["edge"])
    colorbar.outline.set_linewidth(0.5)
    return colorbar


def _add_line_markers(ax, line_freq, fmax):
    """Add line-frequency markers and visible harmonics to an axis."""
    if line_freq is None:
        return

    ax.axvline(
        line_freq,
        color=COLORS["line_marker"],
        linestyle="--",
        alpha=0.7,
        label=f"{line_freq:g} Hz",
    )

    harmonic = 2
    while line_freq * harmonic <= fmax:
        ax.axvline(
            line_freq * harmonic,
            color=COLORS["line_marker"],
            linestyle="--",
            alpha=0.3,
        )
        harmonic += 1


[docs] def plot_narrowband_score_scan( frequencies, eigenvalues, peak_freq=None, true_freqs=None, ax=None, show=True, fname=None, ): """Plot score/eigenvalue profiles from a narrowband scan. Parameters ---------- frequencies : array-like of shape (n_freqs,) Frequency grid used in the scan. eigenvalues : array-like of shape (n_freqs,) | (n_freqs, n_components) Scan scores. For 2D inputs, the first column is treated as dominant. peak_freq : float | None Optional frequency to highlight with a marker and vertical line. true_freqs : sequence of float | None Optional reference frequencies to mark. ax : matplotlib.axes.Axes | None Target axes. If None, create a new figure and axes. show : bool If True, display the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``frequencies`` is not 1D, if ``eigenvalues`` is not 1D/2D, or if their first dimensions do not match. Notes ----- This function is plotting-only and does not run frequency estimation. ``peak_freq`` and ``true_freqs`` are optional annotations supplied directly by the caller. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_narrowband_score_scan >>> freqs = np.linspace(6, 40, 50) >>> scores = np.exp(-0.5 * ((freqs - 12.0) / 1.5) ** 2) >>> fig = plot_narrowband_score_scan( ... freqs, scores, peak_freq=12.0, true_freqs=[12.0, 24.0], show=False ... ) """ frequencies = np.asarray(frequencies, dtype=float) eigenvalues = np.asarray(eigenvalues, dtype=float) if frequencies.ndim != 1: raise ValueError("frequencies must be a 1D array.") if eigenvalues.ndim not in (1, 2): raise ValueError("eigenvalues must be a 1D or 2D array.") if eigenvalues.shape[0] != frequencies.shape[0]: raise ValueError( "frequencies and eigenvalues must have matching first dimensions." ) if ax is None: fig, ax = themed_figure(figsize=(10, 4)) else: fig = ax.figure dominant = eigenvalues[:, 0] if eigenvalues.ndim == 2 else eigenvalues if eigenvalues.ndim == 2 and eigenvalues.shape[1] > 1: ax.plot( frequencies, eigenvalues[:, 1:], color=COLORS["muted"], linestyle="-", alpha=0.6, linewidth=1.2, ) ax.plot( frequencies, dominant, color=COLORS["primary"], marker="o", linestyle="-", markersize=4, linewidth=1.8, label="Dominant component", ) if peak_freq is not None: peak_idx = np.argmin(np.abs(frequencies - peak_freq)) ax.plot( peak_freq, dominant[peak_idx], color=COLORS["accent"], marker="*", linestyle="none", markersize=14, label=f"Peak: {peak_freq:.1f} Hz", ) ax.axvline(peak_freq, color=COLORS["accent"], linestyle="--", alpha=0.5) if true_freqs is not None: palette = [ COLORS["accent"], COLORS["success"], COLORS["secondary"], COLORS["purple"], ] for idx, freq in enumerate(true_freqs): ax.axvline( freq, color=palette[idx % len(palette)], linestyle="--", alpha=0.5, label=f"True: {freq:g} Hz", ) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_ylabel("Score / Eigenvalue", fontsize=FONTS["label"]) ax.set_title("Narrowband Score Scan", fontsize=FONTS["title"]) style_axes(ax, grid=True) if peak_freq is not None or true_freqs is not None: themed_legend(ax, loc="best") return _finalize_fig(fig, show=show, fname=fname, tight=False)
[docs] def plot_psd_comparison( inst_before, inst_after, fmin=0, fmax=np.inf, sfreq=None, line_freq=None, show=True, average=True, ax=None, fname=None, ): """Plot PSD comparison for original and denoised data. Parameters ---------- inst_before, inst_after : MNE object | ndarray Inputs to compare. Supported MNE inputs are Raw, Epochs, and Evoked. Array inputs are interpreted with the last axis as time. When either input is an array, ``sfreq`` must be provided. fmin, fmax : float Frequency bounds to display. sfreq : float | None Sampling frequency for array inputs. line_freq : float | None Optional line frequency to mark, along with visible harmonics. show : bool Whether to display the figure. average : bool If True, average PSDs across non-frequency axes. ax : Axes | None Optional axis to draw into. fname : path-like | None Optional output path. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If array inputs are used without ``sfreq``. Notes ----- PSD backend is selected by input type: - MNE inputs use ``compute_psd``. - Array inputs use SciPy Welch PSD. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_psd_comparison >>> before = np.random.randn(8, 2000) >>> after = before * 0.8 >>> fig = plot_psd_comparison(before, after, sfreq=250.0, show=False) """ if ax is None: fig, ax = themed_figure(figsize=(8, 4)) else: fig = ax.figure for inst, label, color in [ (inst_before, "Before", COLORS["before"]), (inst_after, "After", COLORS["after"]), ]: freqs, psd = _compute_psd_matrix(inst, sfreq=sfreq, fmin=fmin, fmax=fmax) if average: axis = tuple(range(psd.ndim - 1)) psd_mean = np.mean(psd, axis=axis) ax.semilogy(freqs, psd_mean, label=label, color=color) else: psd = psd.reshape(-1, psd.shape[-1]) ax.semilogy(freqs, psd.T, color=color, alpha=0.2) ax.plot([], [], color=color, label=label) display_fmax = float(np.max(freqs)) if np.isinf(fmax) else fmax _add_line_markers(ax, line_freq=line_freq, fmax=display_fmax) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) ax.set_title("PSD Comparison", fontsize=FONTS["title"]) ax.set_xlim(left=max(0.0, fmin), right=display_fmax) style_axes(ax, grid=True) themed_legend(ax, loc="best") return _finalize_fig(fig, show=show, fname=fname, tight=False)
def plot_psd_zoom_comparison( freqs_before, psd_before, freqs_after, psd_after, series_name="", title="", zoom_freqs=None, zoom_annotations=None, fmax=125.0, zoom_half_width_hz=8.0, series_colors=None, series_labels=None, fname=None, show=True, ): """Plot a PSD comparison plus zoomed panels around selected frequencies. Parameters ---------- freqs_before, freqs_after : array-like of shape (n_freqs,) Frequency vectors for the before/after PSD curves. psd_before, psd_after : array-like of shape (n_freqs,) PSD vectors aligned with ``freqs_before`` and ``freqs_after``. series_name : str Series key used for optional color/label mapping. title : str Optional figure suptitle. zoom_freqs : array-like of shape (n_zoom,) Frequency centers for zoom panels. zoom_annotations : sequence[str] | None Optional annotation text per zoom panel. fmax : float Max frequency on the full-spectrum panel. zoom_half_width_hz : float Half-width (Hz) around each ``zoom_freq``. series_colors : mapping[str, str] | None Optional color overrides by series name. series_labels : mapping[str, str] | None Optional display label overrides by series name. fname : path-like | None Optional output path used to save the figure. show : bool If True, display the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``zoom_freqs`` is empty/non-1D or if ``zoom_half_width_hz <= 0``. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_psd_zoom_comparison >>> freqs = np.linspace(0, 120, 512) >>> before = np.exp(-freqs / 40) >>> after = before * 0.7 >>> fig = plot_psd_zoom_comparison( ... freqs, before, freqs, after, zoom_freqs=[50.0], show=False ... ) """ freqs_before = np.asarray(freqs_before, dtype=float) psd_before = np.asarray(psd_before, dtype=float) freqs_after = np.asarray(freqs_after, dtype=float) psd_after = np.asarray(psd_after, dtype=float) zoom_freqs = np.asarray(zoom_freqs, dtype=float) if zoom_freqs.ndim != 1 or zoom_freqs.size == 0: raise ValueError("zoom_freqs must be a non-empty 1D array-like.") zoom_half_width_hz = float(zoom_half_width_hz) if zoom_half_width_hz <= 0: raise ValueError("zoom_half_width_hz must be positive.") n_zoom = len(zoom_freqs) fig, axes = themed_figure(1, 1 + n_zoom, figsize=(4 * (1 + n_zoom), 4)) axes = np.atleast_1d(axes) if series_colors and series_name in series_colors: series_color = series_colors[series_name] else: series_color = COLORS["after"] if series_name: if series_labels and series_name in series_labels: series_label = series_labels[series_name] else: series_label = series_name else: series_label = "After" ax = axes[0] ax.semilogy( freqs_before, psd_before, color=COLORS["before"], alpha=0.5, lw=1, label="Before", ) ax.semilogy( freqs_after, psd_after, color=series_color, lw=1.5, label=series_label, ) for freq in zoom_freqs: ax.axvline(freq, color=COLORS["line_marker"], ls="--", alpha=0.2) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) ax.set_title("PSD Comparison", fontsize=FONTS["title"]) ax.set_xlim(0, fmax) themed_legend(ax) style_axes(ax, grid=True) for idx, freq in enumerate(zoom_freqs): ax = axes[1 + idx] zoom = zoom_half_width_hz before_mask = (freqs_before >= freq - zoom) & (freqs_before <= freq + zoom) after_mask = (freqs_after >= freq - zoom) & (freqs_after <= freq + zoom) ax.semilogy( freqs_before[before_mask], psd_before[before_mask], color=COLORS["before"], alpha=0.5, lw=1, ) ax.semilogy( freqs_after[after_mask], psd_after[after_mask], color=series_color, lw=1.5, ) ax.axvline(freq, color=COLORS["line_marker"], ls="--", alpha=0.4) panel_title = f"{freq:.0f} Hz" if zoom_annotations is not None and idx < len(zoom_annotations): panel_title += f"\n{zoom_annotations[idx]}" ax.set_title(panel_title, fontsize=FONTS["tick"]) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) if idx == 0: ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) style_axes(ax, grid=True) if title: fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold") return _finalize_fig(fig, show=show, fname=fname) def plot_psd_gallery( freqs_reference, psd_reference, series_psds, zoom_freqs, fmax=125.0, zoom_half_width_hz=8.0, title="", series_order=None, series_colors=None, series_labels=None, fname=None, show=True, ): """Plot full-spectrum and zoomed PSD panels across multiple series. Parameters ---------- freqs_reference : array-like of shape (n_freqs,) Frequency vector for the reference PSD. psd_reference : array-like of shape (n_freqs,) Reference PSD values. series_psds : mapping[str, tuple[array-like, array-like]] Mapping from series name to ``(freqs, psd)`` arrays. zoom_freqs : array-like of shape (n_zoom,) Frequency centers for zoom panels. fmax : float Max frequency on the full-spectrum panels. zoom_half_width_hz : float Half-width (Hz) around each zoom center. title : str Optional figure suptitle. series_order : sequence[str] | None Optional plotting order. Missing names are shown as empty placeholders. series_colors : mapping[str, str] | None Optional color overrides by series name. series_labels : mapping[str, str] | None Optional display label overrides by series name. fname : path-like | None Optional output path used to save the figure. show : bool If True, display the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``zoom_freqs`` is empty/non-1D or if ``zoom_half_width_hz <= 0``. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_psd_gallery >>> freqs = np.linspace(0, 120, 512) >>> before = np.exp(-freqs / 40) >>> series = {"A": (freqs, before * 0.8), "B": (freqs, before * 0.6)} >>> fig = plot_psd_gallery(freqs, before, series, zoom_freqs=[50.0], show=False) """ freqs_reference = np.asarray(freqs_reference, dtype=float) psd_reference = np.asarray(psd_reference, dtype=float) zoom_freqs = np.asarray(zoom_freqs, dtype=float) if zoom_freqs.ndim != 1 or zoom_freqs.size == 0: raise ValueError("zoom_freqs must be a non-empty 1D array-like.") zoom_half_width_hz = float(zoom_half_width_hz) if zoom_half_width_hz <= 0: raise ValueError("zoom_half_width_hz must be positive.") if series_order is None: series_order = list(series_psds.keys()) n_rows = len(series_order) n_cols = 1 + len(zoom_freqs) fig, axes = themed_figure(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows)) axes = np.asarray(axes, dtype=object) if n_rows == 1: axes = axes[np.newaxis, :] for row_idx, series_name in enumerate(series_order): if series_name not in series_psds: for col_idx in range(n_cols): ax = axes[row_idx, col_idx] ax.text( 0.5, 0.5, f"{series_name}\nno data", transform=ax.transAxes, ha="center", va="center", fontsize=FONTS["label"], color=COLORS["placeholder"], ) ax.axis("off") continue freqs_series, psd_series = series_psds[series_name] freqs_series = np.asarray(freqs_series, dtype=float) psd_series = np.asarray(psd_series, dtype=float) if series_colors and series_name in series_colors: color = series_colors[series_name] else: color = get_series_color(row_idx) if series_labels and series_name in series_labels: label = series_labels[series_name] else: label = series_name ax = axes[row_idx, 0] ax.semilogy( freqs_reference, psd_reference, color=COLORS["before"], alpha=0.4, lw=0.8, label="Before", ) ax.semilogy(freqs_series, psd_series, color=color, lw=1.2, label=label) for freq in zoom_freqs: ax.axvline(freq, color=COLORS["line_marker"], ls="--", alpha=0.15) ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) if row_idx == 0: ax.set_title("Full PSD", fontsize=FONTS["title"]) ax.text( 0.01, 0.98, label, transform=ax.transAxes, ha="left", va="top", fontsize=FONTS["annotation"], color=color, ) ax.set_xlim(0, fmax) themed_legend(ax, fontsize=6) style_axes(ax, grid=True) for col_idx, freq in enumerate(zoom_freqs): ax = axes[row_idx, 1 + col_idx] zoom = zoom_half_width_hz before_mask = (freqs_reference >= freq - zoom) & ( freqs_reference <= freq + zoom ) series_mask = (freqs_series >= freq - zoom) & (freqs_series <= freq + zoom) ax.semilogy( freqs_reference[before_mask], psd_reference[before_mask], color=COLORS["before"], alpha=0.4, lw=0.8, ) ax.semilogy( freqs_series[series_mask], psd_series[series_mask], color=color, lw=1.2, ) ax.axvline(freq, color=COLORS["line_marker"], ls="--", alpha=0.3) if row_idx == 0: ax.set_title(f"{freq:.0f} Hz", fontsize=FONTS["title"]) if col_idx == 0: ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) style_axes(ax, grid=True) if title: fig.suptitle(title, fontsize=FONTS["suptitle"], fontweight="bold", y=1.01) return _finalize_fig(fig, show=show, fname=fname) def plot_psd_overlay( freqs_reference, psd_reference, series_psds, focus_freq, fmax=125.0, focus_half_width_hz=10.0, n_harmonics=3, title="", series_order=None, series_colors=None, series_labels=None, fname=None, show=True, ): """Plot full-spectrum and focused PSD overlays across multiple series. Parameters ---------- freqs_reference : array-like of shape (n_freqs,) Frequency vector for the reference PSD. psd_reference : array-like of shape (n_freqs,) Reference PSD values. series_psds : mapping[str, tuple[array-like, array-like]] Mapping from series name to ``(freqs, psd)`` arrays. focus_freq : float Center frequency for the zoomed overlay panel. fmax : float Max frequency shown in the full-spectrum panel. focus_half_width_hz : float Half-width (Hz) used for the zoomed panel around ``focus_freq``. n_harmonics : int Number of harmonics to mark on the full-spectrum panel. title : str Optional title for the full-spectrum panel. series_order : sequence[str] | None Optional plotting order. series_colors : mapping[str, str] | None Optional color overrides by series name. series_labels : mapping[str, str] | None Optional label overrides by series name. fname : path-like | None Optional output path used to save the figure. show : bool If True, display the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``focus_half_width_hz <= 0``. Notes ----- When ``series_order`` is not provided, overlay order follows the insertion order of ``series_psds``. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_psd_overlay >>> freqs = np.linspace(0, 120, 512) >>> before = np.exp(-freqs / 40) >>> series = {"A": (freqs, before * 0.8), "B": (freqs, before * 0.6)} >>> fig = plot_psd_overlay(freqs, before, series, focus_freq=50.0, show=False) """ freqs_reference = np.asarray(freqs_reference, dtype=float) psd_reference = np.asarray(psd_reference, dtype=float) focus_half_width_hz = float(focus_half_width_hz) if focus_half_width_hz <= 0: raise ValueError("focus_half_width_hz must be positive.") if series_order is None: series_order = list(series_psds.keys()) fig, axes = themed_figure(1, 2, figsize=(16, 5)) axes = np.atleast_1d(axes) ax = axes[0] ax.semilogy( freqs_reference, psd_reference, color=COLORS["before"], alpha=0.4, lw=1, label="Before", ) for idx, series_name in enumerate(series_order): if series_name not in series_psds: continue freqs_series, psd_series = series_psds[series_name] if series_colors and series_name in series_colors: color = series_colors[series_name] else: color = get_series_color(idx) if series_labels and series_name in series_labels: label = series_labels[series_name] else: label = series_name ax.semilogy( freqs_series, psd_series, color=color, lw=1.2, label=label, ) for harmonic_idx in range(1, n_harmonics + 2): harmonic = focus_freq * harmonic_idx if harmonic < fmax: ax.axvline(harmonic, color=COLORS["line_marker"], ls="--", alpha=0.15) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) ax.set_title(title or "Full Spectrum Comparison", fontsize=FONTS["title"]) ax.set_xlim(0, fmax) themed_legend(ax) style_axes(ax, grid=True) ax = axes[1] zoom = focus_half_width_hz before_mask = (freqs_reference >= focus_freq - zoom) & ( freqs_reference <= focus_freq + zoom ) ax.semilogy( freqs_reference[before_mask], psd_reference[before_mask], color=COLORS["before"], alpha=0.4, lw=1, label="Before", ) for idx, series_name in enumerate(series_order): if series_name not in series_psds: continue freqs_series, psd_series = series_psds[series_name] series_mask = (freqs_series >= focus_freq - zoom) & ( freqs_series <= focus_freq + zoom ) if series_colors and series_name in series_colors: color = series_colors[series_name] else: color = get_series_color(idx) if series_labels and series_name in series_labels: label = series_labels[series_name] else: label = series_name ax.semilogy( freqs_series[series_mask], psd_series[series_mask], color=color, lw=1.5, label=label, ) ax.axvline(focus_freq, color=COLORS["line_marker"], ls="--", alpha=0.4) ax.set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_title(f"Zoom at {focus_freq:.0f} Hz", fontsize=FONTS["title"]) themed_legend(ax) style_axes(ax, grid=True) return _finalize_fig(fig, show=show, fname=fname)
[docs] def plot_component_psd_comparison( inst_before, components, component_indices, sfreq=None, peak_freq=None, fmin=1, fmax=40, show=True, fname=None, ): """Plot input PSD next to PSDs of selected components. Parameters ---------- inst_before : MNE object | ndarray Baseline signal used for the reference PSD. components : MNE object | ndarray Component signals with canonical shape ``(n_components, n_times)``, or ``(n_epochs, n_components, n_times)``. component_indices : int | sequence of int Explicit component index/indices to include in the component PSD panel. sfreq : float | None Sampling frequency for array inputs. If ``components`` is an MNE object and ``sfreq`` is None, ``components.info['sfreq']`` is used. peak_freq : float | None Optional frequency marker shown on both panels. fmin, fmax : float Frequency bounds for PSD computation. show : bool If True, display the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``component_indices`` is empty/out of range or if array inputs are provided without ``sfreq``. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_component_psd_comparison >>> signal = np.random.randn(8, 2000) >>> sources = np.random.randn(4, 2000) >>> fig = plot_component_psd_comparison( ... signal, ... sources, ... component_indices=[0, 1], ... sfreq=250.0, ... show=False, ... ) """ fig, axes = themed_figure( 1, 2, figsize=(12, 4), sharey=True, constrained_layout=True ) axes = np.atleast_1d(axes) freqs_before, psd_before = _compute_psd_matrix( inst_before, sfreq=sfreq, fmin=fmin, fmax=fmax ) axes[0].semilogy( freqs_before, psd_before.mean(axis=0), color=COLORS["before"], label="Before", ) axes[0].set_title("Original Data PSD", fontsize=FONTS["title"]) axes[0].set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) axes[0].set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) style_axes(axes[0], grid=True) component_data = _as_component_data(components) if np.isscalar(component_indices): indices = [int(component_indices)] else: indices = [int(idx) for idx in component_indices] if len(indices) == 0: raise ValueError("component_indices cannot be empty.") invalid = [idx for idx in indices if idx < 0 or idx >= component_data.shape[0]] if invalid: raise ValueError(f"Component indices out of range: {invalid}") component_sfreq = sfreq if component_sfreq is None and isinstance( components, (mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked) ): component_sfreq = float(components.info["sfreq"]) if component_sfreq is None: raise ValueError("sfreq must be provided when components are arrays.") freqs_components, psd_components = _compute_psd_matrix( component_data[indices], sfreq=component_sfreq, fmin=fmin, fmax=fmax ) for idx, component_psd in enumerate(psd_components): comp_idx = indices[idx] axes[1].semilogy( freqs_components, component_psd, color=get_series_color(idx), label=f"Component {comp_idx}", ) axes[1].set_title("Component PSD", fontsize=FONTS["title"]) axes[1].set_xlabel("Frequency (Hz)", fontsize=FONTS["label"]) axes[1].set_ylabel("Power Spectral Density", fontsize=FONTS["label"]) style_axes(axes[1], grid=True) if peak_freq is not None: for axis in axes: axis.axvline( peak_freq, color=COLORS["line_marker"], linestyle="--", alpha=0.7, ) themed_legend(axes[0], loc="best") themed_legend(axes[1], loc="best") return _finalize_fig(fig, show=show, fname=fname, tight=False)
[docs] def plot_spectrogram_comparison( inst_before, inst_after, picks, times, sfreq=None, fmin=1, fmax=40, n_freqs=20, show=True, fname=None, ): """Compare before/after spectrograms averaged across selected channels. Parameters ---------- inst_before, inst_after : MNE object | ndarray Inputs to compare. Either both MNE objects or both arrays. Array inputs must be 2D ``(n_channels, n_times)`` or 3D ``(n_epochs, n_channels, n_times)``. picks : sequence of int Explicit channel picks used for averaging. times : array-like of shape (n_times,) Explicit time vector used on x-axis. sfreq : float | None Sampling frequency for array inputs. fmin, fmax : float Frequency bounds for the spectrogram. n_freqs : int Number of frequencies in the display grid. show : bool If True, display the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If ``picks``/``times`` are invalid, if input types are mixed, if array inputs are missing ``sfreq``, or if shape constraints fail. Notes ----- This function enforces an explicit ``times`` input for both MNE and NumPy inputs to avoid hidden axis inference. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_spectrogram_comparison >>> before = np.random.randn(8, 2000) >>> after = before * 0.8 >>> t = np.arange(before.shape[-1]) / 250.0 >>> fig = plot_spectrogram_comparison( ... before, after, picks=[0, 1], times=t, sfreq=250.0, show=False ... ) """ if n_freqs < 2: raise ValueError("n_freqs must be at least 2.") if fmax <= fmin: raise ValueError("fmax must be greater than fmin.") if picks is None: raise ValueError("picks must be provided explicitly.") picks = list(picks) if len(picks) == 0: raise ValueError("picks cannot be empty.") is_mne_before = isinstance( inst_before, (mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked) ) is_mne_after = isinstance(inst_after, (mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked)) if is_mne_before != is_mne_after: raise ValueError("inst_before and inst_after must be both MNE or both arrays.") times = np.asarray(times, dtype=float) if times.ndim != 1: raise ValueError("times must be a 1D array.") freqs = np.linspace(fmin, fmax, n_freqs, dtype=float) if is_mne_before: n_times_before = inst_before.get_data().shape[-1] n_times_after = inst_after.get_data().shape[-1] if n_times_before != n_times_after: raise ValueError("inst_before and inst_after must share the same n_times.") if times.size != n_times_before: raise ValueError("times must match the signal n_times.") tfr_before = inst_before.compute_tfr( method="multitaper", freqs=freqs, n_cycles=freqs / 2.0, picks=picks, ) tfr_after = inst_after.compute_tfr( method="multitaper", freqs=freqs, n_cycles=freqs / 2.0, picks=picks, ) data_before = np.asarray(tfr_before.data, dtype=float) data_after = np.asarray(tfr_after.data, dtype=float) mean_axes = tuple(range(data_before.ndim - 2)) if mean_axes: data_before = data_before.mean(axis=mean_axes) data_after = data_after.mean(axis=mean_axes) else: if sfreq is None: raise ValueError("sfreq must be provided when inputs are arrays.") data_before_arr = np.asarray(inst_before, dtype=float) data_after_arr = np.asarray(inst_after, dtype=float) if data_before_arr.shape[-1] != data_after_arr.shape[-1]: raise ValueError("inst_before and inst_after must share the same n_times.") if times.size != data_before_arr.shape[-1]: raise ValueError("times must match the signal n_times.") if data_before_arr.ndim not in (2, 3) or data_after_arr.ndim not in (2, 3): raise ValueError( "Array spectrogram inputs must be 2D or 3D with time as last axis." ) n_channels = data_before_arr.shape[-2] pick_indices = [int(pick) for pick in picks] invalid = [idx for idx in pick_indices if idx < 0 or idx >= n_channels] if invalid: raise ValueError(f"Channel picks out of range: {invalid}") freqs, data_before = _compute_array_spectrogram( data_before_arr, picks=pick_indices, sfreq=sfreq, fmin=fmin, fmax=fmax, n_freqs=n_freqs, ) _, data_after = _compute_array_spectrogram( data_after_arr, picks=pick_indices, sfreq=sfreq, fmin=fmin, fmax=fmax, n_freqs=n_freqs, ) diff = data_before - data_after fig, axes = themed_figure( 1, 3, figsize=(15, 4), sharey=True, constrained_layout=True ) axes = np.atleast_1d(axes) vmax = float(max(np.max(data_before), np.max(data_after))) diff_limit = float(np.max(np.abs(diff))) def _plot_im(ax, data, title, cmap, vlims, colorbar_label): im = ax.imshow( data, origin="lower", aspect="auto", cmap=cmap, extent=[times[0], times[-1], freqs[0], freqs[-1]], vmin=vlims[0], vmax=vlims[1], ) ax.set_title(title, fontsize=FONTS["title"]) ax.set_xlabel("Time (s)", fontsize=FONTS["label"]) ax.set_ylabel("Frequency (Hz)", fontsize=FONTS["label"]) style_axes(ax, grid=False) _add_colorbar(fig, ax, im, colorbar_label) _plot_im( axes[0], data_before, "Before", cmap=SEQUENTIAL_CMAP, vlims=(0.0, vmax), colorbar_label="Power", ) _plot_im( axes[1], data_after, "After", cmap=SEQUENTIAL_CMAP, vlims=(0.0, vmax), colorbar_label="Power", ) _plot_im( axes[2], diff, "Before - After", cmap=DIVERGING_CMAP, vlims=(-diff_limit, diff_limit), colorbar_label="Power Difference", ) return _finalize_fig(fig, show=show, fname=fname, tight=False)
[docs] def plot_time_frequency_mask( mask, times, freqs, title="Time-Frequency Mask", ax=None, show=True, fname=None, ): """Visualize a time-frequency mask matrix. Parameters ---------- mask : array-like of shape (n_freqs, n_times) Time-frequency weights. times : array-like of shape (n_times,) Time axis coordinates. freqs : array-like of shape (n_freqs,) Frequency axis coordinates. title : str Panel title. ax : matplotlib.axes.Axes | None Target axes. If None, create a new figure and axes. show : bool If True, display the figure. fname : path-like | None Optional output path used to save the figure. Returns ------- fig : matplotlib.figure.Figure Figure handle. Raises ------ ValueError If dimensions of ``mask``, ``times``, and ``freqs`` are inconsistent. Examples -------- >>> import numpy as np >>> from mne_denoise.viz import plot_time_frequency_mask >>> mask = np.random.rand(20, 100) >>> times = np.linspace(0, 2.0, 100) >>> freqs = np.linspace(1.0, 40.0, 20) >>> fig = plot_time_frequency_mask(mask, times, freqs, show=False) """ mask = np.asarray(mask) times = np.asarray(times) freqs = np.asarray(freqs) if mask.ndim != 2: raise ValueError("mask must be a 2D array.") if times.ndim != 1 or freqs.ndim != 1: raise ValueError("times and freqs must be 1D arrays.") if mask.shape != (len(freqs), len(times)): raise ValueError("mask shape must match (len(freqs), len(times)).") if ax is None: fig, ax = themed_figure(figsize=(10, 5)) else: fig = ax.figure im = ax.pcolormesh( times, freqs, mask, shading="auto", cmap=SEQUENTIAL_CMAP, vmin=0, vmax=1, ) ax.set_ylabel("Frequency (Hz)", fontsize=FONTS["label"]) ax.set_xlabel("Time (s)", fontsize=FONTS["label"]) ax.set_title(title, fontsize=FONTS["title"]) style_axes(ax, grid=False) _add_colorbar(fig, ax, im, "Mask Weight") return _finalize_fig(fig, show=show, fname=fname)