"""Spectrogram-based bias functions for DSS.
Implements denoising based on time-frequency representation masking.
Includes both Linear (SpectrogramBias) and Nonlinear (SpectrogramDenoiser)
implementations.
Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
References
----------
.. [1] Särelä & Valpola (2005). Denoising Source Separation. J. Mach. Learn. Res., 6, 233-272.
"""
from __future__ import annotations
import numpy as np
from scipy import signal
from scipy.ndimage import zoom
from .base import LinearDenoiser, NonlinearDenoiser
def _apply_tf_mask(
    data_1d: np.ndarray, mask: np.ndarray, nperseg: int, noverlap: int
) -> np.ndarray:
    """Apply a time-frequency mask to a 1D signal.

    The signal is transformed with an STFT, multiplied element-wise by the
    mask, and transformed back with the inverse STFT.

    Parameters
    ----------
    data_1d : ndarray, shape (n_times,)
        Signal to filter.
    mask : array-like, 2D
        Mask in STFT coordinates (frequency bins x frames). If its shape
        differs from the STFT shape it is resized with nearest-neighbour
        interpolation.
    nperseg : int
        STFT segment length. If the signal is shorter than this, the signal
        length is used instead.
    noverlap : int
        Overlap between segments. Scaled proportionally when the segment
        length has to be shortened.

    Returns
    -------
    reconstructed : ndarray, shape (n_times,)
        Masked signal, trimmed or zero-padded to the input length.

    Raises
    ------
    ValueError
        If ``mask`` is not a non-empty 2D array.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2 or mask.size == 0:
        raise ValueError(
            f"mask must be a non-empty 2D array, got shape {mask.shape}"
        )

    n_samples = len(data_1d)
    if n_samples == 0:
        # Nothing to transform; the STFT of an empty signal has no 2D shape.
        return np.zeros(0)

    # stft silently shrinks nperseg for short inputs (or raises when
    # noverlap >= len), but istft must be given the segment length that was
    # actually used. Resolve the effective values once and use them for both.
    seg_len, overlap = nperseg, noverlap
    if n_samples < nperseg:
        seg_len = n_samples
        overlap = (noverlap * seg_len) // nperseg  # keep the overlap ratio

    _, _, Zxx = signal.stft(data_1d, nperseg=seg_len, noverlap=overlap)

    if mask.shape != Zxx.shape:
        zoom_factors = (
            Zxx.shape[0] / mask.shape[0],
            Zxx.shape[1] / mask.shape[1],
        )
        # order=0 -> nearest-neighbour, so mask values are never blended.
        mask = zoom(mask, zoom_factors, order=0)

    _, reconstructed = signal.istft(
        Zxx * mask, nperseg=seg_len, noverlap=overlap
    )

    # Match the input length exactly.
    if len(reconstructed) >= n_samples:
        return reconstructed[:n_samples]
    padded = np.zeros(n_samples)
    padded[: len(reconstructed)] = reconstructed
    return padded
[docs]
class SpectrogramBias(LinearDenoiser):
    """Linear spectrogram bias (Section 4.1.3).

    Applies a FIXED time-frequency mask to the data. This is a linear operation
    used to define the signal subspace in the initialization or linear DSS step.

    Parameters
    ----------
    mask : ndarray, shape (n_freqs, n_times)
        The fixed 2D mask to apply. Must be provided for linear biasing.
        The mask lives in STFT coordinates (frequency bins x STFT frames),
        not in sensor coordinates; a mask whose shape differs from the STFT
        of a channel is resized to that shape before it is applied.
    nperseg : int
        Segment length for STFT. Default 256.
    noverlap : int, optional
        Overlap between segments. Default nperseg // 2.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import signal
    >>> from mne_denoise.dss.denoisers import SpectrogramBias
    >>> data = np.random.randn(8, 1000)
    >>> _, _, Zxx = signal.stft(data[0], nperseg=256, noverlap=128)
    >>> mask = np.ones(Zxx.shape)
    >>> bias = SpectrogramBias(mask)
    >>> biased = bias.apply(data)

    See Also
    --------
    SpectrogramDenoiser : Adaptive nonlinear version.

    References
    ----------
    Särelä & Valpola (2005). Section 4.1.3 "SPECTROGRAM DENOISING"
    """

    def __init__(
        self,
        mask: np.ndarray,
        nperseg: int = 256,
        noverlap: int | None = None,
    ) -> None:
        self.mask = mask
        self.nperseg = nperseg
        self.noverlap = noverlap if noverlap is not None else nperseg // 2

    @staticmethod
    def _result_dtype(data: np.ndarray) -> np.dtype:
        """Return the dtype used for the output buffer.

        Float and complex inputs keep their dtype. Anything else (e.g.
        integers) is promoted to float64 so the reconstructed values are not
        truncated when written into the buffer.
        """
        if np.issubdtype(data.dtype, np.inexact):
            return data.dtype
        return np.dtype(np.float64)

    def apply(self, data: np.ndarray) -> np.ndarray:
        """Apply fixed spectrogram mask to all channels.

        Parameters
        ----------
        data : ndarray
            Sensor data, either (n_ch, n_times) or (n_ch, n_times, n_epochs).

        Returns
        -------
        biased : ndarray
            Masked data with the same shape as ``data``.

        Raises
        ------
        ValueError
            If ``data`` is neither 2D nor 3D.
        """
        if data.ndim == 2:
            return self._apply_2d(data)
        if data.ndim == 3:
            # (n_ch, n_times, n_epochs): process one epoch at a time.
            biased = np.zeros_like(data, dtype=self._result_dtype(data))
            for ep in range(data.shape[2]):
                biased[:, :, ep] = self._apply_2d(data[:, :, ep])
            return biased
        raise ValueError(f"Data must be 2D or 3D, got {data.ndim}D")

    def _apply_2d(self, data: np.ndarray) -> np.ndarray:
        """Apply the mask independently to each row of (n_ch, n_times) data."""
        biased = np.zeros_like(data, dtype=self._result_dtype(data))
        for ch, channel in enumerate(data):
            biased[ch] = _apply_tf_mask(
                channel, self.mask, self.nperseg, self.noverlap
            )
        return biased
[docs]
class SpectrogramDenoiser(NonlinearDenoiser):
    """Adaptive/Nonlinear spectrogram denoiser (Section 4.1.3).

    Applies masking in the time-frequency domain. This version is ADAPTIVE,
    calculating the mask from the source estimate itself at each iteration.
    This makes it distinct from the Linear SpectrogramBias.

    Parameters
    ----------
    threshold_percentile : float
        For adaptive masking, threshold below this percentile. Default 90.
        Higher percentile = sparser signal (more aggressive denoising).
    nperseg : int
        Segment length for STFT. Default 256.
    noverlap : int, optional
        Overlap between segments. Default nperseg // 2.
    mask : ndarray, shape (n_freqs, n_times), optional
        Optional FIXED mask to use instead of adaptive (hybrid mode).

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.dss.denoisers import SpectrogramDenoiser
    >>> source = np.random.randn(1000)
    >>> # Retain only the strongest 10% of TF-bins (aggressive denoising)
    >>> denoiser = SpectrogramDenoiser(threshold_percentile=90)
    >>> denoised = denoiser.denoise(source)

    See Also
    --------
    SpectrogramBias : Fixed linear version.

    References
    ----------
    Särelä & Valpola (2005). Section 4.1.3 "SPECTROGRAM DENOISING"
    """

    def __init__(
        self,
        threshold_percentile: float = 90.0,
        nperseg: int = 256,
        noverlap: int | None = None,
        mask: np.ndarray | None = None,
    ) -> None:
        self.threshold_percentile = threshold_percentile
        self.nperseg = nperseg
        self.noverlap = noverlap if noverlap is not None else nperseg // 2
        self.mask = mask

    def denoise(self, source: np.ndarray) -> np.ndarray:
        """Apply adaptive 2D spectrogram masking.

        Parameters
        ----------
        source : ndarray
            Source estimate, either (n_times,) or (n_times, n_epochs).

        Returns
        -------
        denoised : ndarray
            Denoised source with the same shape as ``source``.

        Raises
        ------
        ValueError
            If ``source`` is neither 1D nor 2D.
        """
        if source.ndim == 1:
            return self._denoise_1d(source)
        if source.ndim == 2:
            # Integer input would truncate the reconstructed values when
            # written into a same-dtype buffer, so promote it to float64.
            if np.issubdtype(source.dtype, np.inexact):
                out_dtype = source.dtype
            else:
                out_dtype = np.float64
            denoised = np.zeros_like(source, dtype=out_dtype)
            for ep in range(source.shape[1]):
                denoised[:, ep] = self._denoise_1d(source[:, ep])
            return denoised
        raise ValueError(f"Source must be 1D or 2D, got {source.ndim}D")

    def _denoise_1d(self, source: np.ndarray) -> np.ndarray:
        """Denoise one 1D source with the fixed mask or an adaptive one."""
        if self.mask is not None:
            tf_mask = self.mask
        else:
            # For sources shorter than the segment, shrink the segment and
            # scale the overlap so the STFT does not raise on
            # noverlap >= nperseg.
            seg_len, overlap = self.nperseg, self.noverlap
            n_samples = len(source)
            if 0 < n_samples < self.nperseg:
                seg_len = n_samples
                overlap = (self.noverlap * seg_len) // self.nperseg

            _, _, Zxx = signal.stft(source, nperseg=seg_len, noverlap=overlap)

            # Adaptive magnitude-based mask: keep only the bins whose
            # magnitude exceeds the requested percentile.
            magnitude = np.abs(Zxx)
            threshold = np.percentile(magnitude, self.threshold_percentile)
            tf_mask = (magnitude > threshold).astype(float)

        # Apply mask using shared logic
        return _apply_tf_mask(source, tf_mask, self.nperseg, self.noverlap)