"""Temporal bias functions for DSS.
Implements time-shift and smoothing biases for extracting temporally
extended structure (slow waves, autocorrelated signals).
Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
References
----------
.. [1] de Cheveigné, A. (2010). Time-shift denoising source separation.
Journal of Neuroscience Methods, 189(1), 113-120.
.. [2] de Cheveigné, A. & Simon, J.Z. (2008). Denoising based on spatial filtering.
Journal of Neuroscience Methods, 171(2), 331-339.
"""
from __future__ import annotations
import numpy as np
from .base import LinearDenoiser, NonlinearDenoiser
class TimeShiftBias(LinearDenoiser):
    """Time-shift bias for extracting autocorrelated signals.

    Creates a bias by averaging time-shifted versions of the data,
    emphasizing signals that are predictable across time lags.

    Parameters
    ----------
    shifts : int or array-like
        If a scalar integer (Python or NumPy), use lags ``1, 2, ..., shifts``.
        If array-like, use the specified lag values in samples; negative lags
        are allowed. Default 10.
    method : str
        Method for constructing bias:

        - ``'autocorrelation'``: plain average of shifted versions (default).
        - ``'prediction'``: weighted average with weight
          ``1 / max(|lag|, 1)``, so closer lags are weighted more.

    Notes
    -----
    Shifts run along the time axis only. For 3D input each epoch is processed
    on its own, so no shifted copy reaches across an epoch boundary. Output
    samples within ``max(|shifts|)`` of either end of the time axis (of each
    epoch) are set to zero, because not every shifted copy exists there.

    Examples
    --------
    >>> bias = TimeShiftBias(shifts=[1, 2, 5, 10], method="prediction")
    >>> biased_data = bias.apply(data)

    See Also
    --------
    SmoothingBias : Bias for low-frequency signals.
    """

    def __init__(
        self,
        shifts: int | np.ndarray = 10,
        method: str = "autocorrelation",
    ) -> None:
        self.shifts = shifts
        self.method = method
        # Resolve shifts to a 1-D array of lags. ``np.ndim(...) == 0`` also
        # catches NumPy integer scalars (e.g. ``np.int64(5)``), which are not
        # ``int`` instances and would otherwise become un-iterable 0-d arrays.
        if np.ndim(shifts) == 0:
            self._shift_array = np.arange(1, int(shifts) + 1)
        else:
            self._shift_array = np.asarray(shifts)

    def apply(self, data: np.ndarray) -> np.ndarray:
        """Apply time-shift bias.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_times) or (n_channels, n_times, n_epochs)
            Input data.

        Returns
        -------
        biased : ndarray, same shape as input
            Time-shifted averaged data. Integer input yields float64 output.

        Raises
        ------
        ValueError
            If ``data`` is not 2D or 3D, ``method`` is unknown, no shifts are
            given, or ``max(|shifts|) >= n_times // 2``.
        """
        data = np.asarray(data)
        if data.ndim == 3:
            # (n_channels, n_times, n_epochs) -> (n_channels, n_epochs, n_times).
            # Time must be the last axis so each epoch is shifted on its own;
            # a C-order reshape to (n_channels, -1) would interleave epochs and
            # turn a 1-sample shift into a jump to the neighbouring epoch.
            per_epoch = np.moveaxis(data, 2, 1)
            return np.moveaxis(self._bias(per_epoch), 1, 2)
        if data.ndim != 2:
            raise ValueError(f"data must be 2D or 3D, got {data.ndim}D")
        return self._bias(data)

    def _bias(self, data: np.ndarray) -> np.ndarray:
        """Dispatch to the bias selected by ``self.method`` (time on last axis)."""
        if self.method == "autocorrelation":
            return self._autocorrelation_bias(data)
        if self.method == "prediction":
            return self._prediction_bias(data)
        raise ValueError(f"Unknown method: {self.method}")

    def _autocorrelation_bias(self, data: np.ndarray) -> np.ndarray:
        """Average of time-shifted versions (equal weight per lag)."""
        weights = np.ones(len(self._shift_array))
        return self._weighted_shift_average(data, weights)

    def _prediction_bias(self, data: np.ndarray) -> np.ndarray:
        """Weighted average with weight ``1 / max(|lag|, 1)`` per lag."""
        weights = 1.0 / np.maximum(np.abs(self._shift_array), 1)
        return self._weighted_shift_average(data, weights)

    def _weighted_shift_average(
        self, data: np.ndarray, weights: np.ndarray
    ) -> np.ndarray:
        """Weighted mean of shifted copies of ``data`` along its last axis.

        Shared by both methods so they validate shifts identically.
        """
        shifts = self._shift_array
        if shifts.size == 0:
            raise ValueError("At least one shift is required")
        n_samples = data.shape[-1]
        max_shift = int(np.max(np.abs(shifts)))
        if max_shift >= n_samples // 2:
            raise ValueError(
                f"Max shift ({max_shift}) too large for data length ({n_samples})"
            )
        # Only [start, stop) has every shifted copy available.
        start, stop = max_shift, n_samples - max_shift
        # Accumulate in at least double precision.
        acc = np.zeros(
            data.shape[:-1] + (stop - start,),
            dtype=np.result_type(data.dtype, np.float64),
        )
        for shift, weight in zip(shifts, weights):
            acc += weight * data[..., start + shift : stop + shift]
        # Keep float/complex input dtypes; promote integers so the averages
        # are not truncated.
        if np.issubdtype(data.dtype, np.inexact):
            out_dtype = data.dtype
        else:
            out_dtype = np.float64
        biased = np.zeros(data.shape, dtype=out_dtype)
        biased[..., start:stop] = acc / np.sum(weights)
        return biased
class SmoothingBias(LinearDenoiser):
    """Temporal smoothing bias (causal moving average).

    Uses a boxcar moving-average filter along the time axis to smooth the
    data.

    Parameters
    ----------
    window : int
        Smoothing window size in samples; must be at least 1. Default 10.
        A length-``W`` boxcar has spectral nulls at the nonzero multiples of
        ``sfreq / W`` (below ``sfreq``). To cancel a specific frequency
        (e.g. 50 Hz line noise), set ``window = int(sfreq / 50)``. The null
        falls exactly on 50 Hz only when ``sfreq`` is a multiple of 50.
    iterations : int
        Number of smoothing passes. Repeated smoothing approximates a Gaussian
        filter and provides a sharper frequency cutoff. Default 1.

    Examples
    --------
    >>> bias = SmoothingBias(window=20)  # Simple smoothing
    >>> biased = bias.apply(data)
    >>> # To remove 50 Hz line noise at sfreq = 1000 Hz (period smoothing)
    >>> bias = SmoothingBias(window=int(1000 / 50), iterations=1)
    """

    def __init__(self, window: int = 10, iterations: int = 1) -> None:
        self.window = window
        self.iterations = iterations

    def apply(self, data: np.ndarray) -> np.ndarray:
        """Apply smoothing bias.

        Uses a causal running-mean filter along the time axis:
        ``y[t] = mean(x[t-W+1 : t+1])`` for ``t >= W``, with an expanding
        window ``mean(x[0 : t+1])`` for the first ``W`` samples. Each pass
        delays the signal by about ``(W - 1) / 2`` samples. Repeated
        ``iterations`` passes approximate a Gaussian kernel.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_times) or (n_channels, n_times, n_epochs)
            Input data. For 3D input every epoch is smoothed independently.

        Returns
        -------
        smoothed : ndarray, same shape as input

        Raises
        ------
        ValueError
            If ``window`` is smaller than 1.
        """
        data = np.asarray(data)
        if data.ndim == 3:
            # Put time last, (n_channels, n_epochs, n_times), so the filter
            # runs within each epoch. A C-order reshape to (n_channels, -1)
            # would interleave epochs and average across them, not time.
            per_epoch = np.moveaxis(data, 2, 1)
            return np.moveaxis(self._smooth(per_epoch), 1, 2)
        return self._smooth(data)

    def _smooth(self, data: np.ndarray) -> np.ndarray:
        """Run ``iterations`` causal running-mean passes along the last axis."""
        W = int(self.window)
        if W < 1:
            raise ValueError(f"window must be >= 1, got {self.window}")
        smoothed = data.copy()
        n_samples = smoothed.shape[-1]
        if n_samples == 0:
            return smoothed
        # The expanding-window part can cover at most the whole signal; this
        # keeps windows longer than the data from failing to broadcast.
        n_head = min(W, n_samples)
        counts = np.arange(1, n_head + 1)
        for _ in range(self.iterations):
            # Centering on the mean of the first samples (presumably to limit
            # round-off in the cumulative sum) does not change the result: the
            # running mean maps a constant offset to itself.
            mean_head = np.mean(smoothed[..., : W + 1], axis=-1, keepdims=True)
            centered = smoothed - mean_head
            # Causal running mean via cumulative sums
            cs = np.cumsum(centered, axis=-1)
            out = np.empty_like(centered)
            # First W samples: expanding window
            out[..., :n_head] = cs[..., :n_head] / counts
            # Remaining samples: fixed-width causal window (empty if W >= n)
            out[..., W:] = (cs[..., W:] - cs[..., :-W]) / W
            smoothed = out + mean_head
        return smoothed
class DCTDenoiser(NonlinearDenoiser):
    """DCT domain denoiser (MATLAB denoise_dct.m).

    Applies a mask in the DCT (Discrete Cosine Transform) domain. The source
    is transformed with an orthonormal DCT-II along its time axis, multiplied
    by the mask, and transformed back. Useful for temporal smoothness without
    explicit bandpass.

    Parameters
    ----------
    mask : array-like or None
        DCT domain mask. If its length differs from the number of time
        samples, it is linearly resampled to match. If None, a lowpass mask
        is built from ``cutoff_fraction``. Default None.
    cutoff_fraction : float
        Fraction of DCT coefficients to keep when ``mask`` is None. The first
        ``int(n_times * cutoff_fraction)`` coefficients are kept and the rest
        are zeroed. Must be non-negative; values >= 1 keep every coefficient.
        Default 0.5 (lowpass, keep first 50% of coefficients).

    Examples
    --------
    >>> from mne_denoise.dss.denoisers import DCTDenoiser
    >>> # Keep only the lowest 20% of DCT coefficients (smooth signal)
    >>> denoiser = DCTDenoiser(cutoff_fraction=0.2)
    >>> smooth_source = denoiser.denoise(source)

    References
    ----------
    Särelä & Valpola (2005). Section 4.1.2 "DENOISING BASED ON FREQUENCY CONTENT"
    """

    def __init__(
        self, mask: np.ndarray | None = None, cutoff_fraction: float = 0.5
    ) -> None:
        self.mask = mask
        self.cutoff_fraction = cutoff_fraction
        # Lowpass-mask cache, keyed on signal length and cutoff fraction.
        self._cached_mask = None
        self._cached_len = None
        self._cached_fraction = None

    def denoise(self, source: np.ndarray) -> np.ndarray:
        """Apply DCT filtering.

        Parameters
        ----------
        source : ndarray, shape (n_times,) or (n_times, n_epochs)
            Source time course(s); the DCT runs along the first axis.

        Returns
        -------
        denoised : ndarray, same shape as ``source``
            Filtered source, always floating point.

        Raises
        ------
        ValueError
            If ``source`` is not 1D or 2D, or ``cutoff_fraction`` is negative
            (when no explicit mask is set).
        """
        from scipy.fft import dct, idct

        source = np.asarray(source)
        if source.ndim not in (1, 2):
            raise ValueError(f"Source must be 1D or 2D, got {source.ndim}D")
        n = source.shape[0]
        if n == 0:
            # The DCT is undefined for zero-length input.
            return np.zeros(source.shape)
        mask = self._get_mask(n)
        if source.ndim == 2:
            # Apply the same per-coefficient mask to every epoch (column).
            mask = mask[:, np.newaxis]
        # Transform all epochs at once along axis 0. The result is a fresh
        # float array, so integer input is not truncated on write-back.
        coeffs = dct(source, type=2, norm="ortho", axis=0)
        return idct(coeffs * mask, type=2, norm="ortho", axis=0)

    def _get_mask(self, n: int) -> np.ndarray:
        """Return the DCT mask for a signal of ``n`` time samples."""
        if self.mask is not None:
            user_mask = np.asarray(self.mask)
            if len(user_mask) == n:
                return user_mask
            # Resample mask to match signal length
            return np.interp(
                np.linspace(0, 1, n), np.linspace(0, 1, len(user_mask)), user_mask
            )
        fraction = self.cutoff_fraction
        # A negative fraction would turn ``mask[:cutoff]`` into a negative
        # slice bound that keeps almost every coefficient; also rejects NaN.
        if not fraction >= 0:
            raise ValueError(
                f"cutoff_fraction must be non-negative, got {fraction}"
            )
        if (
            self._cached_mask is None
            or self._cached_len != n
            or getattr(self, "_cached_fraction", None) != fraction
        ):
            mask = np.zeros(n)
            mask[: int(n * min(fraction, 1.0))] = 1.0
            self._cached_mask = mask
            self._cached_len = n
            self._cached_fraction = fraction
        return self._cached_mask

    def _denoise_1d(self, source: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Mask the orthonormal DCT-II of a single 1-D time course."""
        from scipy.fft import dct, idct

        dct_coeffs = dct(source, type=2, norm="ortho")
        dct_filtered = dct_coeffs * mask
        return idct(dct_filtered, type=2, norm="ortho")