Source code for mne_hfo.detect

from typing import Optional, Tuple, Union

import mne
import numpy as np

from .base import Detector
from .config import ACCEPTED_BAND_METHODS


class HilbertDetector(Detector):  # noqa
    """2D HFO hilbert detection used in Kucewicz et al. 2014.

    A multi-taper method with: 4 Hz bandwidth, 1 sec sliding window,
    stepsize 100 ms, for the 1-500 Hz range, no padding, 2 tapers.

    For full details, see :footcite:`kucewicz2014high`.

    Parameters
    ----------
    threshold : float
        Threshold for detection (default=3).
    filter_band : tuple(float, float)
        Low cut-off frequency at index 0 and high cut-off frequency
        at index 1. The default is ``(30, 100)``.
    band_method : str
        Spacing of hilbert frequency bands - options: 'linear' or 'log'
        (default='linear'). Linear provides better frequency resolution
        but is slower. With 'linear', cut-offs are spaced 1 Hz apart
        starting at the low cut-off.
    n_bands : int
        Number of log-spaced candidate cut-offs (from 1 Hz up to the high
        cut-off) generated when ``band_method='log'`` (default=300). Only
        cut-offs lying strictly inside ``filter_band`` are kept. Not used
        to build the cut-offs when ``band_method='linear'``.
    cycle_threshold : float
        Minimum number of cycles to detect (default=1).
    gap_threshold : float
        Number of cycles for gaps (default=1).
    n_jobs : int
        Number of cores to use (default=-1).
    offset : int
        Offset which is added to the final detection. This is used when the
        function is run in separate windows. Default = 0.
    scoring_func : str
        The scoring function to apply when trying to match HFOs with
        a different dataset, such as manual annotations.
    hfo_name : str
        What to name the events detected (e.g. fast ripple if
        ``filter_band`` is ``(250, 500)``).
    verbose : bool
        Verbosity of the detector.

    Raises
    ------
    ValueError
        If ``band_method`` is not one of ``ACCEPTED_BAND_METHODS``.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        threshold: Union[int, float] = 3,
        filter_band: Tuple[int, int] = (30, 100),
        band_method: str = "linear",
        n_bands: int = 300,
        cycle_threshold: float = 1,
        gap_threshold: float = 1,
        n_jobs: int = -1,
        offset: int = 0,
        scoring_func: str = "f1",
        hfo_name: str = "hfo",
        verbose: bool = False,
    ):
        if band_method not in ACCEPTED_BAND_METHODS:
            raise ValueError(
                f"Band method {band_method} is not "
                f"an acceptable parameter. Please use "
                f"one of {ACCEPTED_BAND_METHODS}"
            )
        # win_size/overlap are pinned to 1: ``_compute_hfo_statistic`` also
        # resets ``win_size`` to 1 because this detector slides over
        # frequency bands rather than time windows.
        super(HilbertDetector, self).__init__(
            threshold,
            win_size=1,
            overlap=1,
            scoring_func=scoring_func,
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.band_method = band_method
        self.n_bands = n_bands
        self.filter_band = filter_band
        self.hfo_name = hfo_name
        self.cycle_threshold = cycle_threshold
        self.gap_threshold = gap_threshold
        self.n_jobs = n_jobs
        self.offset = offset

    @property
    def l_freq(self):
        """Lower frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[0]

    @property
    def h_freq(self):
        """Higher frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[1]

    def _create_empty_event_arr(self):
        """Override ``Detector._create_empty_event_arr`` function.

        Also sets the frequency cut-offs (``self.freq_cutoffs``) and the
        frequency span (``self.freq_span``) of the Hilbert detector.

        Returns
        -------
        hfo_event_arr : np.ndarray
            Uninitialized array of shape
            ``(n_chs, len(freq_cutoffs) - 1, n_times)``.

        Raises
        ------
        ValueError
            If ``self.band_method`` is not a supported spacing (e.g. it was
            changed after construction).
        """
        # Determine the splits for freq bands
        if self.band_method == "log":
            low_fc = float(self.filter_band[0])
            high_fc = float(self.filter_band[1])
            # Log-spaced candidates from 10**0 = 1 Hz up to the high cut-off;
            # keep only those strictly inside the requested band.
            freq_cutoffs = np.logspace(0, np.log10(high_fc), self.n_bands)
            self.freq_cutoffs = freq_cutoffs[
                (freq_cutoffs > low_fc) & (freq_cutoffs < high_fc)
            ]
        elif self.band_method == "linear":
            # 1 Hz spaced cut-offs in [low, high).
            self.freq_cutoffs = np.arange(self.filter_band[0], self.filter_band[1])
        else:
            # Without this branch an unsupported method would silently reuse
            # a stale (or missing) ``freq_cutoffs`` attribute.
            raise ValueError(
                f"Band method {self.band_method} is not "
                f"an acceptable parameter. Please use "
                f"one of {ACCEPTED_BAND_METHODS}"
            )
        # Derive the span from the cut-offs actually generated; for integer
        # bands this equals ``(high - low) - 1`` and for non-integer bands it
        # stays an integer count.
        self.freq_span = len(self.freq_cutoffs) - 1

        n_windows = self.n_times
        n_bands = len(self.freq_cutoffs) - 1
        hfo_event_arr = np.empty((self.n_chs, n_bands, n_windows))
        return hfo_event_arr

    def _compute_hfo_statistic(self, X):
        """Override ``Detector._compute_hfo_statistic`` function."""
        # Override the attribute set by fit so we actually slide on freq
        # bands not time windows
        # NOTE(review): ``self.n_bands`` is the log-spacing parameter; with
        # ``band_method='linear'`` it differs from the number of bands in
        # ``self.freq_cutoffs`` -- confirm this is intended.
        self.n_windows = self.n_bands
        self.win_size = 1
        self.n_times = len(X)
        hfo_event_arr = self._compute_frq_band_detection(X, method="hilbert")
        return hfo_event_arr

    def _threshold_statistic(self, X):
        """Override ``Detector._threshold_statistic`` function."""
        hfo_threshold_arr = np.transpose(
            np.array(
                self._apply_threshold(X, threshold_method="hilbert"), dtype="object"
            )
        )
        return hfo_threshold_arr

    def _post_process_ch_hfos(self, detections):
        """Override ``Detector._post_process_ch_hfos`` function."""
        hfo_events = self._merge_contiguous_ch_detections(
            detections, method="freq-bands"
        )
        return hfo_events
class LineLengthDetector(Detector):
    """Line-length detection algorithm.

    Original paper defines HFOs as:

    "(HFOs), which we collectively term as all activity >40 Hz
    (including gamma, high-gamma, ripple, and fast ripple oscillations),
    may have a fundamental role in the generation and spread of focal
    seizures."

    See :footcite:`gardner2007human`.

    In the paper, data were sampled at 200 Hz and bandpass-filtered
    (0.1 – 100 Hz) during acquisition. Data were further digitally
    bandpass-filtered (4th-order Butterworth, forward-backward filtering,
    ``0.1 – 85 Hz``) to minimize potential artifacts due to aliasing.
    Accordingly, this detector filters with an IIR filter
    (``method="iir"`` in ``mne.filter.filter_data``).

    Compared to the RMS detector, this detector uses the line-length metric
    :footcite:`esteller2001line`.

    Parameters
    ----------
    threshold : float
        Number of standard deviations to use as a threshold.
    win_size : int
        Sliding window size in samples.
    overlap : float
        Fraction of the window overlap (0 to 1).
    sfreq : int | None
        The sampling rate of the data.
    filter_band : tuple(float, float) | None
        Low cut-off frequency at index 0 and high cut-off frequency
        at index 1. The default is ``(30, 100)``. If ``None`` (or both
        cut-offs are ``None``), no filtering is applied.
    scoring_func : str
        The scoring function to apply when trying to match HFOs with
        a different dataset, such as manual annotations.
    n_jobs : int
        The number of jobs for joblib parallelization.
    hfo_name : str
        What to name the events detected (e.g. fast ripple if
        ``filter_band`` is ``(250, 500)``).
    verbose : bool
        Verbosity of the detector.

    Notes
    -----
    For processing, a sliding window is used.

    For post-processing, any events that overlap are considered to be the
    same.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        threshold: Union[int, float] = 3,
        win_size: int = 100,
        overlap: float = 0.25,
        sfreq: Optional[int] = None,
        filter_band: Tuple[int, int] = (30, 100),
        scoring_func: str = "f1",
        n_jobs: int = -1,
        hfo_name: str = "hfo",
        verbose: bool = False,
    ):
        super(LineLengthDetector, self).__init__(
            threshold,
            win_size=win_size,
            overlap=overlap,
            scoring_func=scoring_func,
            n_jobs=n_jobs,
            verbose=verbose,
        )

        self.filter_band = filter_band
        self.sfreq = sfreq
        self.hfo_name = hfo_name

    @property
    def l_freq(self):
        """Lower frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[0]

    @property
    def h_freq(self):
        """Higher frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[1]

    def _compute_hfo_statistic(self, X):
        """Override ``Detector._compute_hfo_statistic`` function.

        Optionally IIR-filters ``X`` to ``filter_band``, then computes the
        line-length statistic over sliding windows.

        Returns
        -------
        hfo_event_arr : np.ndarray
            Array of shape ``(n_windows, len(self.freq_cutoffs) - 1)``.
        """
        # store all hfo occurrences as an array of length windows
        # Filter with an IIR filter (as in the original paper) only when at
        # least one cut-off is defined; this matches ``RMSDetector``.
        if self.l_freq is not None or self.h_freq is not None:
            X = mne.filter.filter_data(
                X,
                sfreq=self.sfreq,
                l_freq=self.l_freq,
                h_freq=self.h_freq,
                method="iir",
                verbose=self.verbose,
            )

        hfo_event_arr = self._compute_sliding_window_detection(X, method="line_length")

        # reshape array to be n_wins x n_bands (i.e. 1)
        # NOTE(review): ``freq_cutoffs``/``n_times``/``step_size`` are not set
        # in this class; presumably populated by the ``Detector`` base.
        n_windows = self._compute_n_wins(self.win_size, self.step_size, self.n_times)
        n_bands = len(self.freq_cutoffs) - 1
        shape = (n_windows, n_bands)
        hfo_event_arr = np.array(hfo_event_arr).reshape(shape)

        return hfo_event_arr

    def _threshold_statistic(self, X):
        """Override ``Detector._threshold_statistic`` function."""
        hfo_threshold_arr = self._apply_threshold(X, threshold_method="std")
        return hfo_threshold_arr

    def _post_process_ch_hfos(self, detections):
        """Override ``Detector._post_process_ch_hfos`` function."""
        return self._merge_contiguous_ch_detections(detections, method="time-windows")
class RMSDetector(Detector):
    """Root mean square (RMS) detection algorithm (Staba Detector).

    The original algorithm described in the reference, takes a sliding
    window of 3 ms, computes the RMS values between 100 and 500 Hz.
    Then events separated by less than 10 ms were combined into one event.
    Then events not having a minimum of 6 peaks (i.e. band-pass signal
    rectified above 0 V) with greater than 3 std above mean baseline
    were removed. A finite impulse response (FIR) filter with a
    Hamming window was used.

    See :footcite:`staba2002quantitative`.

    Parameters
    ----------
    threshold : float
        Number of standard deviations to use as a threshold.
    win_size : int
        Sliding window size in samples.
    overlap : float
        Fraction of the window overlap (0 to 1).
    sfreq : int | None
        The sampling rate of the data.
    filter_band : tuple(float, float) | None
        Low cut-off frequency at index 0 and high cut-off frequency
        at index 1. The default is ``(100, 500)``. If ``None`` (or both
        cut-offs are ``None``), no filtering is applied.
    scoring_func : str
        The scoring function to apply when trying to match HFOs with
        a different dataset, such as manual annotations.
    n_jobs : int
        The number of jobs for joblib parallelization.
    hfo_name : str
        What to name the events detected (e.g. fast ripple if
        ``filter_band`` is ``(250, 500)``).
    verbose : bool
        Verbosity of the detector.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        threshold: Union[int, float] = 3,
        win_size: int = 100,
        overlap: float = 0.25,
        sfreq: Optional[int] = None,
        filter_band: Tuple[int, int] = (100, 500),
        scoring_func: str = "f1",
        n_jobs: int = -1,
        hfo_name: str = "hfo",
        verbose: bool = False,
    ):
        # Pass options by keyword (as the sibling detectors do) so the call
        # does not depend on the positional order of ``Detector.__init__``.
        super(RMSDetector, self).__init__(
            threshold,
            win_size=win_size,
            overlap=overlap,
            scoring_func=scoring_func,
            n_jobs=n_jobs,
            verbose=verbose,
        )

        # hyperparameters
        self.filter_band = filter_band
        self.sfreq = sfreq
        self.hfo_name = hfo_name

    @property
    def l_freq(self):
        """Lower frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[0]

    @property
    def h_freq(self):
        """Higher frequency band for HFO definition (``None`` if unset)."""
        if self.filter_band is None:
            return None
        return self.filter_band[1]

    def _compute_hfo_statistic(self, X):
        """Override ``Detector._compute_hfo_statistic`` function.

        Optionally FIR-filters ``X`` to ``filter_band``, then computes the
        RMS statistic over sliding windows.

        Returns
        -------
        hfo_event_arr : np.ndarray
            Array of shape ``(n_windows, len(self.freq_cutoffs) - 1)``.
        """
        # store all hfo occurrences as an array of length windows
        # Filter with an FIR filter (as in the original paper) only when at
        # least one cut-off is defined; with a single cut-off this is a
        # high-/low-pass rather than a band-pass.
        if self.l_freq is not None or self.h_freq is not None:
            X = mne.filter.filter_data(
                X,
                sfreq=self.sfreq,
                l_freq=self.l_freq,
                h_freq=self.h_freq,
                method="fir",
                verbose=self.verbose,
            )

        hfo_event_arr = self._compute_sliding_window_detection(X, method="rms")

        # reshape array to be n_wins x n_bands (i.e. 1)
        # NOTE(review): ``freq_cutoffs``/``n_times``/``step_size`` are not set
        # in this class; presumably populated by the ``Detector`` base.
        n_windows = self._compute_n_wins(self.win_size, self.step_size, self.n_times)
        n_bands = len(self.freq_cutoffs) - 1
        shape = (n_windows, n_bands)
        hfo_event_arr = np.array(hfo_event_arr).reshape(shape)
        return hfo_event_arr

    def _threshold_statistic(self, X):
        """Override ``Detector._threshold_statistic`` function."""
        hfo_threshold_arr = self._apply_threshold(X, threshold_method="std")
        return hfo_threshold_arr

    def _post_process_ch_hfos(self, detections):
        """Override ``Detector._post_process_ch_hfos`` function."""
        return self._merge_contiguous_ch_detections(detections, method="time-windows")