Artifact Correction with DSS.#

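DSS (denoising source separation) is well suited to removing physiological artifacts such as eye blinks (EOG) and heartbeats (ECG) from data. The core idea is that these artifacts are repetitive.

If we can mark when the artifacts occur (e.g., using EOG/ECG channels), we can epoch the data on those events and use a trial-average bias (AverageBias in the code below) to find the artifact source and remove it.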
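
To make the trial-average idea concrete, here is a minimal NumPy sketch on synthetic epochs. It is not the mne_denoise implementation: the data, the variable names, and the choice of whitening followed by an eigendecomposition of the averaged covariance are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 40, 6, 100

# Synthetic epochs: one fixed artifact waveform with a fixed topography, plus noise
waveform = np.hanning(n_times)
topography = rng.standard_normal(n_channels)
epochs = rng.standard_normal((n_epochs, n_channels, n_times))
epochs += 2.0 * topography[None, :, None] * waveform[None, None, :]

# Baseline covariance from all samples, biased covariance from the epoch average
X = epochs.transpose(1, 0, 2).reshape(n_channels, -1)
X = X - X.mean(axis=1, keepdims=True)
cov_all = X @ X.T / X.shape[1]
evoked = epochs.mean(axis=0)
evoked_c = evoked - evoked.mean(axis=1, keepdims=True)
cov_avg = evoked_c @ evoked_c.T / n_times

# Whiten with the baseline covariance, then diagonalize the biased covariance
evals, evecs = np.linalg.eigh(cov_all)
whitener = evecs / np.sqrt(evals)  # each eigenvector scaled by 1/sqrt(eigenvalue)
scores, rotation = np.linalg.eigh(whitener.T @ cov_avg @ whitener)
order = np.argsort(scores)[::-1]
scores = scores[order]
filters = (whitener @ rotation[:, order]).T  # rows are spatial filters

# The first component's averaged time course should follow the artifact waveform
comp0 = filters[0] @ evoked
r = np.corrcoef(comp0, waveform)[0, 1]
print(f"score[0]/score[1] = {scores[0] / scores[1]:.1f}, |r| = {abs(r):.2f}")
```

Each score is the fraction of a component's power that survives averaging across epochs, so a component that repeats identically on every event ends up at the top of the ranking.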

This tutorial demonstrates two artifact-correction workflows with DSS: blink correction based on EOG epochs and heartbeat correction based on ECG epochs.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca)

Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)

Imports#

import contextlib
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.datasets import sample
from mne.preprocessing import create_ecg_epochs, create_eog_epochs

from mne_denoise.dss import DSS, AverageBias, CycleAverageBias
from mne_denoise.viz import (
    plot_component_patterns,
    plot_component_score_curve,
    plot_component_summary,
    plot_component_time_series,
    plot_evoked_gfp_comparison,
    plot_psd_comparison,
)

Load Data#

We use the MNE sample dataset which contains clear ECG and EOG artifacts.

print("Loading MNE Sample data...")
# Ensure the default MNE data directory exists (ignore failures such as a read-only home)
home = os.path.expanduser("~")
mne_data_path = os.path.join(home, "mne_data")
with contextlib.suppress(OSError):
    os.makedirs(mne_data_path, exist_ok=True)

data_path = sample.data_path()
raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif"
raw = mne.io.read_raw_fif(raw_fname, preload=True, verbose=False)
raw.crop(0, 60)  # Keep the first 60 s; channel picking happens later
print(f"Data: {len(raw.ch_names)} channels (MEG, EEG, EOG, ECG), 60s duration")
Loading MNE Sample data...
Data: 376 channels (MEG, EEG, EOG, ECG), 60s duration

Denoising Comparison#

We project the data into the DSS space, zero out the first component (the blink), and project back.

print("Removing blink component...")
# Transform continuous data
# We must ensure we apply to the same channels used in fit (the gradiometers).
raw_meg = raw.copy().pick_types(meg="grad", eeg=False, eog=False, ecg=False)
raw_meg_picks = np.arange(len(raw_meg.ch_names))
sources = dss_eog.transform(raw_meg)

# Check Correlation with EOG channel
# This validates that Comp 0 is indeed the blink artifact.
eog_picks = mne.pick_types(raw.info, meg=False, eog=True)
if len(eog_picks) > 0:
    eog_data = raw.get_data(picks=eog_picks[0]).flatten()
    blink_source = sources[0, :]
    for i in range(3):
        comp_source = sources[i, :]
        corr = np.corrcoef(eog_data, comp_source)[0, 1]
        print(f"Correlation (Comp {i} vs EOG): {abs(corr):.3f}")
else:
    print("No EOG channel found for correlation check.")
    # Fall back to a flat trace so the plotting code below still runs.
    eog_data = np.zeros(len(sources[0]))
    blink_source = sources[0, :]

# Visual Comparison: EOG vs Component 0
# Show a time window with clear blinks, scaled and aligned

# Use a fixed window that contains blinks (samples 5000-10000)
start_idx, end_idx = 5000, 10000
t_window = np.arange(start_idx, end_idx) / raw.info["sfreq"]

# Get data snippets
eog_snippet = eog_data[start_idx:end_idx]
comp_snippet = blink_source[start_idx:end_idx]

# Flip component if negatively correlated
corr_window = np.corrcoef(eog_snippet, comp_snippet)[0, 1]
flip = -1 if corr_window < 0 else 1

# Scale component to match EOG amplitude
scale = np.max(np.abs(eog_snippet)) / np.max(np.abs(comp_snippet))

plt.figure(figsize=(12, 4))
plt.plot(t_window, eog_snippet, "b", linewidth=1.5, label="EOG Channel")
plt.plot(
    t_window,
    flip * comp_snippet * scale,
    "r",
    linewidth=1.5,
    label="DSS Comp 0 (aligned & scaled)",
    alpha=0.8,
)
plt.xlabel("Time (s)")
plt.ylabel("Amplitude (a.u.)")
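# Report the Comp 0 correlation; `corr` left over from the loop above belongs to Comp 2.
corr_comp0 = np.corrcoef(eog_data, blink_source)[0, 1]
plt.title(f"TrialAverageBias: Blink Peaks Aligned (r={abs(corr_comp0):.3f})")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
TrialAverageBias: Blink Peaks Aligned (r=0.682)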
Removing blink component...
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Correlation (Comp 0 vs EOG): 0.682
Correlation (Comp 1 vs EOG): 0.340
Correlation (Comp 2 vs EOG): 0.045
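
The project, zero, and back-project step described at the top of this section can be sketched in plain NumPy. The unmixing matrix here is a random invertible stand-in for fitted DSS filters, and all names are hypothetical; with a fitted estimator the same three steps would presumably go through transform and inverse_transform, as in the ECG part of this tutorial.

```python
import numpy as np

rng = np.random.default_rng(1)
n_channels, n_times = 5, 2000

# Hypothetical stand-in for fitted filters: any invertible unmixing matrix
unmixing = rng.standard_normal((n_channels, n_channels))
mixing = np.linalg.pinv(unmixing)  # columns are the component patterns

# Build sensor data from known sources; source 0 plays the role of the blink
true_sources = rng.standard_normal((n_channels, n_times))
true_sources[0] = 10.0 * (np.arange(n_times) % 400 < 40)  # sparse, large deflections
data = mixing @ true_sources

# 1. project to component space, 2. zero the artifact, 3. project back
sources = unmixing @ data
sources[0, :] = 0.0
data_clean = mixing @ sources

# Equivalent one-step form: subtract the artifact's contribution in sensor space
artifact = np.outer(mixing[:, 0], unmixing[0] @ data)
print(np.allclose(data_clean, data - artifact))
```

Only the zeroed component changes; every other source is reconstructed exactly because the remaining rows of the component matrix are left untouched.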

Alternative: CycleAverageBias (Continuous Data Approach)#

CycleAverageBias is artifact-specific and works directly on continuous data. Instead of pre-epoching, we provide event samples and a window.

print("\n---  Comparing with CycleAverageBias ---")

# Find blink events from continuous data
from mne.preprocessing import find_eog_events

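blink_events = find_eog_events(raw, ch_name="EOG 061", verbose=False)
# MNE event samples include raw.first_samp; subtract it so the indices count from
# the start of the data array (assuming CycleAverageBias indexes the array directly).
blink_samples = blink_events[:, 0] - raw.first_samp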

print(f"Found {len(blink_samples)} blink events")

# Create CycleAverageBias
# Window: 100ms before to 100ms after each blink (in samples)
window_samples = (-int(0.1 * raw.info["sfreq"]), int(0.1 * raw.info["sfreq"]))
bias_cycle = CycleAverageBias(event_samples=blink_samples, window=window_samples)

# Fit DSS on continuous MEG data
dss_cycle = DSS(n_components=10, bias=bias_cycle, return_type="sources")
dss_cycle.fit(raw_meg)

print("Fitted DSS with CycleAverageBias")
---  Comparing with CycleAverageBias ---
Found 10 blink events
Fitted DSS with CycleAverageBias
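
As a rough illustration of what "event samples and a window" mean on continuous data, the sketch below averages fixed windows around each event. The helper event_locked_average is hypothetical and is not the CycleAverageBias implementation; event indices are assumed to be relative to the start of the array.

```python
import numpy as np


def event_locked_average(data, event_samples, window):
    """Average continuous data over windows around each event.

    data has shape (n_channels, n_times); window is (start, stop) in samples,
    applied as data[:, event + start : event + stop + 1].
    """
    start, stop = window
    segments = [
        data[:, s + start : s + stop + 1]
        for s in event_samples
        if s + start >= 0 and s + stop < data.shape[1]  # skip windows that run off the edges
    ]
    return np.mean(segments, axis=0)


sfreq = 100.0
rng = np.random.default_rng(2)
data = rng.standard_normal((3, 3000))
event_samples = np.arange(200, 2900, 300)
for s in event_samples:
    data[0, s - 10 : s + 11] += 5.0 * np.hanning(21)  # artifact on channel 0 only

window = (-int(0.1 * sfreq), int(0.1 * sfreq))  # 100 ms on each side of the event
avg = event_locked_average(data, event_samples, window)
print(avg.shape)
```

Averaging keeps the event-locked bump on channel 0 and shrinks everything that is not time-locked, which is the property the bias relies on.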

Visualize Cycle Average Components#

plot_component_summary(
    dss_cycle,
    data=raw_meg,
    info=raw_meg.info,
    picks=raw_meg_picks,
    n_components=[0, 1],
    show=False,
)
plt.gcf().suptitle("CycleAverageBias Results")
plt.show()
CycleAverageBias Results, Comp 0 Pattern, Comp 0 Time Course, PSD, Comp 1 Pattern, Comp 1 Time Course, PSD

Compare spatial patterns (both bias types)

print("\n--- Comparing Spatial Patterns ---")
plot_component_patterns(
    dss_eog,
    info=eog_epochs.info,
    picks=eog_sensor_picks,
    n_components=1,
    show=False,
)
plt.gcf().suptitle("TrialAverageBias: Blink Component Topography")
plt.show()
TrialAverageBias: Blink Component Topography, Comp 0
--- Comparing Spatial Patterns ---
plot_component_patterns(
    dss_cycle,
    info=raw_meg.info,
    picks=raw_meg_picks,
    n_components=1,
    show=False,
)
plt.gcf().suptitle("CycleAverageBias: Blink Component Topography")
plt.show()

print("\nBoth approaches extract the same blink artifact!")
print("- TrialAverageBias: Works on MNE Epochs (easier integration)")
print("- CycleAverageBias: Works on continuous data + event samples (more direct)")
CycleAverageBias: Blink Component Topography, Comp 0
Both approaches extract the same blink artifact!
- TrialAverageBias: Works on MNE Epochs (easier integration)
- CycleAverageBias: Works on continuous data + event samples (more direct)

# Temporal comparison: EOG channel vs. DSS Comp 0 over the first 20 seconds
n_samples_plot = int(20 * raw.info["sfreq"])  # Plot 20 seconds
scaler_eog = 1.0 / np.std(eog_data[:n_samples_plot])
scaler_dss = 1.0 / np.std(blink_source[:n_samples_plot])

# Flip DSS source if anti-correlated for better visual comparison
corr_0 = np.corrcoef(eog_data, blink_source)[0, 1]
sign = np.sign(corr_0)
if sign == 0:
    sign = 1

plt.figure(figsize=(10, 4))
times_plot = raw.times[:n_samples_plot]
plt.plot(
    times_plot,
    eog_data[:n_samples_plot] * scaler_eog,
    label="EOG Channel (Norm)",
    color="tab:orange",
    alpha=0.7,
)
plt.plot(
    times_plot,
    blink_source[:n_samples_plot] * scaler_dss * sign,
    label="DSS Comp 0 (Sign-Matched)",
    color="tab:blue",
    alpha=0.7,
)
plt.title(f"Temporal Comparison: EOG vs DSS Comp 0 (Corr={abs(corr_0):.2f})")
plt.xlabel("Time (s)")
plt.legend()
plt.tight_layout()
plt.show()
Temporal Comparison: EOG vs DSS Comp 0 (Corr=0.68)
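
The sign flip and rescaling used in these comparison plots are needed because a component obtained from an eigendecomposition has arbitrary sign and units. A small helper along these lines captures the idea; align_to_reference is a hypothetical name and the data are synthetic.

```python
import numpy as np


def align_to_reference(component, reference):
    """Z-score both traces and flip the component so it correlates positively."""
    comp_z = (component - component.mean()) / component.std()
    ref_z = (reference - reference.mean()) / reference.std()
    r = np.mean(comp_z * ref_z)  # Pearson correlation of the two traces
    sign = -1.0 if r < 0 else 1.0
    return sign * comp_z, ref_z, abs(r)


rng = np.random.default_rng(3)
reference = np.sin(np.linspace(0, 20, 1000))
# Same waveform, but inverted and in much smaller units, plus a little noise
component = -0.01 * reference + 0.002 * rng.standard_normal(1000)

comp_aligned, ref_z, r_abs = align_to_reference(component, reference)
print(f"|r| = {r_abs:.2f}")
```

Reporting the absolute correlation alongside the aligned traces keeps the plot and the number consistent regardless of which sign the decomposition happened to return.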

Part 2: ECG (Heartbeat) Correction#

The same idea applies to heartbeat contamination: epoch on the artifact, fit DSS on those repeats, and remove the dominant cardiac component.

print("\n--- Part 2: ECG (Heartbeat) Correction ---")

# 1. Create ECG Epochs
# With ch_name=None, MNE uses an ECG channel if one is present and otherwise
# builds a synthetic ECG trace from the MEG channels.
ecg_epochs = create_ecg_epochs(
    raw, ch_name=None, tmin=-0.1, tmax=0.1, baseline=(None, 0), verbose=False
)
ecg_epochs.pick_types(meg="grad", eeg=False, eog=False, ecg=False)
print(
    f"Found {len(ecg_epochs)} heartbeats. "
    f"Using {len(ecg_epochs.ch_names)} MEG channels."
)
ecg_sensor_picks = np.arange(len(ecg_epochs.ch_names))

# 2. Fit DSS
dss_ecg = DSS(n_components=8, bias=AverageBias(axis="epochs"))
dss_ecg.fit(ecg_epochs)
--- Part 2: ECG (Heartbeat) Correction ---
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Found 59 heartbeats. Using 203 MEG channels.
/home/runner/work/mne-denoise/mne-denoise/mne_denoise/dss/linear.py:496: RuntimeWarning: Epochs are not baseline corrected, covariance matrix may be inaccurate
  baseline_cov = mne.compute_covariance(inst, method=method, **kws)
/home/runner/work/mne-denoise/mne-denoise/mne_denoise/dss/linear.py:498: RuntimeWarning: Epochs are not baseline corrected, covariance matrix may be inaccurate
  biased_cov = mne.compute_covariance(biased_inst, method=method, **kws)
DSS(bias=<mne_denoise.dss.denoisers.averaging.AverageBias object at 0x7f6458c08ef0>,
    n_components=8)

Visualize Cardiac Components#

# Score Curve
plot_component_score_curve(dss_ecg, mode="ratio", show=True)
Component Scores
<Figure size 1400x800 with 1 Axes>

Time Series

The dominant cardiac component should follow a QRS-like shape.

plot_component_time_series(dss_ecg, data=ecg_epochs, n_components=8, show=True)
Component Time Series
<Figure size 2000x800 with 1 Axes>

Spatial Patterns

The corresponding field pattern should look broad, as expected for a deep source, rather than strictly focal.

plot_component_patterns(
    dss_ecg,
    info=ecg_epochs.info,
    picks=ecg_sensor_picks,
    n_components=8,
    show=True,
)
Component Patterns, Comp 0, Comp 1, Comp 2, Comp 3, Comp 4, Comp 5, Comp 6, Comp 7
<Figure size 2400x1200 with 8 Axes>

Summary

plot_component_summary(
    dss_ecg,
    data=ecg_epochs,
    info=ecg_epochs.info,
    picks=ecg_sensor_picks,
    n_components=[0],
    show=True,
)
Comp 0 Pattern, Comp 0 Time Course, PSD
<Figure size 2400x600 with 3 Axes>

Removing the Artifact#

print("Removing cardiac component...")
sources_ecg = dss_ecg.transform(raw_meg)  # Apply to continuous data

sources_ecg[0, :] = 0  # Zero out heartbeat
raw_clean_ecg = mne.io.RawArray(dss_ecg.inverse_transform(sources_ecg), raw_meg.info)
Removing cardiac component...
Creating RawArray with float64 data, n_channels=203, n_times=36038
    Range : 0 ... 36037 =      0.000 ...    60.000 secs
Ready.

Heartbeat-Locked Average Before and After#

The heartbeat-locked average should show a much smaller QRS-like transient after removing the dominant cardiac component.

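# Verification
# raw_clean_ecg is a RawArray that starts at sample 0, whereas the ECG events are
# indexed relative to raw.first_samp, so shift them before re-epoching.
events_clean = ecg_epochs.events.copy()
events_clean[:, 0] -= raw.first_samp
ecg_epochs_clean = mne.Epochs(
    raw_clean_ecg,
    events_clean,
    tmin=-0.1,
    tmax=0.1,
    baseline=(None, 0),
    verbose=False,
)

plot_evoked_gfp_comparison(
    ecg_epochs,
    ecg_epochs_clean,
    times=ecg_epochs.times,
    show=False,
    labels=("Original", "Cleaned"),
)
plt.show()
Evoked GFP Comparison
Using data from preloaded Raw for 59 events and 121 original time points ...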
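
For readers who want a number rather than a plot, the comparison can be summarized as the drop in peak global field power (GFP) of the heartbeat-locked average. The sketch below uses synthetic epochs and takes GFP to be the standard deviation across channels; whether plot_evoked_gfp_comparison uses exactly this definition is an assumption.

```python
import numpy as np


def gfp(evoked):
    """Global field power: standard deviation across channels at each time point."""
    return evoked.std(axis=0)


rng = np.random.default_rng(4)
n_epochs, n_channels, n_times = 50, 8, 121
topography = rng.standard_normal(n_channels)
qrs = np.exp(-0.5 * ((np.arange(n_times) - 60) / 4.0) ** 2)  # QRS-like transient

noise = rng.standard_normal((n_epochs, n_channels, n_times))
original = noise + 3.0 * topography[None, :, None] * qrs
cleaned = noise  # stand-in for epochs with the cardiac component removed

gfp_before = gfp(original.mean(axis=0))
gfp_after = gfp(cleaned.mean(axis=0))
reduction = 1.0 - gfp_after.max() / gfp_before.max()
print(f"peak GFP reduced by {100 * reduction:.0f}%")
```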

Spectral Preservation After Heartbeat Removal#

Cardiac harmonics should be reduced while broader neural rhythms remain.

PSD Comparison (multitaper spectrum estimation with 7 DPSS windows)
<Figure size 1600x800 with 1 Axes>
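
The expectation stated above can also be checked numerically by comparing band power before and after cleaning. The sketch below uses a plain FFT periodogram on a synthetic single-channel signal, not the estimator used for the figure in this tutorial; the signal model, the band limits, and the idealized subtraction are all assumptions.

```python
import numpy as np


def periodogram(x, sfreq):
    """One-sided power spectrum of a 1-D signal via the FFT (no tapering)."""
    freqs = np.fft.rfftfreq(x.size, d=1.0 / sfreq)
    power = np.abs(np.fft.rfft(x - x.mean())) ** 2 / (sfreq * x.size)
    return freqs, power


def band_power(freqs, power, low, high):
    """Sum of spectral power between low and high (inclusive), in Hz."""
    return power[(freqs >= low) & (freqs <= high)].sum()


sfreq, duration = 200.0, 20.0
t = np.arange(int(sfreq * duration)) / sfreq
rng = np.random.default_rng(5)

alpha = np.sin(2 * np.pi * 10.0 * t)  # stand-in neural rhythm at 10 Hz
period = 1.0 / 1.2  # heartbeat at 1.2 Hz
cardiac = 2.0 * np.exp(-0.5 * (((t % period) - period / 2) / 0.02) ** 2)
original = alpha + cardiac + 0.1 * rng.standard_normal(t.size)
cleaned = original - cardiac  # idealized removal of the cardiac source

freqs, psd_orig = periodogram(original, sfreq)
_, psd_clean = periodogram(cleaned, sfreq)

cardiac_change = band_power(freqs, psd_clean, 1.0, 5.0) / band_power(freqs, psd_orig, 1.0, 5.0)
alpha_change = band_power(freqs, psd_clean, 9.8, 10.2) / band_power(freqs, psd_orig, 9.8, 10.2)
print(f"1-5 Hz power ratio: {cardiac_change:.3f}, 10 Hz power ratio: {alpha_change:.3f}")
```

A ratio near zero in the cardiac band together with a ratio near one around 10 Hz is the pattern the section describes: the heartbeat harmonics are reduced while the rhythm of interest is preserved.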

Conclusion#

We used DSS with AverageBias to find and remove stereotypic artifacts. By epoching on the artifact events, we turned the artifact into the most repeatable signal in the data. DSS isolates that repeatable component, and denoising then reduces to zeroing it before projection back to sensor space.
