Source code for daisy.communication.message_stream

# Copyright (C) 2024-2025 DAI-Labor and others
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""An efficient, persistent, and stateless communications stream between two
endpoints over BSD sockets. Supports SSL (soon) and LZ4 compression.

Author: Fabian Hofmann
Modified: 16.10.24
"""

import ctypes
import logging
import pickle
import queue
import select
import socket
import sys
import threading
from time import sleep, time
from typing import Callable, Iterable, Optional, Self

# noinspection PyUnresolvedReferences
from lz4.frame import compress, decompress


class EndpointSocket:
    """A bundle of up to two sockets used to communicate with another endpoint
    over a persistent TCP connection in synchronous manner. Supports
    authentication and encryption over SSL, and stream compression using LZ4.

    Thread-safe, both for access to the same endpoint socket and for multiple
    threads using endpoint sockets set to the same address (this is organized
    through an array of class variables, see below for more info).

    :cvar _listen_socks: Active listen sockets, along with a respective lock to
        access each safely.
    :cvar _acc_r_socks: Pending registered connection cache for each listen socket.
    :cvar _acc_p_socks: Pending unregistered connection queue for each listen socket.
    :cvar _reg_r_addrs: Registered remote addresses.
    :cvar _addr_map: Mapping between registered remote addresses and their aliases.
    :cvar _act_l_counts: Active thread counter for each listen socket. Socket
        closes if counter reaches zero.
    :cvar _lock: General purpose lock to ensure safe access to class variables.
    :cvar _cls_logger: General purpose logger for class methods.
    """

    _listen_socks: dict[tuple[str, int], tuple[socket.socket, threading.Lock]] = {}
    _acc_r_socks: dict[
        tuple[str, int], tuple[dict[tuple[str, int], socket.socket], threading.Lock]
    ] = {}
    _acc_p_socks: dict[
        tuple[str, int], queue.Queue[tuple[socket.socket, tuple[str, int]]]
    ] = {}
    _reg_r_addrs: set[tuple[str, int]] = set()
    _addr_map: dict[tuple[str, int], set[tuple[str, int]]] = {}
    _act_l_counts: dict[tuple[str, int], int] = {}
    _lock = threading.Lock()
    _cls_logger: logging.Logger = logging.getLogger("EndpointSocketCLS")

    _logger: logging.Logger

    _addr: tuple[str, int]
    _remote_addr: Optional[tuple[str, int]]
    _acceptor: bool
    _send_b_size: int
    _recv_b_size: int

    _sock: Optional[socket.socket]
    _sock_lock: threading.Lock
    _keep_alive: bool
    _conn_rdy: threading.Event
    _opened: bool

    def __init__(
        self,
        name: str,
        addr: tuple[str, int] = None,
        remote_addr: tuple[str, int] = None,
        acceptor: bool = True,
        send_b_size: int = 65536,
        recv_b_size: int = 65536,
        keep_alive: bool = True,
    ):
        """Creates a new endpoint socket.

        Implementation note: A pre-defined remote address is no guarantee that
        this endpoint will successfully be allowed to initialize for this
        remote address --- for example, if another endpoint socket with the
        same remote address (be it generic or pre-defined) has already been
        registered, the current one will throw an error.

        :param name: Name of endpoint for logging purposes.
        :param addr: Address of endpoint. Mandatory in acceptor mode (acceptor
            set to True); for initiators this fixes the address the endpoint is
            bound to.
        :param remote_addr: Address of remote endpoint to be connected to.
            Mandatory in initiator mode (acceptor set to False); for acceptors
            this fixes the remote endpoint that is allowed to be connected to
            this endpoint.
        :param acceptor: Determines whether the endpoint accepts or initiates
            connections to/from other endpoints.
        :param send_b_size: Underlying send buffer size of socket.
        :param recv_b_size: Underlying receive buffer size of socket.
        :param keep_alive: Determines whether to attempt re-connects after the
            remote endpoint has terminated the connection.
        :raises ValueError: If the remote address is already taken for the
            acceptor, or if the address/remote address is not provided for
            acceptor/initiator.
        """
        self._logger = logging.getLogger(name + "-Socket")
        self._logger.info(f"Initializing endpoint socket {addr, remote_addr}...")

        self._addr = addr if addr is not None else ("0.0.0.0", 0)
        self._remote_addr = remote_addr
        self._acceptor = acceptor
        if acceptor:
            if addr is None:
                raise ValueError("Accepting endpoint socket requires an address!")
            elif remote_addr is not None:
                self._reg_remote(remote_addr)
                self._fix_rp_acc_socks(addr, remote_addr)
        elif remote_addr is None:
            raise ValueError("Initiating endpoint socket requires a remote address!")
        self._send_b_size = send_b_size
        self._recv_b_size = recv_b_size

        self._sock = None
        self._sock_lock = threading.Lock()
        self._keep_alive = keep_alive
        self._conn_rdy = threading.Event()
        self._opened = False
        self._logger.info(f"Endpoint socket {addr, remote_addr} initialized.")

    def open(self):
        """Opens the endpoint socket along with its underlying socket(s) and
        its connection to a/the remote endpoint socket. Blocking until the
        connection is established.
        """
        self._logger.info("Opening endpoint socket...")
        self._opened = True
        if self._acceptor:
            self._open_l_socket(self._addr)
        self._sock_lock.acquire()
        self._conn_rdy.set()
        self._connect(initial=True)
        self._logger.info("Endpoint socket opened.")

    def close(self, shutdown: bool = False):
        """Closes the endpoint socket, cleaning up any underlying
        datastructures if acceptor. If already closed, still allows the cleanup
        of just the datastructures in case a shutdown is requested.
        """
        if (
            not self._opened
            and shutdown
            and self._acceptor
            and self._remote_addr is not None
        ):
            self._unreg_remote(self._addr, self._remote_addr)
            return
        self._logger.info("Closing endpoint socket...")
        self._opened = False
        with self._sock_lock:
            _close_socket(self._sock)
            self._sock = None
        if self._acceptor:
            if shutdown and self._remote_addr is not None:
                self._unreg_remote(self._addr, self._remote_addr)
            self._close_l_socket(self._addr)
        self._logger.info("Endpoint socket closed.")

    def send(self, p_data: bytes):
        """Sends the given bytes of a single object over the connection,
        performing simple marshalling (size is sent first, then the bytes of
        the object). Fault-tolerant against breakdowns and resets of the
        connection. Blocking.

        :param p_data: Bytes to send.
        :raises RuntimeError: If connection has been terminated by remote
            endpoint and keep-alive is disabled.
        """
        while self._opened:
            try:
                self._sock_lock.acquire()
                if _check_w_socket(self._sock):
                    _send_payload(self._sock, p_data)
                    self._sock_lock.release()
                    return
                self._sock_lock.release()
                sleep(1)
            except (OSError, ValueError, RuntimeError) as e:
                if not self._keep_alive and type(e) is RuntimeError:
                    self._sock_lock.release()
                    raise e
                self._logger.warning(
                    f"{e.__class__.__name__}({e}) "
                    "while trying to send data. Retrying..."
                )
                # release() of sock_lock is done in _connect()
                self._connect()

    def recv(self, timeout: int = None) -> Optional[bytes]:
        """Receives the bytes of a single object sent over the connection,
        performing simple marshalling (size is received first, then the bytes
        of the object). Fault-tolerant against breakdowns and resets of the
        connection. Blocking in default mode if timeout not set.

        :param timeout: Timeout (seconds) to receive an object to return.
        :return: Received bytes, or None if the endpoint socket has been closed.
        :raises TimeoutError: If timeout set and triggered.
        :raises RuntimeError: If connection has been terminated by remote
            endpoint and keep-alive is disabled.
        """
        while self._opened:
            try:
                self._sock_lock.acquire()
                if _check_r_socket(self._sock, timeout=timeout):
                    p_data = _recv_payload(self._sock)
                    self._sock_lock.release()
                    return p_data
                self._sock_lock.release()
                sleep(1)
            except (OSError, ValueError, RuntimeError) as e:
                if timeout is not None and type(e) is TimeoutError:
                    # do not leak the lock when surfacing the timeout
                    self._sock_lock.release()
                    raise e
                if not self._keep_alive and type(e) is RuntimeError:
                    self._sock_lock.release()
                    raise e
                self._logger.warning(
                    f"{e.__class__.__name__}({e}) "
                    "while trying to receive data. Retrying..."
                )
                # release() of sock_lock is done in _connect()
                self._connect()
        return None

    def poll(
        self, lazy: bool = False
    ) -> tuple[list[bool], tuple[tuple[str, int], tuple[str, int]]]:
        """Polls various states and addresses of the endpoint socket:

        * 0,0: Existence of socket (true if connected).
        * 0,1: Whether there is something to read on the underlying socket.
        * 0,2: Whether one is able to write on the underlying socket.
        * 1,0: Address of endpoint socket, else None.
        * 1,1: Address of remote endpoint socket, else None.

        Note this does not necessarily guarantee that the underlying socket is
        actually connected and available for reading/writing; e.g. the
        connection could have broken down since then and is currently being
        re-established.

        :param lazy: Whether to lazily skip the actual state of the underlying
            socket and just check for connectivity.
        :return: Tuple of boolean states (connectivity, readability,
            writability) and address-pair of endpoint socket.
        """
        states = [self._sock is not None] + [False] * 2
        if not lazy:
            with self._sock_lock:
                states[0] = self._sock is not None
                if self._sock is not None:
                    states[1] = len(select.select([self._sock], [], [], 0)[0]) != 0
                    states[2] = len(select.select([], [self._sock], [], 0)[1]) != 0
        return states, (self._addr, self._remote_addr)

    def _connect(self, initial=False):
        """(Re-)Establishes a connection to a/the remote endpoint socket, first
        performing any necessary cleanup of the underlying socket before
        opening it again and trying to connect/accept a remote endpoint socket.
        Fault-tolerant against breakdowns and resets of the connection.
        Blocking.

        Note if called multiple times concurrently, only one call is executed;
        the other callers release any locks they are holding and wait until the
        first caller has established the connection and set the semaphore
        accordingly. Also note the longer an endpoint socket attempts to
        re-establish the connection, the longer it waits between attempts. This
        is done to prevent busy waiting of endpoint sockets occupying the same
        listening socket, the exception being initial attempts to establish the
        connection.

        :param initial: Whether this is the first time the connection is
            established.
        """
        if not self._conn_rdy.is_set():
            self._sock_lock.release()
            self._conn_rdy.wait()
            return
        self._conn_rdy.clear()

        i = 0
        while self._opened:
            try:
                self._logger.info(
                    f"Trying to (re-)establish connection "
                    f"{self._addr, self._remote_addr}..."
                )
                self._setup()
                self._logger.info(
                    f"Connection {self._addr, self._remote_addr} (re-)established."
                )
                break
            except (OSError, ValueError, AttributeError, RuntimeError) as e:
                e_msg = (
                    f"{e.__class__.__name__}({e}) while trying to (re-)establish "
                    f"connection {self._addr, self._remote_addr}. Retrying..."
                )
                if i == 0:
                    self._logger.info(e_msg)
                else:
                    self._logger.debug(e_msg)
                self._sock_lock.release()
                sleep(min(1 << i, 128))
                # initial attempts of an endpoint socket never do binary backoff (+0)
                i += int(not initial)
                self._sock_lock.acquire()
        self._conn_rdy.set()
        self._sock_lock.release()

    def _setup(self):
        """(Re-)Establishes a connection to a/the remote endpoint socket, first
        performing any necessary cleanup of the underlying socket before
        attempting to open it and trying to connect/accept a remote endpoint
        socket. In case it fails, raises the respective error. Idempotent, may
        be called multiple times until connection is established.

        :raises Error: Various errors when failing to establish a connection.
        """
        _close_socket(self._sock)
        self._sock = None
        if self._acceptor:
            remote_addr = self._remote_addr
            while self._opened and self._sock is None:
                self._sock, remote_addr = self._get_a_socket(
                    self._addr, self._remote_addr
                )
            self._remote_addr = remote_addr
        else:
            self._sock, self._addr = self._get_c_socket(self._addr, self._remote_addr)
        self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._send_b_size)
        self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self._recv_b_size)

    @classmethod
    def _reg_remote(cls, remote_addr: tuple[str, int]):
        """Registers a remote address into the class datastructures, notifying
        other endpoints of its existence. Tries to both resolve the address and
        find its fully qualified hostname, to reserve all its aliases. If even
        a single alias is already registered, aborts the whole registration
        process and registers none of the aliases.

        :param remote_addr: Remote address to register.
        :raises ValueError: If remote address is already registered (possibly
            by another caller).
        """
        cls._cls_logger.info(f"Registering remote address ({remote_addr})...")
        addr_mapping = set()
        for _, _, _, _, addr in socket.getaddrinfo(
            *remote_addr, type=socket.SOCK_STREAM
        ):
            addr = _convert_addr_to_name(addr)
            addr_mapping.add(addr)
        cls._cls_logger.debug(
            f"Registering aliases of remote address ({remote_addr}): {addr_mapping}..."
        )
        with cls._lock:
            if remote_addr in cls._addr_map:
                raise ValueError(
                    f"Remote address ({remote_addr}) is already registered!"
                )
            for addr in addr_mapping:
                if addr in cls._reg_r_addrs:
                    raise ValueError(
                        f"Remote address ({addr}) (resolved from {remote_addr})"
                        " is already registered!"
                    )
            cls._addr_map[remote_addr] = addr_mapping
            for addr in addr_mapping:
                cls._reg_r_addrs.add(addr)
        cls._cls_logger.info(f"Remote address ({remote_addr}) registered.")

    @classmethod
    def _fix_rp_acc_socks(cls, addr: tuple[str, int], remote_addr: tuple[str, int]):
        """After a remote address has been registered, cycles the waiting
        connections of that remote address from the pending socket queue to the
        registered socket dictionary. Necessary, as new endpoint sockets may be
        created while the listening socket is already opened and connections
        are already getting accepted.

        Note this method merely helps in speeding up the sorting of connections
        to the correct endpoint socket and operates in best-effort manner, as
        the whole method is not considered a critical section; i.e. during the
        cycling of connections, the registered connection could be accepted by
        another thread and be put into the pending connection cache, where it
        will be dequeued by yet another thread --- which will result in a
        ValueError, since the remote peer is obviously registered. This is not
        much of an issue, however, as the error handling is done in the
        background during the handling of existing connections.

        :param addr: Address of endpoint.
        :param remote_addr: Remote address that was registered.
        """
        cls._cls_logger.info(
            f"Fixing registered and pending connection caches for address pair "
            f"{addr, remote_addr}..."
        )
        l_addr, _, _ = cls._get_l_socket(addr)
        with cls._lock:
            addr_mapping = cls._addr_map[remote_addr]
            acc_p_socks = cls._acc_p_socks[l_addr]
            acc_r_socks, acc_r_lock = cls._acc_r_socks[l_addr]
        p_socks = []
        while True:
            try:
                a_sock, a_addr = acc_p_socks.get_nowait()
                with acc_r_lock:
                    if a_addr in addr_mapping:
                        _close_socket(acc_r_socks.pop(a_addr, None))
                        acc_r_socks[a_addr] = a_sock
                    else:
                        p_socks.append((a_sock, a_addr))
            except queue.Empty:
                break
        for a_sock, a_addr in p_socks:
            try:
                acc_p_socks.put_nowait((a_sock, a_addr))
            except queue.Full:
                _close_socket(a_sock)
        cls._cls_logger.info(
            "Registered and pending connection caches for "
            f"address pair {addr, remote_addr} fixed."
        )

    @classmethod
    def _unreg_remote(cls, addr: tuple[str, int], remote_addr: tuple[str, int]):
        """Unregisters a remote address from the class datastructures, using
        the existing mappings from the original resolution. Also cycles any
        pending connections of that remote address from the registered socket
        dictionary to the pending socket queue, so the connection may be
        accepted by any other (generic or pre-defined) endpoint socket
        listening on the same address.

        :param addr: Address of endpoint.
        :param remote_addr: Remote address that was registered.
        """
        cls._cls_logger.info(
            f"Unregistering remote address pair {addr, remote_addr}..."
        )
        l_addr, _, _ = cls._get_l_socket(addr)
        with cls._lock:
            addr_mapping = cls._addr_map.pop(remote_addr)
            acc_p_socks = cls._acc_p_socks[l_addr]
            acc_r_socks, acc_r_lock = cls._acc_r_socks[l_addr]
        for a_addr in addr_mapping:
            with cls._lock:
                cls._reg_r_addrs.discard(a_addr)
            with acc_r_lock:
                a_sock = acc_r_socks.pop(a_addr, None)
            if a_sock is not None:
                try:
                    acc_p_socks.put_nowait((a_sock, a_addr))
                except queue.Full:
                    _close_socket(a_sock)
        cls._cls_logger.info(f"Remote address pair {addr, remote_addr} unregistered.")

    @classmethod
    def _open_l_socket(cls, addr: tuple[str, int]):
        """Opens the socket listening on a given address, creating it iff no
        other endpoint socket is listening on the same socket yet.

        :param addr: Address of listen socket to open.
        """
        cls._get_l_socket(addr, new_endpoint=True)

    @classmethod
    def _get_l_socket(
        cls, addr: tuple[str, int], new_endpoint: bool = False
    ) -> tuple[tuple[str, int], socket.socket, threading.Lock]:
        """Gets the socket listening on a given address. If this socket does
        not exist already, creates it and with it all accompanying
        datastructures. Supports address resolution.

        :param addr: Address of listen socket.
        :param new_endpoint: Whether the call registers a new endpoint on the
            listen socket (increments its active thread counter).
        :return: A tuple consisting of the address, the socket, and a lock to
            be used for accessing the socket.
        :raises RuntimeError: If none of the addresses/aliases succeed to
            create a working socket.
        """
        cls._cls_logger.debug(f"Trying to retrieve listening socket for {addr}...")
        for res in socket.getaddrinfo(*addr, type=socket.SOCK_STREAM):
            s_af, s_t, s_p, _, s_addr = res
            l_addr = _convert_addr_to_name(s_addr)
            with cls._lock:
                l_sock, l_sock_lock = cls._listen_socks.get(l_addr, (None, None))
                if l_sock is None:
                    try:
                        l_sock = socket.socket(s_af, s_t, s_p)
                        l_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        l_sock.bind(s_addr)
                        l_sock.listen(65535)
                    except OSError:
                        _close_socket(l_sock)
                        continue
                    l_sock_lock = threading.Lock()
                    cls._listen_socks[l_addr] = l_sock, l_sock_lock
                    cls._acc_r_socks[l_addr] = {}, threading.Lock()
                    cls._acc_p_socks[l_addr] = queue.Queue(maxsize=512)
                if new_endpoint:
                    cls._act_l_counts[l_addr] = cls._act_l_counts.get(l_addr, 0) + 1
            cls._cls_logger.debug(
                f"Listening socket {l_addr, l_sock, l_sock_lock} for {addr} retrieved."
            )
            return l_addr, l_sock, l_sock_lock
        raise RuntimeError(f"Could not open listen socket for {addr}!")

    @classmethod
    def _close_l_socket(cls, addr: tuple[str, int]):
        """Closes the socket listening on a given address, iff there are no
        further endpoint sockets listening on the same socket as well.

        :param addr: Address of listen socket to close.
        """
        cls._cls_logger.debug(f"Performing cleanup for listening socket for {addr}...")
        l_addr, l_sock, l_sock_lock = cls._get_l_socket(addr)
        with cls._lock, l_sock_lock:
            cls._act_l_counts[l_addr] = cls._act_l_counts.get(l_addr, 0) - 1
            if cls._act_l_counts[l_addr] > 0:
                return
            cls._listen_socks.pop(l_addr)
            cls._acc_r_socks.pop(l_addr)
            cls._acc_p_socks.pop(l_addr)
            cls._act_l_counts.pop(l_addr)
            _close_socket(l_sock)
        cls._cls_logger.debug(f"Listening socket for {addr} closed.")

    @classmethod
    def _get_a_socket(
        cls, addr: tuple[str, int], remote_addr: tuple[str, int] = None
    ) -> tuple[Optional[socket.socket], Optional[tuple[str, int]]]:
        """Gets an accepted connection socket assigned to a listen socket of a
        given address. The connection socket can either be connected to an
        arbitrary endpoint or to a pre-defined remote peer (remote address). If
        the remote address is not pre-defined, also registers the remote peer's
        address.

        As the underlying datastructures are shared between all endpoint
        sockets (of a process), a connection socket can either be retrieved
        from them or directly from the listen socket.

        :param addr: Address of listen socket.
        :param remote_addr: Address of remote endpoint to be connected to.
        :return: Tuple of the connection socket and the address of the remote
            peer.
        :raises RuntimeError: If none of the addresses/aliases of the listen
            socket succeed to get a working socket.
        """
        cls._cls_logger.debug(
            f"Trying to retrieve accept socket for {addr, remote_addr}..."
        )
        l_addr, _, _ = cls._get_l_socket(addr)

        # Check the active connection cache first, as another endpoint
        # might have accepted this one's registered connection already.
        if remote_addr is not None:
            a_sock = cls._get_r_acc_sock(l_addr, remote_addr)
            if a_sock is not None:
                cls._cls_logger.debug(
                    f"Accept socket {l_addr, remote_addr} "
                    "from registered connection cache retrieved."
                )
                return a_sock, remote_addr
        # Check the pending connection queue if this thread does not care
        # about the address of the remote peer. If there is a connection,
        # registers it.
        else:
            a_sock, a_addr = cls._get_p_acc_sock(l_addr)
            if a_sock is not None:
                try:
                    cls._reg_remote(a_addr)
                    cls._cls_logger.debug(
                        f"Accept socket for {l_addr} "
                        "from pending connection queue retrieved."
                    )
                    return a_sock, a_addr
                except ValueError:
                    _close_socket(a_sock)

        # Check the OS connection backlog for pending connections.
        a_sock, a_addr = cls._get_n_acc_sock(l_addr, remote_addr)
        if a_sock is not None:
            if remote_addr is None:
                try:
                    cls._reg_remote(a_addr)
                    return a_sock, a_addr
                except ValueError:
                    _close_socket(a_sock)
            else:
                return a_sock, remote_addr
        return None, None

    @classmethod
    def _get_r_acc_sock(
        cls, l_addr: tuple[str, int], remote_addr: tuple[str, int]
    ) -> Optional[socket.socket]:
        """Retrieves and returns a registered (accepted) connection socket
        assigned to a listen socket of a given address, if it exists in the
        (shared) active registered connection cache.

        :param l_addr: Address of listen socket.
        :param remote_addr: Address of remote endpoint to be connected to.
        :return: The registered connection socket, None if there is none.
        """
        cls._cls_logger.debug(
            f"Trying to retrieve accept socket for {l_addr, remote_addr} "
            "from registered connection cache..."
        )
        with cls._lock:
            acc_r_socks, acc_r_lock = cls._acc_r_socks[l_addr]
        for _, _, _, _, addr in socket.getaddrinfo(
            *remote_addr, type=socket.SOCK_STREAM
        ):
            addr = _convert_addr_to_name(addr)
            with acc_r_lock:
                a_sock = acc_r_socks.pop(addr, None)
            if a_sock is not None:
                return a_sock
        return None

    @classmethod
    def _get_p_acc_sock(
        cls, l_addr: tuple[str, int]
    ) -> tuple[Optional[socket.socket], Optional[tuple[str, int]]]:
        """Retrieves and returns a pending, not yet registered (accepted)
        connection socket assigned to a listen socket of a given address, if
        there is one in the pending connection queue.

        :param l_addr: Address of listen socket.
        :return: Tuple of a connection socket and the address of the remote
            peer.
        """
        cls._cls_logger.debug(
            f"Trying to retrieve accept socket for {l_addr} "
            "from pending connection queue..."
        )
        with cls._lock:
            acc_p_socks = cls._acc_p_socks[l_addr]
        try:
            return acc_p_socks.get_nowait()
        except queue.Empty:
            return None, None

    @classmethod
    def _get_n_acc_sock(
        cls, l_addr: tuple[str, int], remote_addr: tuple[str, int]
    ) -> tuple[Optional[socket.socket], Optional[tuple[str, int]]]:
        """Retrieves, accepts, and returns a pending connection socket from the
        OS connection backlog if there is one.

        The connection socket can either be connected to an arbitrary endpoint
        or to a pre-defined remote peer (remote address). If the remote address
        is not pre-defined, returns any connection socket that does not belong
        to another (registered) endpoint socket, otherwise stores it in the
        shared datastructures. The same is done the other way around with the
        pending connection queue.

        :param l_addr: Address of listen socket.
        :param remote_addr: Address of remote endpoint to be connected to.
        :return: Tuple of a connection socket and the address of the remote
            peer.
        :raises RuntimeError: If there are no new connections in the OS
            connection backlog.
        """
        cls._cls_logger.debug(f"Trying to accept socket for {l_addr, remote_addr}...")
        with cls._lock:
            l_sock, l_sock_lock = cls._listen_socks[l_addr]
            acc_r_socks, acc_r_lock = cls._acc_r_socks[l_addr]
            acc_p_socks = cls._acc_p_socks[l_addr]
        with l_sock_lock:
            if not select.select([l_sock], [], [], 0)[0]:
                raise RuntimeError(
                    f"Could not open connection socket for {l_addr, remote_addr}!"
                )
            a_sock, a_addr = l_sock.accept()
        a_addr = _convert_addr_to_name(a_addr)

        # 1. If it is the pre-defined remote peer, use it for this endpoint.
        if remote_addr is not None:
            for _, _, _, _, r_addr in socket.getaddrinfo(
                *remote_addr, type=socket.SOCK_STREAM
            ):
                r_addr = _convert_addr_to_name(r_addr)
                if r_addr == a_addr:
                    return a_sock, a_addr
        # 2. If it is an already registered remote peer, put it into the cache.
        with cls._lock, acc_r_lock:
            if a_addr in cls._reg_r_addrs:
                cls._cls_logger.debug(
                    f"Storing accept socket {a_sock, a_addr} "
                    "into registered connection cache..."
                )
                _close_socket(acc_r_socks.pop(a_addr, None))
                acc_r_socks[a_addr] = a_sock
                return None, None
        # 3. If no remote peer is pre-defined, use it for this endpoint.
        if remote_addr is None:
            return a_sock, a_addr
        # Any other connection is stored in the pending connection queue.
        try:
            cls._cls_logger.debug(
                f"Storing accept socket {a_sock, a_addr} "
                "into pending connection queue..."
            )
            acc_p_socks.put_nowait((a_sock, a_addr))
        except queue.Full:
            _close_socket(a_sock)
        return None, None

    @classmethod
    def _get_c_socket(
        cls, addr: tuple[str, int], remote_addr: tuple[str, int]
    ) -> tuple[Optional[socket.socket], Optional[tuple[str, int]]]:
        """Creates and returns a connection socket to a given remote address,
        which may be bound to a specific local address, if given. Non-blocking
        (with timeout) during connection attempts.

        :param addr: Local address to bind endpoint to. If none provided, OS
            chooses an address.
        :param remote_addr: Address of remote endpoint to be connected to.
        :return: Tuple of the connection socket and the address of the socket.
        :raises RuntimeError: If no connection can be established.
        """
        cls._cls_logger.debug(
            f"Trying to open connection socket for {addr, remote_addr}..."
        )
        for res in socket.getaddrinfo(*addr, type=socket.SOCK_STREAM):
            s_af, s_t, s_p, _, s_addr = res
            r_res_list = socket.getaddrinfo(
                *remote_addr, family=s_af.value, type=s_t.value, proto=s_p
            )
            for r_res in r_res_list:
                r_af, r_t, r_p, _, r_addr = r_res
                sock = None
                try:
                    sock = socket.socket(r_af, r_t, r_p)
                    sock.settimeout(10)
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    sock.bind(s_addr)
                    sock.connect(r_addr)
                    sock.settimeout(None)
                except OSError:
                    _close_socket(sock)
                    continue
                return sock, _convert_addr_to_name(sock.getsockname())
        raise RuntimeError(f"Could not open connection socket for {addr, remote_addr}!")


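# Usage sketch (illustrative, addresses hypothetical): EndpointSocket exchanges
# raw bytes; the higher-level StreamEndpoint below adds object (un)marshalling
# on top of it.
#
#   acc = EndpointSocket("Acc", addr=("127.0.0.1", 54321), acceptor=True)
#   ini = EndpointSocket("Ini", remote_addr=("127.0.0.1", 54321), acceptor=False)
#   threading.Thread(target=acc.open, daemon=True).start()
#   ini.open()
#   ini.send(b"ping")
#   data = acc.recv(timeout=10)
#   ini.close()
#   acc.close(shutdown=True)

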
class StreamEndpoint:
    """One of a pair of endpoints that are able to communicate with one
    another over a persistent stateless stream over BSD sockets. Allows the
    transmission of generic objects in both synchronous and asynchronous
    fashion. Supports SSL and LZ4 compression for the stream.

    Thread-safe, both for access to the same endpoint and for multiple threads
    using endpoints set to the same address.
    """

    _logger: logging.Logger

    _endpoint_socket: EndpointSocket
    _marshal_f: Callable[[object], bytes]
    _unmarshal_f: Callable[[bytes], object]

    _multithreading: bool
    _sender: threading.Thread
    _receiver: threading.Thread
    _send_buffer: queue.Queue[bytes]
    _recv_buffer: queue.Queue[bytes]

    _started: bool
    _ready: threading.Event
    _stopped: threading.Event
    _shutdown: bool

    def __init__(
        self,
        name: str = "StreamEndpoint",
        addr: tuple[str, int] = None,
        remote_addr: tuple[str, int] = None,
        acceptor: bool = True,
        send_b_size: int = 65536,
        recv_b_size: int = 65536,
        compression: bool = False,
        marshal_f: Callable[[object], bytes] = pickle.dumps,
        unmarshal_f: Callable[[bytes], object] = pickle.loads,
        multithreading: bool = False,
        buffer_size: int = 1024,
        keep_alive: bool = True,
    ):
        """Creates a new endpoint.

        :param name: Name of endpoint for logging purposes.
        :param addr: Address of endpoint.
        :param remote_addr: Address of remote endpoint to be connected to.
            Optional in acceptor mode.
        :param acceptor: Determines whether the endpoint accepts or initiates
            connections to/from other endpoints.
        :param send_b_size: Underlying send buffer size of socket.
        :param recv_b_size: Underlying receive buffer size of socket.
        :param compression: Enables LZ4 stream compression for bandwidth
            optimization.
        :param marshal_f: Marshal function to serialize objects to send into
            bytes.
        :param unmarshal_f: Unmarshal function to deserialize received bytes
            into objects.
        :param multithreading: Enables transparent multithreading (i.e.
            asynchronous object processing) for speedup.
        :param buffer_size: Size of shared buffer in multithreading mode.
        :param keep_alive: Determines whether to attempt re-connects after the
            remote endpoint has terminated the connection, or to stop the
            endpoint. Such connection endings will also result in RuntimeErrors
            in synchronous mode (during send()/receive()) and the automatic
            exiting of endpoint sender/receiver loops if multithreading is set.
            Note that the actual shutdown of the endpoint may still have to be
            called separately.
        """
        self._logger = logging.getLogger(name)
        self._logger.info(f"Initializing endpoint {addr, remote_addr}...")

        self._endpoint_socket = EndpointSocket(
            name=name,
            addr=addr,
            remote_addr=remote_addr,
            acceptor=acceptor,
            send_b_size=send_b_size,
            recv_b_size=recv_b_size,
            keep_alive=keep_alive,
        )
        if compression:
            self._marshal_f = lambda d: compress(marshal_f(d))
            self._unmarshal_f = lambda e: unmarshal_f(decompress(e))
        else:
            self._marshal_f = marshal_f
            self._unmarshal_f = unmarshal_f

        self._multithreading = multithreading
        self._send_buffer = queue.Queue(maxsize=buffer_size)
        self._recv_buffer = queue.Queue(maxsize=buffer_size)

        self._started = False
        self._ready = threading.Event()
        self._stopped = threading.Event()
        self._shutdown = False
        self._logger.info(f"Endpoint {addr, remote_addr} initialized.")

    def start(self, blocking=True) -> threading.Event:
        """Starts the endpoint, either in threaded fashion or as part of the
        main thread. By doing so, the two endpoints are connected and the
        datastream is opened.

        By default this method blocks until a connection is established, unless
        multithreading is enabled or the blocking flag is unset. In either
        case, the caller can check the readiness of the connection via the
        returned event object. Note that in multithreading mode, objects can
        already be sent/received; however, they will only be stored in internal
        buffers until the connection is established (which sets the semaphore
        accordingly to allow the async sender and receiver to proceed).

        :param blocking: Whether to wait for a connection to be established in
            non-multithreading (sync) mode.
        :return: Event object to check endpoint's readiness to send/receive.
            Always true if start() was called blocking.
        :raises RuntimeError: If endpoint has already been started or shut down.
        """

        def start_ep_socket():
            self._endpoint_socket.open()
            self._ready.set()

        self._logger.info("Starting endpoint...")
        if self._shutdown:
            raise RuntimeError("Endpoint has already been shut down!")
        if self._started:
            raise RuntimeError("Endpoint has already been started!")
        self._started = True
        self._stopped.clear()

        if not blocking or self._multithreading:
            self._logger.info("Starting endpoint socket starter thread...")
            threading.Thread(target=start_ep_socket, daemon=True).start()
        else:
            start_ep_socket()

        if self._multithreading:
            self._logger.info(
                "Multithreading detected, starting endpoint sender/receiver threads..."
            )
            self._sender = threading.Thread(target=self._create_sender, daemon=True)
            self._receiver = threading.Thread(target=self._create_receiver, daemon=True)
            self._sender.start()
            self._receiver.start()
        self._logger.info("Endpoint started.")
        return self._ready

    def stop(self, shutdown=False, timeout=10, blocking=True):
        """Stops the endpoint and closes the stream, cleaning up underlying
        datastructures. If multithreading is enabled, waits for both endpoint
        threads to stop before finishing. Note this does not guarantee the
        sending and receiving of all still pending objects --- they may remain
        in internal buffers and will be processed if the endpoint is opened
        again, or get discarded by the underlying socket.

        By default this method blocks until the endpoint is fully closed /
        shut down, unless multithreading is enabled or the blocking flag is
        unset. In either case, the caller can check the progress via the
        returned event object.

        Also note if the endpoint has not been started or has already been
        closed, a set shutdown flag still results in the full cleanup of the
        underlying datastructures.

        :param shutdown: If set, also cleans up underlying datastructures of
            the socket communication.
        :param timeout: Allows the sender thread to process remaining messages
            until timeout.
        :param blocking: Whether to wait for the endpoint to be closed in
            non-multithreading (sync) mode.
        :return: Event object to check whether endpoint is closed. Always true
            if stop() was called blocking.
        :raises RuntimeError: If endpoint has not been started or already shut
            down.
        """

        def stop_ep_sock():
            if self._multithreading:
                start = time()
                while not self._send_buffer.empty() and time() - start < timeout:
                    sleep(1)
            self._started = False
            self._ready.set()
            self._endpoint_socket.close(shutdown)
            self._shutdown = shutdown
            if self._multithreading:
                self._logger.info(
                    "Multithreading detected, waiting for "
                    "endpoint sender/receiver threads to stop..."
                )
                self._sender.join()
                self._receiver.join()
            self._ready.clear()
            self._stopped.set()

        self._logger.info("Stopping endpoint...")
        if not self._started:
            if not self._shutdown and shutdown:
                self._logger.warning(
                    "Shutdown on closed endpoint detected, cleaning up endpoint..."
                )
                self._endpoint_socket.close(shutdown)
                self._shutdown = True
                return
            else:
                raise RuntimeError(
                    "Endpoint has not been started or already shut down!"
                )

        if not blocking or self._multithreading:
            self._logger.info("Starting endpoint socket stopping thread...")
            threading.Thread(target=stop_ep_sock, daemon=True).start()
        else:
            stop_ep_sock()
        self._logger.info("Endpoint stopped.")
        return self._stopped

    def send(self, obj: object):
        """Generic send function that sends any object as a pickle over the
        persistent datastream. If multithreading is enabled, this function is
        non-blocking.

        :param obj: Object to send.
        :raises RuntimeError: If endpoint has not been started or has been
            terminated by the remote counterpart.
        """
        self._logger.debug("Sending object...")
        if not self._started:
            raise RuntimeError("Endpoint has not been started!")
        p_obj = self._marshal_f(obj)
        if self._multithreading:
            self._logger.debug(
                "Multithreading detected, putting object "
                f"into buffer (size={self._send_buffer.qsize()})..."
            )
            self._send_buffer.put(p_obj)
        else:
            self._endpoint_socket.send(p_obj)
        self._logger.debug(f"Pickled object sent of size {len(p_obj)}.")

    def receive(self, timeout: int = None) -> object:
        """Generic receive function that receives data as a pickle over the
        persistent datastream, unpickles it into the respective object, and
        returns it. Blocking in default mode if timeout not set. Also supports
        receiving objects past the closing of the endpoint, if multithreading
        is enabled and the objects have already been received and stored in the
        receive buffer.

        :param timeout: Timeout (seconds) to receive an object to return.
        :return: Received object.
        :raises RuntimeError: If endpoint has not been started or has been
            terminated by the remote counterpart, and there is nothing to
            receive asynchronously (multithreading is not enabled).
        :raises TimeoutError: If timeout set and triggered.
        """
        self._logger.debug("Receiving object...")
        if not self._started and self._recv_buffer.empty():
            raise RuntimeError("Endpoint has not been started, nothing to receive!")
        p_obj = None
        if self._multithreading:
            self._logger.debug(
                "Multithreading detected, retrieving object "
                f"from buffer (size={self._recv_buffer.qsize()})..."
            )
            while self._started or not self._recv_buffer.empty():
                try:
                    if timeout is not None:
                        p_obj = self._recv_buffer.get(timeout=timeout)
                    else:
                        p_obj = self._recv_buffer.get(timeout=10)
                    break
                except queue.Empty:
                    if timeout is not None:
                        raise TimeoutError
                    continue
        else:
            p_obj = self._endpoint_socket.recv(timeout)
        if p_obj is None:
            raise RuntimeError("Endpoint has not been started!")
        self._logger.debug(f"Pickled data received of size {len(p_obj)}.")
        return self._unmarshal_f(p_obj)

    def poll(self) -> tuple[list[bool], tuple[tuple[str, int], tuple[str, int]]]:
        """Polls various states and addresses of the endpoint:

        * 0,0: Existence of underlying socket (true if connected).
        * 0,1: Whether there is something to read on the internal buffer
          (async) or underlying socket (sync).
        * 0,2: Whether one is able to write on the internal buffer (async) or
          underlying socket (sync).
        * 1,0: Address of endpoint, else None.
        * 1,1: Address of remote endpoint, else None.

        Note this does not necessarily guarantee that the underlying endpoint
        socket is actually connected and available for reading/writing; not
        only could the connection have broken down since then and currently be
        re-established, but in multithreading mode the content of the internal
        buffers might change over time and thus change the read/write state.

        :return: Tuple of boolean states (connectivity, readability,
            writability) and address-pair of endpoint.
        """
        states, addrs = self._endpoint_socket.poll(lazy=self._multithreading)
        if self._multithreading:
            states[1] = not self._recv_buffer.empty()
            states[2] = not self._send_buffer.full() and states[0]
        return states, addrs

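    # A hedged reading of the poll() result described above (names are
    # hypothetical, for illustration only):
    #
    #   states, (local_addr, remote_addr) = endpoint.poll()
    #   connected, readable, writable = states
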
    def _create_sender(self):
        """Starts the loop to send objects over the socket, retrieved from the
        sending buffer. Note that with the keep-alive flag set to False, a
        termination of the connection stops the loop, stops the endpoint (no
        shutdown), and exits the thread automatically.
        """
        self._logger.info("AsyncSender: Starting...")
        self._ready.wait()
        self._logger.info(
            "AsyncSender: Starting to send objects in asynchronous mode..."
        )
        while self._started:
            try:
                p_obj = self._send_buffer.get(timeout=10)
                self._logger.debug(
                    f"AsyncSender: Retrieved and sending object (size: {len(p_obj)}) "
                    f"from buffer (length: {self._send_buffer.qsize()})"
                )
                self._endpoint_socket.send(p_obj)
            except queue.Empty:
                self._logger.debug(
                    "AsyncSender: Timeout triggered: Buffer empty. Retrying..."
                )
            except RuntimeError:
                self._logger.info("AsyncSender: Termination of connection detected!")
                break
        self._logger.info("AsyncSender: Stopping...")

    def _create_receiver(self):
        """Starts the loop to receive objects over the socket and store them in
        the receiving buffer. Note that with the keep-alive flag set to False,
        a termination of the connection stops the loop and exits the thread
        automatically.
        """
        self._logger.info("AsyncReceiver: Starting...")
        self._ready.wait()
        self._logger.info(
            "AsyncReceiver: Starting to receive objects in asynchronous mode..."
        )
        while self._started:
            try:
                p_obj = self._endpoint_socket.recv()
                if p_obj is None:
                    continue
                self._logger.debug(
                    f"AsyncReceiver: Storing received object (size: {len(p_obj)}) "
                    f"in buffer (length: {self._recv_buffer.qsize()})..."
                )
                self._recv_buffer.put(p_obj, timeout=10)
            except queue.Full:
                self._logger.warning(
                    "AsyncReceiver: Timeout triggered: Buffer full. "
                    "Discarding object..."
                )
            except RuntimeError:
                self._logger.info("AsyncReceiver: Termination of connection detected!")
                self.stop(blocking=False)
                break
        self._logger.info("AsyncReceiver: Stopping...")

    def __iter__(self):
        while self._started:
            try:
                yield self.receive()
            except RuntimeError:
                break

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop(shutdown=True)

    def __del__(self):
        # _sender/_receiver only exist in multithreading mode after start().
        if (
            not self._shutdown
            and threading.current_thread() != getattr(self, "_sender", None)
            and threading.current_thread() != getattr(self, "_receiver", None)
        ):
            self.stop(shutdown=True)

    @classmethod
    def create_quick_sender_ep(
        cls,
        objects: Iterable,
        remote_addr: tuple[str, int],
        name: str = "QuickSenderEndpoint",
        addr: tuple[str, int] = None,
        send_b_size: int = 65536,
        compression: bool = False,
        marshal_f: Callable[[object], bytes] = pickle.dumps,
        blocking=True,
    ):
        """Creates a (simplified) one-time endpoint to send a number of objects
        to a remote endpoint before shutting down. May be called non-blocking
        to handle the endpoint entirely in the background.

        :param objects: Iterable of objects to send to remote endpoint.
        :param remote_addr: Address of remote endpoint to send messages to.
        :param name: Name of endpoint for logging purposes.
        :param addr: Address of endpoint.
        :param send_b_size: Underlying send buffer size of socket.
        :param compression: Enables LZ4 stream compression for bandwidth
            optimization.
        :param marshal_f: Marshal function to serialize objects to send into
            bytes.
        :param blocking: Whether endpoint and message handling is to be done
            synchronously or asynchronously (using threads).
        """

        def quick_sender_ep():
            endpoint = StreamEndpoint(
                name=name,
                addr=addr,
                remote_addr=remote_addr,
                acceptor=False,
                send_b_size=send_b_size,
                compression=compression,
                marshal_f=marshal_f,
                multithreading=False,
            )
            endpoint.start()
            for obj in objects:
                endpoint.send(obj)
            endpoint.stop(shutdown=True)

        if not blocking:
            threading.Thread(target=quick_sender_ep, daemon=True).start()
        else:
            quick_sender_ep()

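    # Example (illustrative, address hypothetical): fire-and-forget sending of
    # a few objects in the background.
    #
    #   StreamEndpoint.create_quick_sender_ep(
    #       [{"seq": 1}, {"seq": 2}], ("127.0.0.1", 54321), blocking=False
    #   )
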
    @classmethod
    def receive_latest_ep_objs(
        cls, endpoints: Iterable[Self], obj_type: type = object
    ) -> dict[Self, Optional[object]]:
        """Endpoint helper function to receive the latest objects of a certain
        type from a number of endpoints. Note this flushes any other messages
        held by these endpoints as well, as non-blocking receives are called on
        them until their buffers are exhausted. Any messages of other types are
        discarded, as are endpoints that are not ready.

        :param endpoints: Iterable of endpoints to receive objects from.
        :param obj_type: Type of objects to receive. If none given, receives
            the latest message of any type.
        :return: Dictionary of each endpoint and its respective latest received
            object, None if nothing was received for an endpoint.
        """
        ep_objs = {}
        for endpoint in endpoints:
            ep_obj = None
            try:
                while True:
                    ep_msg = endpoint.receive(timeout=0)
                    if isinstance(ep_msg, obj_type):
                        ep_obj = ep_msg
            except (RuntimeError, TimeoutError):
                pass
            ep_objs[endpoint] = ep_obj
        return ep_objs

    @classmethod
    def select_eps(cls, endpoints: Iterable[Self]) -> tuple[list[Self], list[Self]]:
        """Endpoint select helper function to check a number of endpoints for
        whether objects can be read from or written to them. For simplicity's
        sake, does not mirror the actual UNIX select function (which supports
        separate lists).

        :param endpoints: Iterable of endpoints to check for readiness.
        :return: Tuple of lists of endpoints that are read/write ready.
        """
        ep_states = [(endpoint, endpoint.poll()) for endpoint in endpoints]
        r_ready = [ep for ep, state in ep_states if state[0][1]]
        w_ready = [ep for ep, state in ep_states if state[0][2]]
        return r_ready, w_ready


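# Usage sketch (illustrative, not part of the module's API surface): a minimal
# acceptor/initiator pair exchanging one object over localhost. Addresses and
# names are hypothetical.
#
#   server = StreamEndpoint(name="Server", addr=("127.0.0.1", 54321),
#                           acceptor=True)
#   client = StreamEndpoint(name="Client", remote_addr=("127.0.0.1", 54321),
#                           acceptor=False)
#   ready = server.start(blocking=False)
#   client.start()
#   client.send({"hello": "world"})
#   obj = server.receive(timeout=10)
#   client.stop(shutdown=True)
#   server.stop(shutdown=True)

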
class EndpointServer:
    """Helper class to manage a group of (acceptor) connection endpoints
    listening on the same address. Supports all features of the existing
    endpoint class, besides also supporting thread-safe access, polling, and
    management of them as a group.
    """

    _logger: logging.Logger

    _connection_handler: threading.Thread
    _connection_cleaner: threading.Thread
    _connections: dict[tuple[str, int], StreamEndpoint]
    _p_connections: queue.Queue[tuple[tuple[str, int], StreamEndpoint]]
    _n_connections: int
    _c_timeout: int
    _c_lock: threading.Lock

    _name: str
    _addr: tuple[str, int]
    _send_b_size: int
    _recv_b_size: int
    _compression: bool
    _marshal_f: Callable[[object], bytes]
    _unmarshal_f: Callable[[bytes], object]
    _multithreading: bool
    _buffer_size: int
    _keep_alive: bool

    _started: bool

    def __init__(
        self,
        addr: tuple[str, int],
        name: str = "EndpointServer",
        c_timeout: int = None,
        send_b_size: int = 65536,
        recv_b_size: int = 65536,
        compression: bool = False,
        marshal_f: Callable[[object], bytes] = pickle.dumps,
        unmarshal_f: Callable[[bytes], object] = pickle.loads,
        multithreading: bool = False,
        buffer_size: int = 1024,
        keep_alive: bool = True,
    ):
        """Creates a new endpoint server.

        :param addr: Address of endpoint server.
        :param name: Name of endpoint server for logging purposes.
        :param c_timeout: Timeout (secs) for disconnected connection endpoints
            when performing periodic cleanup. Default is no cleanup.
        :param send_b_size: Underlying send buffer size of all connection
            sockets.
        :param recv_b_size: Underlying receive buffer size of all connection
            sockets.
        :param compression: Enables LZ4 stream compression for bandwidth
            optimization.
        :param marshal_f: Marshal function to serialize objects to send into
            bytes.
        :param unmarshal_f: Unmarshal function to deserialize received bytes
            into objects.
        :param multithreading: Enables transparent multithreading (for
            individual endpoints) for speedup.
        :param buffer_size: Size of shared buffers, both for the server and for
            connection endpoints in multithreading mode.
        :param keep_alive: Determines whether connection endpoints should
            attempt re-connects after their remote counterparts have terminated
            the connection. Such connection terminations will then result in
            RuntimeErrors in synchronous mode during send()/receive() and, when
            multithreading, a more resource-efficient handling by closing
            endpoints prematurely. Note that the actual shutdown and cleanup of
            the endpoint is still done by the server.
        """
        self._logger = logging.getLogger(name)
        self._logger.info(f"Initializing endpoint server {addr}...")

        self._connections = {}
        self._p_connections = queue.Queue(maxsize=buffer_size)
        self._n_connections = 0
        self._c_timeout = c_timeout
        self._c_lock = threading.Lock()

        self._name = name
        self._addr = addr
        self._send_b_size = send_b_size
        self._recv_b_size = recv_b_size
        self._compression = compression
        self._marshal_f = marshal_f
        self._unmarshal_f = unmarshal_f
        self._multithreading = multithreading
        self._buffer_size = buffer_size
        self._keep_alive = keep_alive

        self._started = False
        self._logger.info(f"Endpoint server {addr} initialized.")

    def start(self):
        """Starts the endpoint server, launching the connection handlers in the
        background.

        :raises RuntimeError: If endpoint server has already been started.
        """
        self._logger.info("Starting endpoint server...")
        if self._started:
            raise RuntimeError("Endpoint server has already been started!")
        self._started = True

        self._connection_handler = threading.Thread(
            target=self._create_connection_handler, daemon=True
        )
        self._connection_handler.start()
        if self._c_timeout is not None:
            self._logger.info(
                "Connection timeout detected, starting periodic cleanup thread..."
            )
            self._connection_cleaner = threading.Thread(
                target=self._cleanup_connections, daemon=True
            )
            self._connection_cleaner.start()
        self._logger.info("Endpoint server started.")

    def stop(self, timeout=10, blocking=True):
        """Stops the endpoint server along with all its connection endpoints,
        cleaning up underlying datastructures. This always shuts down all
        connection endpoints with a given timeout (see stop() of the
        StreamEndpoint class for more information on this behavior).

        :param timeout: Allows each connection endpoint to process remaining
            messages until timeout. If blocking, this is done per endpoint and
            not in parallel.
        :param blocking: Whether to wait for endpoints to be closed before
            exiting.
        :raises RuntimeError: If endpoint server has not been started.
        """
        self._logger.info("Stopping endpoint server...")
        if not self._started:
            raise RuntimeError("Endpoint server has not been started!")
        self._started = False

        self._logger.info("Waiting for connection handler thread to stop...")
        self._connection_handler.join()
        if self._c_timeout is not None:
            self._logger.info("Waiting for periodic cleanup thread to stop...")
            self._connection_cleaner.join()

        self._logger.info("Closing connections...")
        with self._c_lock:
            connections = self._connections
            self._connections = {}
        if not blocking:
            threading.Thread(
                target=lambda: self._close_conns(connections, timeout), daemon=True
            ).start()
        else:
            self._close_conns(connections, timeout)
        self._logger.info("Endpoint server stopped.")

    def poll_connections(
        self,
    ) -> tuple[
        dict[tuple[str, int], StreamEndpoint], dict[tuple[str, int], StreamEndpoint]
    ]:
        """Polls the state of all currently available connection endpoints,
        filtering them for readability and writability.

        Note that while this method is thread-safe in itself, there is no
        guarantee that any returned endpoint will still be connected (and
        available) at the point of using it, since the underlying cleanup
        thread (if enabled via a general timeout, see __init__()) might have
        closed any potentially dead endpoint.

        :return: Tuple of dictionaries of addresses and endpoints that can be
            read from / written to.
        """
        with self._c_lock:
            self._logger.debug(
                f"Polling {len(self._connections)} connections "
                "for readability and writability..."
            )
            c_states = {
                addr: (ep, ep.poll()) for addr, ep in self._connections.items()
            }
        r_ready = {addr: t[0] for addr, t in c_states.items() if t[1][0][1]}
        w_ready = {addr: t[0] for addr, t in c_states.items() if t[1][0][2]}
        self._logger.debug(
            f"{len(r_ready)} connections for readability and "
            f"{len(w_ready)} connections for writability found."
        )
        return r_ready, w_ready

    def get_connections(
        self, addrs: list[tuple[str, int]]
    ) -> dict[tuple[str, int], Optional[StreamEndpoint]]:
        """Checks a list of given client addresses for whether there is an
        available connection endpoint for each of them, and retrieves them.

        Note that while this method is thread-safe in itself, there is no
        guarantee that any returned endpoint will still be connected (and
        available) at the point of using it, since the underlying cleanup
        thread (if enabled via a general timeout, see __init__()) might have
        closed any potentially dead endpoint.

        :param addrs: Client addresses to check and retrieve endpoints for.
        :return: Dictionary of addresses and endpoints (None if not existing).
        """
        self._logger.debug(f"Trying to retrieve {len(addrs)} connections...")
        with self._c_lock:
            return {addr: self._connections.get(addr) for addr in addrs}

    def get_new_connections(
        self, n: int = 1, timeout: int = 10
    ) -> dict[tuple[str, int], StreamEndpoint]:
        """Checks and retrieves the first n new connections in the underlying
        queue filled by the connection handler.

        Note that while this method is thread-safe in itself, there is no
        guarantee that any returned endpoint will still be connected (and
        available) at the point of using it, since the underlying cleanup
        thread (if enabled via a general timeout, see __init__()) might have
        closed any potentially dead endpoint.

        :param n: Maximum number of new connections to retrieve.
        :param timeout: Maximum time when polling for new connections.
        :return: Dictionary of new addresses and endpoints of new connections.
        """
        self._logger.debug(f"Trying to retrieve {n} new connections...")
        new_connections = {}
        start = time()
        while self._started and len(new_connections) < n:
            try:
                # remaining time is the timeout minus the elapsed time
                remaining_timeout = max(0.0, timeout - (time() - start))
                addr, ep = self._p_connections.get(timeout=remaining_timeout)
                new_connections[addr] = ep
            except queue.Empty:
                break
        self._logger.debug(
            f"{len(new_connections)} out of {n} new connections retrieved."
        )
        return new_connections

    def close_connections(
        self, addrs: list[tuple[str, int]], timeout: int = 10, blocking=True
    ):
        """Checks a list of given client addresses for whether there is an
        available connection endpoint for each of them, and closes them, also
        shutting them down.

        :param addrs: Client addresses to check and close endpoints for.
        :param timeout: Timeout for shutting down connection endpoints.
        :param blocking: Whether to wait for endpoints to be closed before
            returning.
        """
        self._logger.info("Closing connections...")
        with self._c_lock:
            connections = {addr: self._connections.pop(addr, None) for addr in addrs}
        if not blocking:
            threading.Thread(
                target=lambda: self._close_conns(connections, timeout), daemon=True
            ).start()
        else:
            self._close_conns(connections, timeout)

    def _create_connection_handler(self):
        """Starts the loop to create new endpoints, connect them to their
        remote counterparts, and store them in the underlying datastructures of
        the server.
        """
        self._logger.info("AsyncHandler: Starting connection handler...")
        while self._started:
            self._n_connections += 1
            logging_prefix = f"AsyncHandler: [{self._n_connections}] "
            self._logger.debug(logging_prefix + "Preparing endpoint for connection...")
            new_connection = StreamEndpoint(
                name=f"{self._name}-{self._n_connections}",
                addr=self._addr,
                remote_addr=None,
                acceptor=True,
                send_b_size=self._send_b_size,
                recv_b_size=self._recv_b_size,
                compression=self._compression,
                marshal_f=self._marshal_f,
                unmarshal_f=self._unmarshal_f,
                multithreading=self._multithreading,
                buffer_size=self._buffer_size,
                keep_alive=self._keep_alive,
            )
            n_connection_rdy = new_connection.start(blocking=False)
            while self._started and not n_connection_rdy.wait(10):
                self._logger.debug(
                    logging_prefix + "Waiting for endpoint to establish a connection..."
                )
            if not self._started:
                break
            remote_addr = new_connection.poll()[1][1]

            while self._started:
                try:
                    self._logger.debug(
                        logging_prefix + "Storing connection endpoint in pending queue..."
                    )
                    self._p_connections.put((remote_addr, new_connection), block=False)
                    with self._c_lock:
                        self._connections[remote_addr] = new_connection
                    self._logger.debug(
                        logging_prefix + "New connection endpoint handled."
                    )
                    break
                except queue.Full:
                    self._logger.debug(
                        logging_prefix + "Pending queue full. Discarding oldest endpoint..."
                    )
                    try:
                        self._p_connections.get(block=False)
                    except queue.Empty:
                        continue
        self._logger.info("AsyncHandler: Stopping...")

    def _cleanup_connections(self):
        """Starts the loop to periodically check all connection endpoints of
        the server for dead connections and clean up those that remain dead
        after a set timeout (see __init__()).
        """
        self._logger.info("AsyncCleaner: Starting periodic connection cleanup...")
        c_pending: dict[tuple[str, int], StreamEndpoint] = {}
        while self._started:
            with self._c_lock:
                c_dead = {
                    addr: self._connections.pop(addr, None)
                    for addr, ep in c_pending.items()
                    if not ep.poll()[0][0]
                }
            self._logger.debug(
                f"AsyncCleaner: Closing {len(c_dead)} out of {len(c_pending)} "
                f"inactive and marked connection endpoints..."
            )
            threading.Thread(
                target=lambda: self._close_conns(c_dead, timeout=0), daemon=True
            ).start()

            with self._c_lock:
                c_pending = {
                    addr: ep
                    for addr, ep in self._connections.items()
                    if not ep.poll()[0][0]
                }
            self._logger.debug(
                f"AsyncCleaner: {len(c_pending)} inactive connection endpoints "
                "found and marked."
            )
            sleep(self._c_timeout)
        self._logger.info("AsyncCleaner: Stopping...")

    def _close_conns(
        self, connections: dict[tuple[str, int], StreamEndpoint], timeout=10
    ):
        """Helper method to close and shut down a number of endpoints.

        :param connections: Connection endpoints to close.
        :param timeout: Individual timeout for shutting down connection
            endpoints.
        """
        self._logger.debug(f"Trying to close {len(connections)} connections...")
        n = 0
        for addr, endpoint in connections.items():
            if endpoint is not None:
                self._logger.debug(f"Shutting down connection endpoint {addr}...")
                n += 1
                endpoint.stop(shutdown=True, timeout=timeout)
        self._logger.debug(f"{n} out of {len(connections)} connections closed.")

    def __iter__(self):
        while self._started:
            try:
                yield self._p_connections.get(timeout=10)
            except queue.Empty:
                continue

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def __del__(self):
        # _connection_cleaner only exists if a cleanup timeout was set.
        if (
            self._started
            and threading.current_thread() != self._connection_handler
            and threading.current_thread() != getattr(self, "_connection_cleaner", None)
        ):
            self.stop()


def _convert_addr_to_name(
    addr: tuple[str, int] | tuple[str, int, int, int],
) -> tuple[str, int]:
    """Translates a socket address, which is either a 2-tuple (ipv4) or a
    4-tuple (ipv6), into a 2-tuple (host, port). Tries to resolve the host to
    its (DNS) hostname, otherwise keeps the numeric representation.
    Ports/services are always kept numeric.

    :param addr: Address (ipv4/6) to convert.
    :return: Address tuple.
    """
    host, port = socket.getnameinfo(addr, socket.NI_NUMERICSERV)
    return host, int(port)


# noinspection PyTypeChecker
def _send_payload(sock: socket.socket, payload: bytes):
    """Sends a payload over a socket, performing simple marshalling (size is
    sent first, then the bytes of the object). Blocking (if passed socket not
    configured otherwise).

    :param sock: Socket to send payload over.
    :param payload: Payload to send.
    """
    payload_size = len(payload)
    p_payload_size = bytes(ctypes.c_uint32(payload_size))
    _send_n_data(sock, p_payload_size, 4)
    _send_n_data(sock, payload, payload_size)


def _send_n_data(sock: socket.socket, data: bytes, size: int):
    """Sends a number of bytes over a socket.

    :param sock: Socket to send bytes over.
    :param data: Bytes to send.
    :param size: Number of bytes to send.
    :raises RuntimeError: If connection has been closed by remote.
    """
    sent_bytes = 0
    while sent_bytes < size:
        n_sent_bytes = sock.send(data[sent_bytes:])
        if n_sent_bytes == 0:
            raise RuntimeError("Connection terminated!")
        sent_bytes += n_sent_bytes


def _recv_payload(sock: socket.socket) -> bytes:
    """Receives a payload over a socket, performing simple marshalling (size is
    received first, then the bytes of the object). Blocking (if passed socket
    not configured otherwise).

    :param sock: Socket to receive payload over.
    :return: Received payload.
    """
    p_payload_size = _recv_n_data(sock, 4, 1)
    payload_size = int.from_bytes(p_payload_size, byteorder=sys.byteorder)
    return _recv_n_data(sock, payload_size)


def _recv_n_data(sock: socket.socket, size: int, buff_size: int = 4096) -> bytes:
    """Receives a number of bytes over a socket.

    :param sock: Socket to receive bytes over.
    :param size: Number of bytes to receive.
    :param buff_size: Maximum number of bytes to receive from socket per
        receive iteration.
    :return: Received n bytes.
    :raises RuntimeError: If connection has been closed by remote.
    """
    data = bytearray(size)
    r_size = size
    while r_size > 0:
        n_data = sock.recv(min(r_size, buff_size))
        n_size = len(n_data)
        if n_size == 0:
            raise RuntimeError("Connection terminated!")
        data[size - r_size : size - r_size + n_size] = n_data
        r_size -= n_size
    return bytes(data)


def _check_r_socket(sock: socket.socket, timeout: int = None) -> bool:
    """Checks a given socket for whether data can be read from it.

    :param sock: Socket to check for read readiness.
    :param timeout: Timeout (seconds) to wait for socket to be read-ready.
    :return: True if socket is ready to be read from, else False.
    :raises TimeoutError: If timeout set and triggered.
    """
    if sock is None:
        return False
    if timeout is not None and not select.select([sock], [], [], timeout)[0]:
        raise TimeoutError
    elif not select.select([sock], [], [], 0)[0]:
        return False
    return True


def _check_w_socket(sock: socket.socket) -> bool:
    """Checks a given socket for whether data can be written to it.

    :param sock: Socket to check for write readiness.
    :return: True if socket is ready to be written to, else False.
    """
    if sock is None:
        return False
    elif not select.select([], [sock], [], 0)[1]:
        return False
    return True


def _close_socket(sock: Optional[socket.socket]):
    """Closes the socket of an endpoint, shutting down any potential connection
    that might have been established.

    :param sock: Socket to close.
    """
    if sock is None:
        return
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except OSError:
        pass
    sock.close()