Source code for mne_realtime.stim_server_client
# Author: Mainak Jas <mainak@neuro.hut.fi>
# License: BSD (3-clause)
import queue
import time
import socket
import socketserver
import threading
import numpy as np
from mne.utils import logger, verbose, fill_doc
class _ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that serves every connection on its own thread.

    Behaves exactly like ``socketserver.TCPServer`` combined with
    ``socketserver.ThreadingMixIn``, except that it also stores a reference
    to the owning StimServer so request handlers can reach it through
    ``self.server.stim_server``.

    Parameters
    ----------
    server_address : tuple
        ``(host, port)`` address the server will listen on.
    request_handler_class : subclass of BaseRequestHandler
        Handler class (``_TriggerHandler``) whose ``handle`` method serves
        each connection.
    stim_server : instance of StimServer
        The StimServer that owns this TCP server.
    """

    def __init__(self, server_address, request_handler_class,
                 stim_server):  # noqa: D102
        # The socket is created here but deliberately NOT bound or put into
        # listening mode yet; the owner does that later (StimServer sets
        # allow_reuse_address before calling server_bind()).
        super().__init__(server_address, request_handler_class,
                         bind_and_activate=False)
        # Extra attribute exposed to handlers via ``self.server``.
        self.stim_server = stim_server
class _TriggerHandler(socketserver.BaseRequestHandler):
    """Request handler on the server side.

    One handler instance serves one client connection. It understands two
    text commands: ``'add client'`` registers the connection with the
    StimServer, and ``'get trigger'`` replies with the next queued trigger
    or with ``'Empty'`` when nothing is queued.
    """

    def handle(self):
        """Handle requests on the server side."""
        stim_server = self.server.stim_server
        self.request.settimeout(None)
        # Create this connection's trigger queue before the client can be
        # registered: StimServer.add_trigger() writes to ``_tx_queue`` as
        # soon as the handler appears in ``stim_server._clients``, and a
        # 'get trigger' request must not fail if it arrives first.
        self._tx_queue = queue.Queue()
        while stim_server._running:
            data = self.request.recv(1024)  # clip input at 1Kb
            if not data:
                # recv() returns b'' once the peer has closed the
                # connection; stop instead of spinning on it forever.
                break
            data = data.decode()  # need to turn it into a string (Py3k)
            if data == 'add client':
                client_id = stim_server._add_client(self.client_address[0],
                                                    self)
                self.request.sendall("Client added".encode('utf-8'))
                # Mark the client as running
                for client in stim_server._clients:
                    if client['id'] == client_id:
                        client['running'] = True
            elif data == 'get trigger':
                # Never block here: an empty queue must be answered with
                # 'Empty' so the client can apply its own timeout.
                try:
                    trigger = self._tx_queue.get_nowait()
                except queue.Empty:
                    self.request.sendall("Empty".encode('utf-8'))
                else:
                    self.request.sendall(str(trigger).encode('utf-8'))


class StimServer(object):
    """Stimulation Server.

    Server to communicate with StimClient(s).

    Parameters
    ----------
    port : int
        The port to which the stimulation server must bind to.
    n_clients : int
        The number of clients which will connect to the server.

    See Also
    --------
    StimClient
    """

    def __init__(self, port=4218, n_clients=1):  # noqa: D102
        # Create (but do not yet bind) the threaded TCP server; the empty
        # host string means "listen on all interfaces".
        self._data = _ThreadedTCPServer(('', port),
                                        _TriggerHandler, self)
        self.n_clients = n_clients
        # State read by handler threads and by shutdown(). It is defined
        # here so it also exists when the object is not used via ``with``.
        self._running = False
        self._clients = list()
        self._thread = None

    def __enter__(self):  # noqa: D105
        # This is done to avoid "[Errno 98] Address already in use"
        self._data.allow_reuse_address = True
        self._data.server_bind()
        self._data.server_activate()
        # Reset shared state *before* the serving thread starts, so a
        # handler thread never observes these attributes missing.
        self._running = False
        self._clients = list()
        # Start a thread for the server
        self._thread = threading.Thread(target=self._data.serve_forever)
        # Ctrl-C will cleanly kill all spawned threads
        # Once the main thread exits, other threads will exit
        self._thread.daemon = True
        self._thread.start()
        return self

    def __exit__(self, type, value, traceback):  # noqa: D105
        self.shutdown()

    @verbose
    def start(self, timeout=np.inf, verbose=None):
        """Start the server.

        Blocks until ``n_clients`` clients have registered.

        Parameters
        ----------
        timeout : float
            Maximum time to wait for clients to be added.
        %(verbose)s

        Raises
        ------
        StopIteration
            If fewer than ``n_clients`` clients registered within
            ``timeout`` seconds.
        """
        # Start server
        if not self._running:
            logger.info('RtServer: Start')
            self._running = True
            start_time = time.time()  # init delay counter.
            # wait till n_clients are added
            while len(self._clients) < self.n_clients:
                current_time = time.time()
                if current_time > start_time + timeout:
                    raise StopIteration
                time.sleep(0.1)

    @verbose
    def _add_client(self, ip, sock, verbose=None):
        """Add client.

        Parameters
        ----------
        ip : str
            IP address of the client.
        sock : instance of _TriggerHandler
            The request handler serving this client; add_trigger() puts
            triggers into its ``_tx_queue``.
        %(verbose)s

        Returns
        -------
        client_id : int
            Index assigned to the new client.
        """
        logger.info("Adding client with ip = %s" % ip)
        client = dict(ip=ip, id=len(self._clients), running=False, socket=sock)
        self._clients.append(client)
        return client['id']

    @verbose
    def shutdown(self, verbose=None):
        """Shutdown the client and server.

        Parameters
        ----------
        %(verbose)s
        """
        logger.info("Shutting down ...")
        # stop running all the clients
        for client in self._clients:
            client['running'] = False
        self._running = False
        # BaseServer.shutdown() waits for serve_forever() to return, so it
        # would block forever if the serving thread was never started.
        if self._thread is not None:
            self._data.shutdown()
            self._thread.join()
            self._thread = None
        # server_close() also closes the listening socket.
        # NOTE(review): with ThreadingMixIn.block_on_close (Python >= 3.7)
        # this waits for handler threads, which stay blocked in recv()
        # until their client disconnects.
        self._data.server_close()

    @verbose
    def add_trigger(self, trigger, verbose=None):
        """Add a trigger.

        Parameters
        ----------
        trigger : int
            The trigger to be added to the queue for sending to StimClient.
        %(verbose)s

        See Also
        --------
        StimClient.get_trigger
        """
        for client in self._clients:
            client_id = client['id']
            logger.info("Sending trigger %d to client %d"
                        % (trigger, client_id))
            client['socket']._tx_queue.put(trigger)


@fill_doc
class StimClient(object):
    """Stimulation Client.

    Client to communicate with StimServer

    Parameters
    ----------
    host : str
        Hostname (or IP address) of the host where StimServer is running.
    port : int
        Port to use for the connection.
    timeout : float
        Communication timeout in seconds.
    %(verbose)s

    See Also
    --------
    StimServer
    """

    @verbose
    def __init__(self, host, port=4218, timeout=5.0,
                 verbose=None):  # noqa: D102
        self._sock = None
        try:
            logger.info("Setting up client socket")
            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._sock.settimeout(timeout)
            self._sock.connect((host, port))
            logger.info("Establishing connection with server")
            # sendall() retries short writes instead of failing on them.
            self._sock.sendall("add client".encode('utf-8'))
            resp = self._sock.recv(1024).decode()  # bytes -> str (Py3k)
            if resp == 'Client added':
                logger.info("Connection established")
            else:
                raise RuntimeError('Client not added')
        except Exception as err:
            # Do not leak the socket when the handshake fails.
            if self._sock is not None:
                self._sock.close()
            raise RuntimeError('Setting up acquisition <-> stimulation '
                               'computer connection (host: %s '
                               'port: %d) failed. Make sure StimServer '
                               'is running.' % (host, port)) from err

    def close(self):
        """Close the socket object."""
        self._sock.close()

    @verbose
    def get_trigger(self, timeout=5.0, verbose=None):
        """Get triggers from StimServer.

        Parameters
        ----------
        timeout : float
            maximum time to wait for a valid trigger from the server
        %(verbose)s

        Returns
        -------
        trigger : int | None
            The next trigger queued for this client, or None if the server
            kept answering 'Empty' for ``timeout`` seconds. Socket errors
            (including the socket timeout set in the constructor) are not
            caught and propagate to the caller.

        See Also
        --------
        StimServer.add_trigger
        """
        start_time = time.time()  # init delay counter. Will stop iterations
        while True:
            try:
                current_time = time.time()
                # Give up once the caller's timeout has elapsed
                if current_time > (start_time + timeout):
                    logger.info("received nothing")
                    return None
                self._sock.send("get trigger".encode('utf-8'))
                # recv() returns bytes; decode before comparing with the
                # server's 'Empty' reply, otherwise it can never match.
                trigger = self._sock.recv(1024).decode()
                if trigger != 'Empty':
                    logger.info("received trigger %s" % str(trigger))
                    return int(trigger)
            except RuntimeError as err:
                logger.info('Cannot receive triggers: %s' % (err))