Source code for mne_hfo.posthoc

import collections
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Union

import numpy as np
import pandas as pd

from .config import TIME_SCALE_TO_SECS
from .utils import _check_df


def _to_freq(x, rate: str = "s"):
    """Compute the rate of occurrence for the values of one group.

    Parameters
    ----------
    x : pd.Series
        The series of the group. Its number of entries is used as the
        event count and its mean as the elapsed time.
    rate : str
        Key into ``TIME_SCALE_TO_SECS``. One of ``s`` (second),
        ``m`` (minute), ``h`` (hour), ``d`` (day).

    Returns
    -------
    rate : float
        ``count / mean`` of ``x``, divided by the scale factor that
        ``TIME_SCALE_TO_SECS`` holds for ``rate``.
    """
    n_events = x.count()
    elapsed = x.mean()
    events_per_elapsed = n_events / elapsed
    # the scale lookup happens last, exactly as a single expression would
    return events_per_elapsed / TIME_SCALE_TO_SECS[rate]


def compute_chs_hfo_rates(
    annot_df: pd.DataFrame,
    rate: str,
    ch_names: Optional[List[str]] = None,
    end_sec: Optional[float] = None,
    verbose: bool = True,
):
    """Compute channel HFO rates from annotations DataFrame.

    This function will assume that each row is another
    HFO event. If you want to pre-process the HFOs that
    in some way overlap, do so beforehand.

    Parameters
    ----------
    annot_df : pd.DataFrame
        The DataFrame corresponding to the ``annotations.tsv`` file.
    rate : str
        The frequency at which to compute the HFO rate.
        One of ``s`` (second), ``m`` (minute), ``h`` (hour),
        ``d`` (day) to compute rate of the dataframe.
    ch_names : list of str | None
        A list of channel names to constrain the rate computation to.
        Default = None will compute rate for all channels present in the
        ``annot_df``.
    end_sec : float | None
        The end time (in seconds) of the dataset that HFOs were computed on.
        If None (default), then will take the last detected HFO as the end
        time point.
    verbose : bool
        Verbosity.

    Returns
    -------
    ch_hfo_rates : dict
        The HFO rates per channel with any HFOs.

    Raises
    ------
    ValueError
        If ``ch_names`` contains a channel that does not occur in the
        ``channels`` column of ``annot_df``.

    See Also
    --------
    mne_hfo.io.read_annotations : Reading in annotations.tsv file as DataFrame.

    References
    ----------
    .. [1] https://stackoverflow.com/questions/66143839/computing-rate-of-occurrences-per-unit-of-time-in-a-pandas-dataframe  # noqa
    """
    annot_df = _check_df(annot_df, df_type="annotations")
    # helper columns are added below; copy so the caller's frame stays as is
    annot_df = annot_df.copy()

    # store channel rates
    ch_hfo_rates = collections.defaultdict(list)

    # anchor all onsets to an arbitrary reference timestamp (now)
    ref_timestamp = datetime.now(tz=timezone.utc)
    onset_tdelta = pd.to_timedelta(annot_df["onset"], unit="s")  # type: ignore
    annot_df["timestamp"] = ref_timestamp + onset_tdelta

    # get the end point: last detected HFO unless told otherwise
    if end_sec is None:
        end_timestamp = annot_df["timestamp"].max()
    else:
        end_timestamp = ref_timestamp + timedelta(seconds=end_sec)

    # every row carries the same recording length in seconds
    annot_df["end_time"] = (
        end_timestamp - ref_timestamp
    ).total_seconds()  # type: ignore

    if verbose:
        print(f"Beginning timestamp: {ref_timestamp}")
        print(f"Got end timestamp of: {end_timestamp}")

    # ``x in series`` tests the *index* of a Series, so compare against
    # the actual channel values instead
    available_chs = set(annot_df["channels"])
    if ch_names is None:
        selected_chs = available_chs
    else:
        selected_chs = set(ch_names)
        if not selected_chs <= available_chs:
            raise ValueError(
                "Not all channels are inside the annotation DataFrame."
            )

    for ch_name, group in annot_df.groupby("channels"):
        if ch_name not in selected_chs:
            continue

        # see Reference [1] where we compute rate of occurrence:
        # number of events in the channel over the recording length
        result = _to_freq(group["end_time"], rate=rate)
        if verbose:
            print(f"Found HFO rate per {rate} for {ch_name} as {result}")

        ch_hfo_rates[ch_name] = result

    return ch_hfo_rates
def _join_times(df: pd.DataFrame) -> pd.DataFrame:
    """Join together start and end times sorted in order.

    Builds a table of all start (+1) and end (-1) times sorted by time.
    A running sum of the +1/-1 markers tracks how many intervals are open;
    a start row at which the running sum equals 1 opens a new
    non-overlapping window. Every row of ``df`` is assigned the ``group``
    of the window its start falls in, and the rows of each group are then
    collapsed into one merged interval.

    The input ``df`` is not modified.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the columns ``start_timestamp``, ``end_timestamp``,
        ``label`` and ``ref_timestamp``.

    Returns
    -------
    res : pd.DataFrame
        One row per group of overlapping intervals, indexed by ``group``,
        with the earliest start, the latest end, the unique labels and
        the first ``ref_timestamp`` of the group.

    References
    ----------
    .. [1] https://stackoverflow.com/questions/57804145/combining-rows-with-overlapping-time-periods-in-a-pandas-dataframe  # noqa
    """
    startdf = pd.DataFrame({"time": df["start_timestamp"], "what": 1})
    enddf = pd.DataFrame({"time": df["end_timestamp"], "what": -1})

    # A stable sort keeps starts ahead of ends at identical times (starts
    # come first in the concatenation), so zero-length events and touching
    # intervals are grouped deterministically.
    mergdf = pd.concat([startdf, enddf]).sort_values("time", kind="mergesort")

    # number of intervals open after each start/end marker
    mergdf["running"] = mergdf["what"].cumsum()

    # a start that brings the count to exactly 1 opens a new window
    mergdf["newwin"] = mergdf["running"].eq(1) & mergdf["what"].eq(1)
    mergdf["group"] = mergdf["newwin"].cumsum()

    # group id per original row, taken from the row's start marker and
    # aligned back onto ``df`` through the shared index
    start_groups = mergdf.loc[mergdf["what"].eq(1), "group"]
    grouped = df.assign(group=start_groups).groupby("group")

    # min/max rather than first/last: an interval fully contained in an
    # earlier one must not shorten the merged interval
    res = grouped.agg(
        {
            "start_timestamp": "min",
            "end_timestamp": "max",
            "label": "unique",
            "ref_timestamp": "first",
        }
    )
    return res
def merge_overlapping_events(df: pd.DataFrame):
    """Merge overlapping events detected.

    Events are merged per channel: rows of the same channel whose
    ``[onset, onset + duration]`` intervals overlap are collapsed into
    a single row.

    Parameters
    ----------
    df : pd.DataFrame
        Events dataframe generated from HFO events detected.

    Returns
    -------
    merged_df : pd.DataFrame
        New events dataframe with merged HFOs depending on overlap
        criterion. It has the same columns as the input ``df``.

    See Also
    --------
    mne_hfo.io.create_annotations_df : Create annots DataFrame from HFO detections.
    """  # noqa
    orig_cols = df.columns

    # check dataframe
    df = _check_df(df, df_type="annotations")
    # helper columns are added below; copy so the caller's frame stays as is
    df = df.copy()

    # compute sfreq. XXX: assumes only 1 sampling rate
    sfreq = np.unique(df["sfreq"])[0]

    # start/end timestamp relative to an arbitrary reference (now)
    ref_timestamp = datetime.now(tz=timezone.utc)
    onset_tdelta = pd.to_timedelta(df["onset"], unit="s")  # type: ignore
    df["start_timestamp"] = ref_timestamp + onset_tdelta
    duration_secs = pd.to_timedelta(df["duration"], unit="s")  # type: ignore
    df["end_timestamp"] = df["start_timestamp"] + duration_secs
    df["ref_timestamp"] = ref_timestamp

    # first group by channels, then join rows that are overlapping;
    # ``group`` is the per-channel window id produced by ``_join_times``
    merged_df = (
        df.groupby(["channels"])
        .apply(_join_times)  # type: ignore
        .reset_index()
        .drop("group", axis=1)
    )

    # convert the merged timestamps back into the original columns
    merged_df["duration"] = (
        merged_df["end_timestamp"] - merged_df["start_timestamp"]
    ).dt.total_seconds()
    merged_df["onset"] = (
        merged_df["start_timestamp"] - merged_df["ref_timestamp"]
    ).dt.total_seconds()
    merged_df["sample"] = merged_df["onset"] * sfreq

    # XXX: need to enable different sfreqs maybe
    merged_df["sfreq"] = sfreq

    # selecting the original columns also discards the intermediate
    # timestamp columns
    return merged_df[orig_cols]
def find_coincident_events(hfo_dict1, hfo_dict2):
    """Collect the HFO events of one set that overlap events of another.

    Note: Both input dictionaries should come from the same original
    dataset and therefore contain the same keys.

    Parameters
    ----------
    hfo_dict1 : dict
        Maps channel names to a list of (start, end) time tuples.
    hfo_dict2 : dict
        Maps channel names to a list of (start, end) time tuples.

    Returns
    -------
    coincident_hfo_dict : dict
        For every channel, the entries of ``hfo_dict1`` that overlap
        with entries of ``hfo_dict2`` on the same channel.

    Raises
    ------
    RuntimeError
        If the two dictionaries do not share the same set of keys.
    """
    keys1 = set(hfo_dict1.keys())
    keys2 = set(hfo_dict2.keys())
    if keys1 != keys2:
        raise RuntimeError("The two dictionaries must have the same keys.")

    return {
        ch_name: _find_overlapping_events(events1, hfo_dict2.get(ch_name))
        for ch_name, events1 in hfo_dict1.items()
    }
def _check_detection_overlap(y_true: List[float], y_predict: List[float]):
    """Evaluate if two detections overlap.

    Boundaries are inclusive, so two detections that merely touch
    are reported as overlapping.

    Parameters
    ----------
    y_true: list
        Gold standard detection [start,stop]
    y_predict: list
        Detector detection [start,stop]

    Returns
    -------
    overlap: bool
        Whether two events overlap.
    """
    true_start, true_stop = y_true[0], y_true[1]
    pred_start, pred_stop = y_predict[0], y_predict[1]

    # detection stops inside the gold standard
    stop_inside = true_start <= pred_stop <= true_stop
    # detection starts inside the gold standard
    start_inside = true_start <= pred_start <= true_stop
    # gold standard lies completely inside the detection
    encloses = pred_start <= true_start and pred_stop >= true_stop

    return bool(stop_inside or start_inside or encloses)


def _find_overlapping_events(list1, list2):
    """Get subset of list1 that overlaps with list2.

    Each event of ``list1`` is returned at most once, no matter how many
    events of ``list2`` it overlaps.

    Parameters
    ----------
    list1 : list
        list of tuples (start_time, end_time)
    list2 : list
        list of tuples (start_time, end_time)

    Returns
    -------
    overlapping_events : list
        list of tuples (start_time, end_time) from ``list1`` that overlap
        with at least one event of ``list2``, sorted by start time.
    """
    # Sort events by start times to speed up calculation
    list1 = sorted(list1, key=lambda x: x[0])
    list2 = sorted(list2, key=lambda x: x[0])

    overlapping_events = []
    for event_time1 in list1:
        for event_time2 in list2:
            # list2 is sorted by start time: once an event starts after
            # event_time1 has ended, no later one can overlap either
            if event_time2[0] > event_time1[1]:
                break
            if _check_detection_overlap(event_time1, event_time2):
                overlapping_events.append(event_time1)
                # one match is enough; continuing would add duplicates
                break
    return overlapping_events
def match_detected_annotations(
    ytrue_annot_df: pd.DataFrame,
    ypred_annot_df: pd.DataFrame,
    ch_names: Optional[Union[List[str], str]] = None,
    label: Optional[str] = None,
    sec_margin: float = 1.0,
    method="match-true",
):
    """Given two annotations.tsv DataFrames, match HFO detection overlaps.

    Parameters
    ----------
    ytrue_annot_df : pd.DataFrame
        The reference annotations DataFrame containing the HFO events that are
        considered "ground-truth" in this comparison.
    ypred_annot_df : pd.DataFrame
        The estimated annotations DataFrame containing the HFO events that
        are estimated using a ``Detector``.
    ch_names : list | str | None
        Which channels to match. If None (default), then will match all
        available channels in both dataframes. If str, then must be a
        single channel name available in the ``ytrue_annot_df``. If list of
        strings, then must be a list of channel names available in the
        ``ytrue_annot_df``.
    label : str | None
        The HFO label to use. If None (default) will consider all rows in
        both input DataFrames as an HFO event. If a string, then it must
        match to an element of ``label`` column in the dataframes.
    sec_margin : float
        Number of seconds to consider a valid checking window.
        Default = 1.
    method : str
        Type of strategy for matching HFO events. Must be one of
        ``match-true``, ``match-pred``, or ``match-total``
        (case-insensitive). If "match-true", will return a dataframe of
        all true indices and matching predicted indices if they exist.
        If "match-pred", will return a dataframe of all predicted indices
        and matching true indices if they exist. If "match-total", will
        return the concatenation of the two.

    Returns
    -------
    matched_df : pd.DataFrame
        A DataFrame with the columns ``pred_index`` and ``true_index``,
        which corresponds to indices.

    Raises
    ------
    ValueError
        If ``label`` or any of ``ch_names`` does not occur in both
        input DataFrames.
    NotImplementedError
        If ``method`` is not one of the supported strategies.
    """
    # check adherence of the annotations dataframe structure
    ytrue_annot_df = _check_df(ytrue_annot_df, df_type="annotations")
    ypred_annot_df = _check_df(ypred_annot_df, df_type="annotations")

    # normalize once so every comparison below is case-insensitive
    method = method.lower()

    # select only certain labels
    # NOTE: ``x in series`` tests the Series *index*; membership of values
    # has to go through ``isin`` (or a set of the values).
    if label is not None:
        in_true = ytrue_annot_df["label"].isin([label]).any()
        in_pred = ypred_annot_df["label"].isin([label]).any()
        if not (in_true and in_pred):
            raise ValueError(
                f"Label {label} is not inside the input DataFrames."
            )

        ytrue_annot_df = ytrue_annot_df.loc[ytrue_annot_df["label"] == label]
        ypred_annot_df = ypred_annot_df.loc[ypred_annot_df["label"] == label]

    # select only certain channels
    if ch_names is not None:
        if isinstance(ch_names, str):
            ch_names = [ch_names]

        true_chs = set(ytrue_annot_df["channels"])
        pred_chs = set(ypred_annot_df["channels"])
        if any(ch not in true_chs for ch in ch_names):
            raise ValueError(
                f"Channels {ch_names} are not all inside "
                f"ground-truth HFO DataFrame."
            )
        if any(ch not in pred_chs for ch in ch_names):
            raise ValueError(
                f"Channels {ch_names} are not all inside "
                f"predicted HFO DataFrame."
            )

        ytrue_annot_df = ytrue_annot_df.loc[
            ytrue_annot_df["channels"].isin(ch_names)
        ]
        ypred_annot_df = ypred_annot_df.loc[
            ypred_annot_df["channels"].isin(ch_names)
        ]

    if ypred_annot_df.empty:
        # nothing predicted and matching on predictions:
        # return empty structured dataframe
        if method == "match-pred":
            return pd.DataFrame(columns=("true_index", "pred_index"))

        # otherwise every true event is listed without a partner
        true_inds = ytrue_annot_df.index
        match_df = pd.DataFrame(
            {
                "true_index": list(true_inds),
                "pred_index": [None] * len(true_inds),
            },
            index=true_inds,
            columns=("true_index", "pred_index"),
        )
        if not match_df.empty:
            # keep the converted frame (``apply`` is not in-place)
            match_df = match_df.apply(
                pd.to_numeric, errors="coerce", downcast="float"
            )
        return match_df

    # make sure columns match what is needed; ``assign`` returns new frames
    # so the (possibly sliced) inputs are not written to
    ytrue_annot_df = ytrue_annot_df.assign(
        offset=ytrue_annot_df["onset"] + ytrue_annot_df["duration"]
    )
    ypred_annot_df = ypred_annot_df.assign(
        offset=ypred_annot_df["onset"] + ypred_annot_df["duration"]
    )

    if method == "match-true":
        return _match_detections_overlap(
            ytrue_annot_df, ypred_annot_df, sec_margin,
            ("true_index", "pred_index")
        )
    elif method == "match-pred":
        return _match_detections_overlap(
            ypred_annot_df, ytrue_annot_df, sec_margin,
            ("pred_index", "true_index")
        )
    elif method == "match-total":
        true_match = _match_detections_overlap(
            ytrue_annot_df, ypred_annot_df, sec_margin,
            ("true_index", "pred_index")
        )
        pred_match = _match_detections_overlap(
            ypred_annot_df, ytrue_annot_df, sec_margin,
            ("pred_index", "true_index")
        )
        return (
            pd.concat([true_match, pred_match])
            .drop_duplicates()
            .reset_index(drop=True)
        )
    else:
        raise NotImplementedError(
            "Method must be one of match-true," " match-pred, or match-total"
        )
def _match_detections_overlap(gs_df, check_df, margin, cols):
    """Find the overlapping detections in the two passed dataframes.

    gs_df and check_df need to be the same type (i.e. both annotation
    dataframes or event dataframes). If they are annotation dataframes,
    margin should be in seconds, and if they are event dataframes,
    margin should be in samples.

    For every row of ``gs_df`` the rows of ``check_df`` on the same channel
    whose onset lies within ``margin`` of the reference onset (inclusive)
    are tested for overlap. If several overlap, the one whose
    (onset, offset) is closest to the reference is chosen.

    Parameters
    ----------
    gs_df : pd.DataFrame
        The reference DataFrame containing the HFO events that are
        considered "ground-truth" in this comparison.
    check_df : pd.DataFrame
        The estimated DataFrame containing the HFO events that
        are estimated using a ``Detector``.
    margin : int
        Margin to check. Should be in the same unit as the data
        in the desired columns
    cols : list[str]
        Name of the columns corresponding to gs indices and check indices

    Returns
    -------
    match_df: pd.DataFrame
        A DataFrame with the columns from cols input,
        which corresponds to indices. Indices are the positional row
        numbers of ``gs_df`` and ``check_df``; a reference row without
        a match has a missing value in the second column.

    Raises
    ------
    ValueError
        If either DataFrame lacks the ``onset`` or ``offset`` column.
    """
    if not all(col in gs_df for col in ("onset", "offset")):
        raise ValueError(
            f"Gold standard reference Annotations "
            f'DataFrame must have both "onset" and '
            f'"offset" columns (in seconds). It '
            f"has columns: {gs_df.columns}"
        )
    if not all(col in check_df for col in ("onset", "offset")):
        raise ValueError(
            f"Estimated Annotations "
            f'DataFrame must have both "onset" and '
            f'"offset" columns (in seconds). It '
            f"has columns: {check_df.columns}"
        )

    # DataFrames are expensive to query row by row, so flatten them into
    # tuples (position, onset, offset, ch_name) once.
    gs_events = [
        (i, onset, offset, ch_name)
        for i, (onset, offset, ch_name) in enumerate(
            zip(gs_df["onset"], gs_df["offset"], gs_df["channels"])
        )
    ]

    # Bucket the candidate events by channel so each reference event only
    # looks at events of its own channel.
    check_by_channel = {}
    for i, (onset, offset, ch_name) in enumerate(
        zip(check_df["onset"], check_df["offset"], check_df["channels"])
    ):
        check_by_channel.setdefault(ch_name, []).append(
            (i, onset, offset, ch_name)
        )

    # List of tuples to populate the output DataFrame
    match_indices = []

    # Iterate over true labels (gold standard)
    for gs_hfo in gs_events:
        gs_ind, gs_onset, gs_offset, gs_ch_name = gs_hfo
        window_start = gs_onset - margin
        window_stop = gs_onset + margin

        # same channel AND onset within the expected window
        candidates = [
            check_hfo
            for check_hfo in check_by_channel.get(gs_ch_name, ())
            if window_start <= check_hfo[1] <= window_stop
        ]

        # keep only the candidates that really overlap the reference
        gs_win = (gs_onset, gs_offset)
        potential_matches = [
            check_hfo
            for check_hfo in candidates
            if _check_detection_overlap(gs_win, (check_hfo[1], check_hfo[2]))
        ]

        if not potential_matches:
            match_indices.append((gs_ind, None))
        elif len(potential_matches) == 1:
            match_indices.append((gs_ind, potential_matches[0][0]))
        else:
            # more than one match, find closest
            match_indices.append(_find_best_overlap(gs_hfo, potential_matches))

    if not match_indices:
        match_df = pd.DataFrame(columns=cols)
    else:
        # missing partners (None) become NaN through the numeric coercion
        match_df = pd.DataFrame(match_indices, columns=cols).apply(
            pd.to_numeric, errors="coerce", downcast="float"
        )
    return match_df


def _find_best_overlap(gs, check_list):
    """Find best overlap from an ideal (gs) and a possible list.

    Parameters
    ----------
    gs : sequence
        Reference event as (index, onset, offset, ch_name).
    check_list : list
        Candidate events, each as (index, onset, offset, ch_name).

    Returns
    -------
    best_inds : tuple
        ``(gs index, index of the closest candidate)``. The candidate index
        is None if no candidate has a finite distance to ``gs``.
    """
    gs_ind, gs_onset, gs_offset, _ = gs
    gs_point = np.array([gs_onset, gs_offset])

    dist = np.inf
    best_inds = (gs_ind, None)
    for check_ind, check_onset, check_offset, _ in check_list:
        check_point = np.array([check_onset, check_offset])
        # Euclidean distance between the (onset, offset) points as metric;
        # strict ``<`` keeps the first candidate on ties
        new_dist = np.linalg.norm(gs_point - check_point)
        if new_dist < dist:
            dist = new_dist
            best_inds = (gs_ind, check_ind)
    return best_inds