Tutorial on Computing HFOs (Part 2)¶
In this tutorial, we will walk through how to compute HFOs on the sample dataset described in [1].
We will demonstrate usage of the following detectors:
Line Length detector
RMS detector
Morphology detector (used in the paper)
Dataset Preprocessing¶
Note that the data has been converted to BIDS to facilitate easy loading with the mne-bids
package. Also note that the authors of this dataset reported HFOs detected on a bipolar montage, and they only analyzed HFOs for a subset of the recording channels.
In order to compare results to a monopolar reference, we define an HFO to be “found” if there was an HFO in either of the corresponding bipolar contacts.
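A minimal sketch of that matching rule is shown below. The helper function and the dictionary of monopolar detections are hypothetical and only illustrate the idea of crediting a bipolar detection to either of its monopolar contacts; they are not part of mne_hfo.

[ ]:

import re

# Hypothetical sketch (names and data structures are illustrative):
# a bipolar detection such as 'AHR1-2' counts as "found" if a monopolar
# detection exists on either contact ('AHR1' or 'AHR2').
def bipolar_hfo_found(bipolar_ch_name, monopolar_detections):
    contact, second_idx = bipolar_ch_name.split("-")       # e.g. 'AHR1', '2'
    electrode = re.sub(r'[0-9]+', '', contact)             # e.g. 'AHR'
    first, second = contact, f"{electrode}{second_idx}"    # 'AHR1', 'AHR2'
    return bool(monopolar_detections.get(first)) or bool(monopolar_detections.get(second))

# Example: monopolar_detections maps channel name -> list of detected events
print(bipolar_hfo_found("AHR1-2", {"AHR2": [(100, 180)]}))  # True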
References¶
[1] Fedele T, Burnos S, Boran E, Krayenbühl N, Hilfiker P, Grunwald T, Sarnthein J. Resection of high frequency oscillations predicts seizure outcome in the individual patient. Scientific Reports. 2017;7(1):13836. https://www.nature.com/articles/s41598-017-13064-1 doi:10.1038/s41598-017-13064-1
[1]:
# first let's load in all our packages
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import re
import pandas as pd
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV
from mne_bids import (read_raw_bids, BIDSPath,
                      get_entity_vals, get_datatypes,
                      make_report)
from mne_bids.stats import count_events
import mne
from mne import make_ad_hoc_cov
basepath = os.path.join(os.getcwd(), "../..")
sys.path.append(basepath)
from mne_hfo import LineLengthDetector, RMSDetector
from mne_hfo.score import _compute_score_data, accuracy
from mne_hfo.sklearn import make_Xy_sklearn, DisabledCV
1 Working with Real Data¶
We are now going to work with the dataset from Fedele et al. cited above.
1.1 Load in Real Data¶
1.1.1 Define dataset paths and load the data¶
The data is assumed to be in BIDS format. We have converted the dataset into BIDS, which you can load using mne-bids.
[2]:
# this may change depending on where you store the data
root = '/Users/patrick/Dropbox/fedele_hfo_data'
[3]:
# print a boiler plate summary report using mne-bids
report = make_report(root, verbose=False)
print(report)
Summarizing participants.tsv /Users/patrick/Dropbox/fedele_hfo_data/participants.tsv...
The iEEG Interictal Asleep HFO Dataset dataset was created with BIDS version
1.4.0 by Fedele T, Burnos S, Boran E, Krayenbühl N, Hilfiker P, Grunwald T, and
Sarnthein J.. This report was generated with MNE-BIDS
(https://doi.org/10.21105/joss.01896). The dataset consists of 20 participants
(comprised of 13 men and 6 women; handedness were all unknown; ages ranged from
17.0 to 52.0 (mean = 32.47, std = 11.43; 1 with unknown age))and 1 recording
sessions: interictalsleep. Data was recorded using a iEEG system (Neuralynx
manufacturer) sampled at 2000.0 Hz with line noise at 50.0 Hz using Sampling
with parameters 2000 Downsampled (Hz). There were 385 scans in total. Recording
durations ranged from 204.0 to 720.0 seconds (mean = 302.44, std = 37.85), for a
total of 116438.81 seconds of data recorded over all scans. For each dataset,
there were on average 51.88 (std = 15.55) recording channels per scan, out of
which 51.88 (std = 15.55) were used in analysis (0.0 +/- 0.0 were removed from
analysis).
[4]:
# there are iEEG datatypes throughout the dataset
datatypes = get_datatypes(root)
# get all the subjects, sessions
subjects = get_entity_vals(root, 'subject')
sessions = get_entity_vals(root, 'session')
[5]:
# construct BIDSPath for dataset we will demo
subjectID = subjects[0]
sessionID = sessions[0]
bids_path = BIDSPath(subject=subjectID, session=sessionID,
                     datatype='ieeg',
                     suffix='ieeg',
                     extension='.vhdr', root=root)
# get all matching datasets
fpaths = bids_path.match()
print(f'Found {len(fpaths)} matching filepaths for this subject and session.')
Found 28 matching filepaths for this subject and session.
[6]:
# analyze the first run
dataset_path = fpaths[0]
print(dataset_path)
# count the different HFO types (fr, ripple, frandr) that were detected using
# the published Morphology detector
counts_df = count_events(dataset_path)
display(counts_df)
/Users/patrick/Dropbox/fedele_hfo_data/sub-01/ses-interictalsleep/ieeg/sub-01_ses-interictalsleep_run-01_ieeg.vhdr
| subject | session | run | fr_AHR1-2 | fr_AHR2-3 | fr_AHR3-4 | fr_AL1-2 | fr_AL2-3 | fr_AL3-4 | fr_AR1-2 | fr_AR2-3 | fr_AR3-4 | fr_HL1-2 | ... | ripple_IAR2-3 | ripple_IAR3-4 | ripple_IAR4-5 | ripple_IAR5-6 | ripple_IPR1-2 | ripple_IPR2-3 | ripple_IPR3-4 | ripple_PHR1-2 | ripple_PHR2-3 | ripple_PHR3-4 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 01 | interictalsleep | 01 | 15 | 23 | 12 | 15 | 19 | 20 | 25 | 28 | 11 | 9 | ... | 351 | 223 | 179 | 238 | 12 | 32 | 128 | 121 | 111 | 64 |

1 rows × 60 columns
[7]:
%%capture
# load dataset into mne Raw object
extra_params = dict(preload=True)
raw = read_raw_bids(dataset_path, extra_params)
[ ]:
# Optional - Change backend to interactive for below plot
# Allows you to scroll dataset, highlight (or grey out) certain channels, zoom in/out, etc.
%matplotlib notebook
[8]:
# Plot the raw data with vertical colorbars to denote where HFOs were detected
raw_plot = raw.plot()
raw_plot.show()
print('plotting channels with HFO events detected in '
      'the original publication in color.')
plotting channels with HFO events detected in the original publication in color.
<ipython-input-8-85590d0d121e>:3: UserWarning: Matplotlib is currently using module://ipykernel.pylab.backend_inline, which is a non-GUI backend, so cannot show the figure.
raw_plot.show()
[ ]:
# Optional - Change back to regular plots
%matplotlib inline
1.1.2 Convert to bipolar referencing scheme¶
The Fedele paper seems to use bipolar-referenced channels, so we re-reference our data to a bipolar montage to make the comparison as close as possible.
[9]:
def convert_to_bipolar(raw, drop_originals=True):
    """Re-reference to a bipolar montage by pairing adjacent contacts on each electrode."""
    original_ch_names = raw.ch_names
    ch_names_sorted = sorted(original_ch_names)
    # pair neighboring contacts that belong to the same electrode
    # (i.e. same channel name once the contact digits are stripped)
    ch_pairs = []
    for first, second in zip(ch_names_sorted, ch_names_sorted[1:]):
        firstName = re.sub(r'[0-9]+', '', first)
        secondName = re.sub(r'[0-9]+', '', second)
        if firstName == secondName:
            ch_pairs.append((first, second))
    # create the bipolar channels (named 'anode-cathode' by mne)
    for ch_pair in ch_pairs:
        raw = mne.set_bipolar_reference(raw, ch_pair[0], ch_pair[1], drop_refs=False)
    # optionally drop the original monopolar channels
    if drop_originals:
        raw = raw.drop_channels(original_ch_names)
    return raw
[10]:
%%capture
raw = convert_to_bipolar(raw)
1.1.3 Load Annotated HFOs¶
[11]:
# All annotated HFO events for this file
annotations = raw.annotations
[12]:
# The Fedele bipolar names use the scheme CH#-# (e.g. 'AHR1-2'), but mne's bipolar
# referencing uses the scheme CH#-CH# (e.g. 'AHR1-AHR2'). Reconstruct the Fedele
# names to match the mne names.
def reconstruct_channel_name_to_mne(ch_name):
    ch_split = ch_name.split("-")
    cont_name = re.sub(r'[0-9]+', '', ch_split[0])
    ch_name_mne = f"{ch_split[0]}-{cont_name}{ch_split[1]}"
    return ch_name_mne

# You can also go the other way around. If we convert the mne names to the Fedele
# names, you can use the mne.io.Raw.rename_channels function
def reconstruct_mne_channel_name_to_fedele(ch_name):
    ch_split = ch_name.split("-")
    cont_name = re.sub(r'[0-9]+', '', ch_split[0])
    ch_name_fedele = f"{ch_split[0]}-{ch_split[1].replace(cont_name, '')}"
    return ch_name_fedele
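A quick illustration of the round trip between the two naming schemes (the example channel name is hypothetical but follows the dataset's convention):

[ ]:

# Example round trip between the Fedele and mne naming conventions
print(reconstruct_channel_name_to_mne("AHR1-2"))            # -> 'AHR1-AHR2'
print(reconstruct_mne_channel_name_to_fedele("AHR1-AHR2"))  # -> 'AHR1-2'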
[13]:
# Convert to convenient data structure (pandas DF)
column_names = ["onset", "duration", "sample", "label", "channels"]
sfreq = raw.info["sfreq"]
rows = []
for annot in annotations:
    onset = float(annot.get("onset"))
    duration = float(annot.get("duration"))
    sample = onset * sfreq
    trial_type = annot.get("description").split("_")[0]
    ch_name = annot.get("description").split("_")[1]
    ch_name = reconstruct_channel_name_to_mne(ch_name)
    annot_row = [onset, duration, sample, trial_type, ch_name]
    rows.append(annot_row)
gs_df = pd.concat([pd.DataFrame([row], columns=column_names) for row in rows],
                  ignore_index=True)
[14]:
# for now, let's just look at ripple events:
gs_df_ripple = gs_df[gs_df['label'].str.contains("ripple")]
gs_df_ripple = gs_df_ripple.dropna()
gs_df_ripple.reset_index(drop=True, inplace=True)
1.2 Detect HFOs¶
1.2.1 Line Length Detector¶
[15]:
# Set keyword arguments for the Line Length detector and generate the class object
kwargs = {
    'filter_band': (80, 250),  # (l_freq, h_freq)
    'threshold': 3,  # Number of st. deviations
    'win_size': 100,  # Sliding window size in samples
    'overlap': 0.25,  # Fraction of window overlap [0, 1]
    'hfo_name': "ripple"
}
ll_detector = LineLengthDetector(**kwargs)
[16]:
%%capture
# Detect HFOs in the raw data using the LineLengthDetector method.
# Return the class object with HFOs added
ll_detector = ll_detector.fit(raw)
# Dictionary where keys are channel index and values are a list of tuples in the form of (start_samp, end_samp)
ll_chs_hfo_dict = ll_detector.chs_hfos_
# nCh x nWin ndarray where each value is the line-length of the data window per channel
ll_hfo_event_array = ll_detector.hfo_event_arr_
# Pandas dataframe containing onset, duration, sample trial, and trial type per HFO
ll_hfo_df = ll_detector.df_
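As a quick sanity check of the LineLengthDetector output, a short sketch is shown below; it assumes chs_hfos_ maps each channel to a list of (start_samp, end_samp) tuples, as described in the comments above.

[ ]:

# Sketch: count detected events per channel, show the five busiest channels,
# then peek at the first rows of the detections DataFrame.
n_events = {ch: len(events) for ch, events in ll_chs_hfo_dict.items()}
print(sorted(n_events.items(), key=lambda kv: kv[1], reverse=True)[:5])
print(ll_hfo_df.head())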
1.2.2 RMS Detector¶
[17]:
# Set keyword arguments for the RMS detector and generate the class object
kwargs = {
    'filter_band': (80, 250),
    'threshold': 3,
    'win_size': 100,
    'overlap': 0.25,
    'hfo_name': 'ripple',
}
rms_detector = RMSDetector(**kwargs)
[18]:
%%capture
# Detect HFOs in the raw data using the RMSDetector method.
rms_detector = rms_detector.fit(raw)
rms_chs_hfo_dict = rms_detector.chs_hfos_
rms_hfo_event_array = rms_detector.hfo_event_arr_
rms_hfo_df = rms_detector.df_
1.3 Compare Results¶
1.3.1 Find matches¶
Now that our dataframes are in the same format, we can compare them. We will simply look at the matches for ripples, since that is what the detectors looked for. We will compute, for each detector, the accuracy, precision, true positive rate, false negative rate, and false discovery rate.
[19]:
def scores_to_df(score_dict):
    """Collect the per-detector scores into a single summary DataFrame."""
    columns = ['detector', 'accuracy', 'true positive rate', 'precision',
               'false negative rate', 'false discovery rate']
    rows = []
    for detector_name, scores in score_dict.items():
        rows.append([detector_name] + [str(score) for score in scores])
    return pd.DataFrame(rows, columns=columns)
[20]:
# Note: Since we are computing every score at once, we take a shortcut by calling the internal
# function _compute_score_data, which gives the number of true positives, false positives,
# and false negatives. There are no true negatives in this dataset
# Gold standard vs LineLengthDetector
scores_dict = {}
tp, fp, fn = _compute_score_data(gs_df, ll_hfo_df, method="match-total")
acc_ll = tp / (tp + fp + fn)
tpr_ll = tp / (tp + fn)
prec_ll = tp / (tp + fp)
fnr_ll = fn / (fn + tp)
fdr_ll = fp / (fp + tp)
scores_dict["LineLengthDetector"] = [acc_ll, tpr_ll, prec_ll, fnr_ll, fdr_ll]
# Gold standard vs RMSDetector
tp, fp, fn = _compute_score_data(gs_df, rms_hfo_df, method="match-total")
acc_rms = tp / (tp + fp + fn)
tpr_rms = tp / (tp + fn)
prec_rms = tp / (tp + fp)
fnr_rms = fn / (fn + tp)
fdr_rms = fp / (fp + tp)
scores_dict["RMSDetector"] = [acc_rms, tpr_rms, prec_rms, fnr_rms, fdr_rms]
scores_df = scores_to_df(scores_dict)
scores_df
[20]:
 | detector | accuracy | true positive rate | precision | false negative rate | false discovery rate
---|---|---|---|---|---|---
0 | LineLengthDetector | 0.2599502487562189 | 0.4336099585062241 | 0.3935969868173258 | 0.5663900414937759 | 0.6064030131826742
1 | RMSDetector | 0.26991030171242186 | 0.45749827228749135 | 0.39696182290625626 | 0.5425017277125086 | 0.6030381770937437
2 Optimizing the Detectors¶
The detectors above did reasonably well, but the hyperparameters were chosen arbitrarily. Let's walk through the procedure for optimizing the hyperparameters of the LineLengthDetector using grid-search cross-validation.
2.1 Set up the data¶
scikit-learn requires the input data and true labels to be in a specific format for the procedure to work. We provide helper functions to assist with the data conversion.
[21]:
raw_df, y = make_Xy_sklearn(raw, gs_df_ripple)
2.2 Set up the GridSearchCV function¶
We will test three possible threshold values and three possible win_size values, for a total of 9 parameter combinations. To keep the runtime down, accuracy will be the only scoring metric used here, but multiple scoring functions can be passed in at once (see the sketch after the next cell).
[22]:
scorer = accuracy
parameters = {'threshold': [1, 2, 3], 'win_size': [50, 100, 250]}
kwargs = {
    'filter_band': (80, 250),
    'overlap': 0.25,
    'hfo_name': 'ripple',
}
detector = LineLengthDetector(**kwargs)
scorer = make_scorer(scorer)
cv = DisabledCV()
gs = GridSearchCV(detector, param_grid=parameters, scoring=scorer, cv=cv,
                  verbose=True)
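As noted above, GridSearchCV also accepts a dictionary of scorers; a hedged sketch is shown below. Only the accuracy scorer from mne_hfo.score is used here, since that is the only metric imported in this tutorial; when multiple scorers are given, refit must name the metric used to select the best parameters.

[ ]:

# Sketch: multi-metric scoring with GridSearchCV. The dictionary keys are
# arbitrary labels; 'refit' must name the metric used to pick the best model.
multi_scoring = {'accuracy': make_scorer(accuracy)}
gs_multi = GridSearchCV(detector, param_grid=parameters,
                        scoring=multi_scoring, refit='accuracy',
                        cv=cv, verbose=True)
# gs_multi.fit(raw_df, y, groups=None)  # uncomment to run the search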
2.3 Perform the Search and Print Output¶
[23]:
%%time
%%capture
gs.fit(raw_df, y, groups=None)
CPU times: user 2min 45s, sys: 10.9 s, total: 2min 55s
Wall time: 6min 7s
[24]:
# Nicely display the output
pd.concat([pd.DataFrame(gs.cv_results_["params"]),pd.DataFrame(gs.cv_results_["mean_test_score"], columns=["Accuracy"])],axis=1)
[24]:
 | threshold | win_size | Accuracy
---|---|---|---
0 | 1 | 50 | 0.000000
1 | 1 | 100 | 0.000000
2 | 1 | 250 | 0.136080
3 | 2 | 50 | 0.000000
4 | 2 | 100 | 0.152935
5 | 2 | 250 | 0.225537
6 | 3 | 50 | 0.205174
7 | 3 | 100 | 0.238337
8 | 3 | 250 | 0.232798
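Once the search has finished, the best hyperparameters and score can be read off the fitted GridSearchCV object (standard scikit-learn attributes); a final detector can then be built with those settings, as sketched below.

[ ]:

# Retrieve the best hyperparameters and mean accuracy found by the grid search,
# then build a LineLengthDetector with those settings for further analysis.
print(gs.best_params_, gs.best_score_)
best_detector = LineLengthDetector(filter_band=(80, 250), overlap=0.25,
                                   hfo_name='ripple', **gs.best_params_)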