"""Core nonlinear/iterative DSS algorithm and Estimator.
This module contains:
1. `iterative_dss`: The core mathematical implementation of Nonlinear DSS.
2. `IterativeDSS`: The Scikit-learn estimator compatible with MNE-Python objects or NumPy arrays.
Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
References
----------
.. [1] Särelä & Valpola (2005). Denoising Source Separation. J. Mach. Learn. Res., 6, 233-272.
"""
from __future__ import annotations
from collections.abc import Callable
import numpy as np
from ..utils import extract_data_from_mne
from .utils.whitening import whiten_data
def _resolve_callable(param, x, default=None):
"""Resolve parameters that can be float, callable, or None."""
if param is None:
return default
if callable(param):
return param
# Wrap constant in callable
return lambda *args: param
def iterative_dss_one(
    X_whitened: np.ndarray,
    denoiser: Callable[[np.ndarray], np.ndarray],
    *,
    w_init: np.ndarray | None = None,
    max_iter: int = 100,
    tol: float = 1e-6,
    alpha: float | Callable[[np.ndarray], float] | None = None,
    beta: float | Callable[[np.ndarray], float] | None = None,
    gamma: float | Callable[[np.ndarray, np.ndarray, int], float] | None = None,
    random_state: int | np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray, int, bool]:
    """Fixed-point iteration for extracting a single DSS component.

    This implements Algorithm 1 from Särelä & Valpola (2005) [1]_ with optional
    Newton step acceleration (FastICA equivalence).

    The algorithm finds a spatial filter **w** that maximizes the objective:

    .. math:: J(w) = E[f(w^T X)^2]

    where f(·) is the nonlinear denoising function.

    **Algorithm**::

        Initialize: w = random unit vector
        Repeat until converged:
            1. source = w @ X                      # Project data → 1D source
            2. source_denoised = f(source)         # Apply nonlinearity
            3. source_denoised *= alpha            # (optional) signal normalization
            4. w_new = E[X · source_denoised]      # Gradient direction
            5. w_new += beta · w                   # (optional) Newton step
            6. w_new = normalize(w_new)            # Unit norm constraint
            7. w = w_old + gamma·(±w_new - w_old)  # (optional) relaxation,
                                                   # sign of w_new aligned to w_old
            8. Check convergence: |w · w_old| ≈ 1

    Parameters
    ----------
    X_whitened : ndarray, shape (n_components, n_times)
        Whitened data matrix. Must have identity covariance.
    denoiser : callable
        Nonlinear denoising function f(s) → s_denoised.
        Examples: TanhMaskDenoiser, GaussDenoiser, WienerMaskDenoiser.
    w_init : ndarray, shape (n_components,), optional
        Initial weight vector. If None, random initialization.
    max_iter : int
        Maximum iterations. Default 100.
    tol : float
        Convergence tolerance. Default 1e-6.
    alpha : float or callable, optional
        Signal normalization factor applied after denoising:
        ``source_denoised *= alpha``. A callable is called as
        ``alpha(source)``.
        Useful for denoisers with different output variance.
    beta : float or callable, optional
        Spectral shift parameter for Newton-like acceleration. A callable is
        called as ``beta(source)``.
        For tanh denoiser: ``beta = -E[1 - tanh(s)²]`` (use ``beta_tanh``).
        For cubic denoiser: ``beta = -3`` (use ``beta_pow3``).
    gamma : float or callable, optional
        Learning rate / relaxation parameter. Controls step size:
        ``w = w_old + gamma · (w_new - w_old)``. A callable is called as
        ``gamma(w_new, w_old, iteration)``. Because the sign of a filter is
        arbitrary (and routinely flips when ``beta`` is negative), ``w_new``
        is sign-aligned with ``w_old`` before the relaxation is applied.
        Default None (gamma=1, full step).
    random_state : int, numpy.random.Generator or None, optional
        Seed or generator used for the random initialization and for
        re-initialization when the denoiser output vanishes. A ``Generator``
        is used as-is (its state advances).

    Returns
    -------
    w : ndarray, shape (n_components,)
        Optimal spatial filter (unit norm).
    source : ndarray, shape (n_times,)
        Extracted source time series.
    n_iter : int
        Number of iterations performed.
    converged : bool
        Whether the algorithm converged within max_iter.

    References
    ----------
    .. [1] Särelä & Valpola (2005). Denoising Source Separation. JMLR, 6, 233-272.
    """
    n_components, n_times = X_whitened.shape

    # Initialize RNG (handle both int and Generator)
    if isinstance(random_state, np.random.Generator):
        rng = random_state
    else:
        rng = np.random.default_rng(random_state)

    # Initialize weight vector
    if w_init is not None:
        w = w_init.copy()
    else:
        w = rng.standard_normal(n_components)

    # Normalize; fall back to a uniform unit vector for a (near-)zero start
    norm = np.linalg.norm(w)
    if norm < 1e-12:
        w = np.ones(n_components) / np.sqrt(n_components)
    else:
        w = w / norm

    # Resolve parameters to callables (None means "feature disabled")
    alpha_func = _resolve_callable(alpha, None)
    beta_func = _resolve_callable(beta, None)
    gamma_func = _resolve_callable(gamma, None)

    converged = False
    n_iter = 0

    for iteration in range(max_iter):
        n_iter = iteration + 1
        w_old = w.copy()

        # Project the data onto the current filter and denoise the source
        source = w @ X_whitened  # (n_times,)
        source_denoised = denoiser(source)

        # Apply alpha (signal normalization)
        if alpha_func is not None:
            source_denoised = alpha_func(source) * source_denoised

        # Spectral shift; 0.0 reduces the update to plain DSS
        step_beta = 0.0
        if beta_func is not None:
            step_beta = beta_func(source)

        # Update weights: w_new = E[X * f(s)] + beta * w
        #   Standard DSS:   beta = 0
        #   FastICA Newton: beta = -E[f'(s)]
        gradient_part = X_whitened @ source_denoised / n_times
        w_new = gradient_part + step_beta * w

        norm = np.linalg.norm(w_new)
        if norm < 1e-12:
            # Denoiser killed the signal, reinitialize and try again
            w = rng.standard_normal(n_components)
            w = w / np.linalg.norm(w)
            continue
        w_normalized = w_new / norm

        if gamma_func is not None:
            # The sign of a filter is arbitrary. Without alignment, a sign
            # flip (w_normalized ≈ -w_old) makes the relaxed step cancel out,
            # yielding a near-zero vector and NaNs after normalization.
            if np.dot(w_normalized, w_old) < 0:
                w_normalized = -w_normalized
            step_gamma = gamma_func(w_normalized, w_old, iteration)
            w_relaxed = w_old + step_gamma * (w_normalized - w_old)
            relaxed_norm = np.linalg.norm(w_relaxed)
            if relaxed_norm < 1e-12:
                # Degenerate relaxation step: take the full step instead
                w = w_normalized
            else:
                w = w_relaxed / relaxed_norm
        else:
            w = w_normalized

        # Check convergence (using abs because sign can flip)
        correlation = np.abs(np.dot(w, w_old))
        if 1 - correlation < tol:
            converged = True
            break

    # Final source extraction with the final filter
    source = w @ X_whitened
    return w, source, n_iter, converged
# --- Public functional API ---
def iterative_dss(
    data: np.ndarray,
    denoiser: Callable[[np.ndarray], np.ndarray],
    n_components: int,
    *,
    method: str = "deflation",
    rank: int | None = None,
    reg: float = 1e-9,
    max_iter: int = 100,
    tol: float = 1e-6,
    w_init: np.ndarray | None = None,
    verbose: bool = False,
    alpha: float | Callable | None = None,
    beta: float | Callable | None = None,
    gamma: float | Callable | None = None,
    random_state: int | np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Extract multiple DSS components using iterative (nonlinear) algorithm.

    This implements the Iterative DSS algorithm from Särelä & Valpola (2005) [1]_.
    Unlike linear DSS which uses a closed-form eigendecomposition, iterative DSS
    uses fixed-point iteration with a nonlinear denoising function.

    **Algorithm Overview**::

        1. Center data: X = X - mean(X)
        2. Whiten data: X_white = Whitener @ X (identity covariance)
        3. Extract components using deflation or symmetric method:

           Deflation (sequential):
               For each component i = 1..n_components:
                   w_i = iterative_dss_one(X_deflated)
                   Orthogonalize w_i against w_1..w_{i-1}
                   Deflate: X_deflated -= w_i @ w_i.T @ X_deflated

           Symmetric (parallel):
               Initialize W = [w_1, ..., w_n] randomly
               Repeat until converged:
                   Update all w_i simultaneously
                   W = symmetric_orthogonalize(W)

        4. Convert filters to sensor space: filters = W @ Whitener

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_times) or (n_epochs, n_channels, n_times)
        Input multichannel data. 3D input is concatenated along time.
    denoiser : callable or NonlinearDenoiser
        Nonlinear denoising function f(s) → s_denoised.
        Examples: TanhMaskDenoiser, GaussDenoiser, WienerMaskDenoiser.
    n_components : int
        Number of components to extract. Clamped to the whitened dimension.
    method : {'deflation', 'symmetric'}
        Component extraction method:

        - ``'deflation'``: Extract one-by-one, orthogonalizing after each.
          More stable, but order-dependent.
        - ``'symmetric'``: Update all simultaneously, then orthogonalize.
          Order-independent, may be less stable.

        Default ``'deflation'``.
    rank : int, optional
        Rank for whitening. If None, auto-determined from eigenvalue threshold.
    reg : float
        Regularization for whitening eigenvalue cutoff. Default 1e-9.
    max_iter : int
        Maximum iterations per component. Default 100.
    tol : float
        Convergence tolerance. Default 1e-6.
    w_init : ndarray, shape (n_components, n_whitened), optional
        Initial weight matrix. If None, random initialization.
    verbose : bool
        Print convergence info. Default False.
    alpha : float or callable, optional
        Signal normalization factor (see ``iterative_dss_one``).
    beta : float or callable, optional
        Newton step parameter (see ``iterative_dss_one``).
    gamma : float or callable, optional
        Learning rate / relaxation (see ``iterative_dss_one``). Only the
        ``'deflation'`` method applies it; the symmetric solver accepts the
        argument but does not use it.
    random_state : int, numpy.random.Generator or None, optional
        Seed or generator for the random initialization of the filters.

    Returns
    -------
    filters : ndarray, shape (n_components, n_channels)
        DSS spatial filters in sensor space.
        Apply as: ``sources = filters @ data``.
    sources : ndarray, shape (n_components, n_times)
        Extracted source time series.
    patterns : ndarray, shape (n_channels, n_components)
        Spatial patterns for visualization / reconstruction.
        Note: These are returned in original sensor units (not normalized),
        satisfying the identity :math:`X_{recon} = patterns @ sources`.
    convergence_info : ndarray, shape (n_components, 2)
        ``[n_iterations, converged]`` for each component.

    Raises
    ------
    ValueError
        If ``method`` is not ``'deflation'`` or ``'symmetric'``, or if the
        extracted data is neither 2D nor 3D.

    Examples
    --------
    >>> import numpy as np
    >>> from mne_denoise.dss import iterative_dss
    >>> from mne_denoise.dss.denoisers import TanhMaskDenoiser
    >>> # Basic usage with numpy array
    >>> data = np.random.randn(10, 1000)
    >>> denoiser = TanhMaskDenoiser()
    >>> filters, sources, patterns, _ = iterative_dss(
    ...     data, denoiser, n_components=2, method="symmetric"
    ... )

    See Also
    --------
    iterative_dss_one : Single component extraction.
    IterativeDSS : Sklearn-style estimator wrapper.

    References
    ----------
    .. [1] Särelä & Valpola (2005). Denoising Source Separation. JMLR, 6, 233-272.
    """
    # Fail fast on a bad method name, before the (expensive) whitening step.
    if method == "deflation":
        extract_components = _iterative_dss_deflation
    elif method == "symmetric":
        extract_components = _iterative_dss_symmetric
    else:
        raise ValueError(f"Unknown method: {method}. Use 'deflation' or 'symmetric'")

    # Use helper for validation/extraction; only the array is needed here
    data, _, _, _ = extract_data_from_mne(data)

    # Flatten if 3D (assume n_epochs, n_channels, n_times): epochs are
    # concatenated along the time axis, channel by channel.
    if data.ndim == 3:
        data_2d = data.transpose(1, 0, 2).reshape(data.shape[1], -1)
    else:
        data_2d = data
    if data_2d.ndim != 2:
        raise ValueError(f"Data must be 2D or 3D, got {data.ndim}D")
    n_samples = data_2d.shape[1]

    # Center data (per channel)
    data_centered = data_2d - data_2d.mean(axis=1, keepdims=True)

    # Whiten data
    X_whitened, whitener, dewhitener = whiten_data(data_centered, rank=rank, reg=reg)
    n_whitened = X_whitened.shape[0]

    # Cannot extract more components than whitened dimensions
    n_components = min(n_components, n_whitened)

    filters_whitened, sources, convergence_info = extract_components(
        X_whitened,
        denoiser,
        n_components,
        max_iter=max_iter,
        tol=tol,
        w_init=w_init,
        verbose=verbose,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        random_state=random_state,
    )

    # Convert filters from whitened to sensor space:
    #   filters_whitened: (n_components, n_whitened)
    #   whitener:         (n_whitened, n_channels)
    filters = filters_whitened @ whitener  # (n_components, n_channels)

    # Patterns are the data covariance projected through the filters
    C = data_centered @ data_centered.T / n_samples
    patterns = C @ filters.T
    return filters, sources, patterns, convergence_info
def _iterative_dss_deflation(
    X_whitened: np.ndarray,
    denoiser: Callable,
    n_components: int,
    *,
    max_iter: int,
    tol: float,
    w_init: np.ndarray | None,
    verbose: bool,
    alpha: float | Callable | None = None,
    beta: float | Callable | None = None,
    gamma: float | Callable | None = None,
    random_state: int | np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Extract components one-by-one using deflation.

    **Algorithm**::

        For i = 1..n_components:
            1. w_i = iterative_dss_one(X_deflated)   # Extract one component
            2. Orthogonalize: w_i -= W_prev.T @ (W_prev @ w_i)
            3. Normalize: w_i = w_i / ||w_i||
            4. s_i = w_i @ X_whitened                # Extract source
            5. Deflate: X_deflated -= w_i @ (w_i.T @ X_deflated)

    Parameters
    ----------
    X_whitened : ndarray, shape (n_whitened, n_times)
        Whitened data with identity covariance.
    denoiser : callable
        Nonlinear denoising function.
    n_components : int
        Number of components to extract.
    max_iter, tol : int, float
        Convergence parameters passed to ``iterative_dss_one``.
    w_init : ndarray, optional
        Initial weight matrix; row ``i`` seeds component ``i``. Components
        beyond the number of rows are initialized randomly.
    verbose : bool
        Print progress.
    alpha, beta, gamma : optional
        Passed to ``iterative_dss_one``.
    random_state : int, numpy.random.Generator or None, optional
        Seed or generator. A single generator is shared by all components so
        that each one gets a different random start.

    Returns
    -------
    W : ndarray, shape (n_components, n_whitened)
        Weight matrix (spatial filters in whitened space).
    sources : ndarray, shape (n_components, n_times)
        Extracted source time series.
    convergence_info : ndarray, shape (n_components, 2)
        [n_iter, converged] per component.
    """
    n_whitened, n_times = X_whitened.shape

    # Initialize RNG (handle both int and Generator)
    if isinstance(random_state, np.random.Generator):
        rng = random_state
    else:
        rng = np.random.default_rng(random_state)

    # Storage
    W = np.zeros((n_components, n_whitened))
    sources = np.zeros((n_components, n_times))
    convergence_info = np.zeros((n_components, 2))

    # Work on a copy: the caller's data is still needed to compute sources
    X_deflated = X_whitened.copy()

    for i in range(n_components):
        # Get initial weight (None -> random start inside iterative_dss_one)
        if w_init is not None and i < w_init.shape[0]:
            w_i = w_init[i]
        else:
            w_i = None

        # Run single-component iteration; its source is recomputed below from
        # the orthogonalized filter, so the returned one is discarded.
        w, _, n_iter, converged = iterative_dss_one(
            X_deflated,
            denoiser,
            w_init=w_i,
            max_iter=max_iter,
            tol=tol,
            alpha=alpha,
            beta=beta,
            gamma=gamma,
            random_state=rng,
        )

        if verbose:
            status = "converged" if converged else "max_iter"
            print(f" Component {i + 1}: {n_iter} iterations ({status})")

        # Orthogonalize against previous components (vectorized Gram-Schmidt)
        if i > 0:
            W_prev = W[:i]  # (i, n_whitened)
            w = w - W_prev.T @ (W_prev @ w)
            norm = np.linalg.norm(w)
            if norm < 1e-12:
                # The filter collapsed onto already-extracted components:
                # replace it by a random direction in the remaining subspace.
                if verbose:
                    print(f" Component {i + 1}: degenerate, using random")
                w = rng.standard_normal(n_whitened)
                w = w - W_prev.T @ (W_prev @ w)
                norm = np.linalg.norm(w)
            w = w / norm

        W[i] = w
        sources[i] = w @ X_whitened
        convergence_info[i] = [n_iter, float(converged)]

        # Deflate: remove the component from the data. A rank-1 update,
        # w (w^T X), avoids forming the n x n projector and the O(n^2 * T)
        # product; nothing is left to extract after the last component.
        if i < n_components - 1:
            X_deflated = X_deflated - np.outer(w, w @ X_deflated)

    return W, sources, convergence_info
def _iterative_dss_symmetric(
    X_whitened: np.ndarray,
    denoiser: Callable,
    n_components: int,
    *,
    max_iter: int,
    tol: float,
    w_init: np.ndarray | None,
    verbose: bool,
    alpha: float | Callable | None = None,
    beta: float | Callable | None = None,
    gamma: float | Callable | None = None,
    random_state: int | np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Estimate all components jointly with symmetric orthogonalization.

    **Algorithm**::

        Initialize: W = random (n_components, n_whitened)
        W = symmetric_orthogonalize(W)
        Repeat until converged:
            S = W @ X
            S_denoised = f(S)
            W_new = E[S_denoised @ X.T] + beta * W
            W = symmetric_orthogonalize(W_new)
            Check convergence

    Parameters
    ----------
    X_whitened : ndarray, shape (n_whitened, n_times)
        Whitened data with identity covariance.
    denoiser : callable
        Nonlinear denoising function, applied to the full source matrix.
    n_components : int
        Number of components to extract.
    max_iter, tol : int, float
        Convergence parameters.
    w_init : ndarray, optional
        Initial weight matrix; cropped to ``(n_components, n_whitened)``.
    verbose : bool
        Print progress.
    alpha, beta : optional
        Iteration parameters, each called with the source matrix ``S``.
        A 1-D result is treated as one value per component.
    gamma : optional
        Accepted for signature parity with the deflation solver; it is not
        used by this function.
    random_state : int, numpy.random.Generator or None, optional
        Seed for the random initialization (only used when ``w_init`` is None).

    Returns
    -------
    W : ndarray, shape (n_components, n_whitened)
        Weight matrix (spatial filters in whitened space).
    sources : ndarray, shape (n_components, n_times)
        Extracted source time series.
    convergence_info : ndarray, shape (n_components, 2)
        [n_iter, converged] (same for all components in symmetric).
    """
    n_whitened, n_times = X_whitened.shape

    def _per_component(value):
        # A 1-D result holds one value per component: make it a column so it
        # broadcasts across the second axis of an (n_components, ...) matrix.
        return value[:, np.newaxis] if np.ndim(value) == 1 else value

    # Starting point: user-supplied (cropped) or random, then decorrelated
    if w_init is None:
        generator = np.random.default_rng(random_state)
        W = generator.standard_normal((n_components, n_whitened))
    else:
        W = w_init[:n_components, :n_whitened].copy()
    W = _symmetric_orthogonalize(W)

    scale_of = _resolve_callable(alpha, None)
    shift_of = _resolve_callable(beta, None)

    convergence_info = np.zeros((n_components, 2))
    has_converged = False
    iterations_run = max_iter

    for step in range(1, max_iter + 1):
        W_previous = W.copy()

        # Sources for every component at once: (n_components, n_times)
        S = W @ X_whitened
        S_denoised = denoiser(S)
        if scale_of is not None:
            S_denoised = _per_component(scale_of(S)) * S_denoised

        # Fixed-point update: E[S_denoised @ X.T] plus the optional shift
        gradient = S_denoised @ X_whitened.T / n_times
        shift = 0.0 if shift_of is None else _per_component(shift_of(S))
        W = _symmetric_orthogonalize(gradient + shift * W)

        # Row-wise |<w_new, w_old>|; the worst component decides convergence
        overlap = np.abs(np.sum(W * W_previous, axis=1))
        if np.max(1 - overlap) < tol:
            has_converged = True
            iterations_run = step
            break

    if verbose:
        if has_converged:
            print(f" Symmetric: converged at iteration {iterations_run}")
        else:
            print(f" Symmetric: max iterations ({max_iter})")

    convergence_info[:, 0] = iterations_run
    convergence_info[:, 1] = 1.0 if has_converged else 0.0

    # Sources from the final filters
    sources = W @ X_whitened
    return W, sources, convergence_info
def _symmetric_orthogonalize(W: np.ndarray) -> np.ndarray:
"""Symmetric orthogonalization using (W * W.T)^{-1/2} * W."""
# EVD of W @ W.T
gram = W @ W.T
D, E = np.linalg.eigh(gram)
# Handle numerical issues
D = np.maximum(D, 1e-12)
# W_orth = E @ diag(1/sqrt(D)) @ E.T @ W
D_inv_sqrt = np.diag(1.0 / np.sqrt(D))
W_orth = E @ D_inv_sqrt @ E.T @ W
return W_orth
# --- Public estimator API ---
class IterativeDSS:
    """Iterative (Nonlinear) Denoising Source Separation Transformer.

    Implements Iterative DSS as a scikit-learn compatible transformer that
    fits natively on MNE-Python objects (Raw, Epochs) or numpy arrays.

    Unlike linear DSS which uses closed-form eigendecomposition, Iterative DSS
    uses fixed-point iteration with a nonlinear denoising function, making it
    equivalent to FastICA when using ICA contrast functions (tanh, gauss, cube).

    Parameters
    ----------
    denoiser : callable or NonlinearDenoiser
        Nonlinear denoising function f(s) → s_denoised. Must be an instance of
        `mne_denoise.dss.NonlinearDenoiser` (e.g. `TanhMaskDenoiser`,
        `WienerMaskDenoiser`) or a callable.
    n_components : int
        Number of components to extract. Fewer components may be returned
        when the whitened data has a lower dimension.
    method : {'deflation', 'symmetric'}
        Component extraction method:

        - ``'deflation'``: Extract one-by-one, orthogonalizing after each.
        - ``'symmetric'``: Update all simultaneously, then orthogonalize.

        Default ``'deflation'``.
    rank : int, optional
        Rank for whitening. If None, auto-determined from eigenvalue threshold.
    reg : float
        Regularization for whitening. Default 1e-9.
    normalize_input : bool
        If True (default), every channel is divided by its standard deviation
        (estimated in ``fit``) before filtering, and ``inverse_transform``
        multiplies the reconstruction by it again.
    max_iter : int
        Maximum iterations per component. Default 100.
    tol : float
        Convergence tolerance. Default 1e-6.
    verbose : bool
        Print convergence info. Default False.
    alpha : float or callable, optional
        Signal normalization factor applied after denoising.
    beta : float or callable, optional
        Spectral shift (Newton step) for faster convergence.
        Use ``beta_tanh`` for TanhMaskDenoiser, ``beta_pow3`` for cubic.
    gamma : float or callable, optional
        Learning rate / relaxation parameter.
    random_state : int, numpy.random.Generator or None, optional
        Seed for random initialization.

    Attributes
    ----------
    filters_ : ndarray, shape (n_components, n_channels)
        The spatial filters (un-mixing matrix). Apply as: ``sources = filters_ @ data``.
    patterns_ : ndarray, shape (n_channels, n_components)
        The spatial patterns (mixing matrix). Reconstruct as: ``data = patterns_ @ sources``.
    sources_ : ndarray, shape (n_components, n_times)
        Extracted sources from fit data.
    convergence_info_ : ndarray, shape (n_components, 2)
        [n_iterations, converged] for each component.
    channel_norms_ : ndarray, shape (n_channels,) or None
        Per-channel standard deviations used for input normalization; None
        before fitting or when fitted with ``normalize_input=False``.

    Examples
    --------
    >>> from mne_denoise.dss import IterativeDSS
    >>> from mne_denoise.dss.denoisers import TanhMaskDenoiser, beta_tanh
    >>>
    >>> # With numpy array
    >>> denoiser = TanhMaskDenoiser()
    >>> dss = IterativeDSS(denoiser=denoiser, n_components=3)
    >>> dss.fit(data)
    >>> sources = dss.transform(data)
    >>>
    >>> # With MNE Raw object
    >>> dss.fit(raw)
    >>> sources = dss.transform(raw)

    See Also
    --------
    DSS : Linear DSS transformer.
    iterative_dss : Functional API.
    """

    def __init__(
        self,
        denoiser: Callable[[np.ndarray], np.ndarray],
        n_components: int,
        *,
        method: str = "deflation",
        rank: int | None = None,
        reg: float = 1e-9,
        normalize_input: bool = True,
        max_iter: int = 100,
        tol: float = 1e-6,
        verbose: bool = False,
        alpha: float | Callable | None = None,
        beta: float | Callable | None = None,
        gamma: float | Callable | None = None,
        random_state: int | np.random.Generator | None = None,
    ) -> None:
        self.denoiser = denoiser
        self.n_components = n_components
        self.method = method
        self.rank = rank
        self.reg = reg
        self.normalize_input = normalize_input
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.random_state = random_state

        # Fitted attributes
        self.filters_: np.ndarray | None = None
        self.patterns_: np.ndarray | None = None
        self.sources_: np.ndarray | None = None
        self.convergence_info_: np.ndarray | None = None
        # Declared here so the "is None" guards in transform() and
        # inverse_transform() raise the intended RuntimeError.
        self.channel_norms_: np.ndarray | None = None
        self._mne_info = None

    @staticmethod
    def _flatten_epochs(data: np.ndarray) -> np.ndarray:
        """Concatenate (n_epochs, n_channels, n_times) data along time."""
        if data.ndim == 3:
            return data.transpose(1, 0, 2).reshape(data.shape[1], -1)
        return data

    def _fitted_channel_norms(self) -> np.ndarray:
        """Return ``channel_norms_`` or raise if it was never estimated."""
        norms = getattr(self, "channel_norms_", None)
        if norms is None:
            raise RuntimeError(
                "IterativeDSS not fitted with normalize_input=True. Call fit() first."
            )
        return norms

    def fit(self, X) -> IterativeDSS:
        """Compute Iterative DSS spatial filters.

        Parameters
        ----------
        X : Raw | Epochs | ndarray
            The data to fit. Accepts:

            - ``mne.io.Raw``: Continuous data
            - ``mne.Epochs``: Epoched data
            - ``ndarray``: Shape (n_channels, n_times) or
              (n_epochs, n_channels, n_times)

        Returns
        -------
        self : IterativeDSS
            The fitted transformer.
        """
        # Validate and extract data using shared helper
        data, _, mne_type, mne_info = extract_data_from_mne(X)

        # Store MNE info for later use if available
        if (
            mne_type in ("raw", "epochs")
            and mne_info is not None
            and hasattr(mne_info, "info")
        ):
            self._mne_info = mne_info.info

        if self.normalize_input:
            # Std over all samples of a channel (epochs concatenated)
            norms = np.std(self._flatten_epochs(data), axis=1)
            # Leave flat channels untouched instead of dividing by zero
            self.channel_norms_ = np.where(norms > 0, norms, 1.0)
            # (n_channels, 1) broadcasts over time and any leading epoch axis
            data = data / self.channel_norms_[:, np.newaxis]
        else:
            # Do not keep norms from an earlier fit around
            self.channel_norms_ = None

        filters, sources, patterns, conv_info = iterative_dss(
            data,
            self.denoiser,
            self.n_components,
            method=self.method,
            rank=self.rank,
            reg=self.reg,
            max_iter=self.max_iter,
            tol=self.tol,
            verbose=self.verbose,
            alpha=self.alpha,
            beta=self.beta,
            gamma=self.gamma,
            random_state=self.random_state,
        )
        self.filters_ = filters
        self.patterns_ = patterns
        self.sources_ = sources
        self.convergence_info_ = conv_info
        return self

    def transform(self, X) -> np.ndarray:
        """Apply fitted filters to extract sources.

        Parameters
        ----------
        X : Raw | Epochs | ndarray
            Data to transform. Same formats as ``fit()``.

        Returns
        -------
        sources : ndarray
            Extracted source time series, shape (n_components, n_times) for
            2D input or (n_epochs, n_components, n_times) for 3D input.

        Raises
        ------
        RuntimeError
            If the transformer has not been fitted.
        """
        if self.filters_ is None:
            raise RuntimeError("IterativeDSS not fitted. Call fit() first.")

        # Validate and extract data
        data, _, _, _ = extract_data_from_mne(X)
        original_shape = data.shape
        data_2d = self._flatten_epochs(data)

        if self.normalize_input:
            data_2d = data_2d / self._fitted_channel_norms()[:, np.newaxis]

        # Center (per channel, using the mean of the data being transformed)
        data_centered = data_2d - data_2d.mean(axis=1, keepdims=True)

        # Apply filters
        sources = self.filters_ @ data_centered

        if len(original_shape) == 3:
            # (n_filters, n_epochs * n_times) -> (n_epochs, n_filters, n_times).
            # Use the fitted filter count: it can be smaller than
            # ``n_components`` when the whitened rank was lower.
            n_epochs, _, n_times = original_shape
            n_filters = self.filters_.shape[0]
            sources = sources.reshape(n_filters, n_epochs, n_times)
            sources = sources.transpose(1, 0, 2)
        return sources

    def inverse_transform(self, sources: np.ndarray) -> np.ndarray:
        """Reconstruct data from sources.

        Parameters
        ----------
        sources : ndarray, shape (n_sources, n_times) or (n_epochs, n_sources, n_times)
            Source time series. Can use fewer sources than fitted.

        Returns
        -------
        reconstructed : ndarray, shape (n_channels, n_times) or (n_epochs, n_channels, n_times)
            Reconstructed data: ``patterns_[:, :n_sources] @ sources``,
            rescaled by ``channel_norms_`` when ``normalize_input`` is True.

        Raises
        ------
        RuntimeError
            If the transformer has not been fitted.
        """
        if self.patterns_ is None:
            raise RuntimeError("IterativeDSS not fitted. Call fit() first.")

        n_comp_sources = sources.shape[1] if sources.ndim == 3 else sources.shape[0]
        patterns = self.patterns_[:, :n_comp_sources]

        # matmul broadcasts over a leading epoch axis, so one expression
        # covers both (n_comp, n_times) and (n_epochs, n_comp, n_times).
        rec = patterns @ sources

        if self.normalize_input:
            # Undo the per-channel scaling applied in fit()/transform()
            rec = rec * self._fitted_channel_norms()[:, np.newaxis]
        return rec

    def get_normalized_patterns(self) -> np.ndarray:
        """Get L2-normalized spatial patterns for visualization.

        Returns
        -------
        patterns_norm : ndarray, shape (n_channels, n_components)
            L2-normalized spatial patterns. Columns whose norm is
            (numerically) zero are returned unscaled.
        """
        if self.patterns_ is None:
            raise RuntimeError("IterativeDSS not fitted. Call fit() first.")

        norms = np.linalg.norm(self.patterns_, axis=0)
        # Use relative threshold for physical units
        max_norm = np.max(norms)
        threshold = 1e-15 * max_norm if max_norm > 0 else 1e-30
        norms = np.where(norms > threshold, norms, 1.0)
        return self.patterns_ / norms

    def fit_transform(self, X) -> np.ndarray:
        """Fit and transform in one step.

        Parameters
        ----------
        X : Raw | Epochs | ndarray
            Data to fit and transform.

        Returns
        -------
        sources : ndarray
            Extracted sources.
        """
        return self.fit(X).transform(X)