.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/dss/plot_08_blind_source_separation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_dss_plot_08_blind_source_separation.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_dss_plot_08_blind_source_separation.py:


=============================================================================
Blind Source Separation and ICA Equivalence
=============================================================================

This example demonstrates how **nonlinear DSS** performs blind source
separation (BSS), recovering statistically independent sources from their
linear mixtures. It illustrates the equivalence between DSS with specific
nonlinearities (``tanh`` and cubic) and ICA in its FastICA form. It covers
synthetic BSS with a side-by-side comparison against scikit-learn's FastICA,
and a real MEG decomposition that exposes both artifact and brain-like
components.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca),
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)

.. GENERATED FROM PYTHON SOURCE LINES 18-34

.. code-block:: Python


    import matplotlib.pyplot as plt
    import mne
    import numpy as np
    from mne.datasets import sample
    from scipy import stats

    from mne_denoise.dss import IterativeDSS, KurtosisDenoiser, TanhMaskDenoiser, beta_tanh
    from mne_denoise.viz import (
        plot_component_summary,
        plot_component_time_series,
        plot_signal_overlay,
    )

    print(__doc__)

.. GENERATED FROM PYTHON SOURCE LINES 35-41

Part 1: Synthetic Blind Source Separation
-----------------------------------------
We generate four synthetic sources with different statistical properties
(super-Gaussian, sub-Gaussian, and Gaussian) and mix them linearly with a
random matrix. We then try to recover them with DSS and FastICA.

.. GENERATED FROM PYTHON SOURCE LINES 41-86

.. code-block:: Python


    print("\n--- 1. Creating Synthetic Mixed Data ---")

    # Seed once, before any random draw, so that both the sources and the
    # mixing matrix are reproducible
    np.random.seed(42)

    n_samples = 2000
    time = np.linspace(0, 8, n_samples)

    # 1. Super-Gaussian (Laplace): sparse / bursty, positive excess kurtosis
    s1 = stats.laplace.rvs(size=n_samples)
    s1 /= s1.std()

    # 2. Sub-Gaussian (square wave): strongly negative excess kurtosis
    s2 = np.sign(np.sin(3 * time))
    s2 /= s2.std()

    # 3. Sub-Gaussian (sinusoid): negative excess kurtosis
    s3 = np.sin(10 * time)
    s3 /= s3.std()

    # 4. Gaussian noise: zero excess kurtosis
    # (ICA can separate at most one Gaussian source)
    s4 = np.random.randn(n_samples)

    # Stack true sources: shape (n_sources, n_samples)
    S_true = np.c_[s1, s2, s3, s4].T
    n_sources = S_true.shape[0]

    # Mix sources with a random square mixing matrix
    A = np.random.randn(n_sources, n_sources)  # Mixing matrix
    X = A @ S_true  # Mixed signals, shape (n_channels, n_samples)

    # Visualize true sources and mixtures, vertically offset for readability
    fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
    axes[0].plot(time, S_true.T + np.arange(n_sources) * 5)
    axes[0].set_title("True Sources")
    axes[0].set_yticks(np.arange(n_sources) * 5)
    axes[0].set_yticklabels([f"S{i}" for i in range(n_sources)])

    axes[1].plot(time, X.T + np.arange(n_sources) * 5)
    axes[1].set_title("Mixed Signals (Input)")
    axes[1].set_yticks(np.arange(n_sources) * 5)
    axes[1].set_yticklabels([f"X{i}" for i in range(n_sources)])

    plt.tight_layout()
    plt.show()




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_001.png
   :alt: True Sources, Mixed Signals (Input)
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    --- 1. Creating Synthetic Mixed Data ---




.. GENERATED FROM PYTHON SOURCE LINES 87-95

Run DSS with Tanh Nonlinearity (Robust ICA)
-------------------------------------------
``TanhMaskDenoiser`` applies the ``tanh`` nonlinearity. Because ``tanh`` is
bounded, it is less sensitive to outliers than the cubic (kurtosis)
nonlinearity used in the next section. This section compares the Newton-style
update (``beta=beta_tanh``) with plain gradient ascent (``beta=None``) by
counting iterations to convergence. The Newton step is the same idea that
makes FastICA converge quickly in practice.

.. GENERATED FROM PYTHON SOURCE LINES 95-133

.. code-block:: Python


    print("\nRunning DSS with Tanh Nonlinearity (Robust)...")

    # 1. Gradient ascent (slow)
    print("  Fitting with Gradient Ascent (beta=None)...")
    dss_grad = IterativeDSS(
        denoiser=TanhMaskDenoiser(),
        method="deflation",
        n_components=n_sources,
        beta=None,  # Gradient ascent
        random_state=42,
        verbose=False,
    )
    dss_grad.fit(X)

    # 2. Newton method (fast, FastICA-style)
    print("  Fitting with Newton Method (beta=beta_tanh)...")
    dss_tanh = IterativeDSS(
        denoiser=TanhMaskDenoiser(),
        method="deflation",
        n_components=n_sources,
        beta=beta_tanh,  # Newton step
        random_state=42,
        verbose=False,
    )
    dss_tanh.fit(X)
    S_dss_tanh = dss_tanh.transform(X)

    # Compare total iterations (column 0 of convergence_info_ is used here as
    # the per-component iteration count)
    iters_grad = dss_grad.convergence_info_[:, 0].sum()
    iters_newton = dss_tanh.convergence_info_[:, 0].sum()
    print(f"  Gradient Iterations: {iters_grad:.0f}")
    print(
        f"  Newton Iterations: {iters_newton:.0f} "
        f"(Speedup: {iters_grad / iters_newton:.1f}x)"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Running DSS with Tanh Nonlinearity (Robust)...
      Fitting with Gradient Ascent (beta=None)...
      Fitting with Newton Method (beta=beta_tanh)...
      Gradient Iterations: 116
      Newton Iterations: 20 (Speedup: 5.8x)




.. GENERATED FROM PYTHON SOURCE LINES 134-138

Run DSS with Kurtosis Nonlinearity (Standard FastICA)
-----------------------------------------------------
``KurtosisDenoiser`` with ``nonlinearity='cube'`` applies the cubic
nonlinearity :math:`g(s) = s^3`, whose associated contrast is kurtosis. This
is the nonlinearity of the original FastICA algorithm. The fixed-point
iteration seeks extrema of kurtosis (not only maxima), so it can recover both
super- and sub-Gaussian sources.

.. GENERATED FROM PYTHON SOURCE LINES 138-152

.. code-block:: Python


    print("Running DSS with Kurtosis Nonlinearity (FastICA standard)...")
    dss_kurt = IterativeDSS(
        denoiser=KurtosisDenoiser(nonlinearity="cube"),
        method="deflation",
        n_components=n_sources,
        # Newton step for kurtosis: -E[g'(s)] = -E[3 s**2] = -3 for unit-variance s
        beta=-3.0,
        random_state=42,
        verbose=False,
    )
    dss_kurt.fit(X)
    S_dss_kurt = dss_kurt.transform(X)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Running DSS with Kurtosis Nonlinearity (FastICA standard)...




.. GENERATED FROM PYTHON SOURCE LINES 153-156

Comparison with sklearn FastICA
-------------------------------
We run ``sklearn.decomposition.FastICA`` as a reference implementation.
``fun="logcosh"`` corresponds to the ``tanh`` nonlinearity (the derivative of
``log(cosh(u))`` is ``tanh(u)``), and ``algorithm="deflation"`` extracts
components one at a time, as the DSS runs above do. The ground truth remains
``S_true``.

.. GENERATED FROM PYTHON SOURCE LINES 156-166

.. code-block:: Python


    from sklearn.decomposition import FastICA

    print("Running sklearn FastICA (Benchmark)...")
    ica = FastICA(
        n_components=n_sources, algorithm="deflation", fun="logcosh", random_state=42
    )
    # scikit-learn expects (n_samples, n_features), hence the transposes
    S_fastica = ica.fit_transform(X.T).T





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Running sklearn FastICA (Benchmark)...




.. GENERATED FROM PYTHON SOURCE LINES 167-172

Evaluate Performance (Correlation with True Sources)
----------------------------------------------------
We compute the absolute correlation between every recovered component and
every true source; absolute values are used because BSS recovers sources only
up to sign and scale. Perfect recovery gives a permutation matrix, with
exactly one 1.0 in each row and column. The score in each panel title is the
mean of the row-wise maxima, which by itself does not enforce a one-to-one
match.

.. GENERATED FROM PYTHON SOURCE LINES 172-223

.. code-block:: Python


    def match_sources(S_est, S_true):
        """Return the absolute correlation matrix between estimated and true sources.

        Entry ``[i, j]`` is ``|corr(S_est[i], S_true[j])|``.
        """
        n_est = S_est.shape[0]
        n_true = S_true.shape[0]
        corr = np.zeros((n_est, n_true))
        for i in range(n_est):
            for j in range(n_true):
                corr[i, j] = np.abs(np.corrcoef(S_est[i], S_true[j])[0, 1])
        return corr


    print("\n--- Evaluation ---")
    corr_tanh = match_sources(S_dss_tanh, S_true)
    corr_kurt = match_sources(S_dss_kurt, S_true)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    im0 = axes[0].imshow(corr_tanh, vmin=0, vmax=1, cmap="Greens")
    axes[0].set_title(
        f"DSS (Tanh) Match\nMean Max Corr: {np.mean(np.max(corr_tanh, axis=1)):.3f}"
    )
    axes[0].set_ylabel("Recovered")
    axes[0].set_xlabel("True")
    plt.colorbar(im0, ax=axes[0])

    im1 = axes[1].imshow(corr_kurt, vmin=0, vmax=1, cmap="Blues")
    axes[1].set_title(
        f"DSS (Kurtosis) Match\nMean Max Corr: {np.mean(np.max(corr_kurt, axis=1)):.3f}"
    )
    axes[1].set_xlabel("True")
    plt.colorbar(im1, ax=axes[1])

    corr_ica = match_sources(S_fastica, S_true)
    im2 = axes[2].imshow(corr_ica, vmin=0, vmax=1, cmap="Oranges")
    axes[2].set_title(
        f"sklearn FastICA Match\nMean Max Corr: {np.mean(np.max(corr_ica, axis=1)):.3f}"
    )
    axes[2].set_xlabel("True")
    plt.colorbar(im2, ax=axes[2])

    plt.suptitle("Source Recovery Quality (Abs Correlation)")
    plt.tight_layout()
    plt.show()

    # Plot the recovered time series with the viz module, treating the
    # sources as "components" of the fitted estimator
    print("Visualizing Recovered Sources (Stacked)...")




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_002.png
   :alt: Source Recovery Quality (Abs Correlation), DSS (Tanh) Match Mean Max Corr: 0.961, DSS (Kurtosis) Match Mean Max Corr: 0.936, sklearn FastICA Match Mean Max Corr: 0.999
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    --- Evaluation ---
    Visualizing Recovered Sources (Stacked)...




.. GENERATED FROM PYTHON SOURCE LINES 224-226

Recovered Source Time Series
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 226-231

.. code-block:: Python


    plot_component_time_series(dss_tanh, data=X, show=False)
    plt.gcf().suptitle("Recovered Sources (DSS Tanh) - Newton Optimization")
    plt.show()




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_003.png
   :alt: Recovered Sources (DSS Tanh) - Newton Optimization, Component Time Series
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_003.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 232-237

Part 2: Blind Separation of Real MEG Data
-----------------------------------------
We apply nonlinear DSS to the MEG channels of the MNE ``sample`` dataset to
blindly extract artifact components (such as eye blinks and heartbeats) and
brain-like components. This is analogous to running
``mne.preprocessing.ICA``. The EOG channel is kept out of the decomposition
and used only to check which extracted component matches blinks.

.. GENERATED FROM PYTHON SOURCE LINES 237-298

.. code-block:: Python


    print("\n--- 2. Real MEG Data (Blind Separation) ---")
    data_path = sample.data_path()
    raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif"
    raw = mne.io.read_raw_fif(raw_fname, verbose=False)

    # Keep the first 60 s of MEG + EOG channels (bad channels excluded)
    raw.crop(0, 60).pick(["meg", "eog"], exclude="bads").load_data()

    # Band-pass MEG and EOG to remove slow drifts and high-frequency noise, so
    # the EOG reference shares the same 1-40 Hz band as the MEG data
    raw.filter(1, 40, picks=["meg", "eog"], verbose=False)

    # Prepare MEG-only data for BSS.
    # We want to find artifacts *in the MEG channels*, so the EOG channel is
    # excluded here and the decomposition cannot simply pick it up.
    raw_meg = raw.copy().pick("meg")
    print(f"Data shape (MEG only): {raw_meg.get_data().shape}")

    # Fit DSS-Tanh (blind decomposition)
    print("Fitting Blind DSS (this may take a moment)...")
    n_components = 15
    dss_meg = IterativeDSS(
        denoiser=TanhMaskDenoiser(),
        method="deflation",
        n_components=n_components,
        beta=beta_tanh,
        random_state=42,
        verbose=True,
    )
    dss_meg.fit(raw_meg)

    # Identify the blink component by correlation with the EOG channel.
    # The separate EOG channel only validates which extracted source
    # corresponds to blinks; it was not part of the decomposition.
    eog_ch = raw.get_data(picks="eog")[0]
    sources = dss_meg.transform(raw_meg)
    corrs = [np.abs(np.corrcoef(s, eog_ch)[0, 1]) for s in sources]
    blink_idx = int(np.argmax(corrs))
    print(f"\nMost likely EOG component: #{blink_idx} (Corr: {corrs[blink_idx]:.3f})")

    # Visualize the blink component (gradiometer topography)
    print("Visualizing Blink Component...")
    plot_component_summary(
        dss_meg,
        data=raw_meg,
        info=raw_meg.info,
        picks=mne.pick_types(raw_meg.info, meg="grad"),
        n_components=[blink_idx],
        show=False,
    )
    plt.gcf().suptitle(f"Component #{blink_idx}: Blindly Extracted EOG Artifact")
    plt.show()

    # Visualize a candidate brain component: any component other than the
    # blink component. The choice below (the 2nd remaining index) is arbitrary;
    # no brain-source criterion is applied.
    candidate_indices = [i for i in range(n_components) if i != blink_idx]
    brain_idx = candidate_indices[1]
    print(f"Visualizing Candidate Brain Component #{brain_idx}...")




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_004.png
   :alt: Component #1: Blindly Extracted EOG Artifact, Comp 1 Pattern, Comp 1 Time Course, PSD
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    --- 2. Real MEG Data (Blind Separation) ---
    Reading 0 ... 36037 = 0.000 ... 60.000 secs...
    Data shape (MEG only): (305, 36038)
    Fitting Blind DSS (this may take a moment)...
    Component 1: 9 iterations (converged)
    Component 2: 9 iterations (converged)
    Component 3: 9 iterations (converged)
    Component 4: 8 iterations (converged)
    Component 5: 8 iterations (converged)
    Component 6: 8 iterations (converged)
    Component 7: 8 iterations (converged)
    Component 8: 8 iterations (converged)
    Component 9: 8 iterations (converged)
    Component 10: 8 iterations (converged)
    Component 11: 8 iterations (converged)
    Component 12: 8 iterations (converged)
    Component 13: 8 iterations (converged)
    Component 14: 8 iterations (converged)
    Component 15: 8 iterations (converged)

    Most likely EOG component: #1 (Corr: 0.106)
    Visualizing Blink Component...
    Visualizing Candidate Brain Component #2...




.. GENERATED FROM PYTHON SOURCE LINES 299-301

Candidate Brain Component
-------------------------

.. GENERATED FROM PYTHON SOURCE LINES 301-321

.. code-block:: Python


    plot_component_summary(
        dss_meg,
        data=raw_meg,
        info=raw_meg.info,
        picks=mne.pick_types(raw_meg.info, meg="grad"),
        n_components=[brain_idx],
        show=False,
    )
    plt.gcf().suptitle(f"Component #{brain_idx}: Candidate Brain Source")
    plt.show()

    # Overlay comparison for EOG: wrap the EOG channel and the extracted
    # component as single-channel Raw objects on the same time base
    eog_raw = mne.io.RawArray(
        eog_ch[None, :], mne.create_info(["EOG"], raw.info["sfreq"], "eog")
    )
    comp_raw = mne.io.RawArray(
        sources[blink_idx : blink_idx + 1], mne.create_info(1, raw.info["sfreq"], "misc")
    )




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_005.png
   :alt: Component #2: Candidate Brain Source, Comp 2 Pattern, Comp 2 Time Course, PSD
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Creating RawArray with float64 data, n_channels=1, n_times=36038
        Range : 0 ... 36037 = 0.000 ... 60.000 secs
    Ready.
    Creating RawArray with float64 data, n_channels=1, n_times=36038
        Range : 0 ... 36037 = 0.000 ... 60.000 secs
    Ready.




.. GENERATED FROM PYTHON SOURCE LINES 322-324

EOG Overlay
-----------
Check the absolute correlation printed above before trusting this match: in
the run shown here it is low (0.106), so the component is only weakly related
to the EOG channel and is best treated as a tentative blink component.
Component signs are arbitrary in BSS, so the extracted trace may also appear
inverted relative to the EOG channel.

.. GENERATED FROM PYTHON SOURCE LINES 324-336

.. code-block:: Python


    plot_signal_overlay(
        eog_raw,
        comp_raw,
        times=eog_raw.times,
        start=10,
        stop=20,
        title="EOG Channel vs Extracted Component (Time Domain)",
        show=False,
    )
    plt.show()

    print("\nBlind Source Separation complete!")




.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_006.png
   :alt: EOG Channel vs Extracted Component (Time Domain)
   :srcset: /auto_examples/dss/images/sphx_glr_plot_08_blind_source_separation_006.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Blind Source Separation complete!





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.438 seconds)


.. _sphx_glr_download_auto_examples_dss_plot_08_blind_source_separation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_08_blind_source_separation.ipynb <plot_08_blind_source_separation.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_08_blind_source_separation.py <plot_08_blind_source_separation.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_08_blind_source_separation.zip <plot_08_blind_source_separation.zip>`
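
Appendix: The Tanh Newton Update in NumPy
-----------------------------------------
The Newton-style ``beta`` used in this example can be made concrete with a
small, self-contained sketch of deflationary FastICA with the ``tanh``
nonlinearity, written as a DSS loop: whiten the data, "denoise" the current
source estimate with ``tanh``, re-estimate the projection, add ``beta * w``,
and orthogonalize against components already found. This is the textbook
FastICA fixed-point update, not the ``mne_denoise`` implementation. The helper
names ``whiten`` and ``fastica_tanh_deflation`` are hypothetical, and the
choice ``beta = -mean(1 - tanh(s)**2)`` is an assumption about what
``beta_tanh`` computes, not something verified against the library.

.. code-block:: Python

    import numpy as np


    def whiten(X):
        """Center X (n_channels, n_samples) and apply symmetric (ZCA) whitening."""
        Xc = X - X.mean(axis=1, keepdims=True)
        cov = Xc @ Xc.T / Xc.shape[1]
        eigvals, eigvecs = np.linalg.eigh(cov)
        return eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T @ Xc


    def fastica_tanh_deflation(Z, n_components, max_iter=200, tol=1e-6, seed=0):
        """Hypothetical helper: deflationary FastICA with g = tanh on whitened Z.

        One iteration, in DSS terms:
          s = w @ Z                   (current source estimate)
          g = tanh(s)                 ("denoised" source)
          beta = -mean(1 - g**2)      (assumed Newton step, -E[g'(s)])
          w <- Z @ g / n + beta * w   (re-estimated projection)
        followed by Gram-Schmidt deflation and renormalization.
        """
        rng = np.random.default_rng(seed)
        n_channels, n_samples = Z.shape
        W = np.zeros((n_components, n_channels))
        n_iters = []
        for k in range(n_components):
            w = rng.standard_normal(n_channels)
            w /= np.linalg.norm(w)
            for it in range(1, max_iter + 1):
                s = w @ Z
                g = np.tanh(s)
                beta = -np.mean(1.0 - g**2)
                w_new = Z @ g / n_samples + beta * w
                # Deflation: remove the part of w_new spanned by earlier components
                w_new -= W[:k].T @ (W[:k] @ w_new)
                w_new /= np.linalg.norm(w_new)
                # Converged when the direction stops changing (sign flips allowed)
                converged = abs(abs(w_new @ w) - 1.0) < tol
                w = w_new
                if converged:
                    break
            W[k] = w
            n_iters.append(it)
        return W, n_iters


    # Toy usage on made-up data: one sub-Gaussian square wave and one
    # super-Gaussian Laplace source, mixed by a fixed 2 x 2 matrix.
    rng = np.random.default_rng(42)
    t = np.linspace(0, 8, 2000)
    S = np.vstack([np.sign(np.sin(3 * t)), rng.laplace(size=t.size)])
    S = (S - S.mean(axis=1, keepdims=True)) / S.std(axis=1, keepdims=True)
    A = np.array([[1.0, 0.6], [0.4, 1.0]])

    Z = whiten(A @ S)
    W, n_iters = fastica_tanh_deflation(Z, n_components=2)
    S_hat = W @ Z

    # |correlation| between each recovered and each true source (up to sign/order)
    corr = np.abs(np.corrcoef(S_hat, S)[:2, 2:])
    print(np.round(corr, 3))
    print("iterations per component:", n_iters)

Because the data are whitened and ``w`` has unit norm, each projection ``s``
has unit variance. The same property makes ``beta=-3.0`` the matching Newton
step for the cubic nonlinearity, since :math:`E[3 s^2] = 3` in that case.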