# Source code for mne_connectivity.var

import numpy as np
import scipy
from scipy.linalg import sqrtm
from tqdm import tqdm

from .utils import fill_doc
from .base import Connectivity, EpochConnectivity


@fill_doc
def vector_auto_regression(
        data, times=None, names=None, model_order=1, l2_reg=0.0,
        compute_fb_operator=False, n_jobs=1, model='dynamic', verbose=None):
    r"""Compute vector auto-regressive (VAR) model.

    Parameters
    ----------
    data : array-like, shape=(n_epochs, n_signals, n_times)
        The data from which to compute connectivity. Must expose a 3D
        ``shape`` attribute. The epochs dimension is interpreted
        differently, depending on the ``model`` argument.
    times : array-like
        (Optional) The time points used to construct the epoched ``data``.
        If ``None``, then ``times_used`` in the Connectivity will not be
        available. Required when ``model='dynamic'`` and there is more
        than one epoch.
    %(names)s
    model_order : int, optional
        Autoregressive model order, by default 1.
    l2_reg : float, optional
        Ridge penalty (l2-regularization) parameter, by default 0.0
    compute_fb_operator : bool
        Whether to compute the backwards operator and average with
        the forward operator. Addresses bias in the least-square
        estimation :footcite:`Dawson_2016`. Only forwarded to the
        estimator when ``model='dynamic'``.
    model : str
        Whether to compute one VAR model using all epochs as multiple
        samples of the same VAR model ('avg-epochs'), or to compute
        a separate VAR model for each epoch ('dynamic'), which results
        in a time-varying VAR model. See Notes.
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    conn : Connectivity | EpochConnectivity
        The connectivity data estimated. A ``Connectivity`` for
        ``model='avg-epochs'`` and an ``EpochConnectivity`` for
        ``model='dynamic'``.

    Raises
    ------
    ValueError
        If ``model`` is not one of ``'avg-epochs'`` or ``'dynamic'``.
    RuntimeError
        If ``model='dynamic'``, ``times`` is ``None`` and ``data`` has more
        than one epoch.

    See Also
    --------
    mne_connectivity.Connectivity
    mne_connectivity.EpochConnectivity

    Notes
    -----
    Names can be passed in, which are then used to instantiate the nodes
    of the connectivity class. For example, they can be the electrode names
    of EEG.

    For higher-order VAR models, there are n_order ``A`` matrices,
    representing the linear dynamics with respect to that lag. These
    are represented by vertically concatenated matrices. For example, if
    the input is data where n_signals is 3, then an order-1 VAR model will
    result in a 3x3 connectivity matrix. An order-2 VAR model will result in
    a 6x3 connectivity matrix, with two 3x3 matrices representing the
    dynamics at lag 1 and lag 2, respectively.

    When computing a VAR model (i.e. linear dynamical system), we require
    the input to be a ``(n_epochs, n_signals, n_times)`` 3D array. There
    are two ways one can interpret the data in the model.

    First, epochs can be treated as multiple samples observed for a single
    VAR model. That is, we have $X_1, X_2, ..., X_n$, where each $X_i$
    is a ``(n_signals, n_times)`` data array, with n epochs. We are
    interested in estimating the parameters, $(A_1, A_2, ..., A_{order})$
    from the following model over **all** epochs:

    .. math::
        X(t+1) = \sum_{i=1}^{order} A_i X(t+1-i)

    This results in one VAR model over all the epochs.

    The second approach treats each epoch as a different VAR model,
    estimating a time-varying VAR model. Using the same
    data as above, we now are interested in estimating the
    parameters, $(A_1, A_2, ..., A_{order})$ for **each** epoch. The model
    would be the following for **each** epoch:

    .. math::
        X(t+1) = \sum_{i=1}^{order} A_i X(t+1-i)

    This results in one VAR model for each epoch. This is done according
    to the model in :footcite:`li_linear_2017`.

    *b* is of shape [m, m*p], with sub matrices arranged as follows:

    +------+------+------+------+
    | b_00 | b_01 | ...  | b_0m |
    +------+------+------+------+
    | b_10 | b_11 | ...  | b_1m |
    +------+------+------+------+
    | ...  | ...  | ...  | ...  |
    +------+------+------+------+
    | b_m0 | b_m1 | ...  | b_mm |
    +------+------+------+------+

    Each sub matrix b_ij is a column vector of length p that contains the
    filter coefficients from channel j (source) to channel i (sink).

    References
    ----------
    .. footbibliography::
    """
    if model not in ['avg-epochs', 'dynamic']:
        raise ValueError(f'"model" parameter must be one of '
                         f'(avg-epochs, dynamic), not {model}.')

    # 1. determine shape of the window of data
    n_epochs, n_nodes, n_times = data.shape

    model_params = {
        'model_order': model_order,
        'l2_reg': l2_reg
    }

    if model == 'avg-epochs':
        # compute VAR model where each epoch is a
        # sample of the multivariate time-series of interest
        # ordinary least squares or regularized least squares
        # (ridge regression)
        X, Y = _construct_var_eqns(data, **model_params)

        # only the solution is needed; residues, rank and singular
        # values returned by lstsq are discarded
        b = scipy.linalg.lstsq(X, Y)[0]

        # get the coefficients and flatten them for the container
        coef = b.transpose().flatten()

        # create connectivity
        conn = Connectivity(data=coef, n_nodes=n_nodes, names=names,
                            n_epochs_used=n_epochs,
                            times_used=times,
                            method='VAR', **model_params)
    else:
        # ``model`` was validated above, so this is the 'dynamic' case
        if times is None and n_epochs > 1:
            raise RuntimeError('If computing time (epoch) varying VAR model, '
                               'then "times" must be passed in. From '
                               'MNE epochs, one can extract this using '
                               '"epochs.times".')

        # compute time-varying VAR model where each epoch
        # is one sample of a time-varying multivariate time-series
        # linear system
        conn = _system_identification(
            data=data, times=times, names=names, model_order=model_order,
            l2_reg=l2_reg, n_jobs=n_jobs,
            compute_fb_operator=compute_fb_operator,
            verbose=verbose)
    return conn
def _construct_var_eqns(data, model_order, l2_reg=None):
    """Construct VAR equation system (optionally with RLS constraint).

    This function was originally imported from ``scot``.

    Parameters
    ----------
    data : np.ndarray (n_epochs, n_signals, n_times)
        The multivariate data.
    model_order : int
        The order of the VAR model.
    l2_reg : float, optional
        The l2 penalty term for ridge regression, by default None, which
        will result in ordinary VAR equation.

    Returns
    -------
    X : np.ndarray
        The predictor multivariate time-series. This will have shape
        ``(n_rows, n_signals * model_order)``, where ``n_rows`` is
        ``n_epochs * (n_times - model_order)``, plus
        ``n_signals * model_order`` extra penalty rows when ``l2_reg``
        is not ``None``. Column ``i * model_order + k - 1`` holds signal
        ``i`` delayed by ``k`` samples. See Notes.
    Y : np.ndarray
        The predicted multivariate time-series. This will have shape
        ``(n_rows, n_signals)``; the penalty rows (if any) are zero.
        See Notes.

    Notes
    -----
    This function will format data such as:

        Y = A X

    where Y is time-shifted data copy of X and ``A`` defines
    how X linearly maps to Y.
    """
    n_epochs, n_signals, n_times = np.shape(data)

    # number of linear relations (one per predicted sample per epoch)
    n = (n_times - model_order) * n_epochs
    n_features = n_signals * model_order
    rows = n if l2_reg is None else n + n_features

    # Construct matrix X (predictor variables)
    X = np.zeros((rows, n_features))
    for i in range(n_signals):
        for k in range(1, model_order + 1):
            # signal ``i`` delayed by ``k`` samples, flattened time-major
            lagged = data[:, i, model_order - k:-k]
            X[:n, i * model_order + k - 1] = np.reshape(lagged.T, n)

    if l2_reg is not None:
        # ridge penalty expressed as extra rows: l2_reg on the diagonal of
        # X paired with zeros in Y (``X[n:, :]`` is a view, so this writes
        # into ``X``)
        np.fill_diagonal(X[n:, :], l2_reg)

    # Construct vectors yi (response variables for each channel i)
    Y = np.zeros((rows, n_signals))
    for i in range(n_signals):
        Y[:n, i] = np.reshape(data[:, i, model_order:].T, n)

    return X, Y


def _construct_snapshots(snapshots, order, n_times):
    """Construct snapshots matrix.

    This will construct a matrix along the 0th axis (rows),
    stacking time-shifted copies of the data based on order.

    Parameters
    ----------
    snapshots : np.ndarray (n_signals, n_times)
        A multivariate time-series.
    order : int
        The order of the linear model to be estimated.
    n_times : int
        The number of times in the original dataset.

    Returns
    -------
    snaps : np.ndarray (n_signals * order, n_times - order + 1)
        A snapshot matrix with copies of the original ``snapshots`` along
        the rows based on the ``order`` of the model.

    Notes
    -----
    Say ``snapshots`` is an array with shape (N, T) with order ``M``. We will
    abbreviate this matrix and call it ``X``. ``X_ij`` is the ith signal at
    time point j (both zero-based). The resulting ``snaps`` matrix is a
    (N*M, T - M + 1) array made of ``M`` row-blocks; block ``k`` is ``X``
    restricted to time points ``k, ..., T - M + k``:

    +----------------+----------------+-----+----------------+
    | X_0,0          | X_0,1          | ... | X_0,(T-M)      |
    +----------------+----------------+-----+----------------+
    | ...            | ...            | ... | ...            |
    +----------------+----------------+-----+----------------+
    | X_(N-1),0      | X_(N-1),1      | ... | X_(N-1),(T-M)  |
    +----------------+----------------+-----+----------------+
    | X_0,1          | X_0,2          | ... | X_0,(T-M+1)    |
    +----------------+----------------+-----+----------------+
    | ...            | ...            | ... | ...            |
    +----------------+----------------+-----+----------------+
    | X_(N-1),(M-1)  | X_(N-1),M      | ... | X_(N-1),(T-1)  |
    +----------------+----------------+-----+----------------+
    """
    n_cols = n_times - order + 1
    shifted_blocks = [snapshots[:, i: i + n_cols] for i in range(order)]
    return np.concatenate(shifted_blocks, axis=0)


def _extract_lag_blocks(adjmat, n_nodes, model_order):
    """Stack the per-block coefficients of the last block-row of ``adjmat``.

    Parameters
    ----------
    adjmat : np.ndarray (n_nodes * model_order, n_nodes * model_order)
        Linear operator estimated on the stacked snapshot matrix.
    n_nodes : int
        The number of signals.
    model_order : int
        The order of the model, i.e. the number of column blocks.

    Returns
    -------
    blocks : np.ndarray (n_nodes * model_order, n_nodes)
        The ``model_order`` square ``(n_nodes, n_nodes)`` blocks found in
        the last ``n_nodes`` rows of ``adjmat``, stacked vertically in
        column-block order.
    """
    last_block_row = adjmat[-n_nodes:, :]
    return np.concatenate(
        [last_block_row[:, jdx * n_nodes: (jdx + 1) * n_nodes]
         for jdx in range(model_order)],
        axis=0,
    )


def _system_identification(data, times, names=None, model_order=1, l2_reg=0,
                           random_state=None, n_jobs=-1,
                           compute_fb_operator=False,
                           verbose=True):
    r"""Solve system identification using least-squares over all epochs.

    Treats each epoch as a different window of time to estimate the model:

    .. math::
        X(t+1) = \sum_{i=0}^{order} A_i X(t - i)

    where ``data`` comprises of ``(n_signals, n_times)`` and ``X(t)`` are
    the data snapshots.

    Parameters
    ----------
    data : np.ndarray (n_epochs, n_signals, n_times)
        The epoched multivariate data; one model is fit per epoch.
    times : array-like | None
        Passed to the connectivity container as ``times_used``.
    names : list | None
        Passed to the connectivity container as ``names``.
    model_order : int
        The order of the model.
    l2_reg : float
        Ridge penalty passed to the per-epoch estimator.
    random_state : int | None
        Random state passed to the per-epoch estimator.
    n_jobs : int
        If 1, epochs are processed in a plain loop; otherwise ``joblib``
        is imported and used to process epochs in parallel.
    compute_fb_operator : bool
        Passed to the per-epoch estimator.
    verbose : bool
        Currently unused inside this function.

    Returns
    -------
    conn : EpochConnectivity
        Container holding, for each epoch, the flattened
        ``(n_signals * model_order, n_signals)`` coefficient matrix.
    """
    # 1. determine shape of the window of data
    n_epochs, n_nodes, n_times = data.shape

    model_params = {
        'l2_reg': l2_reg,
        'model_order': model_order,
        'random_state': random_state,
        'compute_fb_operator': compute_fb_operator
    }

    # compute the A matrix for all Epochs
    A_mats = np.zeros((n_epochs, n_nodes * model_order, n_nodes))
    if n_jobs == 1:
        for idx in tqdm(range(n_epochs)):
            adjmat = _compute_lds_func(data[idx, ...], **model_params)
            # the estimator returns a square operator on the stacked
            # snapshots; keep only the per-block coefficients so that
            # model_order > 1 fits the pre-allocated slot (same reduction
            # as the parallel path below)
            A_mats[idx] = _extract_lag_blocks(adjmat, n_nodes, model_order)
    else:
        # joblib is only required for the parallel path
        from joblib import Parallel, delayed

        # run parallelized job to compute over all windows
        results = Parallel(n_jobs=n_jobs)(
            delayed(_compute_lds_func)(
                data[idx, ...], **model_params
            )
            for idx in tqdm(range(n_epochs))
        )
        # add additional order models in dynamic connectivity
        # along the first node axes
        for idx, adjmat in enumerate(results):
            A_mats[idx] = _extract_lag_blocks(adjmat, n_nodes, model_order)

    # create connectivity
    A_mats = A_mats.reshape((n_epochs, -1))
    conn = EpochConnectivity(data=A_mats, n_nodes=n_nodes, names=names,
                             n_epochs_used=n_epochs,
                             times_used=times,
                             method='Time-varying LDS',
                             **model_params)
    return conn


def _compute_lds_func(data, model_order, l2_reg,
                      compute_fb_operator, random_state):
    """Compute linear system using VAR model.

    Allows for parallelization over epochs.

    Parameters
    ----------
    data : np.ndarray (n_signals, n_times)
        The multivariate time-series of a single epoch.
    model_order : int
        The order of the model; controls how many time-shifted copies of
        ``data`` are stacked before fitting.
    l2_reg : float
        The ``alpha`` (l2 penalty) of the ridge regression.
    compute_fb_operator : bool
        If True, also fit the backward operator and combine it with the
        forward one as ``sqrtm(A @ inv(back_A))``.
    random_state : int | None
        Random state passed to the ridge regression.

    Returns
    -------
    A : np.ndarray (n_signals * model_order, n_signals * model_order)
        The estimated linear operator mapping each stacked snapshot to
        the next one.
    """
    from sklearn.linear_model import Ridge

    n_times = data.shape[-1]

    # create large snapshot with time-lags of order specified by
    # ``order`` value
    snaps = _construct_snapshots(
        data, order=model_order, n_times=n_times
    )

    # get the time-shifted components of each
    X, Y = snaps[:, :-1], snaps[:, 1:]

    # use scikit-learn Ridge Regression to fit.
    # NOTE: ``normalize`` is intentionally not passed: it was removed from
    # ``Ridge`` in scikit-learn 1.2, and the previous value (False) was the
    # default and had no effect with ``fit_intercept=False``.
    clf = Ridge(
        alpha=l2_reg,
        fit_intercept=False,
        solver='auto',
        random_state=random_state,
    )

    # n_samples X n_features and n_samples X n_targets
    clf.fit(X.T, Y.T)

    # n_targets X n_features
    A = clf.coef_

    if compute_fb_operator:
        # compute backward linear operator
        clf.fit(Y.T, X.T)
        back_A = clf.coef_
        A = sqrtm(A.dot(np.linalg.inv(back_A)))
    return A