# Authors: Eric Larson <larson.eric.d@gmail.com>
# Sheraz Khan <sheraz@khansheraz.com>
# Denis Engemann <denis.engemann@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from mne.filter import next_fast_len
from mne.source_estimate import _BaseSourceEstimate
from mne.utils import verbose, _check_combine, _check_option
@verbose
def envelope_correlation(data, combine='mean', orthogonalize="pairwise",
                         log=False, absolute=True, verbose=None):
    """Compute the envelope correlation.

    Parameters
    ----------
    data : array-like, shape=(n_epochs, n_signals, n_times) | generator
        The data from which to compute connectivity.
        The array-like object can also be a list/generator of array,
        each with shape (n_signals, n_times), or a :class:`~mne.SourceEstimate`
        object (and ``stc.data`` will be used). If it's float data,
        the Hilbert transform will be applied; if it's complex data,
        it's assumed the Hilbert has already been applied.
    combine : 'mean' | callable | None
        How to combine correlation estimates across epochs.
        Default is 'mean'. Can be None to return without combining.
        If callable, it must accept one positional input.
        For example::

            combine = lambda data: np.median(data, axis=0)

    orthogonalize : 'pairwise' | False
        Whether to orthogonalize with the pairwise method or not.
        Defaults to 'pairwise'. Note that when False,
        the correlation matrix will not be returned with
        absolute values.

        .. versionadded:: 0.19
    log : bool
        If True (default False), square and take the log before
        orthogonalizing envelopes or computing correlations.

        .. versionadded:: 0.22
    absolute : bool
        If True (default), then take the absolute value of correlation
        coefficients before making each epoch's correlation matrix
        symmetric (and thus before combining matrices across epochs).
        Only used when ``orthogonalize='pairwise'``.

        .. versionadded:: 0.22
    %(verbose)s

    Returns
    -------
    corr : ndarray, shape ([n_epochs, ]n_nodes, n_nodes)
        The pairwise orthogonal envelope correlations.
        This matrix is symmetric. If combine is None, the array
        will have three dimensions, the first of which is ``n_epochs``.

    Raises
    ------
    ValueError
        If ``data`` yields no epochs, if an epoch is not 2D, if the number
        of signals differs between epochs, or if an epoch's dtype is
        neither float32/float64 nor complex64/complex128.

    Notes
    -----
    This function computes the power envelope correlation between
    orthogonalized signals :footcite:`HippEtAl2012,KhanEtAl2018`.

    With ``orthogonalize='pairwise'``, entry ``(i, j)`` of an epoch's
    matrix is first computed as the correlation between the envelope of
    signal ``j`` and the envelope of signal ``i`` orthogonalized with
    respect to signal ``j``. The matrix is then (optionally) made absolute
    and symmetrized by averaging it with its transpose. The diagonal is
    exactly zero, because a signal orthogonalized against itself is
    replaced by a constant.

    .. versionchanged:: 0.22
       Computations fixed for ``orthogonalize='pairwise'`` and diagonal
       entries are set explicitly to zero.

    References
    ----------
    .. footbibliography::
    """
    _check_option('orthogonalize', orthogonalize, (False, 'pairwise'))
    from scipy.signal import hilbert
    if combine is not None:
        fun = _check_combine(combine, valid=('mean',))
    else:  # None: just stack the per-epoch matrices
        fun = np.array
    corrs = list()
    # Note: This is embarrassingly parallel, but the overhead of sending
    # the data to different workers is roughly the same as the gain of
    # using multiple CPUs. And we require too much GIL for prefer='threading'
    # to help.
    for ei, epoch_data in enumerate(data):
        if isinstance(epoch_data, _BaseSourceEstimate):
            epoch_data = epoch_data.data
        # Accept any array-like epoch (e.g. nested lists), as documented;
        # this is a no-op for ndarray input.
        epoch_data = np.asarray(epoch_data)
        if epoch_data.ndim != 2:
            raise ValueError('Each entry in data must be 2D, got shape %s'
                             % (epoch_data.shape,))
        n_nodes, n_times = epoch_data.shape
        if ei > 0 and n_nodes != corrs[0].shape[0]:
            # Values are given in the same order as the message names them:
            # data[0] first, then data[ei].
            raise ValueError('n_nodes mismatch between data[0] and data[%d], '
                             'got %s and %s'
                             % (ei, corrs[0].shape[0], n_nodes))
        # Get the complex envelope (allowing complex inputs allows people
        # to do raw.apply_hilbert if they want)
        if epoch_data.dtype in (np.float32, np.float64):
            # Zero-pad to a fast FFT length, then trim back to n_times
            n_fft = next_fast_len(n_times)
            epoch_data = hilbert(epoch_data, N=n_fft, axis=-1)[..., :n_times]
        if epoch_data.dtype not in (np.complex64, np.complex128):
            raise ValueError('data.dtype must be float or complex, got %s'
                             % (epoch_data.dtype,))
        data_mag = np.abs(epoch_data)
        # conj(x) / |x|: multiplying another signal by this and taking the
        # imaginary part keeps the component orthogonal to x
        data_conj_scaled = epoch_data.conj()
        data_conj_scaled /= data_mag
        if log:
            data_mag *= data_mag
            np.log(data_mag, out=data_mag)
        # subtract means
        data_mag_nomean = data_mag - np.mean(data_mag, axis=-1, keepdims=True)
        # compute variances using linalg.norm (square, sum, sqrt) since mean=0
        data_mag_std = np.linalg.norm(data_mag_nomean, axis=-1)
        # avoid division by zero for constant envelopes (corr becomes 0)
        data_mag_std[data_mag_std == 0] = 1
        corr = np.empty((n_nodes, n_nodes))
        for li, label_data in enumerate(epoch_data):
            if orthogonalize is False:  # the new code
                label_data_orth = data_mag[li]
                label_data_orth_std = data_mag_std[li]
            else:
                # Row j: envelope of signal li orthogonalized w.r.t. signal j
                label_data_orth = (label_data * data_conj_scaled).imag
                np.abs(label_data_orth, out=label_data_orth)
                # protect against invalid value -- this will be zero
                # after (log and) mean subtraction
                label_data_orth[li] = 1.
                if log:
                    label_data_orth *= label_data_orth
                    np.log(label_data_orth, out=label_data_orth)
                label_data_orth -= np.mean(label_data_orth, axis=-1,
                                           keepdims=True)
                label_data_orth_std = np.linalg.norm(label_data_orth, axis=-1)
                label_data_orth_std[label_data_orth_std == 0] = 1
            # correlation is dot product divided by variances
            corr[li] = np.sum(label_data_orth * data_mag_nomean, axis=1)
            corr[li] /= data_mag_std
            corr[li] /= label_data_orth_std
        if orthogonalize is not False:
            # Make it symmetric (it isn't at this point)
            if absolute:
                corr = np.abs(corr)
            corr = (corr.T + corr) / 2.
        corrs.append(corr)
        del corr
    if len(corrs) == 0:
        # Otherwise combining an empty list silently yields NaN / bad shape
        raise ValueError('data must contain at least one epoch')
    corr = fun(corrs)
    return corr