From ac833df828d307f3ed436703a5bd8ee18c6c8676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCrkan=20G=C3=BCr?= Date: Fri, 27 Mar 2026 15:51:51 +0100 Subject: [PATCH 1/5] Add flake for easier development --- .gitignore | 2 ++ flake.nix | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 flake.nix diff --git a/.gitignore b/.gitignore index 1255668..c27db58 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ polysh.egg-info build/ dist/ .devcontainer/ +flake.lock +result diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..d3809c3 --- /dev/null +++ b/flake.nix @@ -0,0 +1,60 @@ +# Flake for polysh, only used for easier development and testing on nix-based systems. +# Might be used to build and test the package on multiple platforms, but is not intended/supported for production use. +{ + description = "Remote shell multiplexer for executing commands on multiple hosts"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + }; + + outputs = { self, nixpkgs }: + let + supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; + forAllSystems = nixpkgs.lib.genAttrs supportedSystems; + in + { + packages = forAllSystems (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + polysh = pkgs.python3Packages.buildPythonApplication { + pname = "polysh"; + version = "0.15"; + + pyproject = true; + + build-system = with pkgs.python3Packages; [ + hatchling + ]; + + src = ./.; + + meta = with pkgs.lib; { + description = "Remote shell multiplexer for executing commands on multiple hosts"; + homepage = "https://github.com/innogames/polysh"; + license = licenses.gpl2Plus; + maintainers = with maintainers; [ seqizz ]; + platforms = platforms.unix; + }; + }; + + default = self.packages.${system}.polysh; + } + ); + + devShells = forAllSystems (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = pkgs.mkShell { + packages = with pkgs; [ + python3 + python3Packages.hatchling + ]; + }; + } + ); + }; +} From 1b7ebfa4abe6dbb21ebbe7976b039ec6a8c5e4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCrkan=20G=C3=BCr?= Date: Fri, 27 Mar 2026 15:52:29 +0100 Subject: [PATCH 2/5] Migrate to asyncio --- src/polysh/buffered_dispatcher.py | 52 ++++++--- src/polysh/control_commands.py | 78 +++++++------- src/polysh/dispatcher_registry.py | 112 +++++++++++++++++++ src/polysh/dispatchers.py | 30 +++--- src/polysh/event_loop.py | 79 ++++++++++++++ src/polysh/exceptions.py | 22 ++++ src/polysh/main.py | 174 ++++++++++++++++-------------- src/polysh/remote_dispatcher.py | 158 +++++++++++++++------------ src/polysh/stdin.py | 127 +++++++++++++++------- 9 files changed, 572 insertions(+), 260 deletions(-) create mode 100644 src/polysh/dispatcher_registry.py create mode 100644 src/polysh/event_loop.py create mode 100644 src/polysh/exceptions.py diff --git a/src/polysh/buffered_dispatcher.py b/src/polysh/buffered_dispatcher.py index 66b4238..b404666 100644 --- a/src/polysh/buffered_dispatcher.py +++ b/src/polysh/buffered_dispatcher.py @@ -16,13 +16,16 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore import errno +import fcntl +import os +from polysh import dispatcher_registry from polysh.console import console_output +from polysh.exceptions import ExitNow -class BufferedDispatcher(asyncore.file_dispatcher): +class BufferedDispatcher: """A dispatcher with a write buffer to allow asynchronous writers, and a read buffer to permit line oriented manipulations""" @@ -30,17 +33,39 @@ class BufferedDispatcher(asyncore.file_dispatcher): MAX_BUFFER_SIZE = 1 * 1024 * 1024 def __init__(self, fd: int) -> None: - asyncore.file_dispatcher.__init__(self, fd) self.fd = fd - self.read_buffer = b"" - self.write_buffer = b"" + self.read_buffer = b'' + self.write_buffer = b'' + + # Set non-blocking mode + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + # Register with the dispatcher registry + dispatcher_registry.register(fd, self) + + def recv(self, buffer_size: int) -> bytes: + """Read from the file descriptor.""" + return os.read(self.fd, buffer_size) + + def send(self, data: bytes) -> int: + """Write to the file descriptor.""" + return os.write(self.fd, data) + + def close(self) -> None: + """Unregister and close the file descriptor.""" + dispatcher_registry.unregister(self.fd) + try: + os.close(self.fd) + except OSError: + pass def handle_read(self) -> None: self._handle_read_chunk() def _handle_read_chunk(self) -> bytes: """Some data can be read""" - new_data = b"" + new_data = b'' buffer_length = len(self.read_buffer) try: while buffer_length < self.MAX_BUFFER_SIZE: @@ -50,12 +75,11 @@ def _handle_read_chunk(self) -> bytes: if e.errno == errno.EAGAIN: # End of the available data break - elif e.errno == errno.EIO and new_data: + if e.errno == errno.EIO and new_data: # Hopefully we could read an error message before the # actual termination break - else: - raise + raise if not piece: # A closed connection is indicated by signaling a read @@ -66,7 +90,7 @@ def _handle_read_chunk(self) -> bytes: buffer_length += len(piece) finally: - new_data = new_data.replace(b"\r", b"\n") + new_data = new_data.replace(b'\r', b'\n') self.read_buffer += new_data return new_data @@ -76,16 +100,14 @@ def readable(self) -> bool: def writable(self) -> bool: """Do we have something to write?""" - return self.write_buffer != b"" + return self.write_buffer != b'' def dispatch_write(self, buf: bytes) -> bool: """Augment the buffer with stuff to write when possible""" self.write_buffer += buf if len(self.write_buffer) > self.MAX_BUFFER_SIZE: console_output( - "Buffer too big ({:d}) for {}\n".format( - len(self.write_buffer), str(self) - ).encode() + f'Buffer too big ({len(self.write_buffer):d}) for {str(self)}\n'.encode() ) - raise asyncore.ExitNow(1) + raise ExitNow(1) return True diff --git a/src/polysh/control_commands.py b/src/polysh/control_commands.py index 02eb8fc..7242b72 100644 --- a/src/polysh/control_commands.py +++ b/src/polysh/control_commands.py @@ -18,22 +18,20 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore import os import shlex from typing import List +from polysh import dispatchers, remote_dispatcher, stdin +from polysh.completion import add_to_history, complete_local_path +from polysh.console import console_output from polysh.control_commands_helpers import ( complete_shells, expand_local_path, selected_shells, toggle_shells, ) -from polysh.completion import complete_local_path, add_to_history -from polysh.console import console_output -from polysh import dispatchers -from polysh import remote_dispatcher -from polysh import stdin +from polysh.exceptions import ExitNow def complete_list(line: str, text: str) -> List[str]: @@ -43,11 +41,11 @@ def complete_list(line: str, text: str) -> List[str]: def do_list(command: str) -> None: instances = [i.get_info() for i in selected_shells(command)] flat_instances = dispatchers.format_info(instances) - console_output(b"".join(flat_instances)) + console_output(b''.join(flat_instances)) def do_quit(command: str) -> None: - raise asyncore.ExitNow(0) + raise ExitNow(0) def complete_chdir(line: str, text: str) -> List[str]: @@ -58,29 +56,29 @@ def do_chdir(command: str) -> None: try: os.chdir(expand_local_path(command.strip())) except OSError as e: - console_output("{}\n".format(str(e)).encode()) + console_output(f'{str(e)}\n'.encode()) def complete_send_ctrl(line: str, text: str) -> List[str]: if len(line[:-1].split()) >= 2: # Control letter already given in command line return complete_shells(line, text, lambda i: i.enabled) - if text in ("c", "d", "z"): - return [text + " "] - return ["c ", "d ", "z "] + if text in ('c', 'd', 'z'): + return [text + ' '] + return ['c ', 'd ', 'z '] def do_send_ctrl(command: str) -> None: split = command.split() if not split: - console_output(b"Expected at least a letter\n") + console_output(b'Expected at least a letter\n') return letter = split[0] if len(letter) != 1: - console_output("Expected a single letter, got: {}\n".format(letter).encode()) + console_output(f'Expected a single letter, got: {letter}\n'.encode()) return - control_letter = chr(ord(letter.lower()) - ord("a") + 1) - for i in selected_shells(" ".join(split[1:])): + control_letter = chr(ord(letter.lower()) - ord('a') + 1) + for i in selected_shells(' '.join(split[1:])): if i.enabled: i.dispatch_write(control_letter.encode()) @@ -122,7 +120,9 @@ def complete_reconnect(line: str, text: str) -> List[str]: def do_reconnect(command: str) -> None: selec = selected_shells(command) - to_reconnect = [i for i in selec if i.state == remote_dispatcher.STATE_DEAD] + to_reconnect = [ + i for i in selec if i.state == remote_dispatcher.STATE_DEAD + ] for i in to_reconnect: i.disconnect() i.close() @@ -162,13 +162,13 @@ def do_hide_password(command: str) -> None: i.debug = False if not warned: console_output( - b"Debugging disabled to avoid displaying " b"passwords\n" + b'Debugging disabled to avoid displaying passwords\n' ) warned = True stdin.set_echo(False) if remote_dispatcher.options.log_file: - console_output(b"Logging disabled to avoid writing passwords\n") + console_output(b'Logging disabled to avoid writing passwords\n') remote_dispatcher.options.log_file = None @@ -176,22 +176,22 @@ def complete_set_debug(line: str, text: str) -> List[str]: if len(line[:-1].split()) >= 2: # Debug value already given in command line return complete_shells(line, text) - if text.lower() in ("y", "n"): - return [text + " "] - return ["y ", "n "] + if text.lower() in ('y', 'n'): + return [text + ' '] + return ['y ', 'n '] def do_set_debug(command: str) -> None: split = command.split() if not split: - console_output(b"Expected at least a letter\n") + console_output(b'Expected at least a letter\n') return letter = split[0].lower() - if letter not in ("y", "n"): - console_output("Expected 'y' or 'n', got: {}\n".format(split[0]).encode()) + if letter not in ('y', 'n'): + console_output(f"Expected 'y' or 'n', got: {split[0]}\n".encode()) return - debug = letter == "y" - for i in selected_shells(" ".join(split[1:])): + debug = letter == 'y' + for i in selected_shells(' '.join(split[1:])): i.debug = debug @@ -200,25 +200,25 @@ def do_export_vars(command: str) -> None: for shell in dispatchers.all_instances(): if shell.enabled: environment_variables = { - "POLYSH_RANK": str(rank), - "POLYSH_NAME": shell.hostname, - "POLYSH_DISPLAY_NAME": shell.display_name, + 'POLYSH_RANK': str(rank), + 'POLYSH_NAME': shell.hostname, + 'POLYSH_DISPLAY_NAME': shell.display_name, } for name, value in environment_variables.items(): shell.dispatch_command( - "export {}={}\n".format(name, shlex.quote(value)).encode() + f'export {name}={shlex.quote(value)}\n'.encode() ) rank += 1 for shell in dispatchers.all_instances(): if shell.enabled: shell.dispatch_command( - "export POLYSH_NR_SHELLS={:d}\n".format(rank).encode() + f'export POLYSH_NR_SHELLS={rank:d}\n'.encode() ) -add_to_history("$POLYSH_RANK $POLYSH_NAME $POLYSH_DISPLAY_NAME") -add_to_history("$POLYSH_NR_SHELLS") +add_to_history('$POLYSH_RANK $POLYSH_NAME $POLYSH_DISPLAY_NAME') +add_to_history('$POLYSH_NR_SHELLS') def complete_set_log(line: str, text: str) -> List[str]: @@ -229,13 +229,13 @@ def do_set_log(command: str) -> None: command = command.strip() if command: try: - remote_dispatcher.options.log_file = open(command, "a") - except IOError as e: - console_output("{}\n".format(str(e)).encode()) + remote_dispatcher.options.log_file = open(command, 'a') + except OSError as e: + console_output(f'{str(e)}\n'.encode()) command = None if not command: remote_dispatcher.options.log_file = None - console_output(b"Logging disabled\n") + console_output(b'Logging disabled\n') def complete_show_read_buffer(line: str, text: str) -> List[str]: @@ -248,4 +248,4 @@ def do_show_read_buffer(command: str) -> None: for i in selected_shells(command): if i.read_in_state_not_started: i.print_lines(i.read_in_state_not_started) - i.read_in_state_not_started = b"" + i.read_in_state_not_started = b'' diff --git a/src/polysh/dispatcher_registry.py b/src/polysh/dispatcher_registry.py new file mode 100644 index 0000000..d808ce3 --- /dev/null +++ b/src/polysh/dispatcher_registry.py @@ -0,0 +1,112 @@ +"""Polysh - Dispatcher Registry + +Manages the global selector and dispatcher tracking, replacing asyncore's socket_map. + +Copyright (c) 2006 Guillaume Chazarain +Copyright (c) 2024 InnoGames GmbH +""" +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import selectors +from typing import Dict, Any, List + +# Global selector instance +_selector = selectors.DefaultSelector() + +# Mapping from file descriptor to dispatcher instance +_dispatchers: Dict[int, Any] = {} + +# Track last-registered events per fd to avoid redundant epoll_ctl syscalls +_current_events: Dict[int, int] = {} + + +def register(fd: int, dispatcher: Any) -> None: + """Register a dispatcher with the selector. + + Initially registers with EVENT_READ - events are updated dynamically + based on readable()/writable() before each select(). + """ + _dispatchers[fd] = dispatcher + # Register with EVENT_READ initially; loop_iteration will update as needed + _selector.register(fd, selectors.EVENT_READ, dispatcher) + _current_events[fd] = selectors.EVENT_READ + + +def unregister(fd: int) -> None: + """Unregister a dispatcher from the selector.""" + if fd in _dispatchers: + del _dispatchers[fd] + _current_events.pop(fd, None) + try: + _selector.unregister(fd) + except (KeyError, ValueError): + # Already unregistered or invalid fd + pass + + +def modify_events(fd: int, events: int) -> None: + """Modify the events a dispatcher is interested in. + + Skips the syscall if events haven't changed since last call. + + If events is 0, the fd is temporarily unregistered from the selector + but kept in _dispatchers. It will be re-registered when events become non-zero. + """ + if fd not in _dispatchers: + return + + # Skip if events haven't changed — avoids unnecessary epoll_ctl syscall + if _current_events.get(fd, 0) == events: + return + + old_events = _current_events.get(fd, 0) + + if old_events == 0: + # Not currently registered in selector, need to register + if events != 0: + _selector.register(fd, events, _dispatchers[fd]) + elif events == 0: + # No events - unregister temporarily + _selector.unregister(fd) + else: + # Modify existing registration + _selector.modify(fd, events, _dispatchers[fd]) + + _current_events[fd] = events + + +def all_dispatchers() -> List[Any]: + """Return a snapshot list of all registered dispatchers. + + Use iter_dispatchers() when mutation during iteration is not a concern. + """ + return list(_dispatchers.values()) + + +def iter_dispatchers(): + """Iterate dispatcher values directly, avoiding a list copy. + + Only safe when the caller does not add/remove dispatchers during iteration. + """ + return _dispatchers.values() + + +def get_selector() -> selectors.BaseSelector: + """Return the global selector instance.""" + return _selector + + +def get_dispatcher(fd: int) -> Any: + """Return the dispatcher for a given file descriptor, or None.""" + return _dispatchers.get(fd) diff --git a/src/polysh/dispatchers.py b/src/polysh/dispatchers.py index 4af0a61..170b097 100644 --- a/src/polysh/dispatchers.py +++ b/src/polysh/dispatchers.py @@ -16,27 +16,23 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore import fcntl import struct import sys import termios -from typing import List -from typing import Tuple +from typing import List, Tuple -from polysh import remote_dispatcher -from polysh import display_names +from polysh import dispatcher_registry, display_names, remote_dispatcher from polysh.terminal_size import terminal_size def _split_port(hostname: str) -> Tuple[str, str]: """Splits a string(hostname, given by the user) into hostname and port, returns a tuple""" - s = hostname.split(":", 1) + s = hostname.split(':', 1) if len(s) > 1: return s[0], s[1] - else: - return s[0], "22" + return s[0], '22' def all_instances() -> List[remote_dispatcher.RemoteDispatcher]: @@ -44,10 +40,10 @@ def all_instances() -> List[remote_dispatcher.RemoteDispatcher]: return sorted( [ i - for i in asyncore.socket_map.values() + for i in dispatcher_registry.all_dispatchers() if isinstance(i, remote_dispatcher.RemoteDispatcher) ], - key=lambda i: i.display_name or "", + key=lambda i: i.display_name or '', ) @@ -84,8 +80,8 @@ def update_terminal_size() -> None: w = max(w - display_names.max_display_name_length - 2, min(w, 10)) # python bug http://python.org/sf/1112949 on amd64 # from ajaxterm.py - bug = struct.unpack("i", struct.pack("I", termios.TIOCSWINSZ))[0] - packed_size = struct.pack("HHHH", h, w, 0, 0) + bug = struct.unpack('i', struct.pack('I', termios.TIOCSWINSZ))[0] + packed_size = struct.pack('HHHH', h, w, 0, 0) term_size = w, h for i in all_instances(): if i.enabled and i.term_size != term_size: @@ -112,17 +108,17 @@ def format_info(info_list: List[List[bytes]]) -> List[bytes]: # as it can get much longer in some shells than in others orig_str = info[str_id] indent = max_lengths[str_id] - len(orig_str) - info[str_id] = orig_str + indent * b" " - flattened_info_list.append(b" ".join(info) + b"\n") + info[str_id] = orig_str + indent * b' ' + flattened_info_list.append(b' '.join(info) + b'\n') return flattened_info_list def create_remote_dispatchers(hosts: List[str]) -> None: - last_message = "" + last_message = '' for i, host in enumerate(hosts): if remote_dispatcher.options.interactive: - last_message = "Started %d/%d remote processes\r" % (i, len(hosts)) + last_message = 'Started %d/%d remote processes\r' % (i, len(hosts)) sys.stdout.write(last_message) sys.stdout.flush() try: @@ -133,5 +129,5 @@ def create_remote_dispatchers(hosts: List[str]) -> None: raise if last_message: - sys.stdout.write(" " * len(last_message) + "\r") + sys.stdout.write(' ' * len(last_message) + '\r') sys.stdout.flush() diff --git a/src/polysh/event_loop.py b/src/polysh/event_loop.py new file mode 100644 index 0000000..268b742 --- /dev/null +++ b/src/polysh/event_loop.py @@ -0,0 +1,79 @@ +"""Polysh - Event Loop + +Provides the main loop iteration using selectors, replacing asyncore.loop(). + +Copyright (c) 2006 Guillaume Chazarain +Copyright (c) 2024 InnoGames GmbH +""" +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import errno +import selectors +from typing import Optional + +from polysh import dispatcher_registry + + +def loop_iteration(timeout: Optional[float] = None) -> None: + """Perform a single iteration of the event loop. + + This replaces asyncore.loop(count=1, timeout=timeout, use_poll=True). + + Updates selector registrations based on each dispatcher's readable()/writable() + state, then performs select and dispatches events to handlers. + """ + selector = dispatcher_registry.get_selector() + + # Update event registrations based on current readable/writable state. + # Iterate the dict values directly — this loop doesn't mutate _dispatchers. + for dispatcher in dispatcher_registry.iter_dispatchers(): + events = 0 + if dispatcher.readable(): + events |= selectors.EVENT_READ + if dispatcher.writable(): + events |= selectors.EVENT_WRITE + dispatcher_registry.modify_events(dispatcher.fd, events) + + # Perform select + try: + ready = selector.select(timeout) + except OSError as e: + if e.errno == errno.EINTR: + # Interrupted by signal handler, just return + return + raise + + # Dispatch events + for key, events in ready: + dispatcher = key.data + + # Check if dispatcher is still valid + if dispatcher_registry.get_dispatcher(key.fd) is None: + continue + + if events & selectors.EVENT_READ: + try: + dispatcher.handle_read() + except Exception: + dispatcher.handle_close() + + # Re-check dispatcher is still valid after handle_read + if dispatcher_registry.get_dispatcher(key.fd) is None: + continue + + if events & selectors.EVENT_WRITE: + try: + dispatcher.handle_write() + except Exception: + dispatcher.handle_close() diff --git a/src/polysh/exceptions.py b/src/polysh/exceptions.py new file mode 100644 index 0000000..9bb9dc5 --- /dev/null +++ b/src/polysh/exceptions.py @@ -0,0 +1,22 @@ +"""Polysh - Exceptions + +Copyright (c) 2006 Guillaume Chazarain +Copyright (c) 2024 InnoGames GmbH +""" +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +class ExitNow(Exception): + """Exception to signal clean exit. First argument is exit code.""" + pass diff --git a/src/polysh/main.py b/src/polysh/main.py index 2db3902..2c170cf 100644 --- a/src/polysh/main.py +++ b/src/polysh/main.py @@ -16,26 +16,28 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore +import argparse import atexit import getpass import locale -import argparse import os +import readline +import resource import signal import sys import termios -import readline -import resource -from typing import Callable, List - -from polysh import remote_dispatcher -from polysh import dispatchers -from polysh import stdin +from typing import Callable + +from polysh import ( + VERSION, + control_commands, + dispatchers, + remote_dispatcher, + stdin, +) from polysh.console import console_output +from polysh.exceptions import ExitNow from polysh.host_syntax import expand_syntax -from polysh import control_commands -from polysh import VERSION def kill_all() -> None: @@ -52,100 +54,105 @@ def parse_cmdline() -> argparse.Namespace: description = 'Control commands are prefixed by ":".' parser = argparse.ArgumentParser(description=description) parser.add_argument( - "--hosts-file", + '--hosts-file', type=str, - action="append", - dest="hosts_filenames", - metavar="FILE", + action='append', + dest='hosts_filenames', + metavar='FILE', default=[], - help="read hostnames from given file, one per line", + help='read hostnames from given file, one per line', ) parser.add_argument( - "--command", + '--command', type=str, - dest="command", + dest='command', default=None, - help="command to execute on the remote shells", - metavar="CMD", + help='command to execute on the remote shells', + metavar='CMD', ) - def_ssh = "exec ssh -oLogLevel=Quiet -t %(host)s %(port)s" + def_ssh = 'exec ssh -oLogLevel=Quiet -t %(host)s %(port)s' parser.add_argument( - "--ssh", + '--ssh', type=str, - dest="ssh", + dest='ssh', default=def_ssh, - metavar="SSH", - help="ssh command to use [%(default)s]", + metavar='SSH', + help='ssh command to use [%(default)s]', ) parser.add_argument( - "--user", + '--user', type=str, - dest="user", + dest='user', default=None, - help="remote user to log in as", - metavar="USER", + help='remote user to log in as', + metavar='USER', ) parser.add_argument( - "--no-color", - action="store_true", - dest="disable_color", - help="disable colored hostnames [enabled]", + '--no-color', + action='store_true', + dest='disable_color', + help='disable colored hostnames [enabled]', ) parser.add_argument( - "--password-file", + '--password-file', type=str, - dest="password_file", + dest='password_file', default=None, - metavar="FILE", - help="read a password from the specified file. - is the tty.", + metavar='FILE', + help='read a password from the specified file. - is the tty.', ) parser.add_argument( - "--log-file", + '--log-file', type=str, - dest="log_file", - help="file to log each machine conversation [none]", + dest='log_file', + help='file to log each machine conversation [none]', + ) + parser.add_argument( + '--abort-errors', + action='store_true', + dest='abort_error', + help='abort if some shell fails to initialize [ignore]', ) parser.add_argument( - "--abort-errors", - action="store_true", - dest="abort_error", - help="abort if some shell fails to initialize [ignore]", + '--debug', + action='store_true', + dest='debug', + help='print debugging information', ) parser.add_argument( - "--debug", action="store_true", dest="debug", help="print debugging information" + '--profile', action='store_true', dest='profile', default=False ) - parser.add_argument("--profile", action="store_true", dest="profile", default=False) - parser.add_argument("host_names", nargs="*") + parser.add_argument('host_names', nargs='*') args = parser.parse_args() for filename in args.hosts_filenames: try: - hosts_file = open(filename, "r") + hosts_file = open(filename) for line in hosts_file.readlines(): - if "#" in line: - line = line[: line.index("#")] + if '#' in line: + line = line[: line.index('#')] line = line.strip() if line: args.host_names.append(line) hosts_file.close() - except IOError as e: + except OSError as e: parser.error(str(e)) if args.log_file: try: - args.log_file = open(args.log_file, "a") - except IOError as e: + args.log_file = open(args.log_file, 'a') + except OSError as e: print(e) sys.exit(1) if not args.host_names: - parser.error("no hosts given") + parser.error('no hosts given') - if args.password_file == "-": + if args.password_file == '-': args.password = getpass.getpass() elif args.password_file is not None: - password_file = open(args.password_file, "r") - args.password = password_file.readline().rstrip("\n") + password_file = open(args.password_file) + args.password = password_file.readline().rstrip('\n') else: args.password = None @@ -159,20 +166,20 @@ def find_non_interactive_command(command: str) -> str: stdin = sys.stdin.read() if stdin and command: print( - "--command and reading from stdin are incompatible", + '--command and reading from stdin are incompatible', file=sys.stderr, ) sys.exit(1) - if stdin and not stdin.endswith("\n"): - stdin += "\n" + if stdin and not stdin.endswith('\n'): + stdin += '\n' return command or stdin def init_history(histfile: str) -> None: - if hasattr(readline, "read_history_file"): + if hasattr(readline, 'read_history_file'): try: readline.read_history_file(histfile) - except IOError: + except OSError: pass @@ -182,7 +189,7 @@ def save_history(histfile: str) -> None: def loop(interactive: bool) -> None: - histfile = os.path.expanduser("~/.polysh_history") + histfile = os.path.expanduser('~/.polysh_history') init_history(histfile) next_signal = None last_status = None @@ -191,11 +198,11 @@ def loop(interactive: bool) -> None: if next_signal: current_signal = next_signal next_signal = None - sig2chr = {signal.SIGINT: "C", signal.SIGTSTP: "Z"} + sig2chr = {signal.SIGINT: 'C', signal.SIGTSTP: 'Z'} ctrl = sig2chr[current_signal] - remote_dispatcher.log("> ^{}\n".format(ctrl).encode()) + remote_dispatcher.log(f'> ^{ctrl}\n'.encode()) control_commands.do_send_ctrl(ctrl) - console_output(b"") + console_output(b'') stdin.the_stdin_thread.prepend_text = None while dispatchers.count_awaited_processes()[ 0 @@ -206,14 +213,14 @@ def loop(interactive: bool) -> None: r.print_unfinished_line() current_status = dispatchers.count_awaited_processes() if current_status != last_status: - console_output(b"") + console_output(b'') if remote_dispatcher.options.interactive: stdin.the_stdin_thread.want_raw_input() last_status = current_status if dispatchers.all_terminated(): # Clear the prompt - console_output(b"") - raise asyncore.ExitNow(remote_dispatcher.options.exit_code) + console_output(b'') + raise ExitNow(remote_dispatcher.options.exit_code) if not next_signal: # possible race here with the signal handler remote_dispatcher.main_loop_iteration() @@ -223,22 +230,22 @@ def loop(interactive: bool) -> None: else: kill_all() os.kill(0, signal.SIGINT) - except asyncore.ExitNow as e: - console_output(b"") + except ExitNow as e: + console_output(b'') save_history(histfile) sys.exit(e.args[0]) def _profile(continuation: Callable) -> None: - prof_file = "polysh.prof" + prof_file = 'polysh.prof' import cProfile import pstats - print("Profiling using cProfile") - cProfile.runctx("continuation()", globals(), locals(), prof_file) + print('Profiling using cProfile') + cProfile.runctx('continuation()', globals(), locals(), prof_file) stats = pstats.Stats(prof_file) stats.strip_dirs() - stats.sort_stats("time", "calls") + stats.sort_stats('time', 'calls') stats.print_stats(50) stats.print_callees(50) os.remove(prof_file) @@ -252,7 +259,7 @@ def restore_tty_on_exit() -> None: def run() -> None: """Launch polysh""" - locale.setlocale(locale.LC_ALL, "") + locale.setlocale(locale.LC_ALL, '') atexit.register(kill_all) signal.signal(signal.SIGPIPE, signal.SIG_DFL) @@ -260,7 +267,9 @@ def run() -> None: args.command = find_non_interactive_command(args.command) args.exit_code = 0 - args.interactive = not args.command and sys.stdin.isatty() and sys.stdout.isatty() + args.interactive = ( + not args.command and sys.stdin.isatty() and sys.stdout.isatty() + ) if args.interactive: restore_tty_on_exit() @@ -281,8 +290,8 @@ def run() -> None: resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, new_hard)) except OSError as e: print( - "Failed to change RLIMIT_NOFILE from soft={} hard={} to soft={} " - "hard={}: {}".format(old_soft, old_hard, new_soft, new_hard, e), + f'Failed to change RLIMIT_NOFILE from soft={old_soft} hard={old_hard} to soft={new_soft} ' + f'hard={new_hard}: {e}', file=sys.stderr, ) sys.exit(1) @@ -290,7 +299,8 @@ def run() -> None: dispatchers.create_remote_dispatchers(hosts) signal.signal( - signal.SIGWINCH, lambda signum, frame: dispatchers.update_terminal_size() + signal.SIGWINCH, + lambda signum, frame: dispatchers.update_terminal_size(), ) stdin.the_stdin_thread = stdin.StdinThread(args.interactive) @@ -311,14 +321,14 @@ def safe_loop() -> None: def main(): """Wrapper around run() to setup sentry""" - sentry_dsn = os.environ.get("POLYSH_SENTRY_DSN") + sentry_dsn = os.environ.get('POLYSH_SENTRY_DSN') if sentry_dsn: import sentry_sdk sentry_sdk.init( dsn=sentry_dsn, - release=".".join(map(str, VERSION)), + release='.'.join(map(str, VERSION)), ) try: diff --git a/src/polysh/remote_dispatcher.py b/src/polysh/remote_dispatcher.py index ec5d327..ae66b5a 100644 --- a/src/polysh/remote_dispatcher.py +++ b/src/polysh/remote_dispatcher.py @@ -16,29 +16,27 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore import os +import platform import pty +import select import signal import sys import termios -import select -import platform -from typing import Optional, List -from argparse import Namespace +from typing import List, Optional +from polysh import callbacks, display_names, event_loop from polysh.buffered_dispatcher import BufferedDispatcher -from polysh import callbacks from polysh.console import console_output -from polysh import display_names +from polysh.exceptions import ExitNow options = None # type: Optional[Namespace] # Either the remote shell is expecting a command or one is already running -STATE_NAMES = ["not_started", "idle", "running", "terminated", "dead"] +STATE_NAMES = ['not_started', 'idle', 'running', 'terminated', 'dead'] -STATE_NOT_STARTED, STATE_IDLE, STATE_RUNNING, STATE_TERMINATED, STATE_DEAD = list( - range(len(STATE_NAMES)) +STATE_NOT_STARTED, STATE_IDLE, STATE_RUNNING, STATE_TERMINATED, STATE_DEAD = ( + list(range(len(STATE_NAMES))) ) # Terminal color codes @@ -52,7 +50,7 @@ def main_loop_iteration(timeout: Optional[float] = None) -> int: """Return the number of RemoteDispatcher.handle_read() calls made by this iteration""" prev_nr_read = nr_handle_read - asyncore.loop(count=1, timeout=timeout, use_poll=True) + event_loop.loop_iteration(timeout=timeout) return nr_handle_read - prev_nr_read @@ -63,9 +61,9 @@ def log(msg: bytes) -> None: try: written = os.write(fd, msg) except OSError as e: - print("Exception while writing log:", options.log_file.name) + print('Exception while writing log:', options.log_file.name) print(e) - raise asyncore.ExitNow(1) + raise ExitNow(1) msg = msg[written:] @@ -73,10 +71,10 @@ class RemoteDispatcher(BufferedDispatcher): """A RemoteDispatcher is a ssh process we communicate with""" def __init__(self, hostname: str, port: str) -> None: - if port != "22": - port = "-p " + port + if port != '22': + port = '-p ' + port else: - port = "" + port = '' self.pid, fd = pty.fork() if self.pid == 0: @@ -97,9 +95,9 @@ def __init__(self, hostname: str, port: str) -> None: self.change_name(self.hostname.encode()) self.init_string = self.configure_tty() + self.set_prompt() self.init_string_sent = False - self.read_in_state_not_started = b"" + self.read_in_state_not_started = b'' self.command = options.command - self.last_printed_line = b"" + self.last_printed_line = b'' self.color_code = None if sys.stdout.isatty() and not options.disable_color: COLORS.insert(0, COLORS.pop()) # Rotate the colors @@ -108,11 +106,11 @@ def __init__(self, hostname: str, port: str) -> None: def launch_ssh(self, name: str, port: str) -> None: """Launch the ssh command in the child process""" if options.user: - name = "%s@%s" % (options.user, name) - evaluated = options.ssh % {"host": name, "port": port} + name = '%s@%s' % (options.user, name) + evaluated = options.ssh % {'host': name, 'port': port} if evaluated == options.ssh: - evaluated = "%s %s" % (evaluated, name) - os.execlp("/bin/sh", "sh", "-c", evaluated) + evaluated = '%s %s' % (evaluated, name) + os.execlp('/bin/sh', 'sh', '-c', evaluated) def set_enabled(self, enabled: bool) -> None: if enabled != self.enabled and options.interactive: @@ -126,9 +124,9 @@ def change_state(self, state: int) -> None: """Change the state of the remote process, logging the change""" if state is not self.state: if self.debug: - self.print_debug(b"state => " + STATE_NAMES[state].encode()) + self.print_debug(b'state => ' + STATE_NAMES[state].encode()) if self.state is STATE_NOT_STARTED: - self.read_in_state_not_started = b"" + self.read_in_state_not_started = b'' self.state = state def disconnect(self) -> None: @@ -138,14 +136,14 @@ def disconnect(self) -> None: except OSError: # The process was already dead, no problem pass - self.read_buffer = b"" - self.write_buffer = b"" + self.read_buffer = b'' + self.write_buffer = b'' self.set_enabled(False) if self.read_in_state_not_started: self.print_lines(self.read_in_state_not_started) - self.read_in_state_not_started = b"" + self.read_in_state_not_started = b'' if options.abort_error and self.state is STATE_NOT_STARTED: - raise asyncore.ExitNow(1) + raise ExitNow(1) self.change_state(STATE_DEAD) def configure_tty(self) -> bytes: @@ -164,22 +162,22 @@ def seen_prompt_cb(self, unused: str) -> None: if options.interactive: self.change_state(STATE_IDLE) elif self.command: - p1, p2 = callbacks.add(b"real prompt ends", lambda d: None, True) + p1, p2 = callbacks.add(b'real prompt ends', lambda d: None, True) self.dispatch_command(b'PS1="' + p1 + b'""' + p2 + b'\n"\n') - self.dispatch_command(self.command.encode() + b"\n") - self.dispatch_command(b"exit 2>/dev/null\n") + self.dispatch_command(self.command.encode() + b'\n') + self.dispatch_command(b'exit 2>/dev/null\n') self.command = None def set_prompt(self) -> bytes: """The prompt is important because we detect the readyness of a process by waiting for its prompt.""" # No right prompt - command_line = b"PS2=;RPS1=;RPROMPT=;" - command_line += b"PROMPT_COMMAND=;" - command_line += b"TERM=ansi;" - command_line += b"unset precmd_functions;" - command_line += b"unset HISTFILE;" - prompt1, prompt2 = callbacks.add(b"prompt", self.seen_prompt_cb, True) + command_line = b'PS2=;RPS1=;RPROMPT=;' + command_line += b'PROMPT_COMMAND=;' + command_line += b'TERM=ansi;' + command_line += b'unset precmd_functions;' + command_line += b'unset HISTFILE;' + prompt1, prompt2 = callbacks.add(b'prompt', self.seen_prompt_cb, True) command_line += b'PS1="' + prompt1 + b'""' + prompt2 + b'\n"\n' return command_line @@ -190,24 +188,23 @@ def readable(self) -> bool: def handle_expt(self) -> None: # Dirty hack to ignore POLLPRI flag that is raised on Mac OS, but not - # on linux. asyncore calls this method in case POLLPRI flag is set, but - # self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0 - if platform.system() == "Darwin" and select.POLLPRI: + # on linux. This method is kept for compatibility but the new selectors + # implementation doesn't call it - we handle this in handle_close() + if platform.system() == 'Darwin' and select.POLLPRI: return self.handle_close() def handle_close(self) -> None: if self.state is STATE_DEAD: - # This connection has already been killed. Asyncore has probably - # called handle_close() or handle_expt() on this connection twice. + # This connection has already been killed. return pid, status = os.waitpid(self.pid, 0) exit_code = os.WEXITSTATUS(status) options.exit_code = max(options.exit_code, exit_code) if exit_code and options.interactive: - console_output("Error talking to {}\n".format(self.display_name).encode()) + console_output(f'Error talking to {self.display_name}\n'.encode()) self.disconnect() if self.temporary: self.close() @@ -215,32 +212,36 @@ def handle_close(self) -> None: def print_lines(self, lines: bytes) -> None: from polysh.display_names import max_display_name_length - lines = lines.strip(b"\n") + lines = lines.strip(b'\n') while True: - no_empty_lines = lines.replace(b"\n\n", b"\n") + no_empty_lines = lines.replace(b'\n\n', b'\n') if len(no_empty_lines) == len(lines): break lines = no_empty_lines if not lines: return indent = max_display_name_length - len(self.display_name) - log_prefix = self.display_name.encode() + indent * b" " + b" : " + log_prefix = self.display_name.encode() + indent * b' ' + b' : ' if self.color_code is None: console_prefix = log_prefix else: console_prefix = ( - b"\033[1;" + b'\033[1;' + str(self.color_code).encode() - + b"m" + + b'm' + log_prefix - + b"\033[1;m" + + b'\033[1;m' ) console_data = ( - console_prefix + lines.replace(b"\n", b"\n" + console_prefix) + b"\n" + console_prefix + + lines.replace(b'\n', b'\n' + console_prefix) + + b'\n' + ) + log_data = ( + log_prefix + lines.replace(b'\n', b'\n' + log_prefix) + b'\n' ) - log_data = log_prefix + lines.replace(b"\n", b"\n" + log_prefix) + b"\n" console_output(console_data, logging_msg=log_data) - self.last_printed_line = lines[lines.rfind(b"\n") + 1 :] + self.last_printed_line = lines[lines.rfind(b'\n') + 1 :] def handle_read_fast_case(self, data: bytes) -> bool: """If we are in a fast case we'll avoid the long processing of each @@ -249,7 +250,7 @@ def handle_read_fast_case(self, data: bytes) -> bool: # Slow case :-( return False - last_nl = data.rfind(b"\n") + last_nl = data.rfind(b'\n') if last_nl == -1: # No '\n' in data => slow case return False @@ -266,10 +267,10 @@ def handle_read(self) -> None: nr_handle_read += 1 new_data = self._handle_read_chunk() if self.debug: - self.print_debug(b"==> " + new_data) + self.print_debug(b'==> ' + new_data) if self.handle_read_fast_case(self.read_buffer): return - lf_pos = new_data.find(b"\n") + lf_pos = new_data.find(b'\n') if lf_pos >= 0: # Optimization: we knew there were no '\n' in the previous read # buffer, so we searched only in the new_data and we offset the @@ -278,10 +279,10 @@ def handle_read(self) -> None: elif ( self.state is STATE_NOT_STARTED and options.password is not None - and b"password:" in self.read_buffer.lower() + and b'password:' in self.read_buffer.lower() ): - self.dispatch_write("{}\n".format(options.password).encode()) - self.read_buffer = b"" + self.dispatch_write(f'{options.password}\n'.encode()) + self.read_buffer = b'' return while lf_pos >= 0: # For each line in the buffer @@ -292,24 +293,25 @@ def handle_read(self) -> None: self.print_lines(line) elif self.state is STATE_NOT_STARTED: self.read_in_state_not_started += line - if b"The authenticity of host" in line: - msg = line.strip(b"\n") + b" Closing connection." + if b'The authenticity of host' in line: + msg = line.strip(b'\n') + b' Closing connection.' self.disconnect() - elif b"REMOTE HOST IDENTIFICATION HAS CHANGED" in line: - msg = b"Remote host identification has changed." + elif b'REMOTE HOST IDENTIFICATION HAS CHANGED' in line: + msg = b'Remote host identification has changed.' else: msg = None if msg: self.print_lines( - msg + b" Consider manually connecting or " b"using ssh-keyscan." + msg + b' Consider manually connecting or ' + b'using ssh-keyscan.' ) # Go to the next line in the buffer self.read_buffer = self.read_buffer[lf_pos + 1 :] if self.handle_read_fast_case(self.read_buffer): return - lf_pos = self.read_buffer.find(b"\n") + lf_pos = self.read_buffer.find(b'\n') if self.state is STATE_NOT_STARTED and not self.init_string_sent: self.dispatch_write(self.init_string) self.init_string_sent = True @@ -319,7 +321,7 @@ def print_unfinished_line(self) -> None: if self.state is STATE_RUNNING: if not callbacks.process(self.read_buffer): self.print_lines(self.read_buffer) - self.read_buffer = b"" + self.read_buffer = b'' def writable(self) -> bool: """Do we want to write something?""" @@ -330,22 +332,28 @@ def handle_write(self) -> None: num_sent = self.send(self.write_buffer) if self.debug: if self.state is not STATE_NOT_STARTED or options.password is None: - self.print_debug(b"<== " + self.write_buffer[:num_sent]) + self.print_debug(b'<== ' + self.write_buffer[:num_sent]) self.write_buffer = self.write_buffer[num_sent:] def print_debug(self, msg: bytes) -> None: """Log some debugging information to the console""" state = STATE_NAMES[self.state].encode() console_output( - b"[dbg] " + self.display_name.encode() + b"[" + state + b"]: " + msg + b"\n" + b'[dbg] ' + + self.display_name.encode() + + b'[' + + state + + b']: ' + + msg + + b'\n' ) def get_info(self) -> List[bytes]: """Return a list with all information available about this process""" return [ self.display_name.encode(), - self.enabled and b"enabled" or b"disabled", - STATE_NAMES[self.state].encode() + b":", + self.enabled and b'enabled' or b'disabled', + STATE_NAMES[self.state].encode() + b':', self.last_printed_line.strip(), ] @@ -373,9 +381,17 @@ def rename(self, name: bytes) -> None: """Send to the remote shell, its new name to be shell expanded""" if name: # defug callback add? - rename1, rename2 = callbacks.add(b"rename", self.change_name, False) + rename1, rename2 = callbacks.add( + b'rename', self.change_name, False + ) self.dispatch_command( - b'/bin/echo "' + rename1 + b'""' + rename2 + b'"' + name + b"\n" + b'/bin/echo "' + + rename1 + + b'""' + + rename2 + + b'"' + + name + + b'\n' ) else: self.change_name(self.hostname.encode()) diff --git a/src/polysh/stdin.py b/src/polysh/stdin.py index cafadb4..beeb766 100644 --- a/src/polysh/stdin.py +++ b/src/polysh/stdin.py @@ -16,8 +16,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import asyncore import errno +import fcntl import os import readline # Just to say we want to use it with raw_input import signal @@ -26,22 +26,27 @@ import sys import tempfile import termios -from threading import Thread, Event, Lock +from threading import Event, Lock, Thread +from typing import Optional -from polysh import dispatchers, remote_dispatcher +from polysh import ( + completion, + dispatcher_registry, + dispatchers, + remote_dispatcher, +) from polysh.console import console_output, set_last_status_length -from polysh import completion -from typing import Optional +from polysh.exceptions import ExitNow the_stdin_thread = None # type: StdinThread -class InputBuffer(object): +class InputBuffer: """The shared input buffer between the main thread and the stdin thread""" def __init__(self) -> None: self.lock = Lock() - self.buf = b"" + self.buf = b'' def add(self, data: bytes) -> None: """Add data to the buffer""" @@ -50,9 +55,9 @@ def add(self, data: bytes) -> None: def get(self) -> bytes: """Get the content of the buffer""" - data = b"" + data = b'' with self.lock: - data, self.buf = self.buf, b"" + data, self.buf = self.buf, b'' return data @@ -63,43 +68,43 @@ def process_input_buffer() -> None: from polysh.control_commands_helpers import handle_control_command data = the_stdin_thread.input_buffer.get() - remote_dispatcher.log(b"> " + data) + remote_dispatcher.log(b'> ' + data) - if data.startswith(b":"): + if data.startswith(b':'): try: handle_control_command(data[1:-1].decode()) - except UnicodeDecodeError as e: - console_output(b"Could not decode command.") + except UnicodeDecodeError: + console_output(b'Could not decode command.') return - if data.startswith(b"!"): + if data.startswith(b'!'): try: retcode = subprocess.call(data[1:], shell=True) except OSError as e: if e.errno == errno.EINTR: - console_output(b"Child was interrupted\n") + console_output(b'Child was interrupted\n') retcode = 0 else: raise if retcode > 128 and retcode <= 192: retcode = 128 - retcode if retcode > 0: - console_output("Child returned {:d}\n".format(retcode).encode()) + console_output(f'Child returned {retcode:d}\n'.encode()) elif retcode < 0: console_output( - "Child was terminated by signal {:d}\n".format(-retcode).encode() + f'Child was terminated by signal {-retcode:d}\n'.encode() ) return for r in dispatchers.all_instances(): try: r.dispatch_command(data) - except asyncore.ExitNow as e: + except ExitNow as e: raise e except Exception as msg: raise msg console_output( - "{} for {}, disconnecting\n".format(str(msg), r.display_name).encode() + f'{str(msg)} for {r.display_name}, disconnecting\n'.encode() ) r.disconnect() else: @@ -116,32 +121,82 @@ def process_input_buffer() -> None: # sends the ACK, and the stdin thread can go on. -class SocketNotificationReader(asyncore.dispatcher): +class SocketDispatcher: + """Base dispatcher class for socket-based communication.""" + + def __init__(self, sock: socket.socket) -> None: + self.socket = sock + self.fd = sock.fileno() + + # Set non-blocking mode + flags = fcntl.fcntl(self.fd, fcntl.F_GETFL) + fcntl.fcntl(self.fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + # Register with the dispatcher registry + dispatcher_registry.register(self.fd, self) + + def recv(self, buffer_size: int) -> bytes: + """Read from the socket.""" + return self.socket.recv(buffer_size) + + def send(self, data: bytes) -> int: + """Write to the socket.""" + return self.socket.send(data) + + def close(self) -> None: + """Unregister and close the socket.""" + dispatcher_registry.unregister(self.fd) + try: + self.socket.close() + except OSError: + pass + + def readable(self) -> bool: + """Override in subclass.""" + return True + + def writable(self) -> bool: + """Override in subclass.""" + return False + + def handle_read(self) -> None: + """Override in subclass.""" + pass + + def handle_write(self) -> None: + """Override in subclass.""" + pass + + def handle_close(self) -> None: + """Handle connection close.""" + self.close() + + +class SocketNotificationReader(SocketDispatcher): """The socket reader in the main thread""" - def __init__(self, the_stdin_thread: "StdinThread") -> None: - asyncore.dispatcher.__init__(self, the_stdin_thread.socket_read) + def __init__(self, the_stdin_thread: 'StdinThread') -> None: + super().__init__(the_stdin_thread.socket_read) def _do(self, c: bytes) -> None: - if c == b"d": + if c == b'd': process_input_buffer() else: - raise Exception("Unknown code: %s" % (c)) + raise Exception('Unknown code: %s' % (c)) def handle_read(self) -> None: """Handle all the available character commands in the socket""" while True: try: c = self.recv(1) - except socket.error as e: + except OSError as e: if e.errno == errno.EWOULDBLOCK: return - else: - raise + raise else: self._do(c) self.socket.setblocking(True) - self.send(b"A") + self.send(b'A') self.socket.setblocking(False) def writable(self) -> bool: @@ -155,7 +210,7 @@ def write_main_socket(c: bytes) -> None: while True: try: the_stdin_thread.socket_write.recv(1) - except socket.error as e: + except OSError as e: if e.errno != errno.EINTR: raise else: @@ -169,7 +224,7 @@ def write_main_socket(c: bytes) -> None: # a newline tempfile_fd, tempfile_name = tempfile.mkstemp() os.remove(tempfile_name) -os.write(tempfile_fd, b"\x03") +os.write(tempfile_fd, b'\x03') def get_stdin_pid(cached_result: Optional[int] = None) -> int: @@ -177,7 +232,7 @@ def get_stdin_pid(cached_result: Optional[int] = None) -> int: ID""" if cached_result is None: try: - tasks = os.listdir("/proc/self/task") + tasks = os.listdir('/proc/self/task') except OSError as e: if e.errno != errno.ENOENT: raise @@ -227,7 +282,7 @@ class StdinThread(Thread): """The stdin thread, used to call raw_input()""" def __init__(self, interactive: bool) -> None: - Thread.__init__(self, name="stdin thread") + Thread.__init__(self, name='stdin thread') completion.install_completion_handler() self.input_buffer = InputBuffer() @@ -254,9 +309,9 @@ def prepend_previous_text(self) -> None: def want_raw_input(self) -> None: nr, total = dispatchers.count_awaited_processes() if nr: - prompt = "waiting (%d/%d)> " % (nr, total) + prompt = 'waiting (%d/%d)> ' % (nr, total) else: - prompt = "ready (%d)> " % total + prompt = 'ready (%d)> ' % total self.prompt = prompt set_last_status_length(len(prompt)) self.raw_input_wanted.set() @@ -296,5 +351,5 @@ def run(self) -> None: completion.remove_last_history_item() set_echo(True) if cmd is not None: - self.input_buffer.add("{}\n".format(cmd).encode()) - write_main_socket(b"d") + self.input_buffer.add(f'{cmd}\n'.encode()) + write_main_socket(b'd') From d30b18f968e99727561a8ca71622d4f47199e91d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCrkan=20G=C3=BCr?= Date: Fri, 27 Mar 2026 15:54:33 +0100 Subject: [PATCH 3/5] Add benchmark, excluded from pytest --- tests/benchmark.py | 342 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100755 tests/benchmark.py diff --git a/tests/benchmark.py b/tests/benchmark.py new file mode 100755 index 0000000..2885150 --- /dev/null +++ b/tests/benchmark.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +"""Benchmark script for comparing polysh event loop implementations. + +Compares the old asyncore-based event loop (master branch) against the new +selectors-based event loop (g_asyncio branch) by running polysh in +non-interactive mode and measuring wall-clock time. + +Usage examples: + # Compare two binaries with explicit server list + ./benchmark.py --binary-old ./result-old/bin/polysh \ + --binary-new ./result-new/bin/polysh \ + server1 server2 server3 + + # Single binary benchmark (just timing, no comparison) + ./benchmark.py --binary ./result/bin/polysh server1 server2 + + # Use adminapi to resolve servers + ./benchmark.py --binary-old ./result-old/bin/polysh \ + --binary-new ./result-new/bin/polysh \ + --adminapi 'project=grepo game_market=xx servertype=vm' + + # Custom command and more runs + ./benchmark.py --binary-old ./result-old/bin/polysh \ + --binary-new ./result-new/bin/polysh \ + --command 'cat /proc/loadavg' \ + --runs 10 \ + server1 server2 + + # With custom ssh command + ./benchmark.py --binary ./result/bin/polysh \ + --ssh 'exec ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -oLogLevel=Quiet -t %(host)s %(port)s' \ + server1 server2 +""" + +import argparse +import json +import math +import os +import resource +import shutil +import subprocess +import sys +import time + + +def resolve_adminapi(query): + """Resolve server names using adminapi CLI tool.""" + try: + result = subprocess.run( + ['adminapi', query], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + print(f'adminapi failed: {result.stderr.strip()}', file=sys.stderr) + sys.exit(1) + servers = result.stdout.strip().split() + if not servers: + print('adminapi returned no servers', file=sys.stderr) + sys.exit(1) + return servers + except FileNotFoundError: + print('adminapi not found in PATH', file=sys.stderr) + sys.exit(1) + + +def build_polysh_cmd(binary, ssh_cmd, command, servers): + """Build the polysh command line.""" + cmd = [binary] + if ssh_cmd: + cmd.extend(['--ssh', ssh_cmd]) + cmd.extend(['--command', command]) + cmd.extend(servers) + return cmd + + +def run_single(cmd, run_number, quiet=False): + """Run a single polysh invocation and return wall-clock time in seconds.""" + if not quiet: + print(f' Run {run_number}... ', end='', flush=True) + + start = time.monotonic() + try: + result = subprocess.run( + cmd, + capture_output=True, + timeout=300, # 5 minute timeout per run + ) + except subprocess.TimeoutExpired: + if not quiet: + print('TIMEOUT') + return None + + elapsed = time.monotonic() - start + + if not quiet: + if result.returncode != 0: + print(f'{elapsed:.3f}s (exit code {result.returncode})') + else: + print(f'{elapsed:.3f}s') + + return elapsed + + +def run_benchmark(label, cmd, runs, warmup): + """Run multiple iterations of a benchmark and collect timings.""" + print(f'\n--- {label} ---') + print(f'Command: {" ".join(cmd)}') + print(f'Warmup runs: {warmup}, Measured runs: {runs}') + + # Warmup + for i in range(warmup): + print(f' Warmup {i + 1}... ', end='', flush=True) + t = run_single(cmd, i + 1, quiet=True) + if t is not None: + print(f'{t:.3f}s') + else: + print('TIMEOUT') + + # Measured runs + timings = [] + for i in range(runs): + t = run_single(cmd, i + 1) + if t is not None: + timings.append(t) + + return timings + + +def compute_stats(timings): + """Compute basic statistics from a list of timings.""" + if not timings: + return None + + n = len(timings) + mean = sum(timings) / n + sorted_t = sorted(timings) + median = sorted_t[n // 2] if n % 2 else (sorted_t[n // 2 - 1] + sorted_t[n // 2]) / 2 + + if n > 1: + variance = sum((t - mean) ** 2 for t in timings) / (n - 1) + stddev = math.sqrt(variance) + else: + stddev = 0.0 + + return { + 'n': n, + 'mean': mean, + 'median': median, + 'min': sorted_t[0], + 'max': sorted_t[-1], + 'stddev': stddev, + } + + +def print_stats(label, stats): + """Print statistics for a benchmark run.""" + if stats is None: + print(f'\n{label}: No successful runs') + return + + print(f'\n{label} ({stats["n"]} runs):') + print(f' Mean: {stats["mean"]:.3f}s') + print(f' Median: {stats["median"]:.3f}s') + print(f' Min: {stats["min"]:.3f}s') + print(f' Max: {stats["max"]:.3f}s') + print(f' Stddev: {stats["stddev"]:.3f}s') + + +def print_comparison(stats_old, stats_new): + """Print comparison between old and new implementations.""" + if stats_old is None or stats_new is None: + print('\nCannot compare: one or both benchmarks had no successful runs') + return + + diff = stats_new['mean'] - stats_old['mean'] + if stats_old['mean'] > 0: + pct = (diff / stats_old['mean']) * 100 + else: + pct = 0.0 + + print('\n=== Comparison ===') + print(f' Old mean: {stats_old["mean"]:.3f}s') + print(f' New mean: {stats_new["mean"]:.3f}s') + print(f' Diff: {diff:+.3f}s ({pct:+.1f}%)') + + if diff < 0: + print(f' New is {abs(pct):.1f}% faster') + elif diff > 0: + print(f' New is {pct:.1f}% slower') + else: + print(' No difference') + + # Also compare medians + diff_med = stats_new['median'] - stats_old['median'] + if stats_old['median'] > 0: + pct_med = (diff_med / stats_old['median']) * 100 + else: + pct_med = 0.0 + print(f' Median diff: {diff_med:+.3f}s ({pct_med:+.1f}%)') + + +def main(): + parser = argparse.ArgumentParser( + description='Benchmark polysh event loop implementations', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Binary selection - mutually exclusive groups + binary_group = parser.add_mutually_exclusive_group(required=True) + binary_group.add_argument( + '--binary', + help='Single polysh binary to benchmark (timing only, no comparison)', + ) + binary_group.add_argument( + '--binary-old', + help='Path to old (asyncore) polysh binary', + ) + + parser.add_argument( + '--binary-new', + help='Path to new (selectors) polysh binary (required with --binary-old)', + ) + + # Server selection + parser.add_argument( + 'servers', + nargs='*', + help='Server hostnames to connect to', + ) + parser.add_argument( + '--adminapi', + help='adminapi query string to resolve server names', + ) + + # Polysh options + parser.add_argument( + '--ssh', + default='exec ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -oLogLevel=Quiet -t %(host)s %(port)s', + help='SSH command template (default: %(default)s)', + ) + parser.add_argument( + '--command', + default='hostname', + help='Command to execute on remote hosts (default: %(default)s)', + ) + + # Benchmark parameters + parser.add_argument( + '--runs', + type=int, + default=5, + help='Number of measured runs (default: %(default)s)', + ) + parser.add_argument( + '--warmup', + type=int, + default=1, + help='Number of warmup runs (default: %(default)s)', + ) + parser.add_argument( + '--json', + metavar='FILE', + help='Write results as JSON to FILE', + ) + + args = parser.parse_args() + + # Validate binary arguments + if args.binary_old and not args.binary_new: + parser.error('--binary-new is required when using --binary-old') + + # Resolve servers + servers = list(args.servers) if args.servers else [] + if args.adminapi: + servers.extend(resolve_adminapi(args.adminapi)) + if not servers: + parser.error('No servers specified. Use positional args or --adminapi') + + # Validate binaries exist + binaries = {} + if args.binary: + if not os.path.isfile(args.binary) or not os.access(args.binary, os.X_OK): + print(f'Binary not found or not executable: {args.binary}', file=sys.stderr) + sys.exit(1) + binaries['single'] = args.binary + else: + for label, path in [('old', args.binary_old), ('new', args.binary_new)]: + if not os.path.isfile(path) or not os.access(path, os.X_OK): + print(f'Binary not found or not executable: {path}', file=sys.stderr) + sys.exit(1) + binaries[label] = path + + print(f'Servers ({len(servers)}): {" ".join(servers[:10])}{"..." if len(servers) > 10 else ""}') + print(f'Command: {args.command}') + print(f'SSH: {args.ssh}') + print(f'Runs: {args.runs} (+ {args.warmup} warmup)') + + results = {} + + if 'single' in binaries: + cmd = build_polysh_cmd(binaries['single'], args.ssh, args.command, servers) + timings = run_benchmark('Polysh', cmd, args.runs, args.warmup) + stats = compute_stats(timings) + print_stats('Polysh', stats) + results['single'] = {'timings': timings, 'stats': stats} + else: + # Run old first, then new + cmd_old = build_polysh_cmd(binaries['old'], args.ssh, args.command, servers) + timings_old = run_benchmark('Old (asyncore)', cmd_old, args.runs, args.warmup) + stats_old = compute_stats(timings_old) + + cmd_new = build_polysh_cmd(binaries['new'], args.ssh, args.command, servers) + timings_new = run_benchmark('New (selectors)', cmd_new, args.runs, args.warmup) + stats_new = compute_stats(timings_new) + + print_stats('Old (asyncore)', stats_old) + print_stats('New (selectors)', stats_new) + print_comparison(stats_old, stats_new) + + results['old'] = {'timings': timings_old, 'stats': stats_old} + results['new'] = {'timings': timings_new, 'stats': stats_new} + + # Write JSON output if requested + if args.json: + json_data = { + 'servers': servers, + 'command': args.command, + 'ssh': args.ssh, + 'runs': args.runs, + 'warmup': args.warmup, + 'results': results, + } + with open(args.json, 'w') as f: + json.dump(json_data, f, indent=2) + print(f'\nJSON results written to {args.json}') + + +if __name__ == '__main__': + main() From f5d89d32a4e93b0c32d63ef287d61badc7d76cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCrkan=20G=C3=BCr?= Date: Fri, 27 Mar 2026 16:31:31 +0100 Subject: [PATCH 4/5] Add unit tests related to new functions --- tests/test_dispatcher_registry.py | 208 +++++++++++++++++++++++++++ tests/test_event_loop.py | 224 ++++++++++++++++++++++++++++++ 2 files changed, 432 insertions(+) create mode 100644 tests/test_dispatcher_registry.py create mode 100644 tests/test_event_loop.py diff --git a/tests/test_dispatcher_registry.py b/tests/test_dispatcher_registry.py new file mode 100644 index 0000000..4d9e1cd --- /dev/null +++ b/tests/test_dispatcher_registry.py @@ -0,0 +1,208 @@ +"""Polysh - Tests - Dispatcher Registry + +Unit tests for the selectors-based dispatcher registry. + +Copyright (c) 2024 InnoGames GmbH +""" +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import selectors +import unittest + +from polysh import dispatcher_registry + + +class FakeDispatcher: + """Minimal dispatcher stub for testing.""" + + def __init__(self, fd): + self.fd = fd + self._readable = True + self._writable = False + + def readable(self): + return self._readable + + def writable(self): + return self._writable + + +class TestDispatcherRegistry(unittest.TestCase): + def setUp(self): + """Reset module-level global state and track fds for cleanup.""" + self._fds = [] + dispatcher_registry._dispatchers.clear() + dispatcher_registry._current_events.clear() + dispatcher_registry._selector.close() + dispatcher_registry._selector = selectors.DefaultSelector() + + def tearDown(self): + """Close all pipe fds opened during the test.""" + # Unregister any remaining dispatchers first to avoid selector complaints + for fd in list(dispatcher_registry._dispatchers): + dispatcher_registry.unregister(fd) + dispatcher_registry._selector.close() + dispatcher_registry._selector = selectors.DefaultSelector() + for fd in self._fds: + try: + os.close(fd) + except OSError: + pass + + def _make_pipe(self): + """Create a pipe and track both fds for cleanup.""" + r, w = os.pipe() + self._fds.extend([r, w]) + return r, w + + def test_register_and_get_dispatcher(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + self.assertIs(dispatcher_registry.get_dispatcher(r), d) + # Selector should have it registered with EVENT_READ + key = dispatcher_registry.get_selector().get_key(r) + self.assertEqual(key.events, selectors.EVENT_READ) + self.assertIs(key.data, d) + + def test_register_sets_current_events(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + self.assertEqual( + dispatcher_registry._current_events[r], selectors.EVENT_READ + ) + + def test_unregister(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + dispatcher_registry.unregister(r) + + self.assertIsNone(dispatcher_registry.get_dispatcher(r)) + self.assertNotIn(r, dispatcher_registry._current_events) + with self.assertRaises(KeyError): + dispatcher_registry.get_selector().get_key(r) + + def test_unregister_unknown_fd_is_noop(self): + # Should not raise + dispatcher_registry.unregister(99999) + + def test_unregister_already_unregistered(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + dispatcher_registry.unregister(r) + # Second unregister should be a safe noop + dispatcher_registry.unregister(r) + + def test_modify_events_caching(self): + """modify_events should skip syscall when events unchanged.""" + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + # Initial state is EVENT_READ. Calling modify with same events is a noop. + dispatcher_registry.modify_events(r, selectors.EVENT_READ) + key = dispatcher_registry.get_selector().get_key(r) + self.assertEqual(key.events, selectors.EVENT_READ) + + def test_modify_events_read_to_readwrite(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + both = selectors.EVENT_READ | selectors.EVENT_WRITE + dispatcher_registry.modify_events(r, both) + + key = dispatcher_registry.get_selector().get_key(r) + self.assertEqual(key.events, both) + self.assertEqual(dispatcher_registry._current_events[r], both) + + def test_modify_events_to_zero_unregisters_from_selector(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + dispatcher_registry.modify_events(r, 0) + + # Should be unregistered from selector but still in _dispatchers + with self.assertRaises(KeyError): + dispatcher_registry.get_selector().get_key(r) + self.assertIs(dispatcher_registry.get_dispatcher(r), d) + self.assertEqual(dispatcher_registry._current_events[r], 0) + + def test_modify_events_zero_to_read_reregisters(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + # Unregister from selector + dispatcher_registry.modify_events(r, 0) + # Re-register + dispatcher_registry.modify_events(r, selectors.EVENT_READ) + + key = dispatcher_registry.get_selector().get_key(r) + self.assertEqual(key.events, selectors.EVENT_READ) + + def test_modify_events_unknown_fd_is_noop(self): + # Should not raise + dispatcher_registry.modify_events(99999, selectors.EVENT_READ) + + def test_all_dispatchers_returns_list_copy(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + result = dispatcher_registry.all_dispatchers() + self.assertIsInstance(result, list) + self.assertEqual(result, [d]) + # Mutating the returned list should not affect the registry + result.clear() + self.assertEqual(dispatcher_registry.all_dispatchers(), [d]) + + def test_iter_dispatchers(self): + r, w = self._make_pipe() + d = FakeDispatcher(r) + dispatcher_registry.register(r, d) + + result = list(dispatcher_registry.iter_dispatchers()) + self.assertEqual(result, [d]) + + def test_multiple_dispatchers(self): + r1, w1 = self._make_pipe() + r2, w2 = self._make_pipe() + d1 = FakeDispatcher(r1) + d2 = FakeDispatcher(r2) + dispatcher_registry.register(r1, d1) + dispatcher_registry.register(r2, d2) + + self.assertEqual(len(dispatcher_registry.all_dispatchers()), 2) + self.assertIs(dispatcher_registry.get_dispatcher(r1), d1) + self.assertIs(dispatcher_registry.get_dispatcher(r2), d2) + + dispatcher_registry.unregister(r1) + self.assertEqual(len(dispatcher_registry.all_dispatchers()), 1) + self.assertIsNone(dispatcher_registry.get_dispatcher(r1)) + self.assertIs(dispatcher_registry.get_dispatcher(r2), d2) + + def test_get_dispatcher_returns_none_for_unknown(self): + self.assertIsNone(dispatcher_registry.get_dispatcher(99999)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_event_loop.py b/tests/test_event_loop.py new file mode 100644 index 0000000..4ec610e --- /dev/null +++ b/tests/test_event_loop.py @@ -0,0 +1,224 @@ +"""Polysh - Tests - Event Loop + +Unit tests for the selectors-based event loop. + +Copyright (c) 2024 InnoGames GmbH +""" +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import os +import selectors +import unittest + +from polysh import dispatcher_registry +from polysh.event_loop import loop_iteration + + +class FakeDispatcher: + """Dispatcher stub that records which handlers were called.""" + + def __init__(self, fd, readable=True, writable=False): + self.fd = fd + self._readable = readable + self._writable = writable + self.read_called = 0 + self.write_called = 0 + self.close_called = 0 + + def readable(self): + return self._readable + + def writable(self): + return self._writable + + def handle_read(self): + self.read_called += 1 + + def handle_write(self): + self.write_called += 1 + + def handle_close(self): + self.close_called += 1 + + +class FailingReadDispatcher(FakeDispatcher): + """Dispatcher whose handle_read raises an exception.""" + + def handle_read(self): + self.read_called += 1 + raise RuntimeError('read failed') + + +class FailingWriteDispatcher(FakeDispatcher): + """Dispatcher whose handle_write raises an exception.""" + + def handle_write(self): + self.write_called += 1 + raise RuntimeError('write failed') + + +class SelfUnregisteringDispatcher(FakeDispatcher): + """Dispatcher that unregisters itself during handle_read.""" + + def handle_read(self): + self.read_called += 1 + dispatcher_registry.unregister(self.fd) + + +class TestEventLoop(unittest.TestCase): + def setUp(self): + self._fds = [] + dispatcher_registry._dispatchers.clear() + dispatcher_registry._current_events.clear() + dispatcher_registry._selector.close() + dispatcher_registry._selector = selectors.DefaultSelector() + + def tearDown(self): + for fd in list(dispatcher_registry._dispatchers): + dispatcher_registry.unregister(fd) + dispatcher_registry._selector.close() + dispatcher_registry._selector = selectors.DefaultSelector() + for fd in self._fds: + try: + os.close(fd) + except OSError: + pass + + def _make_pipe(self): + r, w = os.pipe() + self._fds.extend([r, w]) + return r, w + + def test_readable_dispatcher_gets_handle_read(self): + r, w = self._make_pipe() + # Write something so the read end becomes ready + os.write(w, b'hello') + + d = FakeDispatcher(r, readable=True, writable=False) + dispatcher_registry.register(r, d) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d.read_called, 1) + self.assertEqual(d.write_called, 0) + + def test_writable_dispatcher_gets_handle_write(self): + r, w = self._make_pipe() + # Write end of a pipe is immediately writable + d = FakeDispatcher(w, readable=False, writable=True) + dispatcher_registry.register(w, d) + + loop_iteration(timeout=0.1) + self.assertEqual(d.read_called, 0) + self.assertGreaterEqual(d.write_called, 1) + + def test_both_readable_and_writable(self): + r, w = self._make_pipe() + os.write(w, b'data') + + # Use the read end for reading, write end for writing + dr = FakeDispatcher(r, readable=True, writable=False) + dw = FakeDispatcher(w, readable=False, writable=True) + dispatcher_registry.register(r, dr) + dispatcher_registry.register(w, dw) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(dr.read_called, 1) + self.assertGreaterEqual(dw.write_called, 1) + + def test_not_readable_not_writable_skipped(self): + r, w = self._make_pipe() + os.write(w, b'data') + + d = FakeDispatcher(r, readable=False, writable=False) + dispatcher_registry.register(r, d) + + loop_iteration(timeout=0.1) + self.assertEqual(d.read_called, 0) + self.assertEqual(d.write_called, 0) + + def test_no_dispatchers_does_not_raise(self): + # Empty registry, should just return after timeout + loop_iteration(timeout=0.01) + + def test_handle_read_exception_triggers_handle_close(self): + r, w = self._make_pipe() + os.write(w, b'data') + + d = FailingReadDispatcher(r, readable=True, writable=False) + dispatcher_registry.register(r, d) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d.read_called, 1) + self.assertGreaterEqual(d.close_called, 1) + + def test_handle_write_exception_triggers_handle_close(self): + r, w = self._make_pipe() + + d = FailingWriteDispatcher(w, readable=False, writable=True) + dispatcher_registry.register(w, d) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d.write_called, 1) + self.assertGreaterEqual(d.close_called, 1) + + def test_dispatcher_removed_during_handle_read_skips_handle_write(self): + """If handle_read unregisters the dispatcher, handle_write must not run.""" + r, w = self._make_pipe() + os.write(w, b'data') + + # Dispatcher that claims both readable and writable, but unregisters + # itself on read — handle_write must be skipped + d = SelfUnregisteringDispatcher(r, readable=True, writable=True) + dispatcher_registry.register(r, d) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d.read_called, 1) + self.assertEqual(d.write_called, 0) + + def test_events_updated_between_iterations(self): + """Dispatcher state changes should be reflected in the next iteration.""" + r, w = self._make_pipe() + os.write(w, b'data') + + d = FakeDispatcher(r, readable=True, writable=False) + dispatcher_registry.register(r, d) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d.read_called, 1) + + # Now make it not readable + read_count = d.read_called + d._readable = False + loop_iteration(timeout=0.05) + self.assertEqual(d.read_called, read_count) + + def test_multiple_dispatchers_independent(self): + """Multiple dispatchers should be handled independently.""" + r1, w1 = self._make_pipe() + r2, w2 = self._make_pipe() + os.write(w1, b'data1') + # r2 has no data, so not ready for read + + d1 = FakeDispatcher(r1, readable=True, writable=False) + d2 = FakeDispatcher(r2, readable=True, writable=False) + dispatcher_registry.register(r1, d1) + dispatcher_registry.register(r2, d2) + + loop_iteration(timeout=0.1) + self.assertGreaterEqual(d1.read_called, 1) + self.assertEqual(d2.read_called, 0) + + +if __name__ == '__main__': + unittest.main() From dd914b7440a502f0a4c5fba08e3df6b31d318825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCrkan=20G=C3=BCr?= Date: Fri, 27 Mar 2026 16:32:00 +0100 Subject: [PATCH 5/5] Bump version --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5ff0fe..984dba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "polysh" authors = [{ email = "it@innogames.com" }] -version = "0.15.0" +version = "0.16.0" description = "Control thousands of SSH sessions from a single prompt" readme = "README.rst" requires-python = ">=3.5,<=3.12" diff --git a/uv.lock b/uv.lock index 011963b..197c441 100644 --- a/uv.lock +++ b/uv.lock @@ -109,7 +109,7 @@ wheels = [ [[package]] name = "polysh" -version = "0.14.0" +version = "0.16.0" source = { editable = "." } [package.optional-dependencies]