from copy import copy, deepcopy
import numpy as np
import pandas as pd
import xarray as xr
from mne.utils import (
_check_combine,
_check_event_id,
_check_option,
_ensure_events,
_on_missing,
_validate_type,
check_random_state,
copy_function_doc_to_method_doc,
object_size,
sizeof_fmt,
warn,
)
from mne_connectivity.utils import _prepare_xarray_mne_data_structures, fill_doc
from mne_connectivity.viz import plot_connectivity_circle
class SpectralMixin:
    """Mixin class for spectral connectivities.

    Note: In mne-connectivity, we associate the word "spectral" with
    time-frequency. Reference to eigenvalue structure is not captured in this
    mixin.
    """

    @property
    def freqs(self):
        """The frequency points of the connectivity data.

        If these are computed over a frequency band, it will be the median frequency of
        the frequency band.
        """
        # pull the "freqs" coordinate off the underlying xarray object and
        # hand it back as a plain Python list
        freq_coord = self.xarray.coords.get("freqs")
        return freq_coord.values.tolist()
class TimeMixin:
    """Mixin exposing the time axis of the underlying xarray data."""

    @property
    def times(self):
        """The time points of the connectivity data."""
        # fetch the "times" coordinate and convert to a plain list
        time_coord = self.xarray.coords.get("times")
        return time_coord.values.tolist()
class EpochMixin:
    """Mixin for epoched connectivity containers.

    Tracks the ``events`` array and ``event_id`` mapping (as in
    :class:`mne.Epochs`) and provides ``append`` and ``combine`` over the
    epochs dimension.
    """

    def _init_epochs(self, events, event_id, on_missing="warn") -> None:
        """Initialize ``self.events``, ``self.event_id`` and ``self.selection``.

        Parameters
        ----------
        events : array-like, shape (n_events, 3) | None
            The events array as in MNE-Python. ``None`` or an empty array is
            replaced with an empty ``(0, 3)`` array.
        event_id : dict | None
            Mapping of event names to integer event codes; validated against
            ``events``.
        on_missing : str
            What to do when an ``event_id`` value has no matching event
            (passed to ``_on_missing``).
        """
        # Epochs should have the events array that informs user of
        # sample points at which each Epoch was taken from.
        # An empty list occurs when NetCDF stores empty arrays.
        if events is not None and np.array(events).size != 0:
            events = _ensure_events(events)
        else:
            events = np.empty((0, 3))
        event_id = _check_event_id(event_id, events)
        self.event_id = event_id
        self.events = events

        # see BaseEpochs init in MNE-Python
        # NOTE(review): ``events`` can no longer be None here (the branch
        # above substitutes an empty array), so this guard always passes.
        if events is not None:
            # warn/raise for event_id entries with no matching event code
            for key, val in self.event_id.items():
                if val not in events[:, 2]:
                    msg = f"No matching events found for {key} (event id {val})"
                    _on_missing(on_missing, msg)
            # ensure metadata matches original events size
            self.selection = np.arange(len(events))
            self.events = events
            del events
            # keep only events whose code appears in ``event_id``
            values = list(self.event_id.values())
            selected = np.where(np.isin(self.events[:, 2], values))[0]
            self.events = self.events[selected]

    def append(self, epoch_conn):
        """Append another connectivity structure.

        Parameters
        ----------
        epoch_conn : instance of Connectivity
            The epoched Connectivity class to append.

        Returns
        -------
        self : instance of Connectivity
            The altered epoched Connectivity class.
        """
        if not isinstance(self, type(epoch_conn)):
            raise ValueError(
                f"The type of the epoch connectivity to append is {type(epoch_conn)}, "
                f"which does not match {type(self)}."
            )
        # when present, the time/frequency axes must match for concatenation
        if hasattr(self, "times"):
            if not np.allclose(self.times, epoch_conn.times):
                raise ValueError("Epochs must have same times")
        if hasattr(self, "freqs"):
            if not np.allclose(self.freqs, epoch_conn.freqs):
                raise ValueError("Epochs must have same frequencies")
        events = list(deepcopy(self.events))
        event_id = deepcopy(self.event_id)
        metadata = copy(self.metadata)

        # compare event_id: identical keys must map to identical codes
        common_keys = list(set(event_id).intersection(set(epoch_conn.event_id)))
        for key in common_keys:
            if not event_id[key] == epoch_conn.event_id[key]:
                msg = (
                    "event_id values must be the same for identical keys "
                    'for all concatenated epochs. Key "{}" maps to {} in '
                    "some epochs and to {} in others."
                )
                raise ValueError(
                    msg.format(key, event_id[key], epoch_conn.event_id[key])
                )
        evs = epoch_conn.events.copy()
        if epoch_conn.n_epochs == 0:
            warn("Epoch Connectivity object to append was empty.")
        event_id.update(epoch_conn.event_id)
        events = np.concatenate((events, evs), axis=0)
        # NOTE(review): ``metadata`` is computed here (incoming frame placed
        # first) but is never written back to ``self`` — confirm intent.
        metadata = pd.concat([epoch_conn.metadata, metadata])

        # now combine the xarray data, altered events and event ID
        self._obj = xr.concat([self.xarray, epoch_conn.xarray], dim="epochs")
        self.events = events
        self.event_id = event_id
        return self

    def combine(self, combine="mean"):
        """Combine connectivity data over epochs.

        Parameters
        ----------
        combine : ``'mean'`` | ``'median'`` | callable
            How to combine correlation estimates across epochs. Default is ``'mean'``.
            If callable, it must accept one positional input. For example::

                combine = lambda data: np.median(data, axis=0)

        Returns
        -------
        conn : instance of Connectivity
            The combined connectivity data structure. Instance type reflects that of the
            input instance, without the epoch dimension.
        """  # noqa: E501
        from .io import _xarray_to_conn

        if not self.is_epoched:
            raise RuntimeError(
                "Combine only works over Epoched connectivity. It does not work with "
                f"{self}"
            )
        fun = _check_combine(combine, valid=("mean", "median"))

        # get a copy of metadata into attrs as a dictionary
        self = _prepare_xarray_mne_data_structures(self)

        # apply function over the array, collapsing the "epochs" core dim
        new_xr = xr.apply_ufunc(
            fun, self.xarray, input_core_dims=[["epochs"]], vectorize=True
        )
        new_xr.attrs = self.xarray.attrs

        # map epoched class name to its non-epoched counterpart class
        conn_cls = {
            "EpochConnectivity": Connectivity,
            "EpochTemporalConnectivity": TemporalConnectivity,
            "EpochSpectralConnectivity": SpectralConnectivity,
            "EpochSpectroTemporalConnectivity": SpectroTemporalConnectivity,
        }
        cls_func = conn_cls[self.__class__.__name__]

        # convert new xarray to non-Epoch data structure
        conn = _xarray_to_conn(new_xr, cls_func)
        return conn
class DynamicMixin:
    """Mixin for dynamic (vector autoregressive, VAR) connectivity models."""

    def is_stable(self):
        """Return True if the VAR model is stable.

        Stability holds when every eigenvalue of the companion matrix lies
        strictly inside the unit circle.
        """
        companion_mat = self.companion
        return np.abs(np.linalg.eigvals(companion_mat)).max() < 1.0

    def eigvals(self):
        """Return the eigenvalues of the VAR companion matrix."""
        return np.linalg.eigvals(self.companion)

    @property
    def companion(self):
        """Generate block companion matrix.

        Returns the data matrix if the model is VAR(1).
        """
        from .vector_ar.utils import _block_companion

        lags = self.attrs.get("lags")
        data = self.get_data()
        if lags == 1:
            return data
        # one block-companion matrix per epoch; the last axis of ``data``
        # holds the per-lag coefficient blocks
        arrs = []
        for idx in range(self.n_epochs):
            blocks = _block_companion([data[idx, ..., jdx] for jdx in range(lags)])
            arrs.append(blocks)
        return arrs

    def predict(self, data):
        """Predict samples on actual data.

        The result of this function is used for calculating the residuals.

        Parameters
        ----------
        data : array, shape ([n_epochs,] n_signals, n_times)
            Epoched or continuous data set.

        Returns
        -------
        predicted : array, shape ([n_epochs,] n_signals, n_times)
            Data as predicted by the VAR model of shape same as ``data``.

        Notes
        -----
        Residuals are obtained by ``r = x - var.predict(x)``.

        To compute residual covariances::

            # compute the covariance of the residuals
            # row are observations, columns are variables
            t = residuals.shape[0]
            sampled_residuals = np.concatenate(
                np.split(residuals[:, :, lags:], t, 0), axis=2
            ).squeeze(0)
            rescov = np.cov(sampled_residuals)
        """
        if data.ndim < 2 or data.ndim > 3:
            raise ValueError(
                "Data passed in must be either 2D or 3D. The data you passed in has "
                f"{data.ndim} dims."
            )
        if data.ndim == 2 and self.is_epoched:
            raise RuntimeError(
                "If there is a VAR model over epochs, one must pass in a 3D array."
            )
        if data.ndim == 3 and not self.is_epoched:
            raise RuntimeError(
                "If there is a single VAR model, one must pass in a 2D array."
            )

        # make the data 3D
        if data.ndim == 2:
            data = data[np.newaxis, ...]
        n_epochs, _, n_times = data.shape
        var_model = self.get_data(output="dense")

        # get the model order
        lags = self.attrs.get("lags")

        # predict the data by applying forward model
        predicted_data = np.zeros(data.shape)

        # pick the loop nesting (over times vs. over epochs)
        # which takes less loop iterations
        if n_epochs > n_times - lags:
            for idx in range(1, lags + 1):
                for jdx in range(lags, n_times):
                    if self.is_epoched:
                        # NOTE(review): here ``jdx`` is a time index but it
                        # indexes the first (epoch) axis of ``var_model`` —
                        # confirm this is intended.
                        bp = var_model[jdx, :, (idx - 1) :: lags]
                    else:
                        bp = var_model[:, (idx - 1) :: lags]
                    predicted_data[:, :, jdx] += np.dot(data[:, :, jdx - idx], bp.T)
        else:
            for idx in range(1, lags + 1):
                for jdx in range(n_epochs):
                    if self.is_epoched:
                        bp = var_model[jdx, :, (idx - 1) :: lags]
                    else:
                        bp = var_model[:, (idx - 1) :: lags]
                    predicted_data[jdx, :, lags:] += np.dot(
                        bp, data[jdx, :, (lags - idx) : (n_times - idx)]
                    )
        return predicted_data

    @fill_doc
    def simulate(self, n_samples, noise_func=None, random_state=None):
        """Simulate vector autoregressive (VAR) model.

        This function generates data from the VAR model.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate.
        noise_func : callable | None
            This function is used to create the generating noise process. If ``None``,
            Gaussian white noise with zero mean and unit variance is used.
        %(random_state)s

        Returns
        -------
        data : array, shape (n_samples, n_channels)
            Generated data.
        """
        var_model = self.get_data(output="dense")
        # average VAR coefficients over epochs when the model is epoched
        if self.is_epoched:
            var_model = var_model.mean(axis=0)
        n_nodes = self.n_nodes
        lags = self.attrs.get("lags")

        # set noise function
        if noise_func is None:
            rng = check_random_state(random_state)

            def noise_func():
                return rng.normal(size=(1, n_nodes))

        # generate ``10 * lags`` extra burn-in samples, discarded on return
        n = n_samples + 10 * lags

        # simulated data
        data = np.zeros((n, n_nodes))
        res = np.zeros((n, n_nodes))
        # seed the first ``lags`` samples with pure noise
        for jdx in range(lags):
            e = noise_func()
            res[jdx, :] = e
            data[jdx, :] = e
        # then apply the AR recursion: noise plus lagged contributions
        for jdx in range(lags, n):
            e = noise_func()
            res[jdx, :] = e
            data[jdx, :] = e
            for idx in range(1, lags + 1):
                data[jdx, :] += var_model[:, (idx - 1) :: lags].dot(data[jdx - idx, :])
        # self.residuals = res[10 * lags:, :, :].T
        # self.rescov = sp.cov(cat_trials(self.residuals).T, rowvar=False)
        return data[10 * lags :, :].transpose()
@fill_doc
class BaseConnectivity(DynamicMixin, EpochMixin):
    """Base class for connectivity data.

    This class should not be instantiated directly, but should be used to do
    type-checking. All connectivity classes will be returned from corresponding
    connectivity computing functions.

    Connectivity data is anything that represents "connections" between nodes as a
    ``(N, N)`` array. It can be symmetric, or asymmetric (if it is symmetric, storage
    optimization will occur).

    Parameters
    ----------
    %(data)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_nodes)s
    %(events)s
    %(event_id)s
    metadata : instance of pandas.DataFrame | None
        The metadata data frame that would come from the :class:`mne.Epochs` class. See
        :class:`mne.Epochs` docstring for details.
    %(connectivity_kwargs)s

    Notes
    -----
    Connectivity data can be generally represented as a square matrix with values
    intending the connectivity function value between two nodes. We optimize storage of
    symmetric connectivity data and allow support for computing connectivity data on a
    subset of nodes. We store connectivity data as a raveled ``(n_estimated_nodes,
    ...)`` where ``n_estimated_nodes`` can be ``n_nodes_in * n_nodes_out`` if a full
    connectivity structure is computed, or a subset of the nodes (equal to the length of
    the indices passed in).

    Since we store connectivity data as a raveled array, one can easily optimize the
    storage of "symmetric" connectivity data. One can use numpy to convert a full
    all-to-all connectivity into an upper triangular portion, and set
    ``indices='symmetric'``. This would reduce the RAM needed in half.

    The underlying data structure is an :class:`xarray.DataArray`, with a similar API to
    ``xarray``. We provide support for storing connectivity data in a subset of nodes.
    Thus the underlying data structure instead of a ``(n_nodes_in, n_nodes_out)`` 2D
    array would be a ``(n_nodes_in * n_nodes_out,)`` raveled 1D array. This allows us to
    optimize storage also for symmetric connectivity.
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = False

    def __init__(
        self,
        data,
        names,
        indices,
        method,
        n_nodes,
        events=None,
        event_id=None,
        metadata=None,
        **kwargs,
    ):
        # string indices are restricted to the two recognized sentinels
        if isinstance(indices, str) and indices not in ["all", "symmetric"]:
            raise ValueError(
                'Indices can only be "all", "symmetric", or a list of tuples. '
                f"It cannot be {indices}."
            )

        # prepare metadata pandas dataframe and ensure metadata is a Pandas
        # DataFrame object
        if metadata is None:
            metadata = pd.DataFrame(dtype="float64")
        self.metadata = metadata

        # check the incoming data structure
        self._check_data_consistency(data, indices=indices, n_nodes=n_nodes)
        self._prepare_xarray(
            data,
            names=names,
            indices=indices,
            n_nodes=n_nodes,
            method=method,
            events=events,
            event_id=event_id,
            **kwargs,
        )

    def __repr__(self) -> str:
        """Summarize epochs, freqs, times, node counts and memory footprint."""
        r = f"<{self.__class__.__name__} | "
        if self.n_epochs is not None:
            r += f"n_epochs : {self.n_epochs}, "
        if "freqs" in self.dims:
            r += f"freq : [{self.freqs[0]}, {self.freqs[-1]}], "  # type: ignore
        if "times" in self.dims:
            r += f"time : [{self.times[0]}, {self.times[-1]}], "  # type: ignore
        r += f", nave : {self.n_epochs_used}"
        r += f", nodes, n_estimated : {self.n_nodes}, {self.n_estimated_nodes}"
        if "components" in self.dims:
            r += f", n_components : {len(self.coords['components'])}, "
        r += f", ~{sizeof_fmt(self._size)}"
        r += ">"
        return r

    def _get_num_connections(self, data):
        """Compute the number of estimated nodes' connectivity."""
        # account for epoch data structures: the connections axis is axis 1
        # when epoched, axis 0 otherwise
        if self.is_epoched:
            start_idx = 1
        else:
            start_idx = 0
        self.n_estimated_nodes = data.shape[start_idx]

    def _prepare_xarray(
        self, data, names, indices, n_nodes, method, events, event_id, **kwargs
    ):
        """Prepare xarray data structure.

        Builds the internal :class:`xarray.DataArray` (``self._obj``) whose
        dims/coords are derived from the data shape plus any ``components`` /
        ``freqs`` / ``times`` keyword arguments; all remaining metadata is
        stored in ``attrs``.
        """
        # generate events and event_id that originate from Epochs class
        # which stores the windows of Raw that were used to generate
        # the corresponding connectivity data
        self._init_epochs(events, event_id, on_missing="warn")

        # set node names
        if names is None:
            names = list(map(str, range(n_nodes)))

        # the names of each first few dimensions of
        # the data depending if data is epoched or not
        if self.is_epoched:
            dims = ["epochs", "node_in -> node_out"]
        else:
            dims = ["node_in -> node_out"]

        # the coordinates of each dimension
        n_estimated_list = list(map(str, range(self.n_estimated_nodes)))
        coords = dict()
        if self.is_epoched:
            coords["epochs"] = list(map(str, range(data.shape[0])))
        coords["node_in -> node_out"] = n_estimated_list
        # optional trailing dimensions, popped from kwargs in a fixed order
        # that must match the data layout
        if "components" in kwargs:
            coords["components"] = kwargs.pop("components")
            dims.append("components")
        if "freqs" in kwargs:
            coords["freqs"] = kwargs.pop("freqs")
            dims.append("freqs")
        if "times" in kwargs:
            times = kwargs.pop("times")
            if times is None:
                times = list(range(data.shape[-1]))
            coords["times"] = list(times)
            dims.append("times")

        # convert all numpy arrays to lists
        for key, val in kwargs.items():
            if isinstance(val, np.ndarray):
                kwargs[key] = val.tolist()
        kwargs["node_names"] = names

        # set method, indices and n_nodes
        if isinstance(indices, tuple):
            if all(isinstance(inds, np.ndarray) for inds in indices):
                # leave multivariate indices as arrays for easier indexing
                if all(inds.ndim > 1 for inds in indices):
                    new_indices = (indices[0], indices[1])
                else:
                    new_indices = (list(indices[0]), list(indices[1]))
            else:
                new_indices = (list(indices[0]), list(indices[1]))
            indices = new_indices
        kwargs["method"] = method
        kwargs["indices"] = indices
        kwargs["n_nodes"] = n_nodes
        kwargs["events"] = self.events
        # kwargs['event_id'] = self.event_id

        # create xarray object
        xarray_obj = xr.DataArray(data=data, coords=coords, dims=dims, attrs=kwargs)
        self._obj = xarray_obj

    def _check_data_consistency(self, data, indices, n_nodes):
        """Perform data input checks."""
        if not isinstance(data, np.ndarray):
            raise TypeError("Connectivity data must be passed in as a numpy array.")
        # dimensionality bounds depend on whether the leading axis is epochs
        if self.is_epoched:
            if data.ndim < 2 or data.ndim > 5:
                raise RuntimeError(
                    "Data using an epoched data structure should have at least 2 "
                    f"dimensions and at most 5 dimensions. Your data was {data.shape} "
                    "shape."
                )
        else:
            if data.ndim > 4:
                raise RuntimeError(
                    "Data not using an epoched data structure should have at least 1 "
                    f"dimensions and at most 4 dimensions. Your data was {data.shape} "
                    "shape."
                )

        # get the number of estimated nodes
        self._get_num_connections(data)
        if self.is_epoched:
            data_len = data.shape[1]
        else:
            data_len = data.shape[0]

        if isinstance(indices, tuple):
            # check that the indices passed in are of the same length
            if len(indices[0]) != len(indices[1]):
                raise ValueError(
                    "If indices are passed in then they must be the same length. They "
                    f"are right now {len(indices[0])} and {len(indices[1])}."
                )
            # indices length should match the data length
            if len(indices[0]) != data_len:
                raise ValueError(
                    f"The number of indices, {len(indices[0])} should match the "
                    f"raveled data length passed in of {data_len}."
                )
        elif indices == "symmetric":
            # symmetric storage keeps only the upper triangle incl. diagonal
            expected_len = ((n_nodes + 1) * n_nodes) // 2
            if data_len != expected_len:
                raise ValueError(
                    'If "indices" is "symmetric", then '
                    f"connectivity data should be the upper-triangular part of the "
                    f"matrix. There are {data_len} estimated connections. But there "
                    f"should be {expected_len} estimated connections."
                )

    def copy(self):
        """Return a deep copy of this connectivity instance."""
        return deepcopy(self)

    def get_epoch_annotations(self):
        # placeholder; not implemented
        pass

    @property
    def n_epochs(self):
        """The number of epochs the connectivity data varies over."""
        if self.is_epoched:
            n_epochs = self._data.shape[0]
        else:
            n_epochs = None
        return n_epochs

    @property
    def _data(self):
        """Numpy array of connectivity data."""
        return self.xarray.values

    @property
    def dims(self):
        """The dimensions of the xarray data."""
        return self.xarray.dims

    @property
    def coords(self):
        """The coordinates of the xarray data."""
        return self.xarray.coords

    @property
    def attrs(self):
        """Xarray attributes of connectivity.

        See ``xarray``'s ``attrs``.
        """
        return self.xarray.attrs

    @property
    def shape(self):
        """Shape of raveled connectivity."""
        return self.xarray.shape

    @property
    def n_nodes(self):
        """The number of nodes in the original dataset.

        Even if ``indices`` defines a subset of nodes that were computed, this should be
        the total number of nodes in the original dataset.
        """
        return self.attrs["n_nodes"]

    @property
    def method(self):
        """The method used to compute connectivity."""
        return self.attrs["method"]

    @property
    def indices(self):
        """Indices of connectivity data.

        Returns
        -------
        indices : ``'all'`` | ``'symmetric'`` | tuple of list
            Either ``'all'`` for all-to-all connectivity, ``'symmetric'`` for symmetric
            connectivity, or a tuple of lists representing the node-to-nodes that
            connectivity was computed for.
        """
        return self.attrs["indices"]

    @property
    def names(self):
        """Node names."""
        return self.attrs["node_names"]

    @property
    def xarray(self):
        """Xarray of the connectivity data."""
        return self._obj

    @property
    def n_epochs_used(self):
        """Number of epochs used in computation of connectivity.

        Can be ``None``, if there was no epochs used. This is equivalent to the number
        of epochs, if there is no combining of epochs.
        """
        return self.attrs.get("n_epochs_used")

    @property
    def _size(self):
        """Estimate the object size."""
        size = 0
        size += object_size(self._data)
        size += object_size(self.attrs)
        # if self.metadata is not None:
        #     size += self.metadata.memory_usage(index=True).sum()
        return size

    def get_data(self, output="compact"):
        """Get connectivity data as a numpy array.

        Parameters
        ----------
        output : ``'compact'`` | ``'raveled'`` | ``'dense'``
            How to format the output:

            - ``'raveled'`` will represent each connectivity matrix as a
              ``(..., n_nodes_in * n_nodes_out, ...)`` array
            - ``'dense'`` will return each connectivity matrix as a ``(..., n_nodes_in,
              n_nodes_out, ...)`` array
            - ``'compact'`` (default) will return ``'raveled'`` if ``indices`` were
              defined as a tuple of arrays, or ``'dense'`` if ``indices='all'``

            Multivariate connectivity data cannot be returned in a dense form.

        Returns
        -------
        data : array
            The output connectivity data.
        """
        _check_option("output", output, ["raveled", "dense", "compact"])

        # resolve the "compact" choice based on how indices were stored
        if output == "compact":
            if self.indices in ["all", "symmetric"]:
                output = "dense"
            else:
                output = "raveled"

        if output == "raveled":
            data = self._data
        else:
            if (
                isinstance(self.indices, tuple)
                and not isinstance(self.indices[0], int)
                and not isinstance(self.indices[1], int)
            ):  # i.e. check if multivariate results based on nested indices
                # multivariate results cannot be returned in a dense form as a single
                # set of results would correspond to multiple entries in the matrix, and
                # there could also be cases where multiple results correspond to the
                # same entries in the matrix.
                raise ValueError(
                    "cannot return multivariate connectivity data in a dense form"
                )

            # get the new shape of the data array
            if self.is_epoched:
                new_shape = [self.n_epochs]
            else:
                new_shape = []
            # handle the case where model order is defined in VAR connectivity
            # and thus appends the connectivity matrices side by side, so the
            # shape is N x N * lags
            new_shape.extend([self.n_nodes, self.n_nodes])
            if "components" in self.dims:
                new_shape.append(len(self.coords["components"]))
            if "freqs" in self.dims:
                new_shape.append(len(self.coords["freqs"]))
            if "times" in self.dims:
                new_shape.append(len(self.coords["times"]))

            # handle things differently if indices is defined
            if isinstance(self.indices, tuple):
                # TODO: improve this to be more memory efficient
                # from all-to-all connectivity structure
                # entries not covered by ``indices`` stay NaN
                data = np.zeros(new_shape)
                data[:] = np.nan
                row_idx, col_idx = self.indices
                if self.is_epoched:
                    data[:, row_idx, col_idx, ...] = self._data
                else:
                    data[row_idx, col_idx, ...] = self._data
            elif self.indices == "symmetric":
                data = np.zeros(new_shape)
                # get the upper/lower triangular indices; mirror the stored
                # upper triangle into both halves of the dense matrix
                row_triu_inds, col_triu_inds = np.triu_indices(self.n_nodes, k=0)
                if self.is_epoched:
                    data[:, row_triu_inds, col_triu_inds, ...] = self._data
                    data[:, col_triu_inds, row_triu_inds, ...] = self._data
                else:
                    data[row_triu_inds, col_triu_inds, ...] = self._data
                    data[col_triu_inds, row_triu_inds, ...] = self._data
            else:
                data = self._data.reshape(new_shape)
        return data

    def rename_nodes(self, mapping):
        """Rename nodes.

        Parameters
        ----------
        mapping : dict
            Mapping from original node names (keys) to new node names (values).
        """
        names = copy(self.names)

        # first check and assemble clean mappings of index and name
        if isinstance(mapping, dict):
            orig_names = sorted(list(mapping.keys()))
            missing = [orig_name not in names for orig_name in orig_names]
            if any(missing):
                raise ValueError(
                    "Name(s) in mapping missing from info: "
                    f"{np.array(orig_names)[np.array(missing)]}"
                )
            new_names = [
                (names.index(name), new_name) for name, new_name in mapping.items()
            ]
        elif callable(mapping):
            # a callable maps every existing name to its replacement
            new_names = [(ci, mapping(name)) for ci, name in enumerate(names)]
        else:
            raise ValueError(f"mapping must be callable or dict, not {type(mapping)}")

        # check we got all strings out of the mapping
        for new_name in new_names:
            _validate_type(new_name[1], "str", "New name mappings")

        # do the remapping locally
        for c_ind, new_name in new_names:
            names[c_ind] = new_name

        # check that all the channel names are unique
        if len(names) != len(np.unique(names)):
            raise ValueError("New channel names are not unique, renaming failed")

        # rename the new names
        self._obj.attrs["node_names"] = names

    @copy_function_doc_to_method_doc(plot_connectivity_circle)
    def plot_circle(self, **kwargs):
        plot_connectivity_circle(
            self.get_data(), node_names=self.names, indices=self.indices, **kwargs
        )

    # def plot_matrix(self):
    #     pass

    # def plot_3d(self):
    #     pass

    def save(self, fname):
        """Save connectivity data to disk.

        Can later be loaded using the function
        :func:`~mne_connectivity.read_connectivity`.

        Parameters
        ----------
        fname : str | pathlib.Path
            The filepath to save the data. Data is saved as netCDF files (``.nc``
            extension).
        """
        method = self.method
        indices = self.indices
        n_nodes = self.n_nodes

        # create a copy of the old attributes
        old_attrs = copy(self.attrs)

        # assign these to xarray's attrs
        self.attrs["method"] = method
        self.attrs["indices"] = indices
        self.attrs["n_nodes"] = n_nodes

        # save the name of the connectivity structure
        self.attrs["data_structure"] = str(self.__class__.__name__)

        # get a copy of metadata into attrs as a dictionary
        self = _prepare_xarray_mne_data_structures(self)

        # netCDF does not support 'None'
        # so map these to 'n/a'
        for key, val in self.attrs.items():
            if val is None:
                self.attrs[key] = "n/a"

        # save as a netCDF file
        # note this requires the netcdf4 python library
        # and h5netcdf library.
        # The engine specified requires the ability to save
        # complex data types, which was not natively supported
        # in xarray. Therefore, h5netcdf is the only engine
        # to support that feature at this moment.
        self.xarray.to_netcdf(
            fname, mode="w", format="NETCDF4", engine="h5netcdf", invalid_netcdf=True
        )

        # re-set old attributes
        self.xarray.attrs = old_attrs
@fill_doc
class SpectralConnectivity(BaseConnectivity, SpectralMixin):
    """Spectral connectivity class.

    This class stores connectivity data that varies over frequencies. The underlying
    data is an array of shape ``(n_connections, [n_components,] n_freqs)``, or
    ``(n_nodes, n_nodes, [n_components,] n_freqs)``. ``n_components`` is an optional
    dimension for multivariate methods where each connection has multiple components of
    connectivity.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(n_epochs_used)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.phase_slope_index
    mne_connectivity.spectral_connectivity_epochs
    mne_connectivity.spectral_connectivity_time
    """

    expected_n_dim = 2

    def __init__(
        self,
        data,
        freqs,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        spec_method=None,
        n_epochs_used=None,
        **kwargs,
    ):
        # fold the spectral-specific metadata into kwargs and delegate
        # everything to the base class constructor
        kwargs.update(
            freqs=freqs, spec_method=spec_method, n_epochs_used=n_epochs_used
        )
        super().__init__(
            data, names=names, method=method, indices=indices, n_nodes=n_nodes, **kwargs
        )
@fill_doc
class TemporalConnectivity(BaseConnectivity, TimeMixin):
    """Temporal connectivity class.

    This is an array of shape ``(n_connections, [n_components,] n_times)``, or
    ``(n_nodes, n_nodes, [n_components,] n_times)``. This describes how connectivity
    varies over time. It describes sample-by-sample time-varying connectivity (usually
    on the order of milliseconds). Here time (t=0) is the same for all connectivity
    measures. ``n_components`` is an optional dimension for multivariate methods where
    each connection has multiple components of connectivity.

    Parameters
    ----------
    %(data)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s
    %(connectivity_kwargs)s

    Notes
    -----
    :class:`mne_connectivity.EpochConnectivity` is a similar connectivity class to this
    one. However, that describes one connectivity snapshot for each epoch. These epochs
    might be chunks of time that have different meaning for time ``t=0``. Epochs can
    mean separate trials, where the beginning of the trial implies t=0. These epochs may
    also be discontiguous.
    """

    expected_n_dim = 2

    def __init__(
        self,
        data,
        times,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        n_epochs_used=None,
        **kwargs,
    ):
        # attach the time axis and epoch count, then defer to the base class
        kwargs.update(times=times, n_epochs_used=n_epochs_used)
        super().__init__(
            data, names=names, method=method, n_nodes=n_nodes, indices=indices, **kwargs
        )
@fill_doc
class SpectroTemporalConnectivity(BaseConnectivity, SpectralMixin, TimeMixin):
    """Spectrotemporal connectivity class.

    This class stores connectivity data that varies over both frequency and time. The
    temporal part describes sample-by-sample time-varying connectivity (usually on the
    order of milliseconds). Note the difference relative to epochs.

    The underlying data is an array of shape ``(n_connections, [n_components,] n_freqs,
    n_times)``, or ``(n_nodes, n_nodes, [n_components,] n_freqs, n_times)``.
    ``n_components`` is an optional dimension for multivariate methods where each
    connection has multiple components of connectivity.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(n_epochs_used)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.phase_slope_index
    mne_connectivity.spectral_connectivity_epochs
    """

    def __init__(
        self,
        data,
        freqs,
        times,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        spec_method=None,
        n_epochs_used=None,
        **kwargs,
    ):
        # bundle both the frequency and time axes before delegating upward
        kwargs.update(
            freqs=freqs,
            times=times,
            spec_method=spec_method,
            n_epochs_used=n_epochs_used,
        )
        super().__init__(
            data, names=names, method=method, indices=indices, n_nodes=n_nodes, **kwargs
        )
@fill_doc
class EpochSpectralConnectivity(SpectralConnectivity):
    """Spectral connectivity class over epochs.

    This is an array of shape ``(n_epochs, n_connections, [n_components,] n_freqs)``, or
    ``(n_epochs, n_nodes, n_nodes, [n_components,] n_freqs)``. This describes how
    connectivity varies over frequencies for different epochs. ``n_components`` is an
    optional dimension for multivariate methods where each connection has multiple
    components of connectivity.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.spectral_connectivity_time
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(
        self,
        data,
        freqs,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        spec_method=None,
        **kwargs,
    ):
        # simply forward to SpectralConnectivity; the class-level
        # ``is_epoched`` flag is what changes the data layout
        super().__init__(
            data,
            freqs=freqs,
            names=names,
            indices=indices,
            n_nodes=n_nodes,
            method=method,
            spec_method=spec_method,
            **kwargs,
        )
@fill_doc
class EpochTemporalConnectivity(TemporalConnectivity):
    """Temporal connectivity class over epochs.

    This is an array of shape ``(n_epochs, n_connections, [n_components,] n_times)``, or
    ``(n_epochs, n_nodes, n_nodes, [n_components,] n_times)``. This describes how
    connectivity varies over time for different epochs. ``n_components`` is an optional
    dimension for multivariate methods where each connection has multiple components of
    connectivity.

    Parameters
    ----------
    %(data)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.envelope_correlation
    mne_connectivity.vector_auto_regression
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(
        self, data, times, n_nodes, names=None, indices="all", method=None, **kwargs
    ):
        # forward everything to TemporalConnectivity; ``is_epoched`` above
        # switches the expected data layout
        kwargs.update(times=times)
        super().__init__(
            data, names=names, indices=indices, n_nodes=n_nodes, method=method, **kwargs
        )
@fill_doc
class EpochSpectroTemporalConnectivity(SpectroTemporalConnectivity):
    """Spectrotemporal connectivity class over epochs.

    This is an array of shape ``(n_epochs, n_connections, [n_components,] n_freqs,
    n_times)``, or ``(n_epochs, n_nodes, n_nodes, [n_components,] n_freqs, n_times)``.
    This describes how connectivity varies over frequencies and time for different
    epochs. ``n_components`` is an optional dimension for multivariate methods where
    each connection has multiple components of connectivity.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(connectivity_kwargs)s
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(
        self,
        data,
        freqs,
        times,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        spec_method=None,
        **kwargs,
    ):
        # forward both axes to SpectroTemporalConnectivity; ``is_epoched``
        # above switches the expected data layout
        super().__init__(
            data,
            names=names,
            freqs=freqs,
            times=times,
            indices=indices,
            n_nodes=n_nodes,
            method=method,
            spec_method=spec_method,
            **kwargs,
        )
@fill_doc
class Connectivity(BaseConnectivity):
    """Connectivity class without frequency or time component.

    This is an array of shape ``(n_connections[, n_components])``, or ``(n_nodes,
    n_nodes[, n_components])``. This describes a connectivity matrix/graph that does not
    vary over time, frequency, or epochs. ``n_components`` is an optional dimension for
    multivariate methods where each connection has multiple components of connectivity.

    Parameters
    ----------
    %(data)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.vector_auto_regression
    """

    def __init__(
        self,
        data,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        n_epochs_used=None,
        **kwargs,
    ):
        # no freqs/times axes here — pass the epoch count through kwargs and
        # delegate to the base class
        kwargs.update(n_epochs_used=n_epochs_used)
        super().__init__(
            data, names=names, method=method, n_nodes=n_nodes, indices=indices, **kwargs
        )
@fill_doc
class EpochConnectivity(BaseConnectivity):
    """Epoch connectivity class.

    This is an array of shape ``(n_epochs, n_connections[, n_components])``, or
    ``(n_epochs, n_nodes, n_nodes[, n_components])``. This describes how connectivity
    varies for different epochs. ``n_components`` is an optional dimension for
    multivariate methods where each connection has multiple components of connectivity.

    Parameters
    ----------
    %(data)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s
    %(connectivity_kwargs)s

    See Also
    --------
    mne_connectivity.vector_auto_regression
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(
        self,
        data,
        n_nodes,
        names=None,
        indices="all",
        method=None,
        n_epochs_used=None,
        **kwargs,
    ):
        # identical forwarding as Connectivity; the ``is_epoched`` class flag
        # is what changes the expected data layout
        kwargs.update(n_epochs_used=n_epochs_used)
        super().__init__(
            data, names=names, method=method, n_nodes=n_nodes, indices=indices, **kwargs
        )