.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/dss/plot_12_joint_dss.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_dss_plot_12_joint_dss.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_dss_plot_12_joint_dss.py:


Joint DSS (Multi-Dataset Repeatability)
=========================================

This example demonstrates **Joint Denoising Source Separation (JDSS)**, a method
for extracting components that are reproducible across multiple datasets. JDSS is
useful for finding sources that are consistent across subjects, across recording
blocks, or across repeated presentations of the same stimulus.

The objective function is:

.. math::

    \max_w \frac{w^T R_{signal} w}{w^T R_{total} w}

where :math:`R_{signal} = \mathrm{Cov}(\bar{X})` is the covariance of the grand
average :math:`\bar{X} = \frac{1}{N} \sum_{i=1}^{N} X_i` over the :math:`N` datasets,
and :math:`R_{total} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{Cov}(X_i)` is the mean of
the individual covariances. The maximizing filters are the leading eigenvectors of
the generalized eigenvalue problem :math:`R_{signal} w = \lambda R_{total} w`, and
each eigenvalue :math:`\lambda` equals the ratio above for its filter. Components
with a high eigenvalue are therefore highly reproducible.

Reference: de Cheveigné, A., & Parra, L. C. (2014). Joint decorrelation, a
versatile tool for multichannel data analysis. NeuroImage, 98, 487-505.

Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca),
Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca)

.. GENERATED FROM PYTHON SOURCE LINES 32-34

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 34-41

.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np
    from scipy import signal as sig

    from mne_denoise.dss import DSS
    from mne_denoise.dss.denoisers import AverageBias

.. GENERATED FROM PYTHON SOURCE LINES 42-48

Simulate Multi-Subject Data
---------------------------
We simulate five subjects that share one weak 10 Hz source. Each subject also has
strong low-pass filtered (colored) noise with its own spatial pattern, plus
additional low-amplitude sensor noise.
This is exactly the regime where single-subject PCA or ICA struggles, but JDSS can
exploit cross-subject consistency.

.. GENERATED FROM PYTHON SOURCE LINES 48-98

.. code-block:: Python


    print("=== Joint DSS Example ===\n")
    print("Simulating 5 subjects with a shared 10 Hz source buried in noise...")

    n_subjects = 5
    n_channels = 16
    n_times = 1000
    sfreq = 250
    times = np.arange(n_times) / sfreq
    rng = np.random.RandomState(42)

    # Common source: 10 Hz sine (signal), scaled to unit variance
    common_source = np.sin(2 * np.pi * 10 * times)
    common_source /= np.std(common_source)

    # Common (but slightly varying) topography
    base_topo = np.ones(n_channels)
    base_topo[:8] = 1.5  # Stronger in first half of channels
    base_topo /= np.linalg.norm(base_topo)

    datasets = []
    for subj in range(n_subjects):
        # Slightly perturb topography per subject (realistic)
        topo = base_topo + 0.1 * rng.randn(n_channels)
        topo /= np.linalg.norm(topo)

        # Signal component (amplitude 1.0)
        signal_part = np.outer(topo, common_source) * 1.0

        # Subject-specific noise (colored, strong)
        noise_topo = rng.randn(n_channels)
        noise_topo /= np.linalg.norm(noise_topo)
        noise_source = rng.randn(n_times)
        # Low-pass filter (cutoff 0.1 x Nyquist = 12.5 Hz) to color the noise
        b, a = sig.butter(3, 0.1)
        noise_source = sig.filtfilt(b, a, noise_source)
        noise_source /= np.std(noise_source)
        noise_part = np.outer(noise_topo, noise_source) * 3.0  # 3x signal amplitude

        # Sensor noise
        sensor_noise = 0.3 * rng.randn(n_channels, n_times)

        data = signal_part + noise_part + sensor_noise
        datasets.append(data)

    datasets = np.array(datasets)  # (n_subjects, n_channels, n_times)
    print(f"Created {n_subjects} datasets of shape {datasets[0].shape}")
    print("Signal amplitude: 1.0, Noise amplitude: 3.0 (amplitude SNR ~ 0.33)")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    === Joint DSS Example ===

    Simulating 5 subjects with a shared 10 Hz source buried in noise...
    Created 5 datasets of shape (16, 1000)
    Signal amplitude: 1.0, Noise amplitude: 3.0 (amplitude SNR ~ 0.33)
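
The printed SNR of about 0.33 is an amplitude ratio. The shared source and the
subject-specific noise each have a unit-norm topography and a unit-variance time
course, so the expected power of each part, summed over channels, follows directly
from the scale factors used in the simulation. A minimal sketch of that arithmetic,
with illustrative variable names that are not part of the example script:

.. code-block:: python

    # Expected power of each simulated part, summed over channels.
    # Scale factors copied from the simulation above (illustrative names).
    n_channels = 16
    signal_scale = 1.0  # shared 10 Hz source
    noise_scale = 3.0   # subject-specific colored noise
    sensor_std = 0.3    # per-channel sensor noise

    # Unit-norm topography x unit-variance source -> total power = scale**2
    signal_power = signal_scale**2
    noise_power = noise_scale**2
    # Independent sensor noise on every channel
    sensor_power = n_channels * sensor_std**2

    amplitude_snr = signal_scale / noise_scale
    power_snr = signal_power / (noise_power + sensor_power)

    print(f"amplitude SNR: {amplitude_snr:.2f}")
    print(f"power SNR against all noise: {power_snr:.3f}")

This prints an amplitude SNR of 0.33 and a power SNR of about 0.096. In other words,
the shared source carries roughly a tenth of the total power in each subject's
recording.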

.. GENERATED FROM PYTHON SOURCE LINES 99-109

Apply Joint DSS
---------------
JDSS finds the spatial filters that maximize the ratio of the grand-average
variance to the mean of the individual variances.

Note: ``DSS`` expects input of shape (n_channels, n_times, n_epochs). To find
components that are reproducible across subjects, we treat the five subjects as
"epochs" and transpose ``datasets`` from (n_subjects, n_channels, n_times) to
(n_channels, n_times, n_subjects).

.. GENERATED FROM PYTHON SOURCE LINES 109-121

.. code-block:: Python


    print("\nApplying JDSS (via DSS with group averaging)...")

    datasets_dss = np.transpose(datasets, (1, 2, 0))  # (16, 1000, 5)

    # Average over the 'epochs' axis (the 3rd dimension), which holds the subjects here
    jdss = DSS(bias=AverageBias(axis="epochs"), n_components=3)
    jdss.fit(datasets_dss)

    print(f"Eigenvalues (repeatability scores): {jdss.eigenvalues_}")
    print(" -> Score near 1.0 = highly reproducible.")
    print(f" -> Score near 1/{n_subjects} = {1 / n_subjects:.1f}: not reproducible (subject-specific noise).\n")


.. rst-class:: sphx-glr-script-out

.. code-block:: none


    Applying JDSS (via DSS with group averaging)...
    Eigenvalues (repeatability scores): [0.91191608 0.25192183 0.2334974 ]
     -> Score near 1.0 = highly reproducible.
     -> Score near 1/5 = 0.2: not reproducible (subject-specific noise).


.. GENERATED FROM PYTHON SOURCE LINES 122-127

Extract Sources
---------------
Apply the learned filters to the data. Because the input was shaped
(n_channels, n_times, n_subjects), ``transform`` returns an array of shape
(n_components, n_times, n_subjects), which we transpose to
(n_subjects, n_components, n_times).

.. GENERATED FROM PYTHON SOURCE LINES 127-138

.. code-block:: Python


    sources = jdss.transform(datasets_dss)  # (3, 1000, 5)
    sources = np.transpose(sources, (2, 0, 1))  # (n_subjects, n_components, n_times)

    # Grand average of sources
    ga_sources = np.mean(sources, axis=0)  # (n_components, n_times)

    # Grand average of raw data; "best" channel = largest grand-average variance
    ga_raw = np.mean(datasets, axis=0)
    best_ch = np.argmax(np.var(ga_raw, axis=1))
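
The objective stated at the top of this example can be made concrete without
``mne_denoise``. The sketch below is not the library's implementation. It is a
minimal from-scratch version that assumes data shaped (n_datasets, n_channels,
n_times), plain sample covariances, and a well-conditioned :math:`R_{total}`. The
helper name ``jdss_filters`` and the toy data are hypothetical. The sketch solves
:math:`R_{signal} w = \lambda R_{total} w` with ``scipy.linalg.eigh`` and then
applies the filters to every dataset, in the spirit of the ``transform`` step above.

.. code-block:: python

    import numpy as np
    from scipy.linalg import eigh


    def jdss_filters(datasets):
        """Hypothetical helper: JDSS filters for data shaped (n_datasets, n_ch, n_times)."""
        data = np.asarray(datasets, dtype=float)
        # Remove each channel's mean within each dataset
        data = data - data.mean(axis=2, keepdims=True)
        n_times = data.shape[2]

        # R_signal: covariance of the grand average across datasets
        grand_avg = data.mean(axis=0)
        r_signal = grand_avg @ grand_avg.T / n_times
        # R_total: mean of the per-dataset covariances
        r_total = np.mean([x @ x.T / n_times for x in data], axis=0)

        # Generalized symmetric eigenproblem: R_signal w = lambda R_total w
        eigvals, eigvecs = eigh(r_signal, r_total)
        # eigh sorts ascending; put the most reproducible component first
        order = np.argsort(eigvals)[::-1]
        return eigvals[order], eigvecs[:, order]


    # Toy data: one shared 10 Hz source plus independent white noise per dataset
    rng = np.random.default_rng(0)
    n_sets, n_ch, n_t = 5, 8, 2000
    shared = np.sin(2 * np.pi * 10 * np.arange(n_t) / 250)
    topo = np.linspace(0.5, 1.5, n_ch)  # fixed spatial pattern of the shared source
    toy = np.array(
        [np.outer(topo, shared) + rng.standard_normal((n_ch, n_t)) for _ in range(n_sets)]
    )

    eigvals, filters = jdss_filters(toy)
    # Apply the filters to every dataset: (n_sets, n_components, n_times)
    toy_sources = np.einsum("ck,sct->skt", filters, toy)
    toy_corr = np.corrcoef(toy_sources.mean(axis=0)[0], shared)[0, 1]
    print("scores:", np.round(eigvals[:3], 2))
    print("|corr| of component 1 with the shared source:", round(abs(toy_corr), 3))

With this toy setup, the first eigenvalue should come out well above the others.
The remaining eigenvalues hover around 1/5 because the noise is independent across
the five datasets.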

.. GENERATED FROM PYTHON SOURCE LINES 139-142

Visualize Results
-----------------
Compare the ground truth, raw grand average, and JDSS component 1.

.. GENERATED FROM PYTHON SOURCE LINES 142-170

.. code-block:: Python


    fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)

    # Ground truth
    axes[0].plot(times, common_source, "k", lw=2)
    axes[0].set_ylabel("Amplitude")
    axes[0].set_title("Ground Truth: Common 10 Hz Source")
    axes[0].grid(True, alpha=0.3)

    # Raw grand average (best channel)
    axes[1].plot(times, ga_raw[best_ch], "gray", lw=1)
    axes[1].set_ylabel("Amplitude")
    axes[1].set_title(f"Raw Grand Average (Channel {best_ch}) - Noisy")
    axes[1].grid(True, alpha=0.3)

    # JDSS Component 1
    # Flip sign if anti-correlated with ground truth
    corr = np.corrcoef(ga_sources[0], common_source)[0, 1]
    sign = np.sign(corr)
    axes[2].plot(times, sign * ga_sources[0], "g", lw=2)
    axes[2].set_ylabel("Amplitude")
    axes[2].set_xlabel("Time (s)")
    axes[2].set_title(f"JDSS Component 1 (Score: {jdss.eigenvalues_[0]:.3f})")
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_001.png
   :alt: Ground Truth: Common 10 Hz Source, Raw Grand Average (Channel 15) - Noisy, JDSS Component 1 (Score: 0.912)
   :srcset: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 171-174

Quantitative Evaluation
-----------------------
Compute correlation between JDSS output and ground truth.

.. GENERATED FROM PYTHON SOURCE LINES 174-183

.. code-block:: Python


    corr_raw = np.abs(np.corrcoef(ga_raw[best_ch], common_source)[0, 1])
    corr_jdss = np.abs(np.corrcoef(ga_sources[0], common_source)[0, 1])

    print("=== Correlation with Ground Truth ===")
    print(f" Raw Grand Average (best channel): {corr_raw:.3f}")
    print(f" JDSS Component 1: {corr_jdss:.3f}")
    print(f"\nJDSS improves recovery by {(corr_jdss - corr_raw) / corr_raw * 100:.1f}%")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    === Correlation with Ground Truth ===
     Raw Grand Average (best channel): 0.268
     JDSS Component 1: 0.988

    JDSS improves recovery by 268.5%


.. GENERATED FROM PYTHON SOURCE LINES 184-187

Per-Subject Sources
-------------------
Show that the JDSS sources are consistent across subjects.

.. GENERATED FROM PYTHON SOURCE LINES 187-207

.. code-block:: Python


    fig, axes = plt.subplots(2, 3, figsize=(12, 5))

    for i, ax in enumerate(axes.flat):
        if i < n_subjects:
            src = sources[i, 0]
            # Align sign with the ground truth
            if np.corrcoef(src, common_source)[0, 1] < 0:
                src = -src
            ax.plot(times, src, "b", alpha=0.7)
            ax.plot(times, common_source, "k--", alpha=0.5, lw=1)
            ax.set_title(f"Subject {i + 1}")
            ax.set_xlabel("Time (s)")
            ax.grid(True, alpha=0.3)
        else:
            ax.axis("off")  # hide the unused sixth panel

    fig.suptitle("JDSS Component 1 per Subject (blue) vs Ground Truth (dashed)")
    plt.tight_layout()
    plt.show()


.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_002.png
   :alt: JDSS Component 1 per Subject (blue) vs Ground Truth (dashed), Subject 1, Subject 2, Subject 3, Subject 4, Subject 5
   :srcset: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 208-211

Eigenvalue Spectrum
-------------------
The first eigenvalue should be much larger than the rest. With five subjects, a
component that is independent across subjects scores about 1/5 rather than 0:
averaging five independent datasets divides their covariance by five, so the ratio
in the objective is close to 1/5. Components 2 and 3 (about 0.25 and 0.23) sit near
that floor, while component 1 (about 0.91) stands well above it.

.. GENERATED FROM PYTHON SOURCE LINES 211-221

.. code-block:: Python


    plt.figure(figsize=(6, 4))
    plt.bar(range(1, len(jdss.eigenvalues_) + 1), jdss.eigenvalues_, color="steelblue")
    plt.xlabel("Component")
    plt.ylabel("Repeatability Score")
    plt.title("JDSS Eigenvalue Spectrum")
    plt.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    plt.show()


.. image-sg:: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_003.png
   :alt: JDSS Eigenvalue Spectrum
   :srcset: /auto_examples/dss/images/sphx_glr_plot_12_joint_dss_003.png
   :class: sphx-glr-single-img
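
Components that are not reproducible, such as components 2 and 3 here (about 0.25
and 0.23), do not score 0. For :math:`N` datasets of independent noise, the grand
average has :math:`1/N` of the per-dataset covariance, so every filter scores about
:math:`1/N`. The following minimal sketch checks this numerically. It assumes the
same objective as at the top of this example and data containing nothing but
independent white noise, and it uses ``scipy.linalg.eigh`` rather than
``mne_denoise``:

.. code-block:: python

    import numpy as np
    from scipy.linalg import eigh

    rng = np.random.default_rng(1)
    n_datasets, n_channels, n_times = 5, 16, 5000

    # Independent white noise in every dataset: nothing is shared
    noise = rng.standard_normal((n_datasets, n_channels, n_times))

    # The two covariances of the JDSS objective
    grand_avg = noise.mean(axis=0)
    r_signal = grand_avg @ grand_avg.T / n_times
    r_total = np.mean([x @ x.T / n_times for x in noise], axis=0)

    # Generalized eigenvalues = repeatability scores of the optimal filters
    scores = eigh(r_signal, r_total, eigvals_only=True)
    print(f"expected floor 1/N = {1 / n_datasets:.2f}")
    print(f"noise-only scores: {scores.min():.2f} to {scores.max():.2f}")

All sixteen scores should cluster around 0.2. This explains why the trailing bars in
the spectrum sit near that level instead of at zero.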

.. GENERATED FROM PYTHON SOURCE LINES 222-228

Conclusion
----------
JDSS recovered the common 10 Hz signal (correlation of 0.988 with the ground truth,
versus 0.268 for the best raw grand-average channel). It did so even though the
signal was weaker than the noise, the noise patterns varied across subjects, and the
signal topographies were only approximately aligned. That is exactly the setting
where a shared group-level component model is useful.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.743 seconds)


.. _sphx_glr_download_auto_examples_dss_plot_12_joint_dss.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_12_joint_dss.ipynb <plot_12_joint_dss.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_12_joint_dss.py <plot_12_joint_dss.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_12_joint_dss.zip <plot_12_joint_dss.zip>`