Source code for mne_denoise.dss.denoisers.periodic

"""Periodic filter biases for DSS.

Implements peak and comb filter biases for extracting periodic/oscillatory
signals, particularly useful for SSVEP analysis.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
         Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)

References
----------
.. [1] Särelä & Valpola (2005). Denoising Source Separation. J. Mach. Learn. Res., 6, 233-272.
"""

from __future__ import annotations

import numpy as np
from scipy import ndimage, signal

from .base import LinearDenoiser, NonlinearDenoiser


[docs] class PeakFilterBias(LinearDenoiser): """Peak filter bias for single-frequency extraction. Applies a narrow bandpass (peak) filter to emphasize activity at a specific frequency. More selective than BandpassBias, using a resonant filter design. Parameters ---------- freq : float Target frequency in Hz. sfreq : float Sampling frequency in Hz. q_factor : float Quality factor controlling bandwidth. Higher Q = narrower band. Default 30 (roughly 1 Hz bandwidth at 30 Hz). order : int Filter order. Default 2. Examples -------- >>> # Extract 10 Hz alpha with tight filter >>> from mne_denoise.dss.denoisers import PeakFilterBias >>> bias = PeakFilterBias(freq=10, sfreq=250, q_factor=30) >>> biased_data = bias.apply(data) See Also -------- mne_denoise.dss.denoisers.spectral.BandpassBias : For broader band rhythms. Notes ----- Uses a second-order IIR peak filter design. Q factor determines the bandwidth: bandwidth ≈ freq / Q. References ---------- Särelä & Valpola (2005). Section 4.1.2 "DENOISING BASED ON FREQUENCY CONTENT" """
[docs] def __init__( self, freq: float, sfreq: float, *, q_factor: float = 30.0, order: int = 2, ) -> None: self.freq = freq self.sfreq = sfreq self.q_factor = q_factor self.order = order # Design peak filter self._sos = self._design_peak_filter()
def _design_peak_filter(self) -> np.ndarray: """Design IIR peak filter using second-order sections.""" nyq = self.sfreq / 2 if self.freq >= nyq: raise ValueError( f"Target frequency ({self.freq} Hz) must be < Nyquist ({nyq} Hz)" ) # Normalized frequency w0 = self.freq / nyq # Bandwidth from Q factor w0 / self.q_factor # Design peak filter using iirpeak b, a = signal.iirpeak(w0, self.q_factor) sos = signal.tf2sos(b, a) return sos def apply(self, data: np.ndarray) -> np.ndarray: """Apply peak filter bias. Parameters ---------- data : ndarray, shape (n_channels, n_times) or (n_channels, n_times, n_epochs) Input data. Returns ------- biased : ndarray, same shape as input Peak-filtered data. """ if data.ndim == 3: n_channels, n_times, n_epochs = data.shape biased = np.zeros_like(data) for ep in range(n_epochs): biased[:, :, ep] = signal.sosfiltfilt(self._sos, data[:, :, ep], axis=1) elif data.ndim == 2: biased = signal.sosfiltfilt(self._sos, data, axis=1) else: raise ValueError(f"Data must be 2D or 3D, got {data.ndim}D") return biased
[docs] class CombFilterBias(LinearDenoiser): """Comb filter bias for harmonic frequency extraction. Applies a comb filter that passes the fundamental frequency and its harmonics. Ideal for SSVEP analysis where stimulus frequency creates responses at multiple harmonic frequencies. Parameters ---------- fundamental_freq : float Fundamental frequency in Hz (e.g., 15 Hz for SSVEP). sfreq : float Sampling frequency in Hz. n_harmonics : int Number of harmonics to include (1 = fundamental only). Default 3. q_factor : float Quality factor for each peak. Default 30. weights : array-like, optional Weights for each harmonic. If None, uses 1/harmonic_number weighting (decreasing importance of higher harmonics). Examples -------- >>> # SSVEP at 15 Hz with 3 harmonics (15, 30, 45 Hz) >>> from mne_denoise.dss.denoisers import CombFilterBias >>> bias = CombFilterBias(fundamental_freq=50, sfreq=1000, n_harmonics=5) >>> biased_data = bias.apply(data) >>> # Custom weighting (equal weight for all harmonics) >>> bias = CombFilterBias( ... fundamental_freq=12, sfreq=500, n_harmonics=4, weights=[1.0, 1.0, 1.0, 1.0] ... ) See Also -------- PeakFilterBias : For single frequency targets. Notes ----- Implements a sum of peak filters at each harmonic. Harmonics above Nyquist frequency are automatically excluded. References ---------- Särelä & Valpola (2005). Section 4.1.2 "DENOISING BASED ON FREQUENCY CONTENT" """
[docs] def __init__( self, fundamental_freq: float, sfreq: float, *, n_harmonics: int = 3, q_factor: float = 30.0, weights: np.ndarray | None = None, ) -> None: self.fundamental_freq = fundamental_freq self.sfreq = sfreq self.n_harmonics = n_harmonics self.q_factor = q_factor # Set up weights if weights is None: self.weights = np.array([1.0 / h for h in range(1, n_harmonics + 1)]) else: self.weights = np.asarray(weights) if len(self.weights) != n_harmonics: raise ValueError( f"weights length ({len(self.weights)}) must match " f"n_harmonics ({n_harmonics})" ) # Create peak filters for each valid harmonic self._peak_filters: list[tuple[np.ndarray, float]] = [] self._create_harmonic_filters()
def _create_harmonic_filters(self) -> None: """Create peak filter for each harmonic within Nyquist.""" nyq = self.sfreq / 2 for h in range(1, self.n_harmonics + 1): freq = self.fundamental_freq * h if freq >= nyq * 0.95: continue # Skip harmonics too close to Nyquist w0 = freq / nyq weight = self.weights[h - 1] b, a = signal.iirpeak(w0, self.q_factor) sos = signal.tf2sos(b, a) self._peak_filters.append((sos, weight)) def apply(self, data: np.ndarray) -> np.ndarray: """Apply comb filter bias. Parameters ---------- data : ndarray, shape (n_channels, n_times) or (n_channels, n_times, n_epochs) Input data. Returns ------- biased : ndarray, same shape as input Comb-filtered data (sum of harmonics). """ if len(self._peak_filters) == 0: raise ValueError("No valid harmonics within Nyquist frequency") biased = np.zeros_like(data) for sos, weight in self._peak_filters: if data.ndim == 3: n_channels, n_times, n_epochs = data.shape for ep in range(n_epochs): biased[:, :, ep] += weight * signal.sosfiltfilt( sos, data[:, :, ep], axis=1 ) elif data.ndim == 2: biased += weight * signal.sosfiltfilt(sos, data, axis=1) else: raise ValueError(f"Data must be 2D or 3D, got {data.ndim}D") return biased @property def harmonic_frequencies(self) -> list[float]: """Return list of harmonic frequencies being filtered.""" nyq = self.sfreq / 2 return [ self.fundamental_freq * h for h in range(1, self.n_harmonics + 1) if self.fundamental_freq * h < nyq * 0.95 ]
[docs] class QuasiPeriodicDenoiser(NonlinearDenoiser): """Quasi-periodic denoiser via cycle averaging. For signals with repeating structure (ECG, respiration, periodic artifacts): 1. Detect peaks/cycles in the source 2. Segment into individual cycles 3. Time-warp cycles to common length 4. Average to create template 5. Replace each cycle with time-warped template Parameters ---------- peak_distance : int Minimum distance between peaks in samples. Default 100. peak_height_percentile : float Percentile of signal for peak detection threshold. Default 75. warp_length : int, optional Length to warp each cycle to. If None, use median cycle length. smooth_template : bool If True, smooth the template. Default True. Examples -------- >>> # For ECG-like signal at 250 Hz (peaks ~1 sec apart) >>> from mne_denoise.dss.denoisers import QuasiPeriodicDenoiser >>> denoiser = QuasiPeriodicDenoiser(peak_distance=200) >>> denoised = denoiser.denoise(ecg_source) References ---------- Särelä & Valpola (2005). Section 4.1.4 "DENOISING OF QUASIPERIODIC SIGNALS" """
[docs] def __init__( self, peak_distance: int = 100, peak_height_percentile: float = 75.0, *, warp_length: int | None = None, smooth_template: bool = True, ) -> None: self.peak_distance = max(10, peak_distance) self.peak_height_percentile = peak_height_percentile self.warp_length = warp_length self.smooth_template = smooth_template
def denoise(self, source: np.ndarray) -> np.ndarray: """Apply quasi-periodic denoising. Parameters ---------- source : ndarray, shape (n_times,) or (n_times, n_epochs) Source time series with quasi-periodic structure. Returns ------- denoised : ndarray, same shape as input Denoised source with cycles replaced by template. """ if source.ndim == 1: return self._denoise_1d(source) elif source.ndim == 2: n_times, n_epochs = source.shape denoised = np.zeros_like(source) for ep in range(n_epochs): denoised[:, ep] = self._denoise_1d(source[:, ep]) return denoised else: raise ValueError(f"Source must be 1D or 2D, got {source.ndim}D") def _denoise_1d(self, source: np.ndarray) -> np.ndarray: """Apply quasi-periodic denoising to 1D source.""" n_samples = len(source) # Step 1: Detect peaks height_threshold = np.percentile(np.abs(source), self.peak_height_percentile) peaks, _ = signal.find_peaks( np.abs(source), height=height_threshold, distance=self.peak_distance, ) if len(peaks) < 3: # Not enough cycles, return original return source # Step 2: Determine cycle boundaries (midpoints between peaks) boundaries = np.zeros(len(peaks) + 1, dtype=int) boundaries[0] = 0 boundaries[-1] = n_samples for i in range(1, len(peaks)): boundaries[i] = (peaks[i - 1] + peaks[i]) // 2 # Step 3: Extract cycles and determine warp length cycles = [] cycle_lengths = [] for i in range(len(peaks)): start = boundaries[i] end = boundaries[i + 1] if end > start: cycles.append(source[start:end]) cycle_lengths.append(end - start) if len(cycles) < 2: return source # Warp length: use provided or median if self.warp_length is not None: warp_len = self.warp_length else: warp_len = int(np.median(cycle_lengths)) warp_len = max(10, warp_len) # Step 4: Time-warp all cycles to common length and average warped_cycles = [] for cycle in cycles: if len(cycle) >= 3: # Resample to warp_len warped = np.interp( np.linspace(0, 1, warp_len), np.linspace(0, 1, len(cycle)), cycle ) warped_cycles.append(warped) if len(warped_cycles) < 2: return source # Average to create template template = np.mean(warped_cycles, axis=0) # Optional smoothing if self.smooth_template: smooth_window = max(3, warp_len // 20) template = ndimage.uniform_filter1d( template, size=smooth_window, mode="reflect" ) # Step 5: Replace each cycle with time-warped template denoised = np.zeros_like(source) for i, cycle in enumerate(cycles): start = boundaries[i] end = boundaries[i + 1] cycle_len = end - start if cycle_len >= 3: # Warp template back to original cycle length warped_template = np.interp( np.linspace(0, 1, cycle_len), np.linspace(0, 1, warp_len), template ) # Match amplitude to original cycle scale = np.std(cycle) / (np.std(warped_template) + 1e-15) offset = np.mean(cycle) - np.mean(warped_template) * scale denoised[start:end] = warped_template * scale + offset return denoised