vignettes/plot_machine_learning_evoked.Rmd
In this example we are going to use the machine learning functionality of MNE-Python and scikit-learn to analyze evoked responses.
First we attach the R packages used below; loading mne imports MNE-Python through reticulate.
library(dplyr)
library(tidyr)
library(ggplot2)
library(mne)
## Importing MNE version=0.18.dev0, path='/Users/dengeman/github/mne-python/mne'
Let’s first locate the sample dataset and build the path to the raw data file.
data_path <- mne$datasets$sample$data_path()
subject <- "sample"
raw_fname <- paste(data_path,
'MEG',
subject,
'sample_audvis_filt-0-40_raw.fif',
sep = '/')
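As a small optional check (not part of the original analysis), we can make sure the assembled path actually points to a file before trying to read it.
# Optional sanity check: the assembled file name should exist on disk.
stopifnot(file.exists(raw_fname))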
Let’s read in and preprocess the data.
raw <- mne$io$read_raw_fif(raw_fname, preload = T)
# Band-pass filter the signals
raw$filter(1, 30, fir_design = "firwin")
## <Raw | sample_audvis_filt-0-40_raw.fif, n_channels x n_times : 376 x 41700 (277.7 sec), ~123.3 MB, data loaded>
events <- mne$find_events(raw)
storage.mode(events) <- "integer" # R gets the events as floats.
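The third column of the events matrix holds the event code. As a quick optional check, we can tabulate the codes to see how many trials of each condition were found.
# Optional: count the events per condition code.
table(events[, 3])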
tmin <- -0.2
tmax <- 0.5
baseline <- reticulate::tuple(NULL, 0)
reject <- list(mag = 5e-12)
event_id <- list("aud/l" = 1L, "vis/l" = 3L)
picks <- mne$pick_types(raw$info, meg = T, exclude = 'bads') %>%
as.integer() # make sure it's int
decim <- 2L # use every 2nd time sample
epochs <- mne$Epochs(raw = raw, events = events,
event_id = event_id, tmin = tmin,
tmax = tmax, picks = picks, proj = T,
baseline = baseline, reject = reject,
decim = decim, preload = T)
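Printing the epochs object and checking the dimensions of the underlying array is an optional but handy way to verify that epoching worked as expected.
# Optional inspection of the epochs we just created.
print(epochs)
dim(epochs$get_data()) # n_epochs x n_channels x n_times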
We can now set up the model. For this we need a few imports from scikit-learn.
preproc <- reticulate::import("sklearn.preprocessing")
pipeline <- reticulate::import("sklearn.pipeline")
linear_model <- reticulate::import("sklearn.linear_model")
clf <- pipeline$make_pipeline(
preproc$StandardScaler(),
linear_model$LogisticRegression(solver = "lbfgs"))
time_decod <- mne$decoding$SlidingEstimator(
clf, scoring = "roc_auc", n_jobs = 1L)
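The sliding estimator fits one classifier per time sample; we score it with the ROC AUC because it does not depend on a decision threshold. Purely as an illustration (not used below), the same construction works with other scikit-learn scoring strings, for example plain accuracy.
# Illustrative alternative (not used below): score with accuracy instead of AUC.
time_decod_acc <- mne$decoding$SlidingEstimator(
clf, scoring = "accuracy", n_jobs = 1L)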
# MEG signals: n_epochs, n_meg_channels, n_times
X <- epochs$get_data()
# R vs Python indexing: the event code is in column 3 (Python index 2)
y <- epochs$events[, 2 + 1]
# recode and make sure it's an int.
y <- ifelse(y == 1, 0, 3) %>% as.integer()
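Before running the cross-validation, it can be reassuring (purely optional) to confirm that both classes are present and roughly balanced.
# Optional: class counts after recoding (0 = auditory/left, 3 = visual/left).
table(y)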
scores <- mne$decoding$cross_val_multiscore(time_decod, X, y,
cv = 5L, n_jobs = 1L)
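The result is a matrix with one row per cross-validation fold and one column per time sample; an optional dimension check makes that layout explicit before we reshape it for plotting.
# Optional: scores has one row per fold and one column per time sample.
dim(scores)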
Time to plot the results with R!
# first we add the time info after rounding
colnames(scores) <- round(epochs$times * 1e3)
# we create a long table
scores_df <- scores %>%
as.data.frame() %>%
gather(key = "time", value = "score")
# and add the info about the folds.
scores_df$fold <- sprintf("fold %s", 1:5) # R recycles the 5 fold labels across rows (5 folds per time point)
ggplot(
data = scores_df,
mapping = aes(x = as.numeric(time), y = score)) +
geom_line(mapping = aes(group = fold), alpha = 0.4) +
stat_summary(fun.y = "mean", geom = "smooth", size = 1.3,
color = "red") +
geom_hline(yintercept = 0.5, color = "black", linetype="dashed") +
theme_minimal() +
theme(text = element_text(size = 18, family = "Helvetica")) +
labs(y = "AUC score", x = "time [ms]")
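If you want to keep the figure, ggsave() from ggplot2 writes the last plot to disk; the file name and dimensions below are just examples.
# Save the last plot (illustrative file name and dimensions).
ggsave("decoding_scores.png", width = 8, height = 5, dpi = 150)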
We can see that the uncertainty across cross-validation folds is lowest around 150 milliseconds after stimulus onset.
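To back this impression up numerically, one can, as an illustrative sketch, summarise the scores across folds at each time point and look up where the mean AUC peaks and where the spread across folds is smallest.
# Illustrative summary of the scores across folds.
mean_auc <- colMeans(scores)
fold_sd <- apply(scores, 2, sd)
round(epochs$times[which.max(mean_auc)] * 1e3) # time (ms) of the peak mean AUC
round(epochs$times[which.min(fold_sd)] * 1e3) # time (ms) of the smallest fold spread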