"""
.. _ex-receptive-field-mtrf:

=========================================
Receptive Field Estimation and Prediction
=========================================

This example reproduces figures from Lalor et al.'s mTRF toolbox in
MATLAB :footcite:`CrosseEtAl2016`. We will show how the
:class:`mne.decoding.ReceptiveField` class
can perform a similar function along with scikit-learn. We will first fit a
linear encoding model using the continuously-varying speech envelope to predict
activity of a 128 channel EEG system. Then, we will take the reverse approach
and try to predict the speech envelope from the EEG (known in the literature
as a decoding model, or simply stimulus reconstruction).

.. _figure 1: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F1
.. _figure 2: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F2
.. _figure 5: https://www.frontiersin.org/articles/10.3389/fnhum.2016.00604/full#F5
"""

# Authors: Chris Holdgraf <choldgraf@gmail.com>
#          Eric Larson <larson.eric.d@gmail.com>
#          Nicolas Barascud <nicolas.barascud@ens.fr>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from os.path import join

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import KFold
from sklearn.preprocessing import scale

import mne
from mne.decoding import ReceptiveField

# %%
# Load the data from the publication
# ----------------------------------
#
# First we will load the data collected in :footcite:`CrosseEtAl2016`.
# In this experiment subjects
# listened to natural speech. Raw EEG and the speech stimulus are provided.
# We will load these below, downsampling the data in order to speed up
# computation since we know that our features are primarily low-frequency in
# nature. Then we'll visualize both the EEG and speech envelope.

path = mne.datasets.mtrf.data_path()
decim = 2
data = loadmat(join(path, "speech_data.mat"))
raw = data["EEG"].T
speech = data["envelope"].T
sfreq = float(data["Fs"].item())
sfreq /= decim
speech = mne.filter.resample(speech, down=decim, method="polyphase")
raw = mne.filter.resample(raw, down=decim, method="polyphase")

# Read in channel positions and create our MNE objects from the raw data
montage = mne.channels.make_standard_montage("biosemi128")
info = mne.create_info(montage.ch_names, sfreq, "eeg").set_montage(montage)
raw = mne.io.RawArray(raw, info)
n_channels = len(raw.ch_names)

# Plot a sample of brain and stimulus activity
fig, ax = plt.subplots(layout="constrained")
lns = ax.plot(scale(raw[:, :800][0].T), color="k", alpha=0.1)
ln1 = ax.plot(scale(speech[0, :800]), color="r", lw=2)
ax.legend([lns[0], ln1[0]], ["EEG", "Speech Envelope"], frameon=False)
ax.set(title="Sample activity", xlabel="Time (s)")

# %%
# Create and fit a receptive field model
# --------------------------------------
#
# We will construct an encoding model to find the linear relationship between
# a time-delayed version of the speech envelope and the EEG signal. This allows
# us to make predictions about the response to new stimuli.

# Define the delays that we will use in the receptive field
tmin, tmax = -0.2, 0.4

# Initialize the model
rf = ReceptiveField(
    tmin, tmax, sfreq, feature_names=["envelope"], estimator=1.0, scoring="corrcoef"
)
# We'll have (tmax - tmin) * sfreq delays
# and an extra 2 delays since we are inclusive on the beginning / end index
n_delays = int((tmax - tmin) * sfreq) + 2

n_splits = 3
cv = KFold(n_splits)

# Prepare model data (make time the first dimension)
speech = speech.T
Y, _ = raw[:]  # Outputs for the model
Y = Y.T

# Iterate through splits, fit the model, and predict/test on held-out data
coefs = np.zeros((n_splits, n_channels, n_delays))
scores = np.zeros((n_splits, n_channels))
for ii, (train, test) in enumerate(cv.split(speech)):
    print(f"split {ii + 1} / {n_splits}")
    rf.fit(speech[train], Y[train])
    scores[ii] = rf.score(speech[test], Y[test])
    # coef_ is shape (n_outputs, n_features, n_delays). we only have 1 feature
    coefs[ii] = rf.coef_[:, 0, :]
times = rf.delays_ / float(rf.sfreq)

# Average scores and coefficients across CV splits
mean_coefs = coefs.mean(axis=0)
mean_scores = scores.mean(axis=0)

# Plot mean prediction scores across all channels
fig, ax = plt.subplots(layout="constrained")
ix_chs = np.arange(n_channels)
ax.plot(ix_chs, mean_scores)
ax.axhline(0, ls="--", color="r")
ax.set(title="Mean prediction score", xlabel="Channel", ylabel="Score ($r$)")

# %%
# Investigate model coefficients
# ==============================
# Finally, we will look at how the linear coefficients (sometimes
# referred to as beta values) are distributed across time delays as well as
# across the scalp. We will recreate `figure 1`_ and `figure 2`_ from
# :footcite:`CrosseEtAl2016`.

# sphinx_gallery_thumbnail_number = 3

# Print mean coefficients across all time delays / channels (see Fig 1)
time_plot = 0.180  # For highlighting a specific time.
fig, ax = plt.subplots(figsize=(4, 8), layout="constrained")
max_coef = mean_coefs.max()
ax.pcolormesh(
    times,
    ix_chs,
    mean_coefs,
    cmap="RdBu_r",
    vmin=-max_coef,
    vmax=max_coef,
    shading="gouraud",
)
ax.axvline(time_plot, ls="--", color="k", lw=2)
ax.set(
    xlabel="Delay (s)",
    ylabel="Channel",
    title="Mean Model\nCoefficients",
    xlim=times[[0, -1]],
    ylim=[len(ix_chs) - 1, 0],
    xticks=np.arange(tmin, tmax + 0.2, 0.2),
)
plt.setp(ax.get_xticklabels(), rotation=45)

# Make a topographic map of coefficients for a given delay (see Fig 2C)
ix_plot = np.argmin(np.abs(time_plot - times))
fig, ax = plt.subplots(layout="constrained")
mne.viz.plot_topomap(
    mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, vlim=(-max_coef, max_coef)
)
ax.set(title=f"Topomap of model coefficients\nfor delay {time_plot}")

# %%
# Create and fit a stimulus reconstruction model
# ----------------------------------------------
#
# We will now demonstrate another use case for the for the
# :class:`mne.decoding.ReceptiveField` class as we try to predict the stimulus
# activity from the EEG data. This is known in the literature as a decoding, or
# stimulus reconstruction model :footcite:`CrosseEtAl2016`.
# A decoding model aims to find the
# relationship between the speech signal and a time-delayed version of the EEG.
# This can be useful as we exploit all of the available neural data in a
# multivariate context, compared to the encoding case which treats each M/EEG
# channel as an independent feature. Therefore, decoding models might provide a
# better quality of fit (at the expense of not controlling for stimulus
# covariance), especially for low SNR stimuli such as speech.

# We use the same lags as in :footcite:`CrosseEtAl2016`. Negative lags now
# index the relationship
# between the neural response and the speech envelope earlier in time, whereas
# positive lags would index how a unit change in the amplitude of the EEG would
# affect later stimulus activity (obviously this should have an amplitude of
# zero).
tmin, tmax = -0.2, 0.0

# Initialize the model. Here the features are the EEG data. We also specify
# ``patterns=True`` to compute inverse-transformed coefficients during model
# fitting (cf. next section and :footcite:`HaufeEtAl2014`).
# We'll use a ridge regression estimator with an alpha value similar to
# Crosse et al.
sr = ReceptiveField(
    tmin,
    tmax,
    sfreq,
    feature_names=raw.ch_names,
    estimator=1e4,
    scoring="corrcoef",
    patterns=True,
)
# We'll have (tmax - tmin) * sfreq delays
# and an extra 2 delays since we are inclusive on the beginning / end index
n_delays = int((tmax - tmin) * sfreq) + 2

n_splits = 3
cv = KFold(n_splits)

# Iterate through splits, fit the model, and predict/test on held-out data
coefs = np.zeros((n_splits, n_channels, n_delays))
patterns = coefs.copy()
scores = np.zeros((n_splits,))
for ii, (train, test) in enumerate(cv.split(speech)):
    print(f"split {ii + 1} / {n_splits}")
    sr.fit(Y[train], speech[train])
    scores[ii] = sr.score(Y[test], speech[test])[0]
    # coef_ is shape (n_outputs, n_features, n_delays). We have 128 features
    coefs[ii] = sr.coef_[0, :, :]
    patterns[ii] = sr.patterns_[0, :, :]
times = sr.delays_ / float(sr.sfreq)

# Average scores and coefficients across CV splits
mean_coefs = coefs.mean(axis=0)
mean_patterns = patterns.mean(axis=0)
mean_scores = scores.mean(axis=0)
max_coef = np.abs(mean_coefs).max()
max_patterns = np.abs(mean_patterns).max()

# %%
# Visualize stimulus reconstruction
# =================================
#
# To get a sense of our model performance, we can plot the actual and predicted
# stimulus envelopes side by side.

y_pred = sr.predict(Y[test])
time = np.linspace(0, 2.0, 5 * int(sfreq))
fig, ax = plt.subplots(figsize=(8, 4), layout="constrained")
ax.plot(
    time, speech[test][sr.valid_samples_][: int(5 * sfreq)], color="grey", lw=2, ls="--"
)
ax.plot(time, y_pred[sr.valid_samples_][: int(5 * sfreq)], color="r", lw=2)
ax.legend([lns[0], ln1[0]], ["Envelope", "Reconstruction"], frameon=False)
ax.set(title="Stimulus reconstruction")
ax.set_xlabel("Time (s)")

# %%
# Investigate model coefficients
# ==============================
#
# Finally, we will look at how the decoding model coefficients are distributed
# across the scalp. We will attempt to recreate `figure 5`_ from
# :footcite:`CrosseEtAl2016`. The
# decoding model weights reflect the channels that contribute most toward
# reconstructing the stimulus signal, but are not directly interpretable in a
# neurophysiological sense. Here we also look at the coefficients obtained
# via an inversion procedure :footcite:`HaufeEtAl2014`, which have a more
# straightforward
# interpretation as their value (and sign) directly relates to the stimulus
# signal's strength (and effect direction).

time_plot = (-0.140, -0.125)  # To average between two timepoints.
ix_plot = np.arange(
    np.argmin(np.abs(time_plot[0] - times)), np.argmin(np.abs(time_plot[1] - times))
)
fig, ax = plt.subplots(1, 2)
mne.viz.plot_topomap(
    np.mean(mean_coefs[:, ix_plot], axis=1),
    pos=info,
    axes=ax[0],
    show=False,
    vlim=(-max_coef, max_coef),
)
ax[0].set(title=f"Model coefficients\nbetween delays {time_plot[0]} and {time_plot[1]}")

mne.viz.plot_topomap(
    np.mean(mean_patterns[:, ix_plot], axis=1),
    pos=info,
    axes=ax[1],
    show=False,
    vlim=(-max_patterns, max_patterns),
)
ax[1].set(
    title=(
        f"Inverse-transformed coefficients\nbetween delays {time_plot[0]} and "
        f"{time_plot[1]}"
    )
)

# %%
# References
# ----------
#
# .. footbibliography::
