"""Utility and helper functions for MNE-HFO."""
# License: BSD (3-clause)
import json
import os
from os import path as op
import mne
import numpy as np
import pandas as pd
from scipy.signal import hilbert
from tqdm import tqdm
from mne_hfo.config import ANNOT_COLUMNS, EVENT_COLUMNS
def _check_df(df: pd.DataFrame, df_type: str,
              copy: bool = True) -> pd.DataFrame:
    """Check dataframe for correctness.

    Parameters
    ----------
    df : pd.DataFrame
        The annotations or events dataframe to validate.
    df_type : str
        ``'annotations'`` checks for ``ANNOT_COLUMNS`` plus ``'sample'``;
        ``'events'`` checks for ``EVENT_COLUMNS`` plus ``'sample'``. Any
        other value skips the column check.
    copy : bool
        If True (default), return a copy of ``df``; otherwise return
        ``df`` itself.

    Returns
    -------
    df : pd.DataFrame
        The validated dataframe (copied if ``copy`` is True).

    Raises
    ------
    RuntimeError
        If a required column is missing.
    ValueError
        If the rows imply more than one sampling rate
        (``sample / onset``, rounded to 2 decimals).
    """
    if df_type == 'annotations':
        required = ANNOT_COLUMNS + ['sample']
        if any(col not in df.columns for col in required):
            raise RuntimeError(f'Annotations dataframe columns must contain '
                               f'{required}.')
    elif df_type == 'events':
        required = EVENT_COLUMNS + ['sample']
        if any(col not in df.columns for col in required):
            # Report the full set of required columns, including 'sample'.
            raise RuntimeError(f'Events dataframe columns must contain '
                               f'{required}.')

    # Only want to do this check if there are multiple rows. Handles edge case
    # of 1 HFO starting at 0. TODO: handle this more elegantly
    if df.shape[0] > 1:
        # first compute sampling rate from sample / onset columns
        sfreq = df['sample'].divide(df['onset']).round(2)
        # onset=0 yields inf (or NaN for 0/0); drop these rows so they do
        # not count as additional sampling rates
        sfreq = sfreq.replace([np.inf, -np.inf], np.nan).dropna()
        n_sfreqs = sfreq.nunique()
        # Zero remaining rows (e.g. every onset is 0) is not a conflict;
        # only disagreeing rates are an error.
        if n_sfreqs > 1:
            raise ValueError(f'All rows in the annotations dataframe '
                             f'should have the same sampling rate. '
                             f'Found {n_sfreqs} different '
                             f'sampling rates.')
    if copy:
        return df.copy()
    return df
def _ensure_tuple(x):
    """Return a tuple."""
    if x is None:
        return tuple()
    elif isinstance(x, str):
        return (x,)
    else:
        return tuple(x)
def _check_types(variables):
    """Make sure all vars are str or None."""
    for var in variables:
        if not isinstance(var, (str, type(None))):
            raise ValueError(f"You supplied a value ({var}) of type "
                             f"{type(var)}, where a string or None was "
                             f"expected.")
def _write_json(fname, dictionary, overwrite=False, verbose=False):
    """Write JSON to a file."""
    if op.exists(fname) and not overwrite:
        raise FileExistsError(f'"{fname}" already exists. '
                              'Please set overwrite to True.')
    json_output = json.dumps(dictionary, indent=4)
    with open(fname, 'w', encoding='utf-8') as fid:
        fid.write(json_output)
        fid.write('\n')
    if verbose is True:
        print(os.linesep + f"Writing '{fname}'..." + os.linesep)
        print(json_output)
def _band_zscore_detect(signal, sfreq, band_idx, l_freq, h_freq, n_times,
                        cycles_threshold, gap_threshold, zscore_threshold):
    """
    Find detections that meet the Hilbert envelope criteria.
    Parameters
    ----------
    signal : np.ndarray
        A single channel's Hilbert transform within a frequency band
    sfreq : float
        Sampling frequency of the data
    band_idx :  int
        The index of the frequency band
    l_freq : float
        The low frequency of the band
    h_freq : float
        The high frequency of the band
    n_times : int
        The number of timepoints used to calculate the signal
    cycles_threshold : float
        The number of cycles to be considered a valid envelope
    gap_threshold : float
        The number of cycles needed to be considered a gap
    zscore_threshold : float
        Value to threshold the signal on
    Returns
    -------
    tdetects: List[Tuple[int, int, int, int]]
        All HFO events that passed the bandpass, zscore. Each tuple contains:
        [0] - The band index
        [1] - The timepoint of the start of a detection
        [2] - The timepoint of the end of the detection
        [3] - Maximum value of the Hilbert envelope in this event window
    """
    # Detections where the envelope has a zscore greater than threshold
    tdetects = []
    # Create boolean mask of signal greater than zscore_threshold
    thresh_sig = np.zeros(n_times, dtype='bool')
    thresh_sig[signal > zscore_threshold] = 1
    idx = 0
    # Find indices where threshold is met
    thresh_idxs = np.where(thresh_sig == 1)[0]
    # Calculate the required samples to be considered a valid gap
    gap_samp = round(gap_threshold * sfreq / l_freq)
    # Iterate over valid indices (significant zscore timepoints)
    while idx < len(thresh_idxs) - 1:
        # Find the start of the envelope, which occurs when back to back
        # time-points meet the threshold
        if (thresh_idxs[idx + 1] - thresh_idxs[idx]) == 1:
            start_idx = thresh_idxs[idx]
            # Find where the envelope ends by iterating over indices
            while idx < len(thresh_idxs) - 1:
                # Check if last index over threshold. If so,
                # consider this index to be the end of the envelope
                if (thresh_idxs[idx + 1] - thresh_idxs[idx]) == 1:
                    idx += 1
                    if idx == len(thresh_idxs) - 1:
                        stop_idx = thresh_idxs[idx]
                        # Check that envelope meets number of cycles criteria
                        dur = (stop_idx - start_idx) / sfreq
                        cycs = l_freq * dur
                        if cycs > cycles_threshold:
                            # Valid, so append to detections
                            tdetects.append([band_idx, start_idx, stop_idx,
                                             max(signal[start_idx:stop_idx]),
                                             [l_freq, h_freq]])
                else:
                    # If there is no gap between this and the next index,
                    # it is still part of this envelope. Increment the
                    # index.
                    if (thresh_idxs[idx + 1] - (thresh_idxs[idx])) < gap_samp:
                        idx += 1
                    # If the next index has a gap, the current index is the
                    # end of the envelope
                    else:
                        stop_idx = thresh_idxs[idx]
                        # Check that envelope meets number of cycles criteria
                        dur = (stop_idx - start_idx) / sfreq
                        cycs = l_freq * dur
                        if cycs > cycles_threshold:
                            # Valid, so append to detections
                            tdetects.append([band_idx, start_idx, stop_idx,
                                             max(signal[start_idx:stop_idx]),
                                             [l_freq, h_freq]])
                        idx += 1
                        break
        else:
            idx += 1
    return tdetects
def compute_rms(signal, win_size=6):
    """
    Calculate the Root Mean Square (RMS) energy.

    The squared signal is smoothed with a length-``win_size`` boxcar
    (moving average) using ``np.convolve(..., 'same')``, and the square
    root is taken.

    Parameters
    ----------
    signal: numpy array
        1D signal to be transformed
    win_size: int
        Number of the points of the window (default=6)

    Returns
    -------
    rms: numpy array
        Root mean square transformed signal. Because of ``mode='same'`` it
        has ``max(len(signal), win_size)`` samples.
    """
    squared = np.power(signal, 2)
    window = np.ones(win_size) / float(win_size)
    return np.sqrt(np.convolve(squared, window, 'same'))
def compute_line_length(signal, win_size=6):
    """Calculate line length.

    The absolute first difference of ``signal`` is smoothed with a
    length-``win_size`` moving average (full convolution). The result is
    then trimmed by ``floor(win_size / 2)`` samples at the start and
    ``ceil(win_size / 2)`` samples at the end.

    Parameters
    ----------
    signal: numpy array
        1D signal to be transformed
    win_size: int
        Number of the points of the window (default=6)

    Returns
    -------
    line_length: numpy array
        Line length transformed signal, with ``len(signal) - 2`` samples.

    References
    ----------
    .. [1] Esteller, R. et al. (2001). Line length: an efficient feature for
           seizure onset detection. In Engineering in Medicine and Biology
           Society, 2001. Proceedings of the 23rd Annual International
           Conference of the IEEE (Vol. 2, pp. 1707-1710). IEEE.
    .. [2] Dümpelmann et al, 2012.  Clinical Neurophysiology: 123 (9): 1721-31.
    """
    abs_diff = np.abs(np.subtract(signal[1:], signal[:-1]))
    window = np.ones(win_size) / float(win_size)
    smoothed = np.convolve(abs_diff, window)
    # The full convolution has (len(signal) - 1) + (win_size - 1) samples;
    # trimming win_size in total leaves len(signal) - 2.
    start = int(np.floor(win_size / 2))
    stop = int(np.ceil(win_size / 2))
    return smoothed[start:-stop]
def compute_hilbert(signal, freq_cutoffs, freq_span, sfreq):
    """Compute the Hilbert envelope for a single channel.

    For each of ``freq_span`` bands, the ORIGINAL ``signal`` is IIR
    band-pass filtered and z-scored. The magnitude of its analytic signal
    (the envelope) is then computed in 30 second chunks to limit memory
    use.

    Parameters
    ----------
    signal : np.ndarray
        EEG signal for a single channel
    freq_cutoffs : sequence of float
        Band edges; band ``i`` spans ``freq_cutoffs[i]`` to
        ``freq_cutoffs[i + 1]``, so at least ``freq_span + 1`` values
        are required.
    freq_span : int
        Number of frequency bands to compute.
    sfreq : float
        Sampling frequency of ``signal``.

    Returns
    -------
    hfx_bands: list of np.ndarray
        Hilbert envelope per freq band, each with the same shape as the
        filtered signal.
    """
    hfx_bands = []
    # Samples per 30 second Hilbert chunk
    win_size = int(sfreq * 30)
    # Iterate over freq bands
    for ind in range(freq_span):
        l_freq = freq_cutoffs[ind]
        h_freq = freq_cutoffs[ind + 1]
        # Filter the raw input for this band. Use a separate name so later
        # bands are not filtered from an earlier band's output.
        band_sig = mne.filter.filter_data(signal, sfreq=sfreq,
                                          l_freq=l_freq, h_freq=h_freq,
                                          method='iir', verbose=False)
        # compute z-score of data
        band_sig = (band_sig - np.mean(band_sig)) / np.std(band_sig)
        # Chunk the signal into 30 second windows and compute the Hilbert
        # to save memory; the last chunk is truncated to the data length.
        hfx = np.empty(band_sig.shape)
        n_times = len(hfx)
        for start_samp in range(0, n_times, win_size):
            end_samp = min(start_samp + win_size, n_times)
            chunk = band_sig[start_samp:end_samp]
            # absolute value of the analytic signal, i.e. the envelope
            hfx[start_samp:end_samp] = np.abs(hilbert(chunk))
        hfx_bands.append(hfx)
    return hfx_bands
def apply_hilbert(metric, threshold_dict, kwargs):
    """Apply the Hilbert z-score thresholding scheme.

    Parameters
    ----------
    metric : sequence of np.ndarray
        Per-band values to apply the threshold rules to; ``metric[i]`` is
        thresholded as band ``i`` (the layout returned by
        ``compute_hilbert``).
    threshold_dict : dict
        Dictionary of threshold parameters to apply to metric.
        Must have non-None ``zscore``, ``gap``, and ``cycles`` values.
    kwargs : dict
        Additional model parameters needed to apply hilbert threshold.
        Must have non-None n_times, sfreq, filter_band, freq_cutoffs,
        freq_span, and n_jobs. ``filter_band`` and ``n_jobs`` are only
        validated here, not used.

    Returns
    -------
    tdetects: list of list
        One list per frequency band. Each contains that band's detections
        with the structure [band_idx, start, stop, max_amplitude,
        freq_band].

    Raises
    ------
    KeyError
        If a required key is absent from ``threshold_dict`` or ``kwargs``.
    RuntimeError
        If a required value is None.
    """
    # get threshold vals
    zscore_threshold = threshold_dict["zscore"]
    gap_threshold = threshold_dict["gap"]
    cycles_threshold = threshold_dict["cycles"]
    if any(elem is None for elem in [zscore_threshold, gap_threshold,
                                     cycles_threshold]):
        raise RuntimeError(f"threshold_dict must have values for zscore,"
                           f" gap, and cycles. You passed {threshold_dict}")
    n_times = kwargs["n_times"]
    sfreq = kwargs["sfreq"]
    filter_band = kwargs["filter_band"]
    freq_cutoffs = kwargs["freq_cutoffs"]
    freq_span = kwargs["freq_span"]
    n_jobs = kwargs["n_jobs"]
    if any(elem is None for elem in [n_times, sfreq, filter_band,
                                     freq_cutoffs, freq_span, n_jobs]):
        raise RuntimeError(f"kwargs must have values for n_times, sfreq,"
                           f" filter_band, freq_cutoffs, freq_span, n_jobs."
                           f" You passed {kwargs}")
    tdetects = []
    for i in tqdm(range(freq_span), unit="HFO-first-phase"):
        # Find bottom and top of the frequency band
        bot = freq_cutoffs[i]
        top = freq_cutoffs[i + 1]
        # Make sure you only look at Hilbert envelope values
        # for the specific freq band
        tdetects.append(_band_zscore_detect(metric[i], sfreq, i, bot,
                                            top, n_times, cycles_threshold,
                                            gap_threshold,
                                            zscore_threshold))
    return tdetects
def apply_std(metric, threshold_dict, kwargs):
    """Calculate and apply the threshold based on number of standard deviations.

    Runs of consecutive windows whose metric is at or above
    ``mean(metric) + thresh * std(metric)`` are grouped into one event.

    Parameters
    ----------
    metric : np.ndarray
        Values to apply the threshold to, one per window
    threshold_dict : dict
        Dictionary of threshold values. Should just have thresh,
        which is the number of standard deviations to check against
    kwargs : dict
        Additional key-word args from the detector needed to
        apply the threshold.
        Step_size, win_size, and n_times are required keys.

    Returns
    -------
    output: List(tuples)
        ``(event_start, event_stop)`` pairs that pass the threshold.
        Window ``k`` is taken to start at ``k * step_size``, and stops are
        clipped to ``n_times``.

    Raises
    ------
    RuntimeError
        If ``thresh`` or any required kwarg value is None.
    """
    # determine threshold value
    threshold = threshold_dict["thresh"]
    if threshold is None:
        raise RuntimeError(f"threshold_dict must have a value for 'thresh'."
                           f" You passed {threshold_dict}")
    det_th = _get_threshold_std(metric, threshold)
    n_windows = len(metric)
    step_size = kwargs["step_size"]
    win_size = kwargs["win_size"]
    n_times = kwargs["n_times"]
    if any(elem is None for elem in [step_size, win_size, n_times]):
        raise RuntimeError(f"kwargs must have step_size, win_size, "
                           f"and n_times. You passed {kwargs}")
    # store thresholded hfo events as a list
    output = []
    win_idx = 0
    while win_idx < n_windows:
        if metric[win_idx] >= det_th:
            event_start = win_idx * step_size
            # group events together if they occur in contiguous windows
            while win_idx < n_windows and metric[win_idx] >= det_th:
                win_idx += 1
            # NOTE(review): win_idx is now the first window AFTER the run,
            # so this stop lies one step_size past the end of the last
            # supra-threshold window ((win_idx - 1) * step_size + win_size).
            # Kept for backward compatibility; confirm whether intended.
            event_stop = min((win_idx * step_size) + win_size, n_times)
            output.append((event_start, event_stop))
        # Skip the window that ended the run (already known to be below
        # threshold or out of range), or advance past a non-event window.
        win_idx += 1
    return output
def _get_threshold_std(signal, threshold):
    """
    Calculate threshold by Standard Deviations above the mean.
    Parameters
    ----------
    signal: numpy array
        1D signal for threshold determination
    threshold: int
        Number of standard deviations to consider.
    Returns
    -------
    ths_value: float
        Value of the threshold
    """
    ths_value = np.mean(signal) + threshold * np.std(signal)
    return ths_value
def merge_contiguous_freq_bands(detections):
    """Merge detected events in contiguous freq bands and time windows.

    Detections are visited in order. A detection from band 0 always starts
    a new event. A detection from band ``k > 0`` is merged into the first
    already-collected event whose band index is ``k - 1`` (the adjacent
    lower band) and that overlaps it in time. If there is no such event, it
    starts a new one. A merged event takes the higher band's index, so
    overlapping detections across consecutive bands collapse into one event.

    Parameters
    ----------
    detections : list
        Only ``detections[0]`` is read. It must be a list of detections of
        the form [band_idx, start, stop, max_amplitude, freq_band].
        NOTE(review): if callers pass the per-band output of
        ``apply_hilbert`` here, only band 0 is visited. Confirm whether all
        bands should be flattened.

    Returns
    -------
    hfo_events: List(list)
        List of distinct hfo events, which have the form [start, stop]
    max_hilbert: List(float)
        List of max values in each event
    freq_bands: List(list)
        List of the [low, high] freq_band for each event
    """
    from mne_hfo.posthoc import _check_detection_overlap
    outlines = []
    for detection in detections[0]:
        band_idx = detection[0]
        merged = False
        # Band-0 detections are always unique; higher bands may merge into
        # an overlapping event from the band just below.
        if band_idx != 0:
            for ind, outline in enumerate(outlines):
                if outline[0] != band_idx - 1:
                    continue
                if _check_detection_overlap([detection[1], detection[2]],
                                            [outline[1], outline[2]]):
                    outlines[ind] = _merge_outline(outline, detection)
                    merged = True
                    break
        # Append after the scan, never while enumerating ``outlines``.
        if not merged:
            outlines.append(detection)
    # extract start and stop times
    hfo_events = [[o[1], o[2]] for o in outlines]
    max_hilbert = [o[3] for o in outlines]
    freq_bands = [[o[4][0], o[4][1]] for o in outlines]
    return hfo_events, max_hilbert, freq_bands
def _merge_outline(outline, detection):
    band_idx = detection[0]
    start = min(outline[1], detection[1])
    stop = max(outline[2], detection[2])
    max_frq = max(outline[3], detection[3])
    freq_band = [outline[4][0], detection[4][1]]
    return [band_idx, start, stop, max_frq, freq_band]
def threshold_tukey(signal, threshold):
    """
    Calculate threshold by Tukey method.

    Returns ``Q3 + threshold * (Q3 - Q1)``, where Q1 and Q3 are the 25th
    and 75th percentiles of ``signal`` (numpy default interpolation).

    Parameters
    ----------
    signal: numpy array
        1D signal for threshold determination
    threshold: float
        Number of interquartile interval above the 75th percentile

    Returns
    -------
    ths_value: float
        Value of the threshold

    References
    ----------
    [1] TUKEY JW. Comparing individual means in the analysis of
        variance. Biometrics. 1949 Jun;5(2):99-114. PMID: 18151955.
    """
    # One call computes both quartiles instead of three separate passes.
    q25, q75 = np.percentile(signal, [25, 75])
    ths_value = q75 + threshold * (q75 - q25)
    return ths_value
def threshold_quian(signal, threshold):
    """
    Calculate threshold by Quian.

    The noise standard deviation is estimated as ``median(|signal|) /
    0.6745``, and the threshold is ``threshold`` times that estimate.

    Parameters
    ----------
    signal: numpy array
        1D signal for threshold determination
    threshold: float
        Number of estimated noise SD above the mean

    Returns
    -------
    ths_value: float
        Value of the threshold

    References
    ----------
    1. Quian Quiroga, R. 2004. Neural Computation 16: 1661–87.
    """
    median_abs = np.median(np.abs(signal))
    scaled = threshold * median_abs
    return scaled / 0.6745