Note
Go to the end to download the full example code.
Time-Frequency DSS: Spectrogram Masking.#
This example demonstrates DSS for extracting transient oscillatory bursts using time-frequency (TF) domain constraints via spectrogram masking.
We cover SpectrogramBias (linear) and SpectrogramDenoiser (nonlinear) for isolating activity that is sparse in the TF domain.
The examples move from synthetic spindle-like bursts to fixed-mask DSS, adaptive spectrogram denoising, and a real MEG gamma-burst example.
- Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)
Imports#
import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.datasets import somato
from scipy import signal as sp_signal
from mne_denoise.dss import DSS, IterativeDSS
from mne_denoise.dss.denoisers import SpectrogramBias, SpectrogramDenoiser
from mne_denoise.viz import (
plot_channel_time_course_comparison,
plot_component_spectrogram,
plot_component_summary,
plot_signal_overlay,
plot_spectrogram_comparison,
plot_time_frequency_mask,
)
Part 0: Synthetic Transient Bursts (Sleep Spindles)#
Simulate 12 Hz spindle bursts embedded in noise
print("--- Part 0: Synthetic Spindle Bursts ---")
rng = np.random.default_rng(42)
sfreq = 250  # Hz
n_seconds = 10
n_times = n_seconds * sfreq
times = np.arange(n_times) / sfreq

# Broadband background: zero-mean, unit-scale Gaussian samples
noise = rng.normal(0, 1.0, n_times)

# Amplitude envelope: zero everywhere except two Hann-tapered bursts
spindle_freq = 12.0  # Hz
envelope = np.zeros_like(times)
mask1 = (times >= 2) & (times < 3)  # burst 1: 2-3 seconds
mask2 = (times >= 7) & (times < 8)  # burst 2: 7-8 seconds
envelope[mask1] = np.hanning(mask1.sum())
envelope[mask2] = np.hanning(mask2.sum())
signal_spindle = envelope * np.sin(2 * np.pi * spindle_freq * times) * 3.0

# Observed mixture = bursts + noise
data_mixed = signal_spindle + noise

# One row per trace: (samples, line format, plot keywords, panel title)
fig, axes = plt.subplots(3, 1, figsize=(14, 8), sharex=True)
panel_specs = (
    (
        signal_spindle,
        "b",
        {"linewidth": 1.5, "label": "Clean Spindle"},
        "Ground Truth: 12 Hz Spindle Bursts",
    ),
    (
        noise,
        "gray",
        {"alpha": 0.7, "label": "Broadband Noise"},
        "Noise",
    ),
    (
        data_mixed,
        "r",
        {"alpha": 0.7, "label": "Mixed (Signal + Noise)"},
        "Observed Data",
    ),
)
for panel_ax, (trace, line_fmt, plot_kwargs, panel_title) in zip(axes, panel_specs):
    panel_ax.plot(times, trace, line_fmt, **plot_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel("Amplitude")
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)
# Only the bottom panel carries the shared time-axis label
axes[2].set_xlabel("Time (s)")
plt.tight_layout()
plt.show()

--- Part 0: Synthetic Spindle Bursts ---
Visualize Time-Frequency Representation#
Compute spectrogram to see bursts in TF domain
# Short-time spectrum of the noisy mixture (128-sample windows, 96-sample overlap)
f, t, Sxx = sp_signal.spectrogram(data_mixed, fs=sfreq, nperseg=128, noverlap=96)
power_db = 10 * np.log10(Sxx)

# Same figure as the pyplot state-machine version, drawn through explicit
# Figure/Axes handles
spec_fig, spec_ax = plt.subplots(figsize=(12, 5))
spec_mesh = spec_ax.pcolormesh(t, f, power_db, shading="gouraud", cmap="viridis")
spec_fig.colorbar(spec_mesh, ax=spec_ax, label="Power (dB)")
spec_ax.set_ylabel("Frequency (Hz)")
spec_ax.set_xlabel("Time (s)")
spec_ax.set_title("Spectrogram: Spindle Bursts in Time-Frequency Domain")
spec_ax.set_ylim(0, 30)
# Dashed reference line at the simulated spindle frequency
spec_ax.axhline(
    spindle_freq, color="r", linestyle="--", alpha=0.7, label="12 Hz Spindle"
)
spec_ax.legend()
spec_fig.tight_layout()
plt.show()

Part 1: SpectrogramBias with Fixed Mask#
Define a TF mask to isolate burst regions
print("\n--- Part 1: Fixed TF Mask (SpectrogramBias) ---")
# STFT settings used both for the mask grid here and for the bias later on
nperseg = 128
noverlap = 96

# Only the time axis of the spectrogram is needed to size the mask
_, t_grid, _ = sp_signal.spectrogram(
    data_mixed, fs=sfreq, nperseg=nperseg, noverlap=noverlap
)
n_freqs = nperseg // 2 + 1
n_times_tf = len(t_grid)

# Frequency selection: bins inside 10-15 Hz
freq_axis = np.fft.rfftfreq(nperseg, 1 / sfreq)
freq_mask = (freq_axis >= 10) & (freq_axis <= 15)

# Time selection: frames inside either burst window
time_mask1 = (t_grid >= 2) & (t_grid < 3)
time_mask2 = (t_grid >= 7) & (t_grid < 8)

# Mask is 1 where both the band and a burst window are selected, 0 elsewhere
burst_region = freq_mask[:, None] & (time_mask1 | time_mask2)
mask_fixed = np.where(burst_region, 1.0, 0.0)

n_masked = mask_fixed.sum()
print(f"TF mask shape: {mask_fixed.shape}")
print(
    f"Masked points: {n_masked} / {mask_fixed.size} "
    f"({100 * n_masked / mask_fixed.size:.1f}%)"
)

# Visualize mask
plot_time_frequency_mask(
    mask_fixed,
    t_grid,
    freq_axis,
    title="SpectrogramBias Mask: 10-15 Hz Spindles",
    show=False,
)
plt.show()
# Apply SpectrogramBias with DSS
# Demonstration data: every channel carries the same mixture plus its own
# small Gaussian perturbation
ch_names = ["Fz", "Cz", "Pz", "F3", "F4", "C3", "C4", "Oz"]
n_channels = 8
channel_jitter = rng.normal(0, 0.2, (n_channels, n_times))
data_multichan = np.tile(data_mixed, (n_channels, 1)) + channel_jitter

# Wrap the array in an MNE Raw object with a standard 10-20 montage
montage = mne.channels.make_standard_montage("standard_1020")
info = mne.create_info(ch_names, sfreq, "eeg")
info.set_montage(montage)
raw_spindle = mne.io.RawArray(data_multichan, info)

# Fit DSS with the fixed time-frequency mask as the bias
bias_tf = SpectrogramBias(mask=mask_fixed, nperseg=nperseg, noverlap=noverlap)
dss_tf = DSS(n_components=3, bias=bias_tf)
dss_tf.fit(raw_spindle)
print(f"\nDSS Eigenvalues: {dss_tf.eigenvalues_[:3]}")

# Summary figure for the first two DSS components
all_channel_picks = np.arange(len(raw_spindle.ch_names))
plot_component_summary(
    dss_tf,
    data=raw_spindle,
    info=raw_spindle.info,
    picks=all_channel_picks,
    n_components=2,
    show=False,
)
plt.gcf().suptitle("SpectrogramBias: DSS Components")
plt.show()
# Extract component: project the data onto the fitted DSS filters and keep
# the first source
sources_tf = dss_tf.transform(raw_spindle)
comp0_tf = sources_tf[0]
# Compare with ground truth. If the component is anti-correlated with the
# simulated spindle, flip it in place so the reported correlation and the
# overlays below use a positive polarity.
if np.corrcoef(comp0_tf, signal_spindle)[0, 1] < 0:
    comp0_tf *= -1
corr_tf = np.corrcoef(comp0_tf, signal_spindle)[0, 1]
print(f"Correlation with ground truth: {corr_tf:.3f}")
# Spectrogram comparison: wrap the single-channel mixture and the extracted
# component as one-channel Raw objects for the plotting helper
raw_orig = mne.io.RawArray(data_mixed[np.newaxis, :], mne.create_info(1, sfreq, "eeg"))
raw_comp = mne.io.RawArray(comp0_tf[np.newaxis, :], mne.create_info(1, sfreq, "eeg"))
--- Part 1: Fixed TF Mask (SpectrogramBias) ---
TF mask shape: (65, 75)
Masked points: 32.0 / 4875 (0.7%)
Creating RawArray with float64 data, n_channels=8, n_times=2500
Range : 0 ... 2499 = 0.000 ... 9.996 secs
Ready.
DSS Eigenvalues: [0.18872036 0.01285285 0.00731457]
Correlation with ground truth: 0.512
Creating RawArray with float64 data, n_channels=1, n_times=2500
Range : 0 ... 2499 = 0.000 ... 9.996 secs
Ready.
Creating RawArray with float64 data, n_channels=1, n_times=2500
Range : 0 ... 2499 = 0.000 ... 9.996 secs
Ready.
# Compare the spectrogram of the single-channel mixture (raw_orig) with that
# of the extracted component (raw_comp), limited to 5-20 Hz around the 12 Hz
# spindle frequency.
plot_spectrogram_comparison(
raw_orig,
raw_comp,
picks=[0],
times=raw_orig.times,
fmin=5,
fmax=20,
show=False,
)
# The title is set on the current figure before it is displayed
plt.gcf().suptitle("Spectrogram Comparison: Original vs Extracted Spindles")
plt.show()
# Time-domain overlay of the ground-truth spindle and the extracted component
# NOTE(review): scale_after=True presumably rescales the component to the
# ground-truth amplitude - confirm against plot_signal_overlay's documentation

plot_signal_overlay(
signal_spindle,
comp0_tf,
times=times,
title="Spindle Reconstruction: Ground Truth vs SpectrogramBias Component",
scale_after=True,
show=False,
)
plt.show()

Part 2: SpectrogramDenoiser + IterativeDSS (Adaptive)#
Automatically find TF regions with high energy
print("\n--- Part 2: Adaptive TF Masking (SpectrogramDenoiser) ---")
# SpectrogramDenoiser keeps only top percentile of TF energy
spec_denoiser = SpectrogramDenoiser(
    threshold_percentile=90,  # Keep top 10%
    nperseg=128,
    noverlap=96,
)
# Use with IterativeDSS on original spindle data
idss_spec = IterativeDSS(denoiser=spec_denoiser, n_components=2, max_iter=3)
idss_spec.fit(raw_spindle)
# NOTE(review): this message is printed unconditionally after at most
# max_iter=3 iterations; no convergence criterion is checked here.
print("IterativeDSS with SpectrogramDenoiser converged")
sources_idss = idss_spec.transform(raw_spindle)
comp0_idss = sources_idss[0]
# Flip the component in place when it is anti-correlated with the ground
# truth, so the reported correlation uses a positive polarity.
if np.corrcoef(comp0_idss, signal_spindle)[0, 1] < 0:
    comp0_idss *= -1
corr_idss = np.corrcoef(comp0_idss, signal_spindle)[0, 1]
print(f"Correlation with ground truth: {corr_idss:.3f}")
# Compare both methods against the ground truth, one panel per trace.
# Each extracted component is rescaled to the ground truth's standard
# deviation so the three panels share a comparable amplitude.
fig, axes = plt.subplots(3, 1, figsize=(14, 9), sharex=True)
truth_std = np.std(signal_spindle)
comparison_panels = (
    (
        signal_spindle,
        "k",
        {"linewidth": 2, "label": "Ground Truth"},
        "Ground Truth: Spindle Bursts",
    ),
    (
        comp0_tf * (truth_std / np.std(comp0_tf)),
        "b",
        {"label": f"Fixed Mask (r={corr_tf:.3f})"},
        "SpectrogramBias: Fixed TF Mask",
    ),
    (
        comp0_idss * (truth_std / np.std(comp0_idss)),
        "r",
        {"label": f"Adaptive Mask (r={corr_idss:.3f})"},
        "SpectrogramDenoiser: Adaptive TF Masking",
    ),
)
for panel_ax, (trace, line_fmt, plot_kwargs, panel_title) in zip(
    axes, comparison_panels
):
    panel_ax.plot(times, trace, line_fmt, **plot_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel("Amplitude")
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)
# Only the bottom panel carries the shared time-axis label
axes[2].set_xlabel("Time (s)")
plt.tight_layout()
plt.show()

--- Part 2: Adaptive TF Masking (SpectrogramDenoiser) ---
IterativeDSS with SpectrogramDenoiser converged
Correlation with ground truth: 0.111
Part 3: Real MEG Gamma Bursts (Somato Dataset)#
Extract transient gamma oscillations
print("\n--- Part 3: Real MEG Data (Gamma Bursts) ---")
# Download somato dataset (will skip if already downloaded)
data_path = somato.data_path(verbose=True)
raw_path = data_path / "sub-01" / "meg" / "sub-01_task-somato_meg.fif"
raw_somato = mne.io.read_raw_fif(raw_path, preload=True, verbose=False)
# Keep only the gradiometers and drop channels marked as bad. Raw.pick() is
# used instead of pick_types(), which MNE reports as a legacy function.
raw_somato.pick("grad", exclude="bads")
# Use broader band and keep more data
raw_somato.filter(1, 100, fir_design="firwin", verbose=False)  # Broad band first
raw_somato.crop(0, 60)  # 60 seconds for more data
print(f"MEG Data: {len(raw_somato.ch_names)} channels, {raw_somato.times[-1]:.1f}s")
# Adaptive TF denoiser for the MEG recording: keep the top 5% of TF energy
spec_denoiser_meg = SpectrogramDenoiser(threshold_percentile=95, nperseg=256)
idss_meg = IterativeDSS(denoiser=spec_denoiser_meg, n_components=3, max_iter=3)
idss_meg.fit(raw_somato)
print("\nIterativeDSS on MEG data converged")

# Component time series (topomaps are skipped in this section)
sources_meg = idss_meg.transform(raw_somato)

# Wrap the first component as a one-channel Raw so it can be plotted next to
# a single sensor from the recording
component_info = mne.create_info(1, raw_somato.info["sfreq"], "misc")
comp_raw_meg = mne.io.RawArray(sources_meg[:1], component_info)
raw_single_meg = raw_somato.copy().pick([0])

# --- Time Course Comparison ---
# Show only the first five seconds, expressed as a sample index
five_second_stop = int(5 * raw_single_meg.info["sfreq"])
plot_channel_time_course_comparison(
    raw_single_meg,
    comp_raw_meg,
    picks=[0],
    times=raw_single_meg.times,
    start=0,
    stop=five_second_stop,
    show=False,
)
plt.gcf().suptitle("Real MEG: Original vs TF-DSS Component 0 (Gamma Bursts)")
plt.show()
# --- Spectrogram Comparison (Pre/Post) ---
# Compares broadband raw data vs the broadband component

--- Part 3: Real MEG Data (Gamma Bursts) ---
Using default location ~/mne_data for somato...
Fetching 1 file for the somato dataset ...
0%| | 0.00/610M [00:00<?, ?B/s]
0%|▏ | 3.00M/610M [00:00<00:20, 30.0MB/s]
2%|▋ | 10.7M/610M [00:00<00:10, 57.4MB/s]
3%|█▏ | 18.9M/610M [00:00<00:08, 68.8MB/s]
4%|█▋ | 27.0M/610M [00:00<00:07, 73.8MB/s]
6%|██▏ | 35.2M/610M [00:00<00:07, 76.6MB/s]
7%|██▋ | 42.9M/610M [00:00<00:08, 65.4MB/s]
8%|███▏ | 51.1M/610M [00:00<00:07, 70.4MB/s]
10%|███▋ | 59.4M/610M [00:00<00:07, 74.0MB/s]
11%|████▏ | 67.6M/610M [00:00<00:07, 76.3MB/s]
12%|████▋ | 75.8M/610M [00:01<00:06, 78.2MB/s]
14%|█████▏ | 83.7M/610M [00:01<00:06, 77.8MB/s]
15%|█████▋ | 91.6M/610M [00:01<00:06, 78.0MB/s]
16%|██████▏ | 99.4M/610M [00:01<00:06, 77.0MB/s]
18%|██████▊ | 107M/610M [00:01<00:06, 76.8MB/s]
19%|███████▎ | 115M/610M [00:01<00:06, 77.0MB/s]
20%|███████▊ | 123M/610M [00:01<00:06, 77.3MB/s]
21%|████████▎ | 130M/610M [00:01<00:06, 77.3MB/s]
23%|████████▊ | 138M/610M [00:01<00:06, 77.2MB/s]
24%|█████████▎ | 146M/610M [00:01<00:06, 77.2MB/s]
25%|█████████▊ | 154M/610M [00:02<00:05, 77.4MB/s]
26%|██████████▎ | 161M/610M [00:02<00:05, 77.4MB/s]
28%|██████████▊ | 169M/610M [00:02<00:05, 77.6MB/s]
29%|███████████▎ | 177M/610M [00:02<00:05, 77.5MB/s]
30%|███████████▊ | 185M/610M [00:02<00:05, 77.5MB/s]
32%|████████████▎ | 193M/610M [00:02<00:05, 77.7MB/s]
33%|████████████▊ | 200M/610M [00:02<00:05, 78.1MB/s]
34%|█████████████▎ | 208M/610M [00:02<00:05, 68.6MB/s]
35%|█████████████▊ | 216M/610M [00:02<00:05, 70.8MB/s]
37%|██████████████▎ | 224M/610M [00:03<00:05, 72.8MB/s]
38%|██████████████▊ | 231M/610M [00:03<00:09, 41.6MB/s]
39%|███████████████▎ | 239M/610M [00:03<00:07, 48.3MB/s]
40%|███████████████▊ | 247M/610M [00:03<00:06, 54.6MB/s]
42%|████████████████▎ | 254M/610M [00:03<00:05, 60.0MB/s]
43%|████████████████▊ | 262M/610M [00:03<00:05, 64.2MB/s]
44%|█████████████████▏ | 269M/610M [00:03<00:06, 49.7MB/s]
45%|█████████████████▋ | 277M/610M [00:04<00:06, 55.3MB/s]
47%|██████████████████▏ | 285M/610M [00:04<00:05, 60.4MB/s]
48%|██████████████████▋ | 291M/610M [00:04<00:05, 56.0MB/s]
49%|███████████████████ | 299M/610M [00:04<00:05, 60.9MB/s]
50%|███████████████████▌ | 306M/610M [00:04<00:05, 59.4MB/s]
51%|████████████████████ | 314M/610M [00:04<00:04, 63.7MB/s]
53%|████████████████████▌ | 321M/610M [00:04<00:04, 67.0MB/s]
54%|█████████████████████ | 329M/610M [00:04<00:04, 69.8MB/s]
55%|█████████████████████▌ | 337M/610M [00:04<00:03, 71.7MB/s]
56%|█████████████████████▉ | 344M/610M [00:05<00:03, 70.2MB/s]
58%|██████████████████████▍ | 351M/610M [00:05<00:04, 64.0MB/s]
59%|██████████████████████▉ | 359M/610M [00:05<00:03, 66.7MB/s]
60%|███████████████████████▍ | 367M/610M [00:05<00:03, 71.1MB/s]
61%|███████████████████████▉ | 375M/610M [00:05<00:03, 74.5MB/s]
63%|████████████████████████▌ | 383M/610M [00:05<00:02, 76.8MB/s]
64%|█████████████████████████ | 392M/610M [00:05<00:02, 78.2MB/s]
66%|█████████████████████████▌ | 400M/610M [00:05<00:02, 79.3MB/s]
67%|██████████████████████████ | 408M/610M [00:05<00:02, 80.3MB/s]
68%|██████████████████████████▌ | 416M/610M [00:06<00:02, 81.1MB/s]
70%|███████████████████████████▏ | 425M/610M [00:06<00:02, 81.6MB/s]
71%|███████████████████████████▋ | 433M/610M [00:06<00:02, 81.9MB/s]
72%|████████████████████████████▏ | 441M/610M [00:06<00:02, 79.8MB/s]
74%|████████████████████████████▋ | 449M/610M [00:06<00:02, 76.5MB/s]
75%|█████████████████████████████▏ | 457M/610M [00:06<00:01, 78.3MB/s]
76%|█████████████████████████████▊ | 466M/610M [00:06<00:01, 79.4MB/s]
78%|██████████████████████████████▎ | 474M/610M [00:06<00:01, 79.9MB/s]
79%|██████████████████████████████▊ | 482M/610M [00:06<00:01, 80.8MB/s]
80%|███████████████████████████████▎ | 490M/610M [00:06<00:01, 81.3MB/s]
82%|███████████████████████████████▊ | 498M/610M [00:07<00:01, 81.3MB/s]
83%|████████████████████████████████▍ | 507M/610M [00:07<00:01, 61.9MB/s]
84%|████████████████████████████████▉ | 515M/610M [00:07<00:01, 67.1MB/s]
86%|█████████████████████████████████▍ | 523M/610M [00:07<00:01, 71.2MB/s]
87%|█████████████████████████████████▉ | 531M/610M [00:07<00:01, 73.1MB/s]
88%|██████████████████████████████████▍ | 539M/610M [00:07<00:00, 75.6MB/s]
90%|██████████████████████████████████▉ | 547M/610M [00:07<00:00, 77.3MB/s]
91%|███████████████████████████████████▌ | 555M/610M [00:07<00:00, 78.6MB/s]
92%|████████████████████████████████████ | 564M/610M [00:07<00:00, 79.4MB/s]
94%|████████████████████████████████████▌ | 572M/610M [00:08<00:00, 80.4MB/s]
95%|█████████████████████████████████████ | 580M/610M [00:08<00:00, 81.2MB/s]
96%|█████████████████████████████████████▌ | 588M/610M [00:08<00:00, 78.3MB/s]
98%|██████████████████████████████████████ | 596M/610M [00:08<00:00, 76.9MB/s]
99%|██████████████████████████████████████▋| 605M/610M [00:08<00:00, 78.5MB/s]
0%| | 0.00/610M [00:00<?, ?B/s]
100%|███████████████████████████████████████| 610M/610M [00:00<00:00, 2.25TB/s]
Download complete in 18s (581.8 MB)
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
MEG Data: 204 channels, 60.0s
IterativeDSS on MEG data converged
Creating RawArray with float64 data, n_channels=1, n_times=18019
Range : 0 ... 18018 = 0.000 ... 59.999 secs
Ready.
# Compare the spectrogram of the single MEG channel with that of the first
# extracted component over 1-100 Hz, the band the recording was filtered to.
plot_spectrogram_comparison(
raw_single_meg,
comp_raw_meg,
picks=[0],
times=raw_single_meg.times,
fmin=1,
fmax=100,
show=False,
)
# The title is set on the current figure before it is displayed
plt.gcf().suptitle("Spectrogram Comparison: Raw Data vs Extracted Component")
plt.show()
# --- Component Spectrogram (Single Component TFR) ---
# Zoom into the component's frequency content
# Frequencies of interest: 10 to 48 Hz in 2 Hz steps
freqs = np.arange(10, 50, 2)

# Time-frequency view of the first component only, at the recording's
# sampling rate
plot_component_spectrogram(
sources_meg[0],
sfreq=raw_somato.info["sfreq"],
freqs=freqs,
title="TF-DSS Component 0: Extracted Transient Oscillations",
show=False,
)
plt.show()
# --- Full Component Summary (with Topomaps!) ---
# Now that we have info, we can show topomaps
# Note: IterativeDSS stores patterns, so plot_component_summary can extract them.
print("\nPlotting full component dashboard (including Topomaps)...")

Plotting full component dashboard (including Topomaps)...
# Summary figure for the first MEG component, using every channel kept in
# raw_somato and its measurement info.
plot_component_summary(
idss_meg,
data=raw_somato,
info=raw_somato.info,
picks=np.arange(len(raw_somato.ch_names)),
n_components=1,
show=False,
)
# The title is set on the current figure before it is displayed
plt.gcf().suptitle("TF-DSS Component 0 Dashboard")
plt.show()
print("\nSuccessfully extracted transient gamma oscillations using TF masking!")

Successfully extracted transient gamma oscillations using TF masking!
Total running time of the script: (0 minutes 30.214 seconds)

