# Source code for mne_realtime.fieldtrip_client

# Author: Mainak Jas
#
# License: BSD (3-clause)

import copy
import re
import threading
import time

import numpy as np

from mne import create_info, pick_info
from mne.io.pick import _picks_to_idx
from mne.io.constants import FIFF
from mne.epochs import EpochsArray
from mne.utils import logger, warn, fill_doc
from .externals.FieldTrip import Client as FtClient


def _buffer_recv_worker(ft_client):
    """Worker thread that constantly receives buffers."""
    try:
        for raw_buffer in ft_client.iter_raw_buffers():
            ft_client._push_raw_buffer(raw_buffer)
    except RuntimeError as err:
        # something is wrong, the server stopped (or something)
        ft_client._recv_thread = None
        logger.error('Buffer receive thread stopped: %s' % err)


@fill_doc
class FieldTripClient(object):
    """Realtime FieldTrip client.

    Parameters
    ----------
    info : dict | None
        The measurement info read in from a file. If None, it is guessed from
        the Fieldtrip Header object.
    host : str
        Hostname (or IP address) of the host where Fieldtrip buffer is running.
    port : int
        Port to use for the connection.
    wait_max : float
        Maximum time (in seconds) to wait for the Fieldtrip buffer to start.
        The limit applies separately to connecting and to retrieving the
        header; each step is attempted at least once.
    tmin : float | None
        Time instant (in seconds) to start receiving buffers. If None, start
        from the latest samples available.
    tmax : float
        Time instant (in seconds) to stop receiving buffers. The sample at
        ``tmax`` is included.
    buffer_size : int
        Size of each buffer in terms of number of samples.
    %(verbose)s

    Notes
    -----
    This software uses the FieldTrip buffer open source library. See
    http://www.fieldtriptoolbox.org for details. The FieldTrip buffer is used
    under the BSD 3-Clause License.
    """

    # (label prefix, FIFF channel kind) pairs used to guess channel types
    # when no Info is provided; checked in order, first match wins.
    _CH_KIND_PREFIXES = (
        ('EEG', FIFF.FIFFV_EEG_CH),
        ('MEG', FIFF.FIFFV_MEG_CH),
        ('MCG', FIFF.FIFFV_MCG_CH),
        ('EOG', FIFF.FIFFV_EOG_CH),
        ('EMG', FIFF.FIFFV_EMG_CH),
        ('STI', FIFF.FIFFV_STIM_CH),
        ('ECG', FIFF.FIFFV_ECG_CH),
        ('MISC', FIFF.FIFFV_MISC_CH),
        ('SYS', FIFF.FIFFV_SYST_CH),
    )

    def __init__(self, info=None, host='localhost', port=1972, wait_max=30,
                 tmin=None, tmax=np.inf, buffer_size=1000,
                 verbose=None):  # noqa: D102
        self.verbose = verbose
        self.info = info
        self.wait_max = wait_max
        self.tmin = tmin
        self.tmax = tmax
        self.buffer_size = buffer_size
        self.host = host
        self.port = port

        # receive-thread state (see start_receive_thread)
        self._recv_thread = None
        self._recv_callbacks = list()

    def __enter__(self):  # noqa: D105
        # instantiate Fieldtrip client and connect
        self.ft_client = FtClient()
        self._connect()

        # From here on a connection is open. __exit__ is not called when
        # __enter__ raises, so close the connection ourselves on failure.
        try:
            self._retrieve_header()
            self.info = self._create_info()
            self.ch_names = self.ft_header.labels

            # find start and end samples
            sfreq = self.info['sfreq']
            if self.tmin is None:
                # start from the newest sample currently in the buffer
                self.tmin_samp = max(0, self.ft_header.nSamples - 1)
            else:
                self.tmin_samp = int(round(sfreq * self.tmin))

            if self.tmax != np.inf:
                self.tmax_samp = int(round(sfreq * self.tmax))
            else:
                # no end requested: use the largest uint32 value
                self.tmax_samp = np.iinfo(np.uint32).max
        except BaseException:
            self.ft_client.disconnect()
            raise
        return self

    def __exit__(self, type, value, traceback):  # noqa: D105
        self.ft_client.disconnect()

    def _connect(self):
        """Connect to the FieldTrip buffer, retrying until wait_max elapses.

        Raises
        ------
        RuntimeError
            If no connection could be made in time; the last connection
            error is chained as the cause.
        """
        logger.info("FieldTripClient: Waiting for server to start")
        # monotonic clock: immune to wall-clock adjustments during the wait
        deadline = time.monotonic() + self.wait_max
        while True:
            try:
                self.ft_client.connect(self.host, self.port)
            except Exception as err:  # server may not be up yet; retry
                if time.monotonic() >= deadline:
                    raise RuntimeError(
                        'Could not connect to FieldTrip Buffer') from err
                time.sleep(0.1)
            else:
                logger.info("FieldTripClient: Connected")
                return

    def _retrieve_header(self):
        """Fetch the FieldTrip header into ``self.ft_header``.

        ``getHeader`` is polled (at least once) until it returns something
        other than None or wait_max seconds have passed.

        Raises
        ------
        RuntimeError
            If no header was obtained in time.
        """
        logger.info("FieldTripClient: Retrieving header")
        deadline = time.monotonic() + self.wait_max
        self.ft_header = self.ft_client.getHeader()
        while self.ft_header is None and time.monotonic() < deadline:
            time.sleep(0.1)
            self.ft_header = self.ft_client.getHeader()
        if self.ft_header is None:
            raise RuntimeError('Failed to retrieve Fieldtrip header!')
        logger.info("FieldTripClient: Header retrieved")

    def _create_info(self):
        """Create a minimal Info dictionary for epoching, averaging, etc.

        If no Info was given at construction, one is guessed from the
        FieldTrip header labels; otherwise a copy of the given Info is
        adapted to data that FieldTrip has already calibrated.
        """
        if self.info is None:
            return self._guess_info()
        return self._adapt_info()

    def _guess_info(self):
        """Build an Info from the channel labels of the FieldTrip header."""
        warn('Info dictionary not provided. Trying to guess it from '
             'FieldTrip Header object')

        info = create_info(1, self.ft_header.fSample, 'mag')  # create info
        info._unlocked = True

        # modify info attributes according to the FieldTrip Header object
        info['comps'] = list()
        info['projs'] = list()
        info['bads'] = list()

        # channel dictionary list, plus labels whose kind could not be guessed
        info['chs'] = []
        chs_unknown = []
        for idx, ch in enumerate(self.ft_header.labels):
            kind = self._guess_ch_kind(ch)
            if kind is None:
                # cannot guess channel type, mark as MISC and warn later
                kind = FIFF.FIFFV_MISC_CH
                chs_unknown.append(ch)
            info['chs'].append(self._make_ch_info(idx, ch, kind))

        info._update_redundant()
        info._check_consistency()

        if chs_unknown:
            msg = ('Following channels in the FieldTrip header were '
                   'unrecognized and marked as MISC: ')
            warn(msg + ', '.join(chs_unknown))

        info._unlocked = False
        return info

    def _adapt_info(self):
        """Return a copy of the user Info adjusted for FieldTrip data."""
        # XXX: the data in real-time mode and offline mode
        # does not match unless this is done
        info = self.info.copy()
        info._unlocked = True
        info['projs'] = list()

        # FieldTrip buffer already does the calibration
        for this_info in info['chs']:
            this_info['range'] = 1.0
            this_info['cal'] = 1.0
            this_info['unit_mul'] = 0

        info._unlocked = False
        return info

    @classmethod
    def _guess_ch_kind(cls, ch):
        """Return the FIFF kind implied by the label prefix, or None."""
        for prefix, kind in cls._CH_KIND_PREFIXES:
            if ch.startswith(prefix):
                return kind
        return None

    @staticmethod
    def _make_ch_info(idx, ch, kind):
        """Build the channel dict for label ``ch`` at header position idx."""
        if ch.startswith('EEG'):
            coord_frame = FIFF.FIFFV_COORD_HEAD
        elif ch.startswith('MEG'):
            coord_frame = FIFF.FIFFV_COORD_DEVICE
        else:
            coord_frame = FIFF.FIFFV_COORD_UNKNOWN

        # MEG labels ending in 1 get T, ending in 2/3 get T/m (presumably
        # Neuromag-style magnetometer/gradiometer naming); all else is V.
        if ch.startswith('MEG') and ch.endswith('1'):
            unit = FIFF.FIFF_UNIT_T
        elif ch.startswith('MEG') and ch.endswith(('2', '3')):
            unit = FIFF.FIFF_UNIT_T_M
        else:
            unit = FIFF.FIFF_UNIT_V

        # Numerical part of the channel name: the last run of digits. Labels
        # without any digit (e.g. 'Cz') used to crash int(); fall back to the
        # 1-based position, as mne.create_info does.
        digits = re.findall(r'\d+', ch)
        logno = int(digits[-1]) if digits else idx + 1

        return {
            'scanno': idx,
            'logno': logno,
            'kind': kind,
            # Set coil_type (does FT supply this information somehow?)
            'coil_type': FIFF.FIFFV_COIL_NONE,
            # Fieldtrip already does calibration
            'range': 1.0,
            'cal': 1.0,
            'ch_name': ch,
            'loc': np.zeros(12),
            'coord_frame': coord_frame,
            'unit': unit,
            'unit_mul': 0,
        }

    def get_measurement_info(self):
        """Return the measurement info.

        Returns
        -------
        self.info : dict
            The measurement info.
        """
        return self.info

    @fill_doc
    def get_data_as_epoch(self, n_samples=1024, picks=None):
        """Return last n_samples from current time.

        Parameters
        ----------
        n_samples : int
            Number of samples to fetch. Must be at least 1 and at most the
            ``nSamples`` value reported by the FieldTrip header.
        %(picks_all)s

        Returns
        -------
        epoch : instance of Epochs
            The samples fetched as an Epochs object.

        Raises
        ------
        ValueError
            If ``n_samples`` is out of the range described above.
        RuntimeError
            If the FieldTrip header cannot be retrieved.

        See Also
        --------
        mne.Epochs.iter_evoked
        """
        ft_header = self.ft_client.getHeader()
        if ft_header is None:
            raise RuntimeError('Failed to retrieve Fieldtrip header!')

        n_available = ft_header.nSamples
        if n_samples < 1:
            raise ValueError('n_samples must be at least 1, got %s'
                             % (n_samples,))
        if n_samples > n_available:
            raise ValueError('Requested %s samples but only %s are available '
                             'in the FieldTrip buffer'
                             % (n_samples, n_available))

        last_samp = n_available - 1
        start = last_samp - n_samples + 1
        stop = last_samp
        events = np.array([[start, 1, 1]])

        # get the data (stop is inclusive); transpose so channels lie on the
        # first axis, which is what ``data[picks]`` indexes below
        data = self.ft_client.getData([start, stop]).transpose()

        # create epoch from data
        picks = _picks_to_idx(self.info, picks, 'all', exclude=())
        info = pick_info(self.info, picks)
        return EpochsArray(data[picks][np.newaxis], info, events)

    def register_receive_callback(self, callback):
        """Register a raw buffer receive callback.

        Registering the same callback twice has no effect.

        Parameters
        ----------
        callback : callable
            The callback. The raw buffer is passed as the first parameter
            to callback.
        """
        if callback not in self._recv_callbacks:
            self._recv_callbacks.append(callback)

    def unregister_receive_callback(self, callback):
        """Unregister a raw buffer receive callback.

        Unregistering an unknown callback has no effect.

        Parameters
        ----------
        callback : callable
            The callback to unregister.
        """
        if callback in self._recv_callbacks:
            self._recv_callbacks.remove(callback)

    def _push_raw_buffer(self, raw_buffer):
        """Push raw buffer to clients using callbacks."""
        # iterate over a snapshot so a callback that (un)registers callbacks
        # cannot make this loop skip or repeat entries
        for callback in list(self._recv_callbacks):
            callback(raw_buffer)

    def start_receive_thread(self, nchan):
        """Start the receive thread.

        The daemon thread iterates ``iter_raw_buffers`` and passes every
        buffer to the registered receive callbacks. Nothing happens if a
        receive thread is already registered.

        Parameters
        ----------
        nchan : int
            The number of channels in the data. Unused; kept for API
            compatibility.
        """
        if self._recv_thread is None:
            # assign before start() so the worker sees itself registered
            self._recv_thread = threading.Thread(
                target=_buffer_recv_worker, args=(self,), daemon=True)
            self._recv_thread.start()

    def stop_receive_thread(self, stop_measurement=False):
        """Stop the receive thread.

        The worker exits after it finishes delivering the buffer it is
        currently working on.

        Parameters
        ----------
        stop_measurement : bool
            Unused, for compatibility.
        """
        self._recv_thread = None

    def _iter_sample_ranges(self):
        """Yield the half-open ``(start, stop)`` sample ranges to fetch.

        Full chunks of ``buffer_size`` samples are followed by a final chunk
        ending at ``tmax_samp + 1`` so that ``tmax_samp`` is included. The
        ranges are produced lazily: with ``tmax=np.inf`` the span is about
        4.3e9 samples, and materializing every chunk up front would build
        millions of tuples before the first buffer.
        """
        step = self.buffer_size
        stops = range(self.tmin_samp + step, self.tmax_samp + 1, step)
        # range objects are lazy sequences: O(1) memory, O(1) stops[-1]
        yield from zip(range(self.tmin_samp, self.tmax_samp, step), stops)
        last_stop = stops[-1] if stops else self.tmin_samp
        if last_stop < self.tmax_samp + 1:
            yield last_stop, self.tmax_samp + 1

    def iter_raw_buffers(self):
        """Return an iterator over raw buffers.

        Returns
        -------
        raw_buffer : generator
            Generator for iteration over raw buffers.
        """
        uint32_max = np.iinfo(np.uint32).max
        for start, stop in self._iter_sample_ranges():
            # wait for correct number of samples to be available
            # (remaining wait() arguments presumably the event count and
            # timeout, both set to "no limit" -- see FieldTrip Client.wait)
            self.ft_client.wait(stop, uint32_max, uint32_max)

            # get the samples (stop index is inclusive)
            raw_buffer = self.ft_client.getData([start, stop - 1]).transpose()

            yield raw_buffer

            # NOTE(review): only the registered receive thread keeps going;
            # any other caller (e.g. direct iteration from the main thread)
            # stops after the first buffer. Confirm this is intended.
            if self._recv_thread != threading.current_thread():
                # stop_receive_thread has been called
                break