Source code for mne_realtime.client

# Authors: Christoph Dinh <chdinh@nmr.mgh.harvard.edu>
#          Martin Luessi <mluessi@nmr.mgh.harvard.edu>
#          Matti Hamalainen <msh@nmr.mgh.harvard.edu>
#
# License: BSD (3-clause)

from io import BytesIO, StringIO
import socket
import threading
import time

import numpy as np

from mne.utils import logger, verbose, fill_doc
from mne.io.constants import FIFF
try:
    from mne._fiff.meas_info import read_meas_info
    from mne._fiff.tag import Tag, read_tag
    from mne._fiff.tree import make_dir_tree
except ImportError:  # MNE < 1.6
    from mne.io.meas_info import read_meas_info
    from mne.io.tag import Tag, read_tag
    from mne.io.tree import make_dir_tree

# Command codes carried as the first 4 payload bytes of FIFF_MNE_RT_COMMAND
# tags sent over the data connection (see RtClient._send_fiff_command).
MNE_RT_GET_CLIENT_ID = 1  # ask the server for this client's numeric ID
MNE_RT_SET_CLIENT_ALIAS = 2  # register an alias string for this client


def _recv_exactly(sock, n_bytes):
    """Receive exactly ``n_bytes`` bytes from ``sock``.

    ``socket.recv`` may return fewer bytes than requested, so this loops
    until the full amount has been collected.

    Parameters
    ----------
    sock : socket.socket
        The socket to read from.
    n_bytes : int
        The number of bytes to read.

    Returns
    -------
    data : bytes
        The received bytes (``b''`` when ``n_bytes`` is 0).

    Raises
    ------
    RuntimeError
        If the peer closes the connection before ``n_bytes`` have arrived.
    """
    chunks = []
    n_received = 0
    while n_received < n_bytes:
        chunk = sock.recv(min(4096, n_bytes - n_received))
        if not chunk:
            # recv() returning b'' means the peer closed the connection; no
            # more data will ever arrive, so retrying would loop forever.
            raise RuntimeError('Not enough bytes received, something is wrong. '
                               'Make sure the mne_rt_server is running.')
        chunks.append(chunk)
        n_received += len(chunk)
    return b''.join(chunks)


def _recv_tag_raw(sock):
    """Read a tag and the associated data from a socket.

    Parameters
    ----------
    sock : socket.socket
        The socket from which to read the tag.

    Returns
    -------
    tag : instance of Tag
        The tag.
    buff : bytes
        The raw data of the tag (including the 16-byte header).

    Raises
    ------
    RuntimeError
        If the connection is closed before the complete tag was received, or
        the header announces a negative payload size.
    """
    # Header: four big-endian int32 values (kind, type, size, next).
    header = _recv_exactly(sock, 16)
    tag = Tag(*np.frombuffer(header, '>i4'))

    n_data = int(tag.size)
    if n_data < 0:
        # A negative size can never be satisfied; the stream is corrupt.
        raise RuntimeError('Not enough bytes received, something is wrong. '
                           'Make sure the mne_rt_server is running.')

    data = _recv_exactly(sock, n_data)
    return tag, header + data


def _buffer_recv_worker(rt_client, nchan):
    """Worker thread that constantly receives buffers."""
    try:
        for raw_buffer in rt_client.raw_buffers(nchan):
            rt_client._push_raw_buffer(raw_buffer)
    except RuntimeError as err:
        # something is wrong, the server stopped (or something)
        rt_client._recv_thread = None
        print('Buffer receive thread stopped: %s' % err)


@fill_doc
class RtClient(object):
    """Realtime Client.

    Client to communicate with mne_rt_server.

    Parameters
    ----------
    host : str
        Hostname (or IP address) of the host where mne_rt_server is running.
    cmd_port : int
        Port to use for the command connection.
    data_port : int
        Port to use for the data connection.
    timeout : float
        Communication timeout in seconds.
    %(verbose)s
    """

    @verbose
    def __init__(self, host, cmd_port=4217, data_port=4218, timeout=1.0,
                 verbose=None):  # noqa: D102
        self._host = host
        self._data_port = data_port
        self._cmd_port = cmd_port
        self._timeout = timeout

        # The command socket is non-blocking (polled in _send_command); the
        # data socket is blocking (read tag-by-tag by _recv_tag_raw).
        try:
            self._cmd_sock = self._open_connection(host, cmd_port, timeout,
                                                   blocking=False)
        except Exception as err:
            raise RuntimeError('Setting up command connection (host: %s '
                               'port: %d) failed. Make sure mne_rt_server '
                               'is running. ' % (host, cmd_port)) from err

        try:
            self._data_sock = self._open_connection(host, data_port, timeout,
                                                    blocking=True)
        except Exception as err:
            # don't leak the command connection that was already opened
            self._cmd_sock.close()
            raise RuntimeError('Setting up data connection (host: %s '
                               'port: %d) failed. Make sure mne_rt_server '
                               'is running.' % (host, data_port)) from err

        self.verbose = verbose

        # get my client ID
        self._client_id = self.get_client_id()

        self._recv_thread = None
        self._recv_callbacks = list()

    @staticmethod
    def _open_connection(host, port, timeout, blocking):
        """Open a TCP connection, closing the socket again if setup fails.

        Parameters
        ----------
        host : str
            Hostname (or IP address) to connect to.
        port : int
            Port to connect to.
        timeout : float
            Timeout in seconds used while connecting.
        blocking : bool
            Blocking mode applied to the socket once connected.

        Returns
        -------
        sock : socket.socket
            The connected socket.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            sock.connect((host, port))
            sock.setblocking(blocking)
        except Exception:
            sock.close()
            raise
        return sock

    def _send_command(self, command):
        """Send a command to the server.

        The response is read from the non-blocking command socket until no
        new data has arrived for ``timeout`` seconds, or for ``2 * timeout``
        seconds if nothing arrived at all.

        Parameters
        ----------
        command : str
            The command to send.

        Returns
        -------
        resp : str
            The response from the server.
        """
        logger.debug('Sending command: %s', command)
        command += '\n'
        self._cmd_sock.sendall(command.encode('utf-8'))

        buf, begin = [], time.time()
        while True:
            elapsed = time.time() - begin
            # once some data arrived, stop after `timeout` s of silence
            if buf and elapsed > self._timeout:
                break
            # if we got no data at all, wait a little longer
            elif elapsed > self._timeout * 2:
                break
            try:
                chunk = self._cmd_sock.recv(8192)
            except BlockingIOError:
                # Nothing available yet on the non-blocking socket; back off
                # briefly instead of spinning at full CPU.
                time.sleep(0.01)
                continue
            except OSError as err:
                # The connection is unusable; return whatever was received.
                logger.debug('Command connection error: %s', err)
                break
            if not chunk:
                # Peer closed the connection: no more data will arrive.
                break
            buf.append(chunk)
            begin = time.time()

        # recv() yields bytes; decode leniently since this is best-effort.
        return b''.join(buf).decode('utf-8', errors='replace')

    def _send_fiff_command(self, command, data=None):
        """Send a command through the data connection as a fiff tag.

        Parameters
        ----------
        command : int
            The command code.
        data : str | bytes | None
            Additional data to send. Strings are encoded as UTF-8.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        payload = b'' if data is None else bytes(data)

        kind = FIFF.FIFF_MNE_RT_COMMAND
        tag_type = FIFF.FIFFT_VOID
        # the first 4 bytes of the tag payload are the command code
        size = 4 + len(payload)
        next_pos = 0
        header = np.array([kind, tag_type, size, next_pos, command],
                          dtype='>i4')
        self._data_sock.sendall(header.tobytes() + payload)

    def get_measurement_info(self):
        """Get the measurement information.

        Returns
        -------
        info : dict
            The measurement information.
        """
        cmd = 'measinfo %d' % self._client_id
        self._send_command(cmd)

        buff = []
        directory = []
        pos = 0
        while True:
            tag, this_buff = _recv_tag_raw(self._data_sock)
            tag.pos = pos
            pos += 16 + int(tag.size)
            directory.append(tag)
            buff.append(this_buff)
            # stop once the FIFF_BLOCK_END closing FIFFB_MEAS_INFO arrives
            if tag.kind == FIFF.FIFF_BLOCK_END and tag.type == FIFF.FIFFT_INT:
                val = np.frombuffer(this_buff[-4:], dtype='>i4')[0]
                if val == FIFF.FIFFB_MEAS_INFO:
                    break

        # the received tags are binary, so parse them from a bytes buffer
        fid = BytesIO(b''.join(buff))
        tree, _ = make_dir_tree(fid, directory)
        info, _ = read_meas_info(fid, tree)

        return info

    def set_client_alias(self, alias):
        """Set client alias.

        Parameters
        ----------
        alias : str
            The client alias.
        """
        self._send_fiff_command(MNE_RT_SET_CLIENT_ALIAS, alias)

    def get_client_id(self):
        """Get the client ID.

        Returns
        -------
        id : int
            The client ID.

        Raises
        ------
        RuntimeError
            If the server does not answer with an integer client-ID tag.
        """
        self._send_fiff_command(MNE_RT_GET_CLIENT_ID)

        # ID is sent as answer
        tag, buff = _recv_tag_raw(self._data_sock)
        if (tag.kind != FIFF.FIFF_MNE_RT_CLIENT_ID or
                tag.type != FIFF.FIFFT_INT):
            raise RuntimeError('wrong tag received')

        return int(np.frombuffer(buff[-4:], dtype='>i4')[0])

    def start_measurement(self):
        """Start the measurement."""
        cmd = 'start %d' % self._client_id
        self._send_command(cmd)

    def stop_measurement(self):
        """Stop the measurement."""
        self._send_command('stop-all')

    def start_receive_thread(self, nchan):
        """Start the receive thread.

        If the measurement has not been started, it will also be started.

        Parameters
        ----------
        nchan : int
            The number of channels in the data.
        """
        if self._recv_thread is None:
            self.start_measurement()

            # Daemon thread: the worker blocks in recv() on the data socket
            # and would otherwise keep the interpreter alive at exit once the
            # server stops sending data.
            self._recv_thread = threading.Thread(target=_buffer_recv_worker,
                                                 args=(self, nchan),
                                                 daemon=True)
            self._recv_thread.start()

    def stop_receive_thread(self, stop_measurement=False):
        """Stop the receive thread.

        Parameters
        ----------
        stop_measurement : bool
            Also stop the measurement.
        """
        # raw_buffers() keeps looping only while this handle refers to the
        # worker thread, so clearing it ends the loop at its next check.
        self._recv_thread = None

        if stop_measurement:
            self.stop_measurement()

    def register_receive_callback(self, callback):
        """Register a raw buffer receive callback.

        Parameters
        ----------
        callback : callable
            The callback. The raw buffer is passed as the first parameter
            to callback.
        """
        if callback not in self._recv_callbacks:
            self._recv_callbacks.append(callback)

    def unregister_receive_callback(self, callback):
        """Unregister a raw buffer receive callback.

        Parameters
        ----------
        callback : callable
            The callback to unregister.
        """
        if callback in self._recv_callbacks:
            self._recv_callbacks.remove(callback)

    def _push_raw_buffer(self, raw_buffer):
        """Push raw buffer to clients using callbacks."""
        # Iterate over a snapshot so a callback may (un)register callbacks
        # without causing others to be skipped.
        for callback in list(self._recv_callbacks):
            callback(raw_buffer)

    def read_raw_buffer(self, nchan):
        """Read a single buffer with raw data.

        Parameters
        ----------
        nchan : int
            The number of channels (info['nchan']).

        Returns
        -------
        raw_buffer : float array, shape=(nchan, n_times)
            The raw data.
        """
        tag, this_buff = _recv_tag_raw(self._data_sock)

        # skip tags until we get a data buffer
        while tag.kind != FIFF.FIFF_DATA_BUFFER:
            tag, this_buff = _recv_tag_raw(self._data_sock)

        # re-read the raw bytes (header included) with MNE's tag reader to
        # obtain the decoded payload in tag.data
        tag = read_tag(BytesIO(this_buff))
        raw_buffer = tag.data.reshape(-1, nchan).T

        return raw_buffer

    def raw_buffers(self, nchan):
        """Return an iterator over raw buffers.

        Parameters
        ----------
        nchan : int
            The number of channels (info['nchan']).

        Yields
        ------
        raw_buffer : float array, shape=(nchan, n_times)
            The raw data of the next received buffer.
        """
        # stops the loop if stop_receive_thread has been called
        while self._recv_thread is threading.current_thread():
            raw_buffer = self.read_raw_buffer(nchan)
            if raw_buffer is not None:
                yield raw_buffer
            else:
                break