Source code for mne_connectivity.wsmi

# Authors: Giovanni Marraffini <giovanni.marraffini@gmail.com>
#          Laouen Belloli <laouen.belloli@gmail.com>
#          Based on the work of Jean-Remy King, Jacobo Sitt and Federico Raimondo
#
# License: BSD (3-clause)

import math
import warnings
from itertools import permutations

import numpy as np
from mne._fiff.pick import _picks_to_idx
from mne.epochs import BaseEpochs
from mne.fixes import jit
from mne.utils import _time_mask, logger, verbose
from mne.utils.check import _check_option, _validate_type
from mne.utils.docs import fill_doc
from scipy.signal import butter, filtfilt

from .base import Connectivity, EpochConnectivity
from .utils import check_indices


def _define_symbols(kernel):
    """Define all possible symbols for a given kernel size (original implementation)."""
    result_dict = dict()
    total_symbols = math.factorial(kernel)
    cursymbol = 0
    for perm in permutations(range(kernel)):
        order = "".join(map(str, perm))
        if order not in result_dict:
            result_dict[order] = cursymbol
            cursymbol = cursymbol + 1
            result_dict[order[::-1]] = total_symbols - cursymbol
    result = []
    for v in range(total_symbols):
        for symbol, value in result_dict.items():
            if value == v:
                result += [symbol]
    return result
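
# Illustration (comment only, not executed): for kernel=3 the six ordinal
# patterns are ordered so that each pattern and its reversal occupy
# complementary positions in the symbol list:
#     _define_symbols(3) == ["012", "021", "102", "201", "120", "210"]
# i.e. symbol s and symbol 5 - s describe time-reversed (mirrored) patterns.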


def _symb(data, kernel, tau):
    """Compute symbolic transform using original logic but optimized.

    This matches the original _symb_python exactly but with optimizations.
    """
    symbols = _define_symbols(kernel)
    dims = data.shape

    signal_sym_shape = list(dims)
    signal_sym_shape[1] = data.shape[1] - tau * (kernel - 1)
    signal_sym = np.zeros(signal_sym_shape, np.int32)

    count_shape = list(dims)
    count_shape[1] = len(symbols)
    count = np.zeros(count_shape, np.int32)

    # Create a dict for fast lookup (instead of symbols.index which is O(n))
    symbol_to_idx = {symbol: idx for idx, symbol in enumerate(symbols)}

    for k in range(signal_sym_shape[1]):
        subsamples = range(k, k + kernel * tau, tau)
        ind = np.argsort(data[:, subsamples], 1)

        # Process each channel and epoch
        for ch in range(data.shape[0]):
            for ep in range(data.shape[2]):
                symbol_str = "".join(map(str, ind[ch, :, ep]))
                signal_sym[ch, k, ep] = symbol_to_idx[symbol_str]

    count = np.double(
        np.apply_along_axis(
            lambda x: np.bincount(x, minlength=len(symbols)), 1, signal_sym
        )
    )

    return signal_sym, (count / signal_sym_shape[1])


def _get_weights_matrix(nsym):
    """Aux function (original implementation)."""
    wts = np.ones((nsym, nsym))
    np.fill_diagonal(wts, 0)
    wts = np.fliplr(wts)
    np.fill_diagonal(wts, 0)
    wts = np.fliplr(wts)
    return wts
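
# Illustration (comment only, not executed): for nsym = 6 (kernel=3) the
# weight matrix is all ones except for zeros on the main diagonal (identical
# symbols) and on the anti-diagonal (time-reversed symbols, see
# _define_symbols), so trivially coupled pattern pairs contribute nothing:
#     _get_weights_matrix(6)[0] -> [0., 1., 1., 1., 1., 0.]
#     _get_weights_matrix(6)[1] -> [1., 0., 1., 1., 0., 1.]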


@jit(parallel=True)  # Enabled parallel execution
def _wsmi_jitted(  # pragma: no cover
    data_sym, counts, wts_matrix, weighted=True
):
    """Compute raw wSMI or SMI from symbolic data (Numba-jitted, if installed).

    Parameters
    ----------
    data_sym : ndarray
        Symbolic data.
    counts : ndarray
        Symbol counts.
    wts_matrix : ndarray
        Weights matrix.
    weighted : bool
        If True, compute wSMI. If False, compute SMI.

    Returns
    -------
    result : ndarray
        Computed connectivity values (either wSMI or SMI).
    """
    nchannels, nsamples_after_symb, ntrials = data_sym.shape
    n_unique_symbols = counts.shape[1]

    result = np.zeros((nchannels, nchannels, ntrials), dtype=np.double)

    epsilon = 1e-15
    log_counts = np.log(counts + epsilon)
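
    # The nested loops below accumulate, for each channel pair (X, Y),
    #     sum_{x, y} wts_matrix[x, y] * p(x, y) * log( p(x, y) / (p(x) * p(y)) )
    # (the wts_matrix factor is skipped when weighted=False), and the total is
    # then divided by log(n_unique_symbols) to normalise the result.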

    for trial_idx in range(ntrials):
        for ch1_idx in range(nchannels):
            for ch2_idx in range(ch1_idx + 1, nchannels):
                pxy = np.zeros((n_unique_symbols, n_unique_symbols), dtype=np.double)
                for sample_idx in range(nsamples_after_symb):
                    sym1 = data_sym[ch1_idx, sample_idx, trial_idx]
                    sym2 = data_sym[ch2_idx, sample_idx, trial_idx]

                    pxy[sym1, sym2] += 1

                if nsamples_after_symb > 0:
                    pxy /= nsamples_after_symb

                current_result_val = 0.0

                # Compute MI terms manually to avoid broadcasting issues in Numba
                for r_idx in range(n_unique_symbols):
                    for c_idx in range(n_unique_symbols):
                        if pxy[r_idx, c_idx] > epsilon:
                            log_pxy_val = np.log(pxy[r_idx, c_idx])
                            log_px_val = log_counts[ch1_idx, r_idx, trial_idx]
                            log_py_val = log_counts[ch2_idx, c_idx, trial_idx]

                            mi_term = pxy[r_idx, c_idx] * (
                                log_pxy_val - log_px_val - log_py_val
                            )

                            if weighted:
                                current_result_val += wts_matrix[r_idx, c_idx] * mi_term
                            else:
                                current_result_val += mi_term

                result[ch1_idx, ch2_idx, trial_idx] = current_result_val

    if n_unique_symbols > 1:
        norm_factor = np.log(n_unique_symbols)
        if norm_factor > epsilon:
            result /= norm_factor
    else:
        result_fill_val = 0.0
        result[:, :, :] = result_fill_val

    return result + result.transpose(1, 0, 2)  # make symmetric


def _apply_anti_aliasing(data, sfreq, kernel, tau, anti_aliasing, is_epochs, info=None):
    """Apply anti-aliasing filtering based on parameters and data type.

    Parameters
    ----------
    data : ndarray
        Data array of shape (n_epochs, n_channels, n_times).
    sfreq : float
        Sampling frequency in Hz.
    kernel : int
        Pattern length for symbolic analysis.
    tau : int
        Time delay between pattern elements.
    anti_aliasing : bool | str
        Anti-aliasing mode: True (always), False (never), or "auto" (smart detection).
    is_epochs : bool
        Whether the original data was an MNE Epochs object.
    info : mne.Info | None
        MNE Info object (only available if is_epochs=True).

    Returns
    -------
    filtered_data : ndarray
        Data array of shape (n_channels, n_times, n_epochs) ready for symbolic
        transformation, with anti-aliasing applied if needed.
    """
    n_epochs = data.shape[0]
    anti_alias_freq = np.double(sfreq) / kernel / tau
    nyquist_freq = sfreq / 2.0
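
    # Worked example (illustrative values): with sfreq=250 Hz, kernel=3 and
    # tau=8, the anti-aliasing cutoff is 250 / 3 / 8 ≈ 10.4 Hz, well below the
    # 125 Hz Nyquist frequency, so filtering can proceed.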

    # Determine if filtering is needed based on anti_aliasing mode
    should_filter = False
    skip_reason = None

    if anti_aliasing is False:
        # Never filter - warn about potential issues
        effective_sfreq = sfreq / tau
        warnings.warn(
            f"Anti-aliasing disabled. Effective sampling rate for symbolic "
            f"transformation is {effective_sfreq:.1f} Hz (sfreq/tau={sfreq}/{tau}). "
            f"Ensure your data is appropriately filtered to prevent aliasing.",
            UserWarning,
        )
        should_filter = False
    else:  # True or "auto"
        # Check if anti-aliasing frequency is too close to Nyquist
        if anti_alias_freq >= nyquist_freq * 0.99:
            skip_reason = (
                f"Anti-aliasing frequency ({anti_alias_freq:.2f} Hz) too close to "
                f"Nyquist frequency ({nyquist_freq:.2f} Hz)"
            )
            should_filter = False
        else:
            if anti_aliasing is True:
                should_filter = True
            else:  # Auto mode: smart detection based on data type and preprocessing
                if not is_epochs:
                    # Array input: always filter since we don't know preprocessing
                    logger.info(
                        "Auto anti-aliasing: Array input detected, applying filter "
                        "(preprocessing history unknown)."
                    )
                    should_filter = True
                else:
                    # MNE Epochs: check if already appropriately filtered
                    existing_lowpass = info.get("lowpass", None)

                    if (
                        existing_lowpass is not None
                        and existing_lowpass <= anti_alias_freq
                    ):
                        # Data already filtered at or below required frequency
                        logger.info(
                            f"Auto anti-aliasing: Data already low-pass filtered at "
                            f"{existing_lowpass:.2f} Hz (<= {anti_alias_freq:.2f} Hz). "
                            f"Skipping additional filtering."
                        )
                        should_filter = False
                    else:
                        # Need to apply filtering
                        if existing_lowpass is not None:
                            logger.info(
                                f"Auto anti-aliasing: Existing lowpass "
                                f"({existing_lowpass:.2f} Hz) > required "
                                f"({anti_alias_freq:.2f} Hz). Applying filter."
                            )
                        else:
                            logger.info(
                                f"Auto anti-aliasing: No lowpass filter info found. "
                                f"Applying filter at {anti_alias_freq:.2f} Hz."
                            )
                        should_filter = True

    # Apply filtering if needed
    if should_filter:
        logger.info(f"Applying anti-aliasing filter at {anti_alias_freq:.2f} Hz")

        # Make a copy to avoid modifying original data
        data = data.copy()

        # Design and apply low-pass filter
        normalized_freq = 2.0 * anti_alias_freq / np.double(sfreq)
        b, a = butter(6, normalized_freq, "lowpass")

        # Concatenate epochs horizontally for filtering
        data_concatenated = np.hstack(data)

        # Filter the concatenated data
        fdata_concatenated = filtfilt(b, a, data_concatenated)

        # Split back into epochs and transpose to match expected format
        # Output shape: (n_channels, n_times, n_epochs)
        filtered_data = np.transpose(
            np.array(np.split(fdata_concatenated, n_epochs, axis=1)), [1, 2, 0]
        )
    else:
        if skip_reason:
            logger.info(f"{skip_reason}. Skipping anti-aliasing filter.")
        # Transpose to match expected format: (n_channels, n_times, n_epochs)
        filtered_data = data.transpose(1, 2, 0)

    return filtered_data


def _validate_kernel(kernel, tau):
    """Validate kernel and tau parameters for wSMI computation.

    Parameters
    ----------
    kernel : int
        Pattern length (symbol dimension) for symbolic analysis.
    tau : int
        Time delay (lag) between consecutive pattern elements.

    Raises
    ------
    ValueError
        If kernel or tau parameters are invalid.
    """
    _validate_type(kernel, "int", "kernel")
    _validate_type(tau, "int", "tau")

    if kernel <= 1:
        raise ValueError(f"kernel (pattern length) must be > 1, got {kernel}")
    if tau <= 0:
        raise ValueError(f"tau (delay) must be > 0, got {tau}")

    # Warn about potentially large memory requirements for large kernels
    if kernel > 7:  # Factorial grows extremely fast beyond this
        n_symbols = math.factorial(kernel)
        memory_gb = (n_symbols**2 * 8) / (1024**3)  # 8 bytes per double
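        # Worked example (illustrative): kernel=8 gives 8! = 40320 symbols, so
        # an (n_symbols x n_symbols) matrix of doubles takes 40320**2 * 8 bytes,
        # roughly 12.1 GB, hence the warning below.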
        warnings.warn(
            f"kernel={kernel} will require ~{memory_gb:.1f} GB of memory "
            f"(factorial({kernel}) = {n_symbols} symbols). "
            f"Consider using kernel <= 7 if you encounter memory errors.",
            UserWarning,
            stacklevel=3,
        )


@fill_doc
@verbose
def wsmi(
    data,
    kernel,
    tau,
    indices=None,
    sfreq=None,
    names=None,
    tmin=None,
    tmax=None,
    anti_aliasing="auto",
    weighted=True,
    average=False,
    verbose=None,
):
    """Compute weighted symbolic mutual information (wSMI).

    Parameters
    ----------
    data : array_like, shape (n_epochs, n_signals, n_times) | ~mne.Epochs
        The data from which to compute connectivity. Can be an
        :class:`mne.Epochs` object or array-like data.
    kernel : int
        Pattern length (symbol dimension) for symbolic analysis. Must be > 1.
        Values > 7 may require significant memory.
    tau : int
        Time delay (lag; in samples) between consecutive pattern elements.
        Must be > 0.
    indices : tuple of array_like | None
        Two array-likes with indices of connections for which to compute
        connectivity. If ``None``, all connections are computed (lower
        triangular matrix). For example, to compute connectivity between
        channels 0 and 2, and between channels 1 and 3, use
        ``indices = (np.array([0, 1]), np.array([2, 3]))``.
    sfreq : float | None
        The sampling frequency. Required if ``data`` is an array-like.
    names : array_like | None
        A list of names associated with the signals in ``data``. If ``None``
        and ``data`` is an :class:`mne.Epochs` object, the names in ``data``
        will be used. If ``data`` is an array-like, the names will be a list
        of indices of the number of nodes.
    tmin : float | None
        Time to start connectivity estimation. Note: when ``data`` is an
        array-like, the first sample is assumed to be at time 0. For
        :class:`mne.Epochs`, the time information contained in the object is
        used to compute the time indices. If ``None``, uses the first sample.
    tmax : float | None
        Time to end connectivity estimation. Note: when ``data`` is an
        array-like, the first sample is assumed to be at time 0. For
        :class:`mne.Epochs`, the time information contained in the object is
        used to compute the time indices. If ``None``, uses the last sample.
    anti_aliasing : ``'auto'`` | bool
        Controls anti-aliasing low-pass filtering before symbolic
        transformation.

        - ``'auto'`` (default): Smart detection based on ``data`` type and
          preprocessing. For array inputs, always applies filtering. For
          :class:`mne.Epochs`, checks ``info['lowpass']`` to determine if the
          data is already appropriately filtered. Only applies filtering if
          the existing lowpass > required frequency.
        - ``True``: Always apply an anti-aliasing filter at
          ``sfreq / kernel / tau`` Hz.
        - ``False``: Never apply filtering. Use only if you have already
          applied appropriate low-pass filtering to your data.

        .. warning::
            Setting to ``False`` may produce unreliable results if the
            effective sampling rate (``sfreq / tau``) violates the Nyquist
            criterion for the spectral content of your data.
    weighted : bool
        Whether to compute weighted SMI (wSMI) or standard SMI. If ``True``
        (default), computes wSMI with distance-based weights. If ``False``,
        computes standard SMI without weights.
    average : bool
        Whether to average connectivity across epochs. If ``True``, returns
        connectivity averaged over epochs. If ``False`` (default), returns
        connectivity for each epoch separately.
    %(verbose)s

    Returns
    -------
    conn : instance of Connectivity or EpochConnectivity
        Computed connectivity measure. If ``average=True``, returns a
        :class:`Connectivity` instance with connectivity averaged across
        epochs. If ``average=False``, returns an :class:`EpochConnectivity`
        instance with connectivity for each epoch.

    Notes
    -----
    The weighted Symbolic Mutual Information (wSMI) is a connectivity measure
    that quantifies non-linear statistical dependencies between time series
    based on symbolic dynamics :footcite:`KingEtAl2013`. The method involves:

    1. Symbolic transformation of time series using ordinal patterns
    2. Computation of mutual information between symbolic sequences
    3. Weighting based on pattern distance for enhanced sensitivity

    References
    ----------
    .. footbibliography::
    """
    # Input validation and data handling for both Epochs and arrays
    _validate_type(weighted, bool, "weighted")
    _validate_type(average, bool, "average")
    _check_option("anti_aliasing", anti_aliasing, (True, False, "auto"))

    # Handle both MNE Epochs and array inputs
    picks = None
    is_epochs = isinstance(data, BaseEpochs)
    info = None

    if is_epochs:
        info = data.info
        sfreq = info["sfreq"]
        events = data.events
        event_id = data.event_id
        metadata = data.metadata
        ch_names = data.ch_names
        times = data.times

        # Get data
        data_for_comp = data.get_data()
        n_epochs, n_nodes, n_times_epoch = data_for_comp.shape

        # Only exclude bad channels when indices is None
        if indices is None:
            picks = _picks_to_idx(info, picks="all", exclude="bads")
            # Apply picks to data for computation
            data_for_comp = data_for_comp[:, picks, :]
            n_epochs, n_channels, n_times_epoch = data_for_comp.shape
        else:
            # User provided explicit indices, use all channels
            n_channels = n_nodes
    else:
        # Array-like input
        if sfreq is None:
            raise ValueError(
                "Sampling frequency (sfreq) is required with array input."
            )

        data_for_comp = np.asarray(data)
        if data_for_comp.ndim != 3:
            raise ValueError(
                f"Array input must be 3D (n_epochs, n_channels, n_times), "
                f"got shape {data_for_comp.shape}"
            )

        n_epochs, n_channels, n_times_epoch = data_for_comp.shape
        n_nodes = n_channels
        picks = np.arange(n_channels)
        times = np.arange(n_times_epoch) / sfreq

        # Set default values for array input
        events = None
        event_id = None
        metadata = None

        # Handle names parameter - just validate if provided
        if names is not None and len(names) != n_channels:
            raise ValueError(
                f"Number of names ({len(names)}) must match number of "
                f"channels ({n_channels})"
            )
        ch_names = names

    # Validate all parameters early
    _validate_kernel(kernel, tau)

    # Check for insufficient channels for connectivity computation
    if n_channels < 2:
        raise ValueError(
            f"At least 2 channels are required for connectivity computation, "
            f"but only {n_channels} channels are available."
        )

    logger.info(
        f"Processing {n_epochs} epochs, {n_channels} channels "
        f"({ch_names}), {n_times_epoch} time points per epoch."
    )

    # Handle indices parameter
    if indices is None:
        logger.info("using all connections for lower-triangular matrix")
        # Compute lower-triangular connections
        indices_use = np.tril_indices(n_channels, k=-1)
    else:
        # User provided explicit indices
        indices_use = check_indices(indices)

        # Check that we have at least one valid connection
        if len(indices_use[0]) == 0:
            raise ValueError("No valid connections specified in indices parameter.")

        # Validate that indices are within the range of channels
        max_idx = max(np.max(indices_use[0]), np.max(indices_use[1]))
        if max_idx >= n_channels:
            raise ValueError(
                f"Index {max_idx} is out of range for {n_channels} channels"
            )

        # Check that indices don't refer to the same channel
        # (no self-connectivity)
        same_channel_mask = indices_use[0] == indices_use[1]
        if np.any(same_channel_mask):
            invalid_pairs = [
                (indices_use[0][i], indices_use[1][i])
                for i in range(len(indices_use[0]))
                if same_channel_mask[i]
            ]
            raise ValueError(
                f"Self-connectivity not supported. "
                f"Found invalid pairs: {invalid_pairs}"
            )

    logger.info(f"computing connectivity for {len(indices_use[0])} connections")

    # unique signals for which we actually need to compute values
    sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]])

    # map indices to unique indices
    idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use]

    # select only needed signals
    data_for_comp = data_for_comp[:, sig_idx]

    # --- 2. Anti-aliasing filtering ---
    fdata = _apply_anti_aliasing(
        data_for_comp, sfreq, kernel, tau, anti_aliasing, is_epochs, info
    )

    # --- Time masking ---
    time_mask = _time_mask(times, tmin, tmax)
    fdata_masked = fdata[:, time_mask, :]

    # Check if time masking resulted in too few samples for symbolization
    min_samples_needed_for_one_symbol = tau * (kernel - 1) + 1
    if fdata_masked.shape[1] < min_samples_needed_for_one_symbol:
        raise ValueError(
            f"After time masking ({tmin}-{tmax}s), data has "
            f"{fdata_masked.shape[1]} samples per epoch, but at least "
            f"{min_samples_needed_for_one_symbol} are needed for "
            f"kernel={kernel}, tau={tau}. Adjust tmin/tmax or check epoch "
            f"length."
        )

    # Data is all ready for symbolic transformation:
    # (n_channels, n_times, n_epochs)
    fdata_for_symb = fdata_masked

    # --- 3. Symbolic Transformation ---
    logger.info("Performing symbolic transformation...")
    try:
        sym, count = _symb(fdata_for_symb, kernel, tau)
    except MemoryError as error:
        n_symbols = math.factorial(kernel)
        memory_gb = (n_symbols**2 * 8) / (1024**3)
        raise MemoryError(
            f"Insufficient memory for kernel={kernel} "
            f"(requires ~{memory_gb:.1f} GB). Try reducing kernel size "
            f"(e.g., kernel <= 7) or use fewer channels/epochs."
        ) from error
    except Exception as e:
        raise RuntimeError(
            "Error during symbolic transformation. Please contact the "
            "MNE-Connectivity developers."
        ) from e

    n_unique_symbols = count.shape[1]
    wts = _get_weights_matrix(n_unique_symbols)

    # --- 4. wSMI/SMI Computation ---
    method_name = "wSMI" if weighted else "SMI"
    logger.info(f"Computing {method_name} for {n_unique_symbols} unique symbols...")
    result = _wsmi_jitted(sym, count, wts, weighted)

    # Result is (n_channels, n_channels, n_epochs)
    result = result.transpose(2, 0, 1)  # make epochs first dimension

    # --- Packaging results ---
    if indices is None:
        # Make it a lower-triangular matrix
        result = np.tril(result, k=-1)

        # Return all-to-all connectivity matrices raveled into a 1D array
        if len(picks) < n_nodes:
            # Bad channels were excluded, need to create full
            # n_nodes x n_nodes matrix and fill only the good channel entries
            con = np.zeros((n_epochs, n_nodes, n_nodes))
            con[np.ix_(range(n_epochs), picks, picks)] = result
        else:
            con = result
        con = con.reshape(n_epochs, -1)
    else:
        # Extract only requested connections
        con = result[:, idx_map[0], idx_map[1]]

    # Create connectivity object with prepared data
    con_kwargs = dict(
        names=ch_names,
        method=method_name,
        indices=indices,
        n_epochs_used=n_epochs,
        n_nodes=n_nodes,
        events=events,
        event_id=event_id,
        metadata=metadata,
    )

    if average:
        result_connectivity = Connectivity(data=np.mean(con, axis=0), **con_kwargs)
    else:
        result_connectivity = EpochConnectivity(data=con, **con_kwargs)

    logger.info(f"{method_name} computation finished.")
    return result_connectivity
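

# Usage sketch (illustrative, not part of the module): compute wSMI between
# all channel pairs of an existing ``mne.Epochs`` object with the common
# kernel=3, tau=8 setting and average over epochs. The ``epochs`` variable and
# the package-level re-export of ``wsmi`` are assumptions here.
#
#     from mne_connectivity import wsmi
#
#     conn = wsmi(epochs, kernel=3, tau=8, average=True)
#     dense = conn.get_data(output="dense")  # (n_nodes, n_nodes) wSMI matrix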