# Source code for mne_denoise.dss.denoisers.masking

"""Adaptive masking denoisers for DSS.

This module implements denoisers that estimate a time-varying mask $m(t)$
based on the local variance of the source signal.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
         Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)

References
----------
.. [1] Särelä & Valpola (2005). Denoising Source Separation. J. Mach. Learn. Res., 6, 233-272.
"""

from __future__ import annotations

import numpy as np
from scipy import ndimage

from .base import NonlinearDenoiser


class WienerMaskDenoiser(NonlinearDenoiser):
    """Adaptive Wiener mask denoiser.

    The core nonlinear DSS denoiser from Särelä & Valpola (2005). Estimates
    time-varying signal variance and applies soft Wiener-style masking::

        m(t) = σ²_signal(t) / [σ²_signal(t) + σ²_noise]
        s⁺(t) = s(t) · m(t)

    with ``σ²_signal(t) = max(0, σ²_local(t) - σ²_noise)``, where
    ``σ²_local`` is a sliding-window variance of the source.

    This is adaptive/nonlinear because the mask is estimated from the data.
    Ideal for bursty, non-stationary signals (spindles, beta bursts,
    intermittent artifacts).

    Parameters
    ----------
    window_samples : int
        Window size for local variance estimation. Values below 3 are raised
        to 3; at denoising time the window is also capped at half the source
        length (but never below 1 sample). Default 50.
    noise_percentile : float
        Percentile of local variance used to estimate the noise floor.
        Higher values = larger noise estimate = more aggressive denoising.
        Default 25.
    min_gain : float
        Minimum mask value (prevents complete zeroing). Default 0.01.
    noise_variance : float, optional
        If provided, use this fixed noise variance instead of estimating it
        from ``noise_percentile``. Values below 1e-15 (including negative
        values) are clamped to 1e-15.

    Examples
    --------
    >>> from mne_denoise.dss.denoisers import WienerMaskDenoiser
    >>> denoiser = WienerMaskDenoiser(window_samples=50)
    >>> denoised = denoiser.denoise(source)

    References
    ----------
    Särelä & Valpola (2005). Section 4.4 "Spectral Shift and Approximation
    of the Objective Function with Mask-Based Denoisings"
    """

    def __init__(
        self,
        window_samples: int = 50,
        noise_percentile: float = 25.0,
        *,
        min_gain: float = 0.01,
        noise_variance: float | None = None,
    ) -> None:
        # Enforce a minimum window of 3 samples.
        self.window_samples = max(3, window_samples)
        self.noise_percentile = noise_percentile
        self.min_gain = min_gain
        self.noise_variance = noise_variance

    def denoise(self, source: np.ndarray) -> np.ndarray:
        """Apply Wiener mask denoising.

        Parameters
        ----------
        source : ndarray, shape (n_times,) or (n_times, n_epochs)
            Source time series. Non-floating input (int, bool) is converted
            to float64 before filtering.

        Returns
        -------
        denoised : ndarray, same shape as input
            Wiener-masked source.

        Raises
        ------
        ValueError
            If ``source`` is not 1D or 2D.
        """
        source = self._as_float(source)
        if source.ndim == 1:
            return self._denoise_1d(source)
        if source.ndim == 2:
            denoised = np.empty_like(source)
            # Each epoch (column) gets its own local-variance and noise estimate.
            for ep in range(source.shape[1]):
                denoised[:, ep] = self._denoise_1d(source[:, ep])
            return denoised
        raise ValueError(f"Source must be 1D or 2D, got {source.ndim}D")

    @staticmethod
    def _as_float(source) -> np.ndarray:
        """Return ``source`` as an ndarray with an inexact (float/complex) dtype."""
        source = np.asarray(source)
        if not np.issubdtype(source.dtype, np.inexact):
            # uniform_filter1d returns the input dtype, so integer input would
            # otherwise have its local means silently truncated.
            source = source.astype(np.float64)
        return source

    @staticmethod
    def _local_variance(source: np.ndarray, window: int) -> np.ndarray:
        """Sliding-window variance σ²(t) = E[s²] - E[s]², clipped at zero."""
        local_mean_sq = ndimage.uniform_filter1d(source**2, size=window, mode="reflect")
        local_mean = ndimage.uniform_filter1d(source, size=window, mode="reflect")
        # Round-off can make the difference slightly negative.
        return np.maximum(local_mean_sq - local_mean**2, 0)

    def _denoise_1d(self, source: np.ndarray) -> np.ndarray:
        """Apply the Wiener mask to a single 1D source."""
        source = self._as_float(source)
        n_samples = source.shape[0]
        if n_samples == 0:
            return source.copy()

        # Cap the window at half the source length, but keep it >= 1 so
        # uniform_filter1d accepts very short sources.
        window = max(1, min(self.window_samples, n_samples // 2))
        local_var = self._local_variance(source, window)

        # Estimate noise variance (from quiet periods)
        if self.noise_variance is not None:
            noise_var = self.noise_variance
        else:
            # Use percentile of local variance as noise floor estimate
            noise_var = np.percentile(local_var, self.noise_percentile)
        # Prevent division by zero; also rejects a negative user-supplied
        # variance, which would otherwise push the mask above 1.
        noise_var = max(noise_var, 1e-15)

        # Wiener mask: m(t) = σ²_signal / (σ²_signal + σ²_noise)
        # where σ²_signal = max(0, local_var - noise_var)
        signal_var = np.maximum(local_var - noise_var, 0)
        mask = signal_var / (signal_var + noise_var + 1e-15)

        # Apply minimum gain
        mask = np.maximum(mask, self.min_gain)
        return source * mask
class VarianceMaskDenoiser(NonlinearDenoiser):
    """Nonlinear denoiser using local variance masking.

    Identifies high-variance regions in the source time series and weights
    them higher, effectively emphasizing transient activity. Useful for
    extracting non-stationary sources.

    Parameters
    ----------
    window_samples : int
        Window size for local variance computation. Values below 3 are raised
        to 3; at denoising time the window is also capped at the source
        length. Default 100.
    percentile : float
        Percentile of local variance used as the threshold (binary mask) or
        as the centre of the logistic weighting (soft mode). Default 75.
    soft : bool
        If True, weight each sample with a logistic function of its local
        variance, centred on the threshold with scale ``threshold / 2``.
        If False, keep samples whose local variance is >= the threshold and
        zero the rest. Default True.

    Examples
    --------
    >>> from mne_denoise.dss.denoisers import VarianceMaskDenoiser
    >>> denoiser = VarianceMaskDenoiser(window_samples=50, percentile=80)
    >>> denoised_source = denoiser.denoise(source)

    References
    ----------
    Särelä & Valpola (2005). Section 4.4 "Spectral Shift and Approximation
    of the Objective Function with Mask-Based Denoisings"
    """

    def __init__(
        self,
        window_samples: int = 100,
        percentile: float = 75.0,
        *,
        soft: bool = True,
    ) -> None:
        # Enforce a minimum window of 3 samples.
        self.window_samples = max(3, window_samples)
        self.percentile = percentile
        self.soft = soft

    def denoise(self, source: np.ndarray) -> np.ndarray:
        """Apply variance-based masking to source time series.

        Parameters
        ----------
        source : ndarray, shape (n_times,) or (n_times, n_epochs)
            Source time series. Non-floating input (int, bool) is converted
            to float64 before filtering.

        Returns
        -------
        denoised : ndarray, same shape as input
            Variance-weighted source. Always a new array, never ``source``
            itself.

        Raises
        ------
        ValueError
            If ``source`` is not 1D or 2D.
        """
        source = self._as_float(source)
        if source.ndim == 1:
            return self._denoise_1d(source)
        if source.ndim == 2:
            denoised = np.empty_like(source)
            # Each epoch (column) gets its own local-variance threshold.
            for ep in range(source.shape[1]):
                denoised[:, ep] = self._denoise_1d(source[:, ep])
            return denoised
        raise ValueError(f"Source must be 1D or 2D, got {source.ndim}D")

    @staticmethod
    def _as_float(source) -> np.ndarray:
        """Return ``source`` as an ndarray with an inexact (float/complex) dtype."""
        source = np.asarray(source)
        if not np.issubdtype(source.dtype, np.inexact):
            # uniform_filter1d returns the input dtype, so integer input would
            # otherwise have its local means silently truncated.
            source = source.astype(np.float64)
        return source

    @staticmethod
    def _local_variance(source: np.ndarray, window: int) -> np.ndarray:
        """Sliding-window variance σ²(t) = E[s²] - E[s]², clipped at zero."""
        local_mean_sq = ndimage.uniform_filter1d(source**2, size=window, mode="reflect")
        local_mean = ndimage.uniform_filter1d(source, size=window, mode="reflect")
        # Round-off can make the difference slightly negative.
        return np.maximum(local_mean_sq - local_mean**2, 0)

    def _denoise_1d(self, source: np.ndarray) -> np.ndarray:
        """Process single 1D source."""
        source = self._as_float(source)
        n_samples = source.shape[0]
        if n_samples == 0:
            return source.copy()

        window = min(self.window_samples, n_samples)
        local_var = self._local_variance(source, window)
        threshold = np.percentile(local_var, self.percentile)

        if not self.soft:
            # Binary mask: keep only samples at or above the threshold.
            mask = local_var >= threshold
            return source * mask.astype(float)

        if threshold < 1e-15:
            # At least `percentile`% of local variances are ~0; fall back to
            # half the peak local variance as the sigmoid centre.
            threshold = np.max(local_var) * 0.5
            if threshold < 1e-15:
                # Local variance is ~0 everywhere: nothing to emphasise.
                # Return a copy so callers never alias their input.
                return source.copy()

        # Logistic weighting centred on the threshold, scale threshold / 2.
        weights = 1 / (1 + np.exp(-(local_var - threshold) / (threshold * 0.5)))
        return source * weights