"""Source code for mne_connectivity.base."""

from copy import copy

import numpy as np
import xarray as xr
from sklearn.utils import check_random_state
from mne.utils import (_check_combine, _check_option, _validate_type,
                       copy_function_doc_to_method_doc, object_size,
                       sizeof_fmt)

from mne_connectivity.utils import fill_doc
from mne_connectivity.viz import plot_connectivity_circle

class SpectralMixin:
    """Mixin class for spectral connectivities.

    Note: In mne-connectivity, we associate the word
    "spectral" with time-frequency. Reference to
    eigenvalue structure is not captured in this mixin.
    """

    @property
    def freqs(self):
        """The frequency points of the connectivity data.

        If these are computed over a frequency band, it will
        be the median frequency of the frequency band.

        Returns
        -------
        freqs : list
            The values of the ``freqs`` coordinate of ``self.xarray``.
        """
        return self.xarray.coords.get('freqs').values.tolist()

class TimeMixin:
    """Mixin class for connectivities that vary over time."""

    @property
    def times(self):
        """The time points of the connectivity data.

        Returns
        -------
        times : list
            The values of the ``times`` coordinate of ``self.xarray``.
        """
        return self.xarray.coords.get('times').values.tolist()

class EpochMixin:
    """Mixin class for connectivity data structures defined over epochs."""

    def combine(self, combine='mean'):
        """Combine connectivity data over epochs.

        Parameters
        ----------
        combine : 'mean' | 'median' | callable
            How to combine correlation estimates across epochs.
            Default is 'mean'. If callable, it must accept one
            positional input. For example::

                combine = lambda data: np.median(data, axis=0)

        Returns
        -------
        conn : instance of Connectivity
            The combined connectivity data structure.

        Raises
        ------
        RuntimeError
            If the connectivity is not defined over epochs.
        """
        # local import: the conversion helper lives in the io module
        from .io import _xarray_to_conn

        if not self.is_epoched:
            raise RuntimeError('Combine only works over Epoched '
                               'connectivity.')

        fun = _check_combine(combine, valid=('mean', 'median'))

        def _reduce_over_epochs(arr):
            # ``xr.apply_ufunc`` moves the core 'epochs' dimension to the
            # last axis; move it back to the front so that ``fun`` (which
            # reduces over axis 0) combines across epochs.
            return fun(np.moveaxis(arr, -1, 0))

        # apply function over the 'epochs' dimension of the array
        new_xr = xr.apply_ufunc(_reduce_over_epochs, self.xarray,
                                input_core_dims=[['epochs']])
        # copy so the combined object does not share (and mutate) the
        # attributes dictionary of this object
        new_xr.attrs = dict(self.xarray.attrs)

        # map each epoched class name to its non-epoched counterpart
        conn_cls = {
            'EpochConnectivity': Connectivity,
            'EpochTemporalConnectivity': TemporalConnectivity,
            'EpochSpectralConnectivity': SpectralConnectivity,
            'EpochSpectroTemporalConnectivity': SpectroTemporalConnectivity,
        }
        # walk the MRO so user subclasses of the epoched classes also work
        cls_func = next((conn_cls[klass.__name__]
                         for klass in type(self).__mro__
                         if klass.__name__ in conn_cls), None)
        if cls_func is None:
            raise RuntimeError(f'Cannot combine epochs for connectivity '
                               f'class {self.__class__.__name__}.')

        # convert new xarray to non-Epoch data structure
        conn = _xarray_to_conn(new_xr, cls_func)
        return conn

class DynamicMixin:
    """Mixin providing prediction and simulation for VAR connectivity."""

    def predict(self, data):
        """Predict samples on actual data.

        The result of this function is used for calculating the residuals.

        Parameters
        ----------
        data : array
            Epoched or continuous data set. Has shape
            (n_epochs, n_signals, n_times) or (n_signals, n_times).

        Returns
        -------
        predicted : array, shape (n_epochs, n_signals, n_times)
            Data as predicted by the VAR model. A 2D ``data`` input is
            treated as a single epoch, so the output then has a leading
            axis of length 1. The first ``model_order`` time points are
            zero, because they have no lagged history.

        Raises
        ------
        ValueError
            If ``data`` is not 2D or 3D.
        RuntimeError
            If the dimensionality of ``data`` does not match whether the
            VAR model is defined over epochs.

        Notes
        -----
        Residuals are obtained by r = x - var.predict(x).

        To compute residual covariances::

            # compute the covariance of the residuals
            # rows are variables (signals), columns are observations
            t = residuals.shape[0]
            sampled_residuals = np.concatenate(
                np.split(residuals[:, :, model_order:], t, 0),
                axis=2).squeeze(0)
            rescov = np.cov(sampled_residuals)
        """
        if data.ndim < 2 or data.ndim > 3:
            raise ValueError(f'Data passed in must be either 2D or 3D. '
                             f'The data you passed in has {data.ndim} dims.')
        if data.ndim == 2 and self.is_epoched:
            raise RuntimeError('If there is a VAR model over epochs, '
                               'one must pass in a 3D array.')
        if data.ndim == 3 and not self.is_epoched:
            raise RuntimeError('If there is a single VAR model, '
                               'one must pass in a 2D array.')

        # make the data 3D
        if data.ndim == 2:
            data = data[np.newaxis, ...]

        n_epochs, _, n_times = data.shape
        var_model = self.get_data(output='dense')

        # get the model order (a plain dense matrix is a lag-1 model,
        # matching the default used by ``get_data``)
        model_order = self.attrs.get('model_order', 1)

        # predict the data by applying forward model
        predicted_data = np.zeros(data.shape)
        # loop over whichever axis (time or epochs) takes fewer iterations
        if n_epochs > n_times - model_order:
            for idx in range(1, model_order + 1):
                for jdx in range(model_order, n_times):
                    if self.is_epoched:
                        # every epoch has its own VAR model: apply epoch
                        # e's lag coefficients to epoch e's lagged sample
                        bp = var_model[:, :, (idx - 1)::model_order]
                        predicted_data[:, :, jdx] += np.einsum(
                            'eij,ej->ei', bp, data[:, :, jdx - idx])
                    else:
                        bp = var_model[:, (idx - 1)::model_order]
                        predicted_data[:, :, jdx] += np.dot(
                            data[:, :, jdx - idx], bp.T)
        else:
            for idx in range(1, model_order + 1):
                for jdx in range(n_epochs):
                    if self.is_epoched:
                        bp = var_model[jdx, :, (idx - 1)::model_order]
                    else:
                        bp = var_model[:, (idx - 1)::model_order]
                    predicted_data[jdx, :, model_order:] += np.dot(
                        bp,
                        data[jdx, :, (model_order - idx):(n_times - idx)])

        return predicted_data

    def simulate(self, n_samples, noise_func=None, random_state=None):
        """Simulate vector autoregressive (VAR) model.

        This function generates data from the VAR model.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate.
        noise_func : func, optional
            This function is used to create the generating noise process.
            It is called with no arguments once per generated sample and
            must return one sample of noise for all nodes. If set to None,
            Gaussian white noise with zero mean and unit variance is used.
        random_state : None | int | instance of RandomState
            Seed or random state for the default Gaussian noise. Only
            used when ``noise_func`` is None.

        Returns
        -------
        data : array, shape (n_nodes, n_samples)
            Generated data.
        """
        var_model = self.get_data(output='dense')
        if self.is_epoched:
            # simulate from the VAR model averaged over epochs
            var_model = var_model.mean(axis=0)

        n_nodes = self.n_nodes
        model_order = self.attrs.get('model_order', 1)

        # set noise function
        if noise_func is None:
            rng = check_random_state(random_state)

            def noise_func():
                return rng.normal(size=(1, n_nodes))

        # generate extra leading samples which are discarded before
        # returning, so the output depends less on the initial values
        n_discard = 10 * model_order
        n = n_samples + n_discard

        # simulated data
        data = np.zeros((n, n_nodes))

        # the first ``model_order`` samples have no history: noise only
        for jdx in range(model_order):
            data[jdx, :] = noise_func()
        for jdx in range(model_order, n):
            data[jdx, :] = noise_func()
            for idx in range(1, model_order + 1):
                data[jdx, :] += var_model[:, (idx - 1)::model_order].dot(
                    data[jdx - idx, :])

        return data[n_discard:, :].transpose()

class _Connectivity(DynamicMixin):
    """Base class for connectivity data.

    Connectivity data is anything that represents "connections"
    between nodes as a (N, N) array. It can be symmetric, or
    asymmetric (if it is symmetric, storage optimization will
    occur).

    Parameters
    ----------
    data : np.ndarray
        The connectivity data, raveled over connections. For epoched
        structures the first axis indexes epochs, followed by the
        estimated-connections axis and then, optionally, the ``freqs``
        and/or ``times`` axes.
    names : list of str | None
        The node names. If None, nodes are named ``'0'``, ``'1'``, ...
    indices : 'all' | 'symmetric' | tuple of lists
        The connections stored in ``data``: 'all' for all-to-all
        connectivity, 'symmetric' for the upper-triangular part
        (including the diagonal), or a ``(row_indices, col_indices)``
        tuple.
    method : str | None
        The method used to compute connectivity.
    n_nodes : int
        The number of nodes in the original dataset.
    kwargs : dict
        Extra connectivity parameters. These may include
        ``freqs`` for spectral connectivity, and/or
        ``times`` for connectivity over time. In addition,
        these may include extra parameters that are stored
        as xarray ``attrs``.

    Notes
    -----
    Connectivity data can be generally represented as a square matrix
    with values intending the connectivity function value between two
    nodes. We optimize storage of symmetric connectivity data
    and allow support for computing connectivity data on a subset of nodes.
    We store connectivity data as a raveled ``(n_estimated_nodes, ...)``
    where ``n_estimated_nodes`` can be ``n_nodes_in * n_nodes_out`` if a
    full connectivity structure is computed, or a subset of the nodes
    (equal to the length of the indices passed in).

    Since we store connectivity data as a raveled array, one can
    easily optimize the storage of "symmetric" connectivity data.
    One can use numpy to convert a full all-to-all connectivity
    into an upper triangular portion, and set ``indices='symmetric'``.
    This would reduce the RAM needed in half.

    The underlying data structure is an ``xarray.DataArray``,
    with a similar API to ``xarray``. We provide support for storing
    connectivity data in a subset of nodes. Thus the underlying
    data structure instead of a ``(n_nodes_in, n_nodes_out)`` 2D array
    would be a ``(n_nodes_in * n_nodes_out,)`` raveled 1D array. This
    allows us to optimize storage also for symmetric connectivity.
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = False

    def __init__(self, data, names, indices, method,
                 n_nodes, **kwargs):

        if isinstance(indices, str) and \
                indices not in ['all', 'symmetric']:
            raise ValueError(f'Indices can only be '
                             f'"all" or "symmetric", otherwise '
                             f'they should be a tuple of '
                             f'(row_indices, col_indices). '
                             f'It cannot be {indices}.')

        # check the incoming data structure
        self._check_data_consistency(data, indices=indices, n_nodes=n_nodes)
        self._prepare_xarray(data, names=names, indices=indices,
                             n_nodes=n_nodes, method=method, **kwargs)

    def __repr__(self) -> str:
        r = f'<{self.__class__.__name__} | '

        if 'freqs' in self.dims:
            r += "freq : [%f, %f], " % (self.freqs[0], self.freqs[-1])
        if 'times' in self.dims:
            r += "time : [%f, %f], " % (self.times[0], self.times[-1])
        # the freq/time parts already end with ", "
        r += f"nave : {self.n_epochs_used}"
        r += f', nodes, n_estimated : {self.n_nodes}, ' \
             f'{self.n_estimated_nodes}'
        r += ', ~%s' % (sizeof_fmt(self._size),)
        r += '>'
        return r

    def _get_num_connections(self, data):
        """Compute the number of estimated nodes' connectivity.

        Sets ``self.n_epochs`` (None when not epoched) and
        ``self.n_estimated_nodes`` from the shape of ``data``.
        """
        # account for epoch data structures
        if self.is_epoched:
            start_idx = 1
            self.n_epochs = data.shape[0]
        else:
            self.n_epochs = None
            start_idx = 0
        self.n_estimated_nodes = data.shape[start_idx]

    def _prepare_xarray(self, data, names, indices, n_nodes, method,
                        **kwargs):
        """Prepare xarray data structure.

        Parameters
        ----------
        data : np.ndarray
            The (already validated) connectivity data.
        names : list of str | None
            The node names. If None, nodes are named ``'0'``, ``'1'``, ...
        indices : 'all' | 'symmetric' | tuple
            The connectivity indices. A tuple is converted to a tuple of
            lists before being stored in ``attrs``.
        n_nodes : int
            The number of nodes in the original dataset.
        method : str | None
            The method used to compute connectivity.
        **kwargs : dict
            ``freqs`` and ``times`` become coordinates (and dimensions);
            everything else is stored as xarray ``attrs``, with numpy
            arrays converted to lists.
        """
        # set node names
        if names is None:
            names = list(map(str, range(n_nodes)))

        # the names of each first few dimensions of
        # the data depending if data is epoched or not
        if self.is_epoched:
            dims = ['epochs', 'node_in -> node_out']
        else:
            dims = ['node_in -> node_out']

        # the coordinates of each dimension
        n_estimated_list = list(map(str, range(self.n_estimated_nodes)))
        coords = dict()
        if self.is_epoched:
            coords['epochs'] = list(map(str, range(data.shape[0])))
        coords["node_in -> node_out"] = n_estimated_list
        if 'freqs' in kwargs:
            coords['freqs'] = kwargs.pop('freqs')
            dims.append('freqs')
        if 'times' in kwargs:
            times = kwargs.pop('times')
            if times is None:
                times = list(range(data.shape[-1]))
            coords['times'] = list(times)
            dims.append('times')

        # convert all numpy arrays to lists
        for key, val in kwargs.items():
            if isinstance(val, np.ndarray):
                kwargs[key] = val.tolist()
        kwargs['node_names'] = names

        # set method, indices and n_nodes
        if isinstance(indices, tuple):
            new_indices = (list(indices[0]), list(indices[1]))
            indices = new_indices
        kwargs['method'] = method
        kwargs['indices'] = indices
        kwargs['n_nodes'] = n_nodes

        # create xarray object
        xarray_obj = xr.DataArray(
            data=data,
            coords=coords,
            dims=dims,
            attrs=kwargs,
        )
        self._obj = xarray_obj

    def _check_data_consistency(self, data, indices, n_nodes):
        """Perform data input checks.

        Also records the number of epochs and estimated connections
        (via ``_get_num_connections``).

        Raises
        ------
        TypeError
            If ``data`` is not a numpy array.
        RuntimeError
            If ``data`` has an unsupported number of dimensions.
        ValueError
            If ``indices`` are inconsistent with ``data``.
        """
        if not isinstance(data, np.ndarray):
            raise TypeError('Connectivity data must be passed in as a '
                            'numpy array.')

        if self.is_epoched:
            if data.ndim < 2 or data.ndim > 4:
                raise RuntimeError(f'Data using an epoched data '
                                   f'structure should have at least '
                                   f'2 dimensions and at most 4 '
                                   f'dimensions. Your data was '
                                   f'{data.shape} shape.')
        else:
            if data.ndim > 3:
                raise RuntimeError(f'Data not using an epoched data '
                                   f'structure should have at least '
                                   f'1 dimensions and at most 3 '
                                   f'dimensions. Your data was '
                                   f'{data.shape} shape.')

        # get the number of estimated nodes
        self._get_num_connections(data)
        if self.is_epoched:
            data_len = data.shape[1]
        else:
            data_len = data.shape[0]

        if isinstance(indices, tuple):
            # check that the indices passed in are of the same length
            if len(indices[0]) != len(indices[1]):
                raise ValueError(f'If indices are passed in '
                                 f'then they must be the same '
                                 f'length. They are right now '
                                 f'{len(indices[0])} and '
                                 f'{len(indices[1])}.')
            # indices length should match the data length
            if len(indices[0]) != data_len:
                raise ValueError(
                    f'The number of indices, {len(indices[0])} '
                    f'should match the raveled data length passed '
                    f'in of {data_len}.')

        elif indices == 'symmetric':
            # upper triangle including the diagonal
            expected_len = ((n_nodes + 1) * n_nodes) // 2
            if data_len != expected_len:
                raise ValueError(f'If "indices" is "symmetric", then '
                                 f'connectivity data should be the '
                                 f'upper-triangular part of the matrix. There '
                                 f'are {data_len} estimated connections. '
                                 f'But there should be {expected_len} '
                                 f'estimated connections.')

    @property
    def _data(self):
        """Numpy array of connectivity data."""
        return self.xarray.values

    @property
    def dims(self):
        """The dimensions of the xarray data."""
        return self.xarray.dims

    @property
    def coords(self):
        """The coordinates of the xarray data."""
        return self.xarray.coords

    @property
    def attrs(self):
        """Xarray attributes of connectivity.

        See ``xarray``'s ``attrs``.
        """
        return self.xarray.attrs

    @property
    def shape(self):
        """Shape of raveled connectivity."""
        return self.xarray.shape

    @property
    def n_nodes(self):
        """The number of nodes in the original dataset.

        Even if ``indices`` defines a subset of nodes that
        were computed, this should be the total number of
        nodes in the original dataset.
        """
        return self.attrs['n_nodes']

    @property
    def method(self):
        """The method used to compute connectivity."""
        return self.attrs['method']

    @property
    def indices(self):
        """Indices of connectivity data.

        Returns
        -------
        indices : str | tuple of lists
            Either 'all' for all-to-all connectivity,
            'symmetric' for symmetric all-to-all connectivity,
            or a tuple of lists representing the node-to-nodes
            that connectivity was computed.
        """
        return self.attrs['indices']

    @property
    def names(self):
        """Node names."""
        return self.attrs['node_names']

    @property
    def xarray(self):
        """Xarray of the connectivity data."""
        return self._obj

    @property
    def n_epochs_used(self):
        """Number of epochs used in computation of connectivity.

        Can be 'None', if there was no epochs used.
        """
        return self.attrs.get('n_epochs_used')

    @property
    def _size(self):
        """Estimate the object size."""
        size = 0
        size += object_size(self.get_data())
        size += object_size(self.attrs)
        return size

    def get_data(self, output='compact'):
        """Get connectivity data as a numpy array.

        Parameters
        ----------
        output : str, optional
            How to format the output, by default 'compact'. If 'raveled',
            each connectivity matrix is represented as a
            ``(n_nodes_in * n_nodes_out,)`` list. If 'dense', each
            connectivity matrix is returned as a 2D array. If 'compact',
            'raveled' is returned if ``indices`` were defined as a tuple
            of lists, or 'dense' if indices is 'all' or 'symmetric'.

        Returns
        -------
        data : np.ndarray
            The output connectivity data. In 'dense' output, connections
            not listed in tuple ``indices`` are NaN; 'symmetric' data is
            mirrored across the diagonal.
        """
        _check_option('output', output, ['raveled', 'dense', 'compact'])

        if output == 'compact':
            if self.indices in ['all', 'symmetric']:
                output = 'dense'
            else:
                output = 'raveled'

        if output == 'raveled':
            return self._data

        # get the new shape of the data array
        if self.is_epoched:
            new_shape = [self.n_epochs]
        else:
            new_shape = []

        # handle the case where model order is defined in VAR connectivity
        # and thus appends the connectivity matrices side by side, so the
        # shape is N x N * model_order
        model_order = self.attrs.get('model_order', 1)
        new_shape.extend([self.n_nodes * model_order, self.n_nodes])
        if 'freqs' in self.dims:
            new_shape.append(len(self.coords['freqs']))
        if 'times' in self.dims:
            new_shape.append(len(self.coords['times']))

        # keep complex-valued data complex when scattering into a dense
        # array; real/integer data still yields float64 as before
        dense_dtype = np.result_type(self._data.dtype, np.float64)

        # handle things differently if indices is defined
        if isinstance(self.indices, tuple):
            # TODO: improve this to be more memory efficient
            # from all-to-all connectivity structure
            data = np.full(new_shape, np.nan, dtype=dense_dtype)

            row_idx, col_idx = self.indices
            if self.is_epoched:
                data[:, row_idx, col_idx, ...] = self._data
            else:
                data[row_idx, col_idx, ...] = self._data
        elif self.indices == 'symmetric':
            data = np.zeros(new_shape, dtype=dense_dtype)

            # get the upper/lower triangular indices
            row_triu_inds, col_triu_inds = np.triu_indices(
                self.n_nodes, k=0)
            if self.is_epoched:
                data[:, row_triu_inds, col_triu_inds, ...] = self._data
                data[:, col_triu_inds, row_triu_inds, ...] = self._data
            else:
                data[row_triu_inds, col_triu_inds, ...] = self._data
                data[col_triu_inds, row_triu_inds, ...] = self._data
        else:
            data = self._data.reshape(new_shape)

        return data

    def rename_nodes(self, mapping):
        """Rename nodes.

        Parameters
        ----------
        mapping : dict | callable
            Mapping from original node names (keys) to new node names
            (values), or a callable applied to every node name.

        Raises
        ------
        ValueError
            If a key of ``mapping`` is not a node name, if ``mapping`` is
            neither a dict nor callable, or if the new names are not
            unique.
        """
        names = copy(self.names)

        # first check and assemble clean mappings of index and name
        if isinstance(mapping, dict):
            orig_names = sorted(list(mapping.keys()))
            missing = [orig_name not in names for orig_name in orig_names]
            if any(missing):
                raise ValueError(
                    "Name(s) in mapping missing from info: "
                    "%s" % np.array(orig_names)[np.array(missing)])
            new_names = [(names.index(name), new_name)
                         for name, new_name in mapping.items()]
        elif callable(mapping):
            new_names = [(ci, mapping(name))
                         for ci, name in enumerate(names)]
        else:
            raise ValueError('mapping must be callable or dict, not %s'
                             % (type(mapping),))

        # check we got all strings out of the mapping
        for new_name in new_names:
            _validate_type(new_name[1], 'str', 'New name mappings')

        # do the remapping locally
        for c_ind, new_name in new_names:
            names[c_ind] = new_name

        # check that all the channel names are unique
        if len(names) != len(np.unique(names)):
            raise ValueError(
                'New channel names are not unique, renaming failed')

        # rename the new names
        self._obj.attrs['node_names'] = names

    @copy_function_doc_to_method_doc(plot_connectivity_circle)
    def plot_circle(self, **kwargs):
        return plot_connectivity_circle(
            self.get_data(),
            node_names=self.names,
            indices=self.indices, **kwargs)

    def save(self, fname):
        """Save connectivity data to disk.

        Parameters
        ----------
        fname : str | pathlib.Path
            The filepath to save the data. Data is saved
            as netCDF files (``.nc`` extension).
        """
        method = self.method
        indices = self.indices
        n_nodes = self.n_nodes

        # create a copy of the old attributes
        old_attrs = copy(self.attrs)

        try:
            # assign these to xarray's attrs
            self.attrs['method'] = method
            self.attrs['indices'] = indices
            self.attrs['n_nodes'] = n_nodes

            # save the name of the connectivity structure
            self.attrs['data_structure'] = str(self.__class__.__name__)

            # netCDF does not support 'None'
            # so map these to 'n/a'
            for key, val in self.attrs.items():
                if val is None:
                    self.attrs[key] = 'n/a'

            # save as a netCDF file
            # note this requires the netcdf4 python library
            # NOTE(review): engine chosen to match the comment above —
            # confirm against the project's I/O reader.
            self.xarray.to_netcdf(fname, mode='w', format='NETCDF4',
                                  engine='netcdf4')
        finally:
            # re-set old attributes, even if writing failed
            self.xarray.attrs = old_attrs

@fill_doc
class SpectralConnectivity(_Connectivity, SpectralMixin):
    """Spectral connectivity class.

    This class stores connectivity data that varies over
    frequencies. The underlying data is an array of shape
    (n_connections, n_freqs), or (n_nodes, n_nodes, n_freqs).

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(n_epochs_used)s

    See Also
    --------
    mne_connectivity.spectral_connectivity
    """

    expected_n_dim = 2

    def __init__(self, data, freqs, n_nodes, names=None,
                 indices='all', method=None, spec_method=None,
                 n_epochs_used=None, **kwargs):
        super().__init__(data, names=names, method=method,
                         indices=indices, n_nodes=n_nodes,
                         freqs=freqs, spec_method=spec_method,
                         n_epochs_used=n_epochs_used, **kwargs)
@fill_doc
class TemporalConnectivity(_Connectivity, TimeMixin):
    """Temporal connectivity class.

    This is an array of shape (n_connections, n_times),
    or (n_nodes, n_nodes, n_times). This describes how connectivity
    varies over time. It describes sample-by-sample time-varying
    connectivity (usually on the order of milliseconds). Here
    time (t=0) is the same for all connectivity measures.

    Parameters
    ----------
    %(data)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s

    Notes
    -----
    `mne_connectivity.EpochConnectivity` is a similar connectivity
    class to this one. However, that describes one connectivity snapshot
    for each epoch. These epochs might be chunks of time that have
    different meaning for time ``t=0``. Epochs can mean separate trials,
    where the beginning of the trial implies t=0. These Epochs may
    also be discontiguous.
    """

    expected_n_dim = 2

    def __init__(self, data, times, n_nodes, names=None, indices='all',
                 method=None, n_epochs_used=None, **kwargs):
        super().__init__(data, names=names, method=method,
                         n_nodes=n_nodes, indices=indices,
                         times=times, n_epochs_used=n_epochs_used,
                         **kwargs)
@fill_doc
class SpectroTemporalConnectivity(_Connectivity, SpectralMixin, TimeMixin):
    """Spectrotemporal connectivity class.

    This class stores connectivity data that varies over both frequency
    and time. The temporal part describes sample-by-sample time-varying
    connectivity (usually on the order of milliseconds). Note the
    difference relative to Epochs.

    The underlying data is an array of shape (n_connections, n_freqs,
    n_times), or (n_nodes, n_nodes, n_freqs, n_times).

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    %(n_epochs_used)s
    """

    def __init__(self, data, freqs, times, n_nodes, names=None,
                 indices='all', method=None,
                 spec_method=None, n_epochs_used=None, **kwargs):
        super().__init__(data, names=names, method=method,
                         indices=indices, n_nodes=n_nodes, freqs=freqs,
                         spec_method=spec_method, times=times,
                         n_epochs_used=n_epochs_used, **kwargs)
@fill_doc
class EpochSpectralConnectivity(SpectralConnectivity, EpochMixin):
    """Spectral connectivity class over Epochs.

    This is an array of shape (n_epochs, n_connections, n_freqs),
    or (n_epochs, n_nodes, n_nodes, n_freqs). This describes how
    connectivity varies over frequencies for different epochs.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(self, data, freqs, n_nodes, names=None,
                 indices='all', method=None,
                 spec_method=None, **kwargs):
        super().__init__(
            data, freqs=freqs, names=names, indices=indices,
            n_nodes=n_nodes, method=method,
            spec_method=spec_method, **kwargs)
@fill_doc
class EpochTemporalConnectivity(TemporalConnectivity, EpochMixin):
    """Temporal connectivity class over Epochs.

    This is an array of shape (n_epochs, n_connections, n_times),
    or (n_epochs, n_nodes, n_nodes, n_times). This describes how
    connectivity varies over time for different epochs.

    Parameters
    ----------
    %(data)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(self, data, times, n_nodes, names=None,
                 indices='all', method=None, **kwargs):
        super().__init__(data, times=times, names=names,
                         indices=indices, n_nodes=n_nodes,
                         method=method, **kwargs)
@fill_doc
class EpochSpectroTemporalConnectivity(
    SpectroTemporalConnectivity, EpochMixin
):
    """Spectrotemporal connectivity class over Epochs.

    This is an array of shape (n_epochs, n_connections, n_freqs, n_times),
    or (n_epochs, n_nodes, n_nodes, n_freqs, n_times). This describes how
    connectivity varies over frequencies and time for different epochs.

    Parameters
    ----------
    %(data)s
    %(freqs)s
    %(times)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(spec_method)s
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(self, data, freqs, times, n_nodes,
                 names=None, indices='all', method=None,
                 spec_method=None, **kwargs):
        super().__init__(
            data, names=names, freqs=freqs, times=times,
            indices=indices, n_nodes=n_nodes, method=method,
            spec_method=spec_method, **kwargs)
@fill_doc
class Connectivity(_Connectivity, EpochMixin):
    """Connectivity class without frequency or time component.

    This is an array of shape (n_connections,),
    or (n_nodes, n_nodes). This describes a connectivity matrix/graph
    that does not vary over time, frequency, or epochs.

    Parameters
    ----------
    %(data)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s

    See Also
    --------
    mne_connectivity.vector_auto_regression
    mne_connectivity.envelope_correlation
    """

    def __init__(self, data, n_nodes, names=None, indices='all',
                 method=None, n_epochs_used=None, **kwargs):
        super().__init__(data, names=names, method=method,
                         n_nodes=n_nodes, indices=indices,
                         n_epochs_used=n_epochs_used,
                         **kwargs)
@fill_doc
class EpochConnectivity(_Connectivity, EpochMixin):
    """Epoch connectivity class.

    This is an array of shape (n_epochs, n_connections),
    or (n_epochs, n_nodes, n_nodes). This describes how
    connectivity varies for different epochs.

    Parameters
    ----------
    %(data)s
    %(n_nodes)s
    %(names)s
    %(indices)s
    %(method)s
    %(n_epochs_used)s

    See Also
    --------
    mne_connectivity.vector_auto_regression
    mne_connectivity.envelope_correlation
    """

    # whether or not the connectivity occurs over epochs
    is_epoched = True

    def __init__(self, data, n_nodes, names=None, indices='all',
                 method=None, n_epochs_used=None, **kwargs):
        super().__init__(data, names=names, method=method,
                         n_nodes=n_nodes, indices=indices,
                         n_epochs_used=n_epochs_used,
                         **kwargs)