"""
.. _ex-dSPM-epochs:

==================================================
Compute MNE-dSPM inverse solution on single epochs
==================================================

Compute dSPM inverse solution on single trial epochs restricted
to a brain label.
"""
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# %%

import matplotlib.pyplot as plt
import numpy as np

import mne
from mne.datasets import sample
from mne.minimum_norm import apply_inverse, apply_inverse_epochs, read_inverse_operator

# Echo the module docstring so the purpose of the example shows in the output
print(__doc__)

# Locations of the input files inside the "sample" dataset
data_path = sample.data_path()
meg_path = data_path / "MEG" / "sample"
fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"
fname_raw = meg_path / "sample_audvis_filt-0-40_raw.fif"
fname_event = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
label_name = "Aud-lh"
fname_label = meg_path / "labels" / (label_name + ".label")

# Epoching parameters: the event code and the time window around each event
event_id = 1
tmin = -0.2
tmax = 0.5

# The same inverse operator is used for single trials and for evoked data.
# An SNR of 3 is the standard assumption for averaged data; it is reused
# here for single trials as well.
snr = 3.0
lambda2 = 1.0 / snr**2

# Inverse method: dSPM ("MNE" or "sLORETA" could be used instead)
method = "dSPM"

# Read the inverse operator, the label, the raw recording and the events
inverse_operator = read_inverse_operator(fname_inv)
label = mne.read_label(fname_label)
raw = mne.io.read_raw_fif(fname_raw)
events = mne.read_events(fname_event)

# No channels are force-included on top of the type-based selection
include = []

# Append one more channel to the existing list of bad channels
raw.info["bads"] += ["EEG 053"]

# Select MEG channels plus EOG (an EOG threshold is part of the rejection
# criteria below); EEG, stim and bad channels are left out
picks = mne.pick_types(
    raw.info,
    meg=True,
    eeg=False,
    eog=True,
    stim=False,
    include=include,
    exclude="bads",
)

# Rejection thresholds per channel type, handed to ``reject``
rejection_thresholds = {"mag": 4e-12, "grad": 4000e-13, "eog": 150e-6}

# Cut the raw recording into epochs around the selected events
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=(None, 0),
    picks=picks,
    reject=rejection_thresholds,
)

# Sensor-space average across trials; its ``nave`` is reused just below
evoked = epochs.average()

# Invert every epoch, restricted to the label, with the same inverse operator
# as for the evoked data. Passing the evoked ``nave`` keeps both estimates on
# the same scale: with a different nave, dSPM just scales by a factor
# sqrt(nave).
stcs = apply_inverse_epochs(
    epochs,
    inverse_operator,
    lambda2,
    method=method,
    label=label,
    nave=evoked.nave,
    pick_ori="normal",
)

# Average across trials, keeping one time course per vertex in the label
mean_stc = sum(stcs) / len(stcs)

# Sign flips prevent signal cancellation when signed values are averaged
flip = mne.label_sign_flip(label, inverse_operator["src"])

# Pool over the vertices of the label, without and with the sign flip
label_mean = mean_stc.data.mean(axis=0)
label_mean_flip = (flip[:, np.newaxis] * mean_stc.data).mean(axis=0)

# Opposite order of operations: invert the already averaged (evoked) data
stc_evoked = apply_inverse(
    evoked, inverse_operator, lambda2=lambda2, method=method, pick_ori="normal"
)

# apply_inverse() covers the whole brain, so keep only the label of interest
stc_evoked_label = stc_evoked.in_label(label)

# Plain mean over the label (polarities are deliberately not aligned here)
label_mean_evoked = stc_evoked_label.data.mean(axis=0)

# %%
# View activation time-series to illustrate the benefit of aligning/flipping

# Time axis converted to milliseconds for plotting
times = stcs[0].times * 1e3

plt.figure()

# One black trace per dipole in the label, then the two pooled averages on top
dipole_lines = plt.plot(times, mean_stc.data.T, "k")
mean_line = plt.plot(times, label_mean, "r", linewidth=3)[0]
flipped_mean_line = plt.plot(times, label_mean_flip, "g", linewidth=3)[0]

# A single dipole trace stands in for the whole group in the legend
legend_handles = [dipole_lines[0], mean_line, flipped_mean_line]
legend_texts = ["all dipoles in label", "mean", "mean with sign flip"]
plt.legend(legend_handles, legend_texts)

plt.xlabel("time (ms)")
plt.ylabel("dSPM value")
plt.show()

# %%
# View single-trial dSPM and average dSPM, pooled over the label without sign
# flipping. Compare (1) inverse (dSPM) then average with (2) evoked then dSPM.

# Single trials: one dashed trace per epoch, pooled over the label
plt.figure()
trial_label = "Single Trials"
for stc_trial in stcs:
    trial_trace = stc_trial.data.mean(axis=0)
    plt.plot(times, trial_trace, "k--", alpha=0.5, label=trial_label)
    # Only the first trace gets a legend entry; the rest are hidden from it
    trial_label = "_nolegend_"

# Inverse on single trials, then average; drawn with a large linewidth so
# that it is not masked by the curve plotted on top of it
plt.plot(times, label_mean, "b", label="dSPM first, then average", linewidth=6)

# Average (evoked) first, then inverse
plt.plot(times, label_mean_evoked, "r", label="Average first, then dSPM", linewidth=2)

plt.xlabel("time (ms)")
plt.ylabel("dSPM value")
plt.legend()
plt.show()
