diff --git a/.gitignore b/.gitignore
index 1255668..c27db58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,5 @@ polysh.egg-info
build/
dist/
.devcontainer/
+flake.lock
+result
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000..d3809c3
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,60 @@
+# Flake for polysh, only used for easier development and testing on nix-based systems.
+# Might be used to build and test the package on multiple platforms, but is not intended/supported for production use.
+{
+ description = "Remote shell multiplexer for executing commands on multiple hosts";
+
+ inputs = {
+ nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
+ };
+
+ outputs = { self, nixpkgs }:
+ let
+ supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
+ forAllSystems = nixpkgs.lib.genAttrs supportedSystems;
+ in
+ {
+ packages = forAllSystems (system:
+ let
+ pkgs = nixpkgs.legacyPackages.${system};
+ in
+ {
+ polysh = pkgs.python3Packages.buildPythonApplication {
+ pname = "polysh";
+ version = "0.15";
+
+ pyproject = true;
+
+ build-system = with pkgs.python3Packages; [
+ hatchling
+ ];
+
+ src = ./.;
+
+ meta = with pkgs.lib; {
+ description = "Remote shell multiplexer for executing commands on multiple hosts";
+ homepage = "https://github.com/innogames/polysh";
+ license = licenses.gpl2Plus;
+ maintainers = with maintainers; [ seqizz ];
+ platforms = platforms.unix;
+ };
+ };
+
+ default = self.packages.${system}.polysh;
+ }
+ );
+
+ devShells = forAllSystems (system:
+ let
+ pkgs = nixpkgs.legacyPackages.${system};
+ in
+ {
+ default = pkgs.mkShell {
+ packages = with pkgs; [
+ python3
+ python3Packages.hatchling
+ ];
+ };
+ }
+ );
+ };
+}
diff --git a/pyproject.toml b/pyproject.toml
index d5ff0fe..984dba8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "polysh"
authors = [{ email = "it@innogames.com" }]
-version = "0.15.0"
+version = "0.16.0"
description = "Control thousands of SSH sessions from a single prompt"
readme = "README.rst"
requires-python = ">=3.5,<=3.12"
diff --git a/src/polysh/buffered_dispatcher.py b/src/polysh/buffered_dispatcher.py
index 66b4238..b404666 100644
--- a/src/polysh/buffered_dispatcher.py
+++ b/src/polysh/buffered_dispatcher.py
@@ -16,13 +16,16 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
import errno
+import fcntl
+import os
+from polysh import dispatcher_registry
from polysh.console import console_output
+from polysh.exceptions import ExitNow
-class BufferedDispatcher(asyncore.file_dispatcher):
+class BufferedDispatcher:
"""A dispatcher with a write buffer to allow asynchronous writers, and a
read buffer to permit line oriented manipulations"""
@@ -30,17 +33,39 @@ class BufferedDispatcher(asyncore.file_dispatcher):
MAX_BUFFER_SIZE = 1 * 1024 * 1024
def __init__(self, fd: int) -> None:
- asyncore.file_dispatcher.__init__(self, fd)
self.fd = fd
- self.read_buffer = b""
- self.write_buffer = b""
+ self.read_buffer = b''
+ self.write_buffer = b''
+
+ # Set non-blocking mode
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+ # Register with the dispatcher registry
+ dispatcher_registry.register(fd, self)
+
+ def recv(self, buffer_size: int) -> bytes:
+ """Read from the file descriptor."""
+ return os.read(self.fd, buffer_size)
+
+ def send(self, data: bytes) -> int:
+ """Write to the file descriptor."""
+ return os.write(self.fd, data)
+
+ def close(self) -> None:
+ """Unregister and close the file descriptor."""
+ dispatcher_registry.unregister(self.fd)
+ try:
+ os.close(self.fd)
+ except OSError:
+ pass
def handle_read(self) -> None:
self._handle_read_chunk()
def _handle_read_chunk(self) -> bytes:
"""Some data can be read"""
- new_data = b""
+ new_data = b''
buffer_length = len(self.read_buffer)
try:
while buffer_length < self.MAX_BUFFER_SIZE:
@@ -50,12 +75,11 @@ def _handle_read_chunk(self) -> bytes:
if e.errno == errno.EAGAIN:
# End of the available data
break
- elif e.errno == errno.EIO and new_data:
+ if e.errno == errno.EIO and new_data:
# Hopefully we could read an error message before the
# actual termination
break
- else:
- raise
+ raise
if not piece:
# A closed connection is indicated by signaling a read
@@ -66,7 +90,7 @@ def _handle_read_chunk(self) -> bytes:
buffer_length += len(piece)
finally:
- new_data = new_data.replace(b"\r", b"\n")
+ new_data = new_data.replace(b'\r', b'\n')
self.read_buffer += new_data
return new_data
@@ -76,16 +100,14 @@ def readable(self) -> bool:
def writable(self) -> bool:
"""Do we have something to write?"""
- return self.write_buffer != b""
+ return self.write_buffer != b''
def dispatch_write(self, buf: bytes) -> bool:
"""Augment the buffer with stuff to write when possible"""
self.write_buffer += buf
if len(self.write_buffer) > self.MAX_BUFFER_SIZE:
console_output(
- "Buffer too big ({:d}) for {}\n".format(
- len(self.write_buffer), str(self)
- ).encode()
+ f'Buffer too big ({len(self.write_buffer):d}) for {str(self)}\n'.encode()
)
- raise asyncore.ExitNow(1)
+ raise ExitNow(1)
return True
diff --git a/src/polysh/control_commands.py b/src/polysh/control_commands.py
index 02eb8fc..7242b72 100644
--- a/src/polysh/control_commands.py
+++ b/src/polysh/control_commands.py
@@ -18,22 +18,20 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
import os
import shlex
from typing import List
+from polysh import dispatchers, remote_dispatcher, stdin
+from polysh.completion import add_to_history, complete_local_path
+from polysh.console import console_output
from polysh.control_commands_helpers import (
complete_shells,
expand_local_path,
selected_shells,
toggle_shells,
)
-from polysh.completion import complete_local_path, add_to_history
-from polysh.console import console_output
-from polysh import dispatchers
-from polysh import remote_dispatcher
-from polysh import stdin
+from polysh.exceptions import ExitNow
def complete_list(line: str, text: str) -> List[str]:
@@ -43,11 +41,11 @@ def complete_list(line: str, text: str) -> List[str]:
def do_list(command: str) -> None:
instances = [i.get_info() for i in selected_shells(command)]
flat_instances = dispatchers.format_info(instances)
- console_output(b"".join(flat_instances))
+ console_output(b''.join(flat_instances))
def do_quit(command: str) -> None:
- raise asyncore.ExitNow(0)
+ raise ExitNow(0)
def complete_chdir(line: str, text: str) -> List[str]:
@@ -58,29 +56,29 @@ def do_chdir(command: str) -> None:
try:
os.chdir(expand_local_path(command.strip()))
except OSError as e:
- console_output("{}\n".format(str(e)).encode())
+ console_output(f'{str(e)}\n'.encode())
def complete_send_ctrl(line: str, text: str) -> List[str]:
if len(line[:-1].split()) >= 2:
# Control letter already given in command line
return complete_shells(line, text, lambda i: i.enabled)
- if text in ("c", "d", "z"):
- return [text + " "]
- return ["c ", "d ", "z "]
+ if text in ('c', 'd', 'z'):
+ return [text + ' ']
+ return ['c ', 'd ', 'z ']
def do_send_ctrl(command: str) -> None:
split = command.split()
if not split:
- console_output(b"Expected at least a letter\n")
+ console_output(b'Expected at least a letter\n')
return
letter = split[0]
if len(letter) != 1:
- console_output("Expected a single letter, got: {}\n".format(letter).encode())
+ console_output(f'Expected a single letter, got: {letter}\n'.encode())
return
- control_letter = chr(ord(letter.lower()) - ord("a") + 1)
- for i in selected_shells(" ".join(split[1:])):
+ control_letter = chr(ord(letter.lower()) - ord('a') + 1)
+ for i in selected_shells(' '.join(split[1:])):
if i.enabled:
i.dispatch_write(control_letter.encode())
@@ -122,7 +120,9 @@ def complete_reconnect(line: str, text: str) -> List[str]:
def do_reconnect(command: str) -> None:
selec = selected_shells(command)
- to_reconnect = [i for i in selec if i.state == remote_dispatcher.STATE_DEAD]
+ to_reconnect = [
+ i for i in selec if i.state == remote_dispatcher.STATE_DEAD
+ ]
for i in to_reconnect:
i.disconnect()
i.close()
@@ -162,13 +162,13 @@ def do_hide_password(command: str) -> None:
i.debug = False
if not warned:
console_output(
- b"Debugging disabled to avoid displaying " b"passwords\n"
+ b'Debugging disabled to avoid displaying passwords\n'
)
warned = True
stdin.set_echo(False)
if remote_dispatcher.options.log_file:
- console_output(b"Logging disabled to avoid writing passwords\n")
+ console_output(b'Logging disabled to avoid writing passwords\n')
remote_dispatcher.options.log_file = None
@@ -176,22 +176,22 @@ def complete_set_debug(line: str, text: str) -> List[str]:
if len(line[:-1].split()) >= 2:
# Debug value already given in command line
return complete_shells(line, text)
- if text.lower() in ("y", "n"):
- return [text + " "]
- return ["y ", "n "]
+ if text.lower() in ('y', 'n'):
+ return [text + ' ']
+ return ['y ', 'n ']
def do_set_debug(command: str) -> None:
split = command.split()
if not split:
- console_output(b"Expected at least a letter\n")
+ console_output(b'Expected at least a letter\n')
return
letter = split[0].lower()
- if letter not in ("y", "n"):
- console_output("Expected 'y' or 'n', got: {}\n".format(split[0]).encode())
+ if letter not in ('y', 'n'):
+ console_output(f"Expected 'y' or 'n', got: {split[0]}\n".encode())
return
- debug = letter == "y"
- for i in selected_shells(" ".join(split[1:])):
+ debug = letter == 'y'
+ for i in selected_shells(' '.join(split[1:])):
i.debug = debug
@@ -200,25 +200,25 @@ def do_export_vars(command: str) -> None:
for shell in dispatchers.all_instances():
if shell.enabled:
environment_variables = {
- "POLYSH_RANK": str(rank),
- "POLYSH_NAME": shell.hostname,
- "POLYSH_DISPLAY_NAME": shell.display_name,
+ 'POLYSH_RANK': str(rank),
+ 'POLYSH_NAME': shell.hostname,
+ 'POLYSH_DISPLAY_NAME': shell.display_name,
}
for name, value in environment_variables.items():
shell.dispatch_command(
- "export {}={}\n".format(name, shlex.quote(value)).encode()
+ f'export {name}={shlex.quote(value)}\n'.encode()
)
rank += 1
for shell in dispatchers.all_instances():
if shell.enabled:
shell.dispatch_command(
- "export POLYSH_NR_SHELLS={:d}\n".format(rank).encode()
+ f'export POLYSH_NR_SHELLS={rank:d}\n'.encode()
)
-add_to_history("$POLYSH_RANK $POLYSH_NAME $POLYSH_DISPLAY_NAME")
-add_to_history("$POLYSH_NR_SHELLS")
+add_to_history('$POLYSH_RANK $POLYSH_NAME $POLYSH_DISPLAY_NAME')
+add_to_history('$POLYSH_NR_SHELLS')
def complete_set_log(line: str, text: str) -> List[str]:
@@ -229,13 +229,13 @@ def do_set_log(command: str) -> None:
command = command.strip()
if command:
try:
- remote_dispatcher.options.log_file = open(command, "a")
- except IOError as e:
- console_output("{}\n".format(str(e)).encode())
+ remote_dispatcher.options.log_file = open(command, 'a')
+ except OSError as e:
+ console_output(f'{str(e)}\n'.encode())
command = None
if not command:
remote_dispatcher.options.log_file = None
- console_output(b"Logging disabled\n")
+ console_output(b'Logging disabled\n')
def complete_show_read_buffer(line: str, text: str) -> List[str]:
@@ -248,4 +248,4 @@ def do_show_read_buffer(command: str) -> None:
for i in selected_shells(command):
if i.read_in_state_not_started:
i.print_lines(i.read_in_state_not_started)
- i.read_in_state_not_started = b""
+ i.read_in_state_not_started = b''
diff --git a/src/polysh/dispatcher_registry.py b/src/polysh/dispatcher_registry.py
new file mode 100644
index 0000000..d808ce3
--- /dev/null
+++ b/src/polysh/dispatcher_registry.py
@@ -0,0 +1,112 @@
+"""Polysh - Dispatcher Registry
+
+Manages the global selector and dispatcher tracking, replacing asyncore's socket_map.
+
+Copyright (c) 2006 Guillaume Chazarain
+Copyright (c) 2024 InnoGames GmbH
+"""
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import selectors
+from typing import Dict, Any, List
+
+# Global selector instance
+_selector = selectors.DefaultSelector()
+
+# Mapping from file descriptor to dispatcher instance
+_dispatchers: Dict[int, Any] = {}
+
+# Track last-registered events per fd to avoid redundant epoll_ctl syscalls
+_current_events: Dict[int, int] = {}
+
+
+def register(fd: int, dispatcher: Any) -> None:
+ """Register a dispatcher with the selector.
+
+ Initially registers with EVENT_READ - events are updated dynamically
+ based on readable()/writable() before each select().
+ """
+ _dispatchers[fd] = dispatcher
+ # Register with EVENT_READ initially; loop_iteration will update as needed
+ _selector.register(fd, selectors.EVENT_READ, dispatcher)
+ _current_events[fd] = selectors.EVENT_READ
+
+
+def unregister(fd: int) -> None:
+ """Unregister a dispatcher from the selector."""
+ if fd in _dispatchers:
+ del _dispatchers[fd]
+ _current_events.pop(fd, None)
+ try:
+ _selector.unregister(fd)
+ except (KeyError, ValueError):
+ # Already unregistered or invalid fd
+ pass
+
+
+def modify_events(fd: int, events: int) -> None:
+ """Modify the events a dispatcher is interested in.
+
+ Skips the syscall if events haven't changed since last call.
+
+ If events is 0, the fd is temporarily unregistered from the selector
+ but kept in _dispatchers. It will be re-registered when events become non-zero.
+ """
+ if fd not in _dispatchers:
+ return
+
+ # Skip if events haven't changed — avoids unnecessary epoll_ctl syscall
+ if _current_events.get(fd, 0) == events:
+ return
+
+ old_events = _current_events.get(fd, 0)
+
+ if old_events == 0:
+ # Not currently registered in selector, need to register
+ if events != 0:
+ _selector.register(fd, events, _dispatchers[fd])
+ elif events == 0:
+ # No events - unregister temporarily
+ _selector.unregister(fd)
+ else:
+ # Modify existing registration
+ _selector.modify(fd, events, _dispatchers[fd])
+
+ _current_events[fd] = events
+
+
+def all_dispatchers() -> List[Any]:
+ """Return a snapshot list of all registered dispatchers.
+
+ Use iter_dispatchers() when mutation during iteration is not a concern.
+ """
+ return list(_dispatchers.values())
+
+
+def iter_dispatchers():
+ """Iterate dispatcher values directly, avoiding a list copy.
+
+ Only safe when the caller does not add/remove dispatchers during iteration.
+ """
+ return _dispatchers.values()
+
+
+def get_selector() -> selectors.BaseSelector:
+ """Return the global selector instance."""
+ return _selector
+
+
+def get_dispatcher(fd: int) -> Any:
+ """Return the dispatcher for a given file descriptor, or None."""
+ return _dispatchers.get(fd)
diff --git a/src/polysh/dispatchers.py b/src/polysh/dispatchers.py
index 4af0a61..170b097 100644
--- a/src/polysh/dispatchers.py
+++ b/src/polysh/dispatchers.py
@@ -16,27 +16,23 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
import fcntl
import struct
import sys
import termios
-from typing import List
-from typing import Tuple
+from typing import List, Tuple
-from polysh import remote_dispatcher
-from polysh import display_names
+from polysh import dispatcher_registry, display_names, remote_dispatcher
from polysh.terminal_size import terminal_size
def _split_port(hostname: str) -> Tuple[str, str]:
"""Splits a string(hostname, given by the user) into hostname and port,
returns a tuple"""
- s = hostname.split(":", 1)
+ s = hostname.split(':', 1)
if len(s) > 1:
return s[0], s[1]
- else:
- return s[0], "22"
+ return s[0], '22'
def all_instances() -> List[remote_dispatcher.RemoteDispatcher]:
@@ -44,10 +40,10 @@ def all_instances() -> List[remote_dispatcher.RemoteDispatcher]:
return sorted(
[
i
- for i in asyncore.socket_map.values()
+ for i in dispatcher_registry.all_dispatchers()
if isinstance(i, remote_dispatcher.RemoteDispatcher)
],
- key=lambda i: i.display_name or "",
+ key=lambda i: i.display_name or '',
)
@@ -84,8 +80,8 @@ def update_terminal_size() -> None:
w = max(w - display_names.max_display_name_length - 2, min(w, 10))
# python bug http://python.org/sf/1112949 on amd64
# from ajaxterm.py
- bug = struct.unpack("i", struct.pack("I", termios.TIOCSWINSZ))[0]
- packed_size = struct.pack("HHHH", h, w, 0, 0)
+ bug = struct.unpack('i', struct.pack('I', termios.TIOCSWINSZ))[0]
+ packed_size = struct.pack('HHHH', h, w, 0, 0)
term_size = w, h
for i in all_instances():
if i.enabled and i.term_size != term_size:
@@ -112,17 +108,17 @@ def format_info(info_list: List[List[bytes]]) -> List[bytes]:
# as it can get much longer in some shells than in others
orig_str = info[str_id]
indent = max_lengths[str_id] - len(orig_str)
- info[str_id] = orig_str + indent * b" "
- flattened_info_list.append(b" ".join(info) + b"\n")
+ info[str_id] = orig_str + indent * b' '
+ flattened_info_list.append(b' '.join(info) + b'\n')
return flattened_info_list
def create_remote_dispatchers(hosts: List[str]) -> None:
- last_message = ""
+ last_message = ''
for i, host in enumerate(hosts):
if remote_dispatcher.options.interactive:
- last_message = "Started %d/%d remote processes\r" % (i, len(hosts))
+ last_message = 'Started %d/%d remote processes\r' % (i, len(hosts))
sys.stdout.write(last_message)
sys.stdout.flush()
try:
@@ -133,5 +129,5 @@ def create_remote_dispatchers(hosts: List[str]) -> None:
raise
if last_message:
- sys.stdout.write(" " * len(last_message) + "\r")
+ sys.stdout.write(' ' * len(last_message) + '\r')
sys.stdout.flush()
diff --git a/src/polysh/event_loop.py b/src/polysh/event_loop.py
new file mode 100644
index 0000000..268b742
--- /dev/null
+++ b/src/polysh/event_loop.py
@@ -0,0 +1,79 @@
+"""Polysh - Event Loop
+
+Provides the main loop iteration using selectors, replacing asyncore.loop().
+
+Copyright (c) 2006 Guillaume Chazarain
+Copyright (c) 2024 InnoGames GmbH
+"""
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import errno
+import selectors
+from typing import Optional
+
+from polysh import dispatcher_registry
+
+
+def loop_iteration(timeout: Optional[float] = None) -> None:
+ """Perform a single iteration of the event loop.
+
+ This replaces asyncore.loop(count=1, timeout=timeout, use_poll=True).
+
+ Updates selector registrations based on each dispatcher's readable()/writable()
+ state, then performs select and dispatches events to handlers.
+ """
+ selector = dispatcher_registry.get_selector()
+
+ # Update event registrations based on current readable/writable state.
+ # Iterate the dict values directly — this loop doesn't mutate _dispatchers.
+ for dispatcher in dispatcher_registry.iter_dispatchers():
+ events = 0
+ if dispatcher.readable():
+ events |= selectors.EVENT_READ
+ if dispatcher.writable():
+ events |= selectors.EVENT_WRITE
+ dispatcher_registry.modify_events(dispatcher.fd, events)
+
+ # Perform select
+ try:
+ ready = selector.select(timeout)
+ except OSError as e:
+ if e.errno == errno.EINTR:
+ # Interrupted by signal handler, just return
+ return
+ raise
+
+ # Dispatch events
+ for key, events in ready:
+ dispatcher = key.data
+
+ # Check if dispatcher is still valid
+ if dispatcher_registry.get_dispatcher(key.fd) is None:
+ continue
+
+ if events & selectors.EVENT_READ:
+ try:
+ dispatcher.handle_read()
+ except Exception:
+ dispatcher.handle_close()
+
+ # Re-check dispatcher is still valid after handle_read
+ if dispatcher_registry.get_dispatcher(key.fd) is None:
+ continue
+
+ if events & selectors.EVENT_WRITE:
+ try:
+ dispatcher.handle_write()
+ except Exception:
+ dispatcher.handle_close()
diff --git a/src/polysh/exceptions.py b/src/polysh/exceptions.py
new file mode 100644
index 0000000..9bb9dc5
--- /dev/null
+++ b/src/polysh/exceptions.py
@@ -0,0 +1,22 @@
+"""Polysh - Exceptions
+
+Copyright (c) 2006 Guillaume Chazarain
+Copyright (c) 2024 InnoGames GmbH
+"""
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+class ExitNow(Exception):
+ """Exception to signal clean exit. First argument is exit code."""
+ pass
diff --git a/src/polysh/main.py b/src/polysh/main.py
index 2db3902..2c170cf 100644
--- a/src/polysh/main.py
+++ b/src/polysh/main.py
@@ -16,26 +16,28 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
+import argparse
import atexit
import getpass
import locale
-import argparse
import os
+import readline
+import resource
import signal
import sys
import termios
-import readline
-import resource
-from typing import Callable, List
-
-from polysh import remote_dispatcher
-from polysh import dispatchers
-from polysh import stdin
+from typing import Callable
+
+from polysh import (
+ VERSION,
+ control_commands,
+ dispatchers,
+ remote_dispatcher,
+ stdin,
+)
from polysh.console import console_output
+from polysh.exceptions import ExitNow
from polysh.host_syntax import expand_syntax
-from polysh import control_commands
-from polysh import VERSION
def kill_all() -> None:
@@ -52,100 +54,105 @@ def parse_cmdline() -> argparse.Namespace:
description = 'Control commands are prefixed by ":".'
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
- "--hosts-file",
+ '--hosts-file',
type=str,
- action="append",
- dest="hosts_filenames",
- metavar="FILE",
+ action='append',
+ dest='hosts_filenames',
+ metavar='FILE',
default=[],
- help="read hostnames from given file, one per line",
+ help='read hostnames from given file, one per line',
)
parser.add_argument(
- "--command",
+ '--command',
type=str,
- dest="command",
+ dest='command',
default=None,
- help="command to execute on the remote shells",
- metavar="CMD",
+ help='command to execute on the remote shells',
+ metavar='CMD',
)
- def_ssh = "exec ssh -oLogLevel=Quiet -t %(host)s %(port)s"
+ def_ssh = 'exec ssh -oLogLevel=Quiet -t %(host)s %(port)s'
parser.add_argument(
- "--ssh",
+ '--ssh',
type=str,
- dest="ssh",
+ dest='ssh',
default=def_ssh,
- metavar="SSH",
- help="ssh command to use [%(default)s]",
+ metavar='SSH',
+ help='ssh command to use [%(default)s]',
)
parser.add_argument(
- "--user",
+ '--user',
type=str,
- dest="user",
+ dest='user',
default=None,
- help="remote user to log in as",
- metavar="USER",
+ help='remote user to log in as',
+ metavar='USER',
)
parser.add_argument(
- "--no-color",
- action="store_true",
- dest="disable_color",
- help="disable colored hostnames [enabled]",
+ '--no-color',
+ action='store_true',
+ dest='disable_color',
+ help='disable colored hostnames [enabled]',
)
parser.add_argument(
- "--password-file",
+ '--password-file',
type=str,
- dest="password_file",
+ dest='password_file',
default=None,
- metavar="FILE",
- help="read a password from the specified file. - is the tty.",
+ metavar='FILE',
+ help='read a password from the specified file. - is the tty.',
)
parser.add_argument(
- "--log-file",
+ '--log-file',
type=str,
- dest="log_file",
- help="file to log each machine conversation [none]",
+ dest='log_file',
+ help='file to log each machine conversation [none]',
+ )
+ parser.add_argument(
+ '--abort-errors',
+ action='store_true',
+ dest='abort_error',
+ help='abort if some shell fails to initialize [ignore]',
)
parser.add_argument(
- "--abort-errors",
- action="store_true",
- dest="abort_error",
- help="abort if some shell fails to initialize [ignore]",
+ '--debug',
+ action='store_true',
+ dest='debug',
+ help='print debugging information',
)
parser.add_argument(
- "--debug", action="store_true", dest="debug", help="print debugging information"
+ '--profile', action='store_true', dest='profile', default=False
)
- parser.add_argument("--profile", action="store_true", dest="profile", default=False)
- parser.add_argument("host_names", nargs="*")
+ parser.add_argument('host_names', nargs='*')
args = parser.parse_args()
for filename in args.hosts_filenames:
try:
- hosts_file = open(filename, "r")
+ hosts_file = open(filename)
for line in hosts_file.readlines():
- if "#" in line:
- line = line[: line.index("#")]
+ if '#' in line:
+ line = line[: line.index('#')]
line = line.strip()
if line:
args.host_names.append(line)
hosts_file.close()
- except IOError as e:
+ except OSError as e:
parser.error(str(e))
if args.log_file:
try:
- args.log_file = open(args.log_file, "a")
- except IOError as e:
+ args.log_file = open(args.log_file, 'a')
+ except OSError as e:
print(e)
sys.exit(1)
if not args.host_names:
- parser.error("no hosts given")
+ parser.error('no hosts given')
- if args.password_file == "-":
+ if args.password_file == '-':
args.password = getpass.getpass()
elif args.password_file is not None:
- password_file = open(args.password_file, "r")
- args.password = password_file.readline().rstrip("\n")
+ password_file = open(args.password_file)
+ args.password = password_file.readline().rstrip('\n')
else:
args.password = None
@@ -159,20 +166,20 @@ def find_non_interactive_command(command: str) -> str:
stdin = sys.stdin.read()
if stdin and command:
print(
- "--command and reading from stdin are incompatible",
+ '--command and reading from stdin are incompatible',
file=sys.stderr,
)
sys.exit(1)
- if stdin and not stdin.endswith("\n"):
- stdin += "\n"
+ if stdin and not stdin.endswith('\n'):
+ stdin += '\n'
return command or stdin
def init_history(histfile: str) -> None:
- if hasattr(readline, "read_history_file"):
+ if hasattr(readline, 'read_history_file'):
try:
readline.read_history_file(histfile)
- except IOError:
+ except OSError:
pass
@@ -182,7 +189,7 @@ def save_history(histfile: str) -> None:
def loop(interactive: bool) -> None:
- histfile = os.path.expanduser("~/.polysh_history")
+ histfile = os.path.expanduser('~/.polysh_history')
init_history(histfile)
next_signal = None
last_status = None
@@ -191,11 +198,11 @@ def loop(interactive: bool) -> None:
if next_signal:
current_signal = next_signal
next_signal = None
- sig2chr = {signal.SIGINT: "C", signal.SIGTSTP: "Z"}
+ sig2chr = {signal.SIGINT: 'C', signal.SIGTSTP: 'Z'}
ctrl = sig2chr[current_signal]
- remote_dispatcher.log("> ^{}\n".format(ctrl).encode())
+ remote_dispatcher.log(f'> ^{ctrl}\n'.encode())
control_commands.do_send_ctrl(ctrl)
- console_output(b"")
+ console_output(b'')
stdin.the_stdin_thread.prepend_text = None
while dispatchers.count_awaited_processes()[
0
@@ -206,14 +213,14 @@ def loop(interactive: bool) -> None:
r.print_unfinished_line()
current_status = dispatchers.count_awaited_processes()
if current_status != last_status:
- console_output(b"")
+ console_output(b'')
if remote_dispatcher.options.interactive:
stdin.the_stdin_thread.want_raw_input()
last_status = current_status
if dispatchers.all_terminated():
# Clear the prompt
- console_output(b"")
- raise asyncore.ExitNow(remote_dispatcher.options.exit_code)
+ console_output(b'')
+ raise ExitNow(remote_dispatcher.options.exit_code)
if not next_signal:
# possible race here with the signal handler
remote_dispatcher.main_loop_iteration()
@@ -223,22 +230,22 @@ def loop(interactive: bool) -> None:
else:
kill_all()
os.kill(0, signal.SIGINT)
- except asyncore.ExitNow as e:
- console_output(b"")
+ except ExitNow as e:
+ console_output(b'')
save_history(histfile)
sys.exit(e.args[0])
def _profile(continuation: Callable) -> None:
- prof_file = "polysh.prof"
+ prof_file = 'polysh.prof'
import cProfile
import pstats
- print("Profiling using cProfile")
- cProfile.runctx("continuation()", globals(), locals(), prof_file)
+ print('Profiling using cProfile')
+ cProfile.runctx('continuation()', globals(), locals(), prof_file)
stats = pstats.Stats(prof_file)
stats.strip_dirs()
- stats.sort_stats("time", "calls")
+ stats.sort_stats('time', 'calls')
stats.print_stats(50)
stats.print_callees(50)
os.remove(prof_file)
@@ -252,7 +259,7 @@ def restore_tty_on_exit() -> None:
def run() -> None:
"""Launch polysh"""
- locale.setlocale(locale.LC_ALL, "")
+ locale.setlocale(locale.LC_ALL, '')
atexit.register(kill_all)
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
@@ -260,7 +267,9 @@ def run() -> None:
args.command = find_non_interactive_command(args.command)
args.exit_code = 0
- args.interactive = not args.command and sys.stdin.isatty() and sys.stdout.isatty()
+ args.interactive = (
+ not args.command and sys.stdin.isatty() and sys.stdout.isatty()
+ )
if args.interactive:
restore_tty_on_exit()
@@ -281,8 +290,8 @@ def run() -> None:
resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, new_hard))
except OSError as e:
print(
- "Failed to change RLIMIT_NOFILE from soft={} hard={} to soft={} "
- "hard={}: {}".format(old_soft, old_hard, new_soft, new_hard, e),
+ f'Failed to change RLIMIT_NOFILE from soft={old_soft} hard={old_hard} to soft={new_soft} '
+ f'hard={new_hard}: {e}',
file=sys.stderr,
)
sys.exit(1)
@@ -290,7 +299,8 @@ def run() -> None:
dispatchers.create_remote_dispatchers(hosts)
signal.signal(
- signal.SIGWINCH, lambda signum, frame: dispatchers.update_terminal_size()
+ signal.SIGWINCH,
+ lambda signum, frame: dispatchers.update_terminal_size(),
)
stdin.the_stdin_thread = stdin.StdinThread(args.interactive)
@@ -311,14 +321,14 @@ def safe_loop() -> None:
def main():
"""Wrapper around run() to setup sentry"""
- sentry_dsn = os.environ.get("POLYSH_SENTRY_DSN")
+ sentry_dsn = os.environ.get('POLYSH_SENTRY_DSN')
if sentry_dsn:
import sentry_sdk
sentry_sdk.init(
dsn=sentry_dsn,
- release=".".join(map(str, VERSION)),
+ release='.'.join(map(str, VERSION)),
)
try:
diff --git a/src/polysh/remote_dispatcher.py b/src/polysh/remote_dispatcher.py
index ec5d327..ae66b5a 100644
--- a/src/polysh/remote_dispatcher.py
+++ b/src/polysh/remote_dispatcher.py
@@ -16,29 +16,27 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
import os
+import platform
import pty
+import select
import signal
import sys
import termios
-import select
-import platform
-from typing import Optional, List
-from argparse import Namespace
+from typing import List, Optional
+from polysh import callbacks, display_names, event_loop
from polysh.buffered_dispatcher import BufferedDispatcher
-from polysh import callbacks
from polysh.console import console_output
-from polysh import display_names
+from polysh.exceptions import ExitNow
options = None # type: Optional[Namespace]
# Either the remote shell is expecting a command or one is already running
-STATE_NAMES = ["not_started", "idle", "running", "terminated", "dead"]
+STATE_NAMES = ['not_started', 'idle', 'running', 'terminated', 'dead']
-STATE_NOT_STARTED, STATE_IDLE, STATE_RUNNING, STATE_TERMINATED, STATE_DEAD = list(
- range(len(STATE_NAMES))
+STATE_NOT_STARTED, STATE_IDLE, STATE_RUNNING, STATE_TERMINATED, STATE_DEAD = (
+ list(range(len(STATE_NAMES)))
)
# Terminal color codes
@@ -52,7 +50,7 @@ def main_loop_iteration(timeout: Optional[float] = None) -> int:
"""Return the number of RemoteDispatcher.handle_read() calls made by this
iteration"""
prev_nr_read = nr_handle_read
- asyncore.loop(count=1, timeout=timeout, use_poll=True)
+ event_loop.loop_iteration(timeout=timeout)
return nr_handle_read - prev_nr_read
@@ -63,9 +61,9 @@ def log(msg: bytes) -> None:
try:
written = os.write(fd, msg)
except OSError as e:
- print("Exception while writing log:", options.log_file.name)
+ print('Exception while writing log:', options.log_file.name)
print(e)
- raise asyncore.ExitNow(1)
+ raise ExitNow(1)
msg = msg[written:]
@@ -73,10 +71,10 @@ class RemoteDispatcher(BufferedDispatcher):
"""A RemoteDispatcher is a ssh process we communicate with"""
def __init__(self, hostname: str, port: str) -> None:
- if port != "22":
- port = "-p " + port
+ if port != '22':
+ port = '-p ' + port
else:
- port = ""
+ port = ''
self.pid, fd = pty.fork()
if self.pid == 0:
@@ -97,9 +95,9 @@ def __init__(self, hostname: str, port: str) -> None:
self.change_name(self.hostname.encode())
self.init_string = self.configure_tty() + self.set_prompt()
self.init_string_sent = False
- self.read_in_state_not_started = b""
+ self.read_in_state_not_started = b''
self.command = options.command
- self.last_printed_line = b""
+ self.last_printed_line = b''
self.color_code = None
if sys.stdout.isatty() and not options.disable_color:
COLORS.insert(0, COLORS.pop()) # Rotate the colors
@@ -108,11 +106,11 @@ def __init__(self, hostname: str, port: str) -> None:
def launch_ssh(self, name: str, port: str) -> None:
"""Launch the ssh command in the child process"""
if options.user:
- name = "%s@%s" % (options.user, name)
- evaluated = options.ssh % {"host": name, "port": port}
+ name = '%s@%s' % (options.user, name)
+ evaluated = options.ssh % {'host': name, 'port': port}
if evaluated == options.ssh:
- evaluated = "%s %s" % (evaluated, name)
- os.execlp("/bin/sh", "sh", "-c", evaluated)
+ evaluated = '%s %s' % (evaluated, name)
+ os.execlp('/bin/sh', 'sh', '-c', evaluated)
def set_enabled(self, enabled: bool) -> None:
if enabled != self.enabled and options.interactive:
@@ -126,9 +124,9 @@ def change_state(self, state: int) -> None:
"""Change the state of the remote process, logging the change"""
if state is not self.state:
if self.debug:
- self.print_debug(b"state => " + STATE_NAMES[state].encode())
+ self.print_debug(b'state => ' + STATE_NAMES[state].encode())
if self.state is STATE_NOT_STARTED:
- self.read_in_state_not_started = b""
+ self.read_in_state_not_started = b''
self.state = state
def disconnect(self) -> None:
@@ -138,14 +136,14 @@ def disconnect(self) -> None:
except OSError:
# The process was already dead, no problem
pass
- self.read_buffer = b""
- self.write_buffer = b""
+ self.read_buffer = b''
+ self.write_buffer = b''
self.set_enabled(False)
if self.read_in_state_not_started:
self.print_lines(self.read_in_state_not_started)
- self.read_in_state_not_started = b""
+ self.read_in_state_not_started = b''
if options.abort_error and self.state is STATE_NOT_STARTED:
- raise asyncore.ExitNow(1)
+ raise ExitNow(1)
self.change_state(STATE_DEAD)
def configure_tty(self) -> bytes:
@@ -164,22 +162,22 @@ def seen_prompt_cb(self, unused: str) -> None:
if options.interactive:
self.change_state(STATE_IDLE)
elif self.command:
- p1, p2 = callbacks.add(b"real prompt ends", lambda d: None, True)
+ p1, p2 = callbacks.add(b'real prompt ends', lambda d: None, True)
self.dispatch_command(b'PS1="' + p1 + b'""' + p2 + b'\n"\n')
- self.dispatch_command(self.command.encode() + b"\n")
- self.dispatch_command(b"exit 2>/dev/null\n")
+ self.dispatch_command(self.command.encode() + b'\n')
+ self.dispatch_command(b'exit 2>/dev/null\n')
self.command = None
def set_prompt(self) -> bytes:
"""The prompt is important because we detect the readyness of a process
by waiting for its prompt."""
# No right prompt
- command_line = b"PS2=;RPS1=;RPROMPT=;"
- command_line += b"PROMPT_COMMAND=;"
- command_line += b"TERM=ansi;"
- command_line += b"unset precmd_functions;"
- command_line += b"unset HISTFILE;"
- prompt1, prompt2 = callbacks.add(b"prompt", self.seen_prompt_cb, True)
+ command_line = b'PS2=;RPS1=;RPROMPT=;'
+ command_line += b'PROMPT_COMMAND=;'
+ command_line += b'TERM=ansi;'
+ command_line += b'unset precmd_functions;'
+ command_line += b'unset HISTFILE;'
+ prompt1, prompt2 = callbacks.add(b'prompt', self.seen_prompt_cb, True)
command_line += b'PS1="' + prompt1 + b'""' + prompt2 + b'\n"\n'
return command_line
@@ -190,24 +188,23 @@ def readable(self) -> bool:
def handle_expt(self) -> None:
# Dirty hack to ignore POLLPRI flag that is raised on Mac OS, but not
- # on linux. asyncore calls this method in case POLLPRI flag is set, but
- # self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) == 0
- if platform.system() == "Darwin" and select.POLLPRI:
+ # on linux. This method is kept for compatibility but the new selectors
+ # implementation doesn't call it - we handle this in handle_close()
+ if platform.system() == 'Darwin' and select.POLLPRI:
return
self.handle_close()
def handle_close(self) -> None:
if self.state is STATE_DEAD:
- # This connection has already been killed. Asyncore has probably
- # called handle_close() or handle_expt() on this connection twice.
+ # This connection has already been killed.
return
pid, status = os.waitpid(self.pid, 0)
exit_code = os.WEXITSTATUS(status)
options.exit_code = max(options.exit_code, exit_code)
if exit_code and options.interactive:
- console_output("Error talking to {}\n".format(self.display_name).encode())
+ console_output(f'Error talking to {self.display_name}\n'.encode())
self.disconnect()
if self.temporary:
self.close()
@@ -215,32 +212,36 @@ def handle_close(self) -> None:
def print_lines(self, lines: bytes) -> None:
from polysh.display_names import max_display_name_length
- lines = lines.strip(b"\n")
+ lines = lines.strip(b'\n')
while True:
- no_empty_lines = lines.replace(b"\n\n", b"\n")
+ no_empty_lines = lines.replace(b'\n\n', b'\n')
if len(no_empty_lines) == len(lines):
break
lines = no_empty_lines
if not lines:
return
indent = max_display_name_length - len(self.display_name)
- log_prefix = self.display_name.encode() + indent * b" " + b" : "
+ log_prefix = self.display_name.encode() + indent * b' ' + b' : '
if self.color_code is None:
console_prefix = log_prefix
else:
console_prefix = (
- b"\033[1;"
+ b'\033[1;'
+ str(self.color_code).encode()
- + b"m"
+ + b'm'
+ log_prefix
- + b"\033[1;m"
+ + b'\033[1;m'
)
console_data = (
- console_prefix + lines.replace(b"\n", b"\n" + console_prefix) + b"\n"
+ console_prefix
+ + lines.replace(b'\n', b'\n' + console_prefix)
+ + b'\n'
+ )
+ log_data = (
+ log_prefix + lines.replace(b'\n', b'\n' + log_prefix) + b'\n'
)
- log_data = log_prefix + lines.replace(b"\n", b"\n" + log_prefix) + b"\n"
console_output(console_data, logging_msg=log_data)
- self.last_printed_line = lines[lines.rfind(b"\n") + 1 :]
+ self.last_printed_line = lines[lines.rfind(b'\n') + 1 :]
def handle_read_fast_case(self, data: bytes) -> bool:
"""If we are in a fast case we'll avoid the long processing of each
@@ -249,7 +250,7 @@ def handle_read_fast_case(self, data: bytes) -> bool:
# Slow case :-(
return False
- last_nl = data.rfind(b"\n")
+ last_nl = data.rfind(b'\n')
if last_nl == -1:
# No '\n' in data => slow case
return False
@@ -266,10 +267,10 @@ def handle_read(self) -> None:
nr_handle_read += 1
new_data = self._handle_read_chunk()
if self.debug:
- self.print_debug(b"==> " + new_data)
+ self.print_debug(b'==> ' + new_data)
if self.handle_read_fast_case(self.read_buffer):
return
- lf_pos = new_data.find(b"\n")
+ lf_pos = new_data.find(b'\n')
if lf_pos >= 0:
# Optimization: we knew there were no '\n' in the previous read
# buffer, so we searched only in the new_data and we offset the
@@ -278,10 +279,10 @@ def handle_read(self) -> None:
elif (
self.state is STATE_NOT_STARTED
and options.password is not None
- and b"password:" in self.read_buffer.lower()
+ and b'password:' in self.read_buffer.lower()
):
- self.dispatch_write("{}\n".format(options.password).encode())
- self.read_buffer = b""
+ self.dispatch_write(f'{options.password}\n'.encode())
+ self.read_buffer = b''
return
while lf_pos >= 0:
# For each line in the buffer
@@ -292,24 +293,25 @@ def handle_read(self) -> None:
self.print_lines(line)
elif self.state is STATE_NOT_STARTED:
self.read_in_state_not_started += line
- if b"The authenticity of host" in line:
- msg = line.strip(b"\n") + b" Closing connection."
+ if b'The authenticity of host' in line:
+ msg = line.strip(b'\n') + b' Closing connection.'
self.disconnect()
- elif b"REMOTE HOST IDENTIFICATION HAS CHANGED" in line:
- msg = b"Remote host identification has changed."
+ elif b'REMOTE HOST IDENTIFICATION HAS CHANGED' in line:
+ msg = b'Remote host identification has changed.'
else:
msg = None
if msg:
self.print_lines(
- msg + b" Consider manually connecting or " b"using ssh-keyscan."
+ msg + b' Consider manually connecting or '
+ b'using ssh-keyscan.'
)
# Go to the next line in the buffer
self.read_buffer = self.read_buffer[lf_pos + 1 :]
if self.handle_read_fast_case(self.read_buffer):
return
- lf_pos = self.read_buffer.find(b"\n")
+ lf_pos = self.read_buffer.find(b'\n')
if self.state is STATE_NOT_STARTED and not self.init_string_sent:
self.dispatch_write(self.init_string)
self.init_string_sent = True
@@ -319,7 +321,7 @@ def print_unfinished_line(self) -> None:
if self.state is STATE_RUNNING:
if not callbacks.process(self.read_buffer):
self.print_lines(self.read_buffer)
- self.read_buffer = b""
+ self.read_buffer = b''
def writable(self) -> bool:
"""Do we want to write something?"""
@@ -330,22 +332,28 @@ def handle_write(self) -> None:
num_sent = self.send(self.write_buffer)
if self.debug:
if self.state is not STATE_NOT_STARTED or options.password is None:
- self.print_debug(b"<== " + self.write_buffer[:num_sent])
+ self.print_debug(b'<== ' + self.write_buffer[:num_sent])
self.write_buffer = self.write_buffer[num_sent:]
def print_debug(self, msg: bytes) -> None:
"""Log some debugging information to the console"""
state = STATE_NAMES[self.state].encode()
console_output(
- b"[dbg] " + self.display_name.encode() + b"[" + state + b"]: " + msg + b"\n"
+ b'[dbg] '
+ + self.display_name.encode()
+ + b'['
+ + state
+ + b']: '
+ + msg
+ + b'\n'
)
def get_info(self) -> List[bytes]:
"""Return a list with all information available about this process"""
return [
self.display_name.encode(),
- self.enabled and b"enabled" or b"disabled",
- STATE_NAMES[self.state].encode() + b":",
+ self.enabled and b'enabled' or b'disabled',
+ STATE_NAMES[self.state].encode() + b':',
self.last_printed_line.strip(),
]
@@ -373,9 +381,17 @@ def rename(self, name: bytes) -> None:
"""Send to the remote shell, its new name to be shell expanded"""
if name:
# defug callback add?
- rename1, rename2 = callbacks.add(b"rename", self.change_name, False)
+ rename1, rename2 = callbacks.add(
+ b'rename', self.change_name, False
+ )
self.dispatch_command(
- b'/bin/echo "' + rename1 + b'""' + rename2 + b'"' + name + b"\n"
+ b'/bin/echo "'
+ + rename1
+ + b'""'
+ + rename2
+ + b'"'
+ + name
+ + b'\n'
)
else:
self.change_name(self.hostname.encode())
diff --git a/src/polysh/stdin.py b/src/polysh/stdin.py
index cafadb4..beeb766 100644
--- a/src/polysh/stdin.py
+++ b/src/polysh/stdin.py
@@ -16,8 +16,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import asyncore
import errno
+import fcntl
import os
import readline # Just to say we want to use it with raw_input
import signal
@@ -26,22 +26,27 @@
import sys
import tempfile
import termios
-from threading import Thread, Event, Lock
+from threading import Event, Lock, Thread
+from typing import Optional
-from polysh import dispatchers, remote_dispatcher
+from polysh import (
+ completion,
+ dispatcher_registry,
+ dispatchers,
+ remote_dispatcher,
+)
from polysh.console import console_output, set_last_status_length
-from polysh import completion
-from typing import Optional
+from polysh.exceptions import ExitNow
the_stdin_thread = None # type: StdinThread
-class InputBuffer(object):
+class InputBuffer:
"""The shared input buffer between the main thread and the stdin thread"""
def __init__(self) -> None:
self.lock = Lock()
- self.buf = b""
+ self.buf = b''
def add(self, data: bytes) -> None:
"""Add data to the buffer"""
@@ -50,9 +55,9 @@ def add(self, data: bytes) -> None:
def get(self) -> bytes:
"""Get the content of the buffer"""
- data = b""
+ data = b''
with self.lock:
- data, self.buf = self.buf, b""
+ data, self.buf = self.buf, b''
return data
@@ -63,43 +68,43 @@ def process_input_buffer() -> None:
from polysh.control_commands_helpers import handle_control_command
data = the_stdin_thread.input_buffer.get()
- remote_dispatcher.log(b"> " + data)
+ remote_dispatcher.log(b'> ' + data)
- if data.startswith(b":"):
+ if data.startswith(b':'):
try:
handle_control_command(data[1:-1].decode())
- except UnicodeDecodeError as e:
- console_output(b"Could not decode command.")
+ except UnicodeDecodeError:
+ console_output(b'Could not decode command.')
return
- if data.startswith(b"!"):
+ if data.startswith(b'!'):
try:
retcode = subprocess.call(data[1:], shell=True)
except OSError as e:
if e.errno == errno.EINTR:
- console_output(b"Child was interrupted\n")
+ console_output(b'Child was interrupted\n')
retcode = 0
else:
raise
if retcode > 128 and retcode <= 192:
retcode = 128 - retcode
if retcode > 0:
- console_output("Child returned {:d}\n".format(retcode).encode())
+ console_output(f'Child returned {retcode:d}\n'.encode())
elif retcode < 0:
console_output(
- "Child was terminated by signal {:d}\n".format(-retcode).encode()
+ f'Child was terminated by signal {-retcode:d}\n'.encode()
)
return
for r in dispatchers.all_instances():
try:
r.dispatch_command(data)
- except asyncore.ExitNow as e:
+ except ExitNow as e:
raise e
except Exception as msg:
raise msg
console_output(
- "{} for {}, disconnecting\n".format(str(msg), r.display_name).encode()
+ f'{str(msg)} for {r.display_name}, disconnecting\n'.encode()
)
r.disconnect()
else:
@@ -116,32 +121,82 @@ def process_input_buffer() -> None:
# sends the ACK, and the stdin thread can go on.
-class SocketNotificationReader(asyncore.dispatcher):
+class SocketDispatcher:
+ """Base dispatcher class for socket-based communication."""
+
+ def __init__(self, sock: socket.socket) -> None:
+ self.socket = sock
+ self.fd = sock.fileno()
+
+ # Set non-blocking mode
+ flags = fcntl.fcntl(self.fd, fcntl.F_GETFL)
+ fcntl.fcntl(self.fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+ # Register with the dispatcher registry
+ dispatcher_registry.register(self.fd, self)
+
+ def recv(self, buffer_size: int) -> bytes:
+ """Read from the socket."""
+ return self.socket.recv(buffer_size)
+
+ def send(self, data: bytes) -> int:
+ """Write to the socket."""
+ return self.socket.send(data)
+
+ def close(self) -> None:
+ """Unregister and close the socket."""
+ dispatcher_registry.unregister(self.fd)
+ try:
+ self.socket.close()
+ except OSError:
+ pass
+
+ def readable(self) -> bool:
+ """Override in subclass."""
+ return True
+
+ def writable(self) -> bool:
+ """Override in subclass."""
+ return False
+
+ def handle_read(self) -> None:
+ """Override in subclass."""
+ pass
+
+ def handle_write(self) -> None:
+ """Override in subclass."""
+ pass
+
+ def handle_close(self) -> None:
+ """Handle connection close."""
+ self.close()
+
+
+class SocketNotificationReader(SocketDispatcher):
"""The socket reader in the main thread"""
- def __init__(self, the_stdin_thread: "StdinThread") -> None:
- asyncore.dispatcher.__init__(self, the_stdin_thread.socket_read)
+ def __init__(self, the_stdin_thread: 'StdinThread') -> None:
+ super().__init__(the_stdin_thread.socket_read)
def _do(self, c: bytes) -> None:
- if c == b"d":
+ if c == b'd':
process_input_buffer()
else:
- raise Exception("Unknown code: %s" % (c))
+ raise Exception('Unknown code: %s' % (c))
def handle_read(self) -> None:
"""Handle all the available character commands in the socket"""
while True:
try:
c = self.recv(1)
- except socket.error as e:
+ except OSError as e:
if e.errno == errno.EWOULDBLOCK:
return
- else:
- raise
+ raise
else:
self._do(c)
self.socket.setblocking(True)
- self.send(b"A")
+ self.send(b'A')
self.socket.setblocking(False)
def writable(self) -> bool:
@@ -155,7 +210,7 @@ def write_main_socket(c: bytes) -> None:
while True:
try:
the_stdin_thread.socket_write.recv(1)
- except socket.error as e:
+ except OSError as e:
if e.errno != errno.EINTR:
raise
else:
@@ -169,7 +224,7 @@ def write_main_socket(c: bytes) -> None:
# a newline
tempfile_fd, tempfile_name = tempfile.mkstemp()
os.remove(tempfile_name)
-os.write(tempfile_fd, b"\x03")
+os.write(tempfile_fd, b'\x03')
def get_stdin_pid(cached_result: Optional[int] = None) -> int:
@@ -177,7 +232,7 @@ def get_stdin_pid(cached_result: Optional[int] = None) -> int:
ID"""
if cached_result is None:
try:
- tasks = os.listdir("/proc/self/task")
+ tasks = os.listdir('/proc/self/task')
except OSError as e:
if e.errno != errno.ENOENT:
raise
@@ -227,7 +282,7 @@ class StdinThread(Thread):
"""The stdin thread, used to call raw_input()"""
def __init__(self, interactive: bool) -> None:
- Thread.__init__(self, name="stdin thread")
+ Thread.__init__(self, name='stdin thread')
completion.install_completion_handler()
self.input_buffer = InputBuffer()
@@ -254,9 +309,9 @@ def prepend_previous_text(self) -> None:
def want_raw_input(self) -> None:
nr, total = dispatchers.count_awaited_processes()
if nr:
- prompt = "waiting (%d/%d)> " % (nr, total)
+ prompt = 'waiting (%d/%d)> ' % (nr, total)
else:
- prompt = "ready (%d)> " % total
+ prompt = 'ready (%d)> ' % total
self.prompt = prompt
set_last_status_length(len(prompt))
self.raw_input_wanted.set()
@@ -296,5 +351,5 @@ def run(self) -> None:
completion.remove_last_history_item()
set_echo(True)
if cmd is not None:
- self.input_buffer.add("{}\n".format(cmd).encode())
- write_main_socket(b"d")
+ self.input_buffer.add(f'{cmd}\n'.encode())
+ write_main_socket(b'd')
diff --git a/tests/benchmark.py b/tests/benchmark.py
new file mode 100755
index 0000000..2885150
--- /dev/null
+++ b/tests/benchmark.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+"""Benchmark script for comparing polysh event loop implementations.
+
+Compares the old asyncore-based event loop (master branch) against the new
+selectors-based event loop (g_asyncio branch) by running polysh in
+non-interactive mode and measuring wall-clock time.
+
+Usage examples:
+ # Compare two binaries with explicit server list
+ ./benchmark.py --binary-old ./result-old/bin/polysh \
+ --binary-new ./result-new/bin/polysh \
+ server1 server2 server3
+
+ # Single binary benchmark (just timing, no comparison)
+ ./benchmark.py --binary ./result/bin/polysh server1 server2
+
+ # Use adminapi to resolve servers
+ ./benchmark.py --binary-old ./result-old/bin/polysh \
+ --binary-new ./result-new/bin/polysh \
+ --adminapi 'project=grepo game_market=xx servertype=vm'
+
+ # Custom command and more runs
+ ./benchmark.py --binary-old ./result-old/bin/polysh \
+ --binary-new ./result-new/bin/polysh \
+ --command 'cat /proc/loadavg' \
+ --runs 10 \
+ server1 server2
+
+ # With custom ssh command
+ ./benchmark.py --binary ./result/bin/polysh \
+ --ssh 'exec ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -oLogLevel=Quiet -t %(host)s %(port)s' \
+ server1 server2
+"""
+
+import argparse
+import json
+import math
+import os
+import resource
+import shutil
+import subprocess
+import sys
+import time
+
+
+def resolve_adminapi(query):
+ """Resolve server names using adminapi CLI tool."""
+ try:
+ result = subprocess.run(
+ ['adminapi', query],
+ capture_output=True,
+ text=True,
+ timeout=30,
+ )
+ if result.returncode != 0:
+ print(f'adminapi failed: {result.stderr.strip()}', file=sys.stderr)
+ sys.exit(1)
+ servers = result.stdout.strip().split()
+ if not servers:
+ print('adminapi returned no servers', file=sys.stderr)
+ sys.exit(1)
+ return servers
+ except FileNotFoundError:
+ print('adminapi not found in PATH', file=sys.stderr)
+ sys.exit(1)
+
+
+def build_polysh_cmd(binary, ssh_cmd, command, servers):
+ """Build the polysh command line."""
+ cmd = [binary]
+ if ssh_cmd:
+ cmd.extend(['--ssh', ssh_cmd])
+ cmd.extend(['--command', command])
+ cmd.extend(servers)
+ return cmd
+
+
+def run_single(cmd, run_number, quiet=False):
+ """Run a single polysh invocation and return wall-clock time in seconds."""
+ if not quiet:
+ print(f' Run {run_number}... ', end='', flush=True)
+
+ start = time.monotonic()
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ timeout=300, # 5 minute timeout per run
+ )
+ except subprocess.TimeoutExpired:
+ if not quiet:
+ print('TIMEOUT')
+ return None
+
+ elapsed = time.monotonic() - start
+
+ if not quiet:
+ if result.returncode != 0:
+ print(f'{elapsed:.3f}s (exit code {result.returncode})')
+ else:
+ print(f'{elapsed:.3f}s')
+
+ return elapsed
+
+
+def run_benchmark(label, cmd, runs, warmup):
+ """Run multiple iterations of a benchmark and collect timings."""
+ print(f'\n--- {label} ---')
+ print(f'Command: {" ".join(cmd)}')
+ print(f'Warmup runs: {warmup}, Measured runs: {runs}')
+
+ # Warmup
+ for i in range(warmup):
+ print(f' Warmup {i + 1}... ', end='', flush=True)
+ t = run_single(cmd, i + 1, quiet=True)
+ if t is not None:
+ print(f'{t:.3f}s')
+ else:
+ print('TIMEOUT')
+
+ # Measured runs
+ timings = []
+ for i in range(runs):
+ t = run_single(cmd, i + 1)
+ if t is not None:
+ timings.append(t)
+
+ return timings
+
+
+def compute_stats(timings):
+ """Compute basic statistics from a list of timings."""
+ if not timings:
+ return None
+
+ n = len(timings)
+ mean = sum(timings) / n
+ sorted_t = sorted(timings)
+ median = sorted_t[n // 2] if n % 2 else (sorted_t[n // 2 - 1] + sorted_t[n // 2]) / 2
+
+ if n > 1:
+ variance = sum((t - mean) ** 2 for t in timings) / (n - 1)
+ stddev = math.sqrt(variance)
+ else:
+ stddev = 0.0
+
+ return {
+ 'n': n,
+ 'mean': mean,
+ 'median': median,
+ 'min': sorted_t[0],
+ 'max': sorted_t[-1],
+ 'stddev': stddev,
+ }
+
+
+def print_stats(label, stats):
+ """Print statistics for a benchmark run."""
+ if stats is None:
+ print(f'\n{label}: No successful runs')
+ return
+
+ print(f'\n{label} ({stats["n"]} runs):')
+ print(f' Mean: {stats["mean"]:.3f}s')
+ print(f' Median: {stats["median"]:.3f}s')
+ print(f' Min: {stats["min"]:.3f}s')
+ print(f' Max: {stats["max"]:.3f}s')
+ print(f' Stddev: {stats["stddev"]:.3f}s')
+
+
+def print_comparison(stats_old, stats_new):
+ """Print comparison between old and new implementations."""
+ if stats_old is None or stats_new is None:
+ print('\nCannot compare: one or both benchmarks had no successful runs')
+ return
+
+ diff = stats_new['mean'] - stats_old['mean']
+ if stats_old['mean'] > 0:
+ pct = (diff / stats_old['mean']) * 100
+ else:
+ pct = 0.0
+
+ print('\n=== Comparison ===')
+ print(f' Old mean: {stats_old["mean"]:.3f}s')
+ print(f' New mean: {stats_new["mean"]:.3f}s')
+ print(f' Diff: {diff:+.3f}s ({pct:+.1f}%)')
+
+ if diff < 0:
+ print(f' New is {abs(pct):.1f}% faster')
+ elif diff > 0:
+ print(f' New is {pct:.1f}% slower')
+ else:
+ print(' No difference')
+
+ # Also compare medians
+ diff_med = stats_new['median'] - stats_old['median']
+ if stats_old['median'] > 0:
+ pct_med = (diff_med / stats_old['median']) * 100
+ else:
+ pct_med = 0.0
+ print(f' Median diff: {diff_med:+.3f}s ({pct_med:+.1f}%)')
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Benchmark polysh event loop implementations',
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+
+ # Binary selection - mutually exclusive groups
+ binary_group = parser.add_mutually_exclusive_group(required=True)
+ binary_group.add_argument(
+ '--binary',
+ help='Single polysh binary to benchmark (timing only, no comparison)',
+ )
+ binary_group.add_argument(
+ '--binary-old',
+ help='Path to old (asyncore) polysh binary',
+ )
+
+ parser.add_argument(
+ '--binary-new',
+ help='Path to new (selectors) polysh binary (required with --binary-old)',
+ )
+
+ # Server selection
+ parser.add_argument(
+ 'servers',
+ nargs='*',
+ help='Server hostnames to connect to',
+ )
+ parser.add_argument(
+ '--adminapi',
+ help='adminapi query string to resolve server names',
+ )
+
+ # Polysh options
+ parser.add_argument(
+ '--ssh',
+ default='exec ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -oLogLevel=Quiet -t %(host)s %(port)s',
+ help='SSH command template (default: %(default)s)',
+ )
+ parser.add_argument(
+ '--command',
+ default='hostname',
+ help='Command to execute on remote hosts (default: %(default)s)',
+ )
+
+ # Benchmark parameters
+ parser.add_argument(
+ '--runs',
+ type=int,
+ default=5,
+ help='Number of measured runs (default: %(default)s)',
+ )
+ parser.add_argument(
+ '--warmup',
+ type=int,
+ default=1,
+ help='Number of warmup runs (default: %(default)s)',
+ )
+ parser.add_argument(
+ '--json',
+ metavar='FILE',
+ help='Write results as JSON to FILE',
+ )
+
+ args = parser.parse_args()
+
+ # Validate binary arguments
+ if args.binary_old and not args.binary_new:
+ parser.error('--binary-new is required when using --binary-old')
+
+ # Resolve servers
+ servers = list(args.servers) if args.servers else []
+ if args.adminapi:
+ servers.extend(resolve_adminapi(args.adminapi))
+ if not servers:
+ parser.error('No servers specified. Use positional args or --adminapi')
+
+ # Validate binaries exist
+ binaries = {}
+ if args.binary:
+ if not os.path.isfile(args.binary) or not os.access(args.binary, os.X_OK):
+ print(f'Binary not found or not executable: {args.binary}', file=sys.stderr)
+ sys.exit(1)
+ binaries['single'] = args.binary
+ else:
+ for label, path in [('old', args.binary_old), ('new', args.binary_new)]:
+ if not os.path.isfile(path) or not os.access(path, os.X_OK):
+ print(f'Binary not found or not executable: {path}', file=sys.stderr)
+ sys.exit(1)
+ binaries[label] = path
+
+ print(f'Servers ({len(servers)}): {" ".join(servers[:10])}{"..." if len(servers) > 10 else ""}')
+ print(f'Command: {args.command}')
+ print(f'SSH: {args.ssh}')
+ print(f'Runs: {args.runs} (+ {args.warmup} warmup)')
+
+ results = {}
+
+ if 'single' in binaries:
+ cmd = build_polysh_cmd(binaries['single'], args.ssh, args.command, servers)
+ timings = run_benchmark('Polysh', cmd, args.runs, args.warmup)
+ stats = compute_stats(timings)
+ print_stats('Polysh', stats)
+ results['single'] = {'timings': timings, 'stats': stats}
+ else:
+ # Run old first, then new
+ cmd_old = build_polysh_cmd(binaries['old'], args.ssh, args.command, servers)
+ timings_old = run_benchmark('Old (asyncore)', cmd_old, args.runs, args.warmup)
+ stats_old = compute_stats(timings_old)
+
+ cmd_new = build_polysh_cmd(binaries['new'], args.ssh, args.command, servers)
+ timings_new = run_benchmark('New (selectors)', cmd_new, args.runs, args.warmup)
+ stats_new = compute_stats(timings_new)
+
+ print_stats('Old (asyncore)', stats_old)
+ print_stats('New (selectors)', stats_new)
+ print_comparison(stats_old, stats_new)
+
+ results['old'] = {'timings': timings_old, 'stats': stats_old}
+ results['new'] = {'timings': timings_new, 'stats': stats_new}
+
+ # Write JSON output if requested
+ if args.json:
+ json_data = {
+ 'servers': servers,
+ 'command': args.command,
+ 'ssh': args.ssh,
+ 'runs': args.runs,
+ 'warmup': args.warmup,
+ 'results': results,
+ }
+ with open(args.json, 'w') as f:
+ json.dump(json_data, f, indent=2)
+ print(f'\nJSON results written to {args.json}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/test_dispatcher_registry.py b/tests/test_dispatcher_registry.py
new file mode 100644
index 0000000..4d9e1cd
--- /dev/null
+++ b/tests/test_dispatcher_registry.py
@@ -0,0 +1,208 @@
+"""Polysh - Tests - Dispatcher Registry
+
+Unit tests for the selectors-based dispatcher registry.
+
+Copyright (c) 2024 InnoGames GmbH
+"""
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import os
+import selectors
+import unittest
+
+from polysh import dispatcher_registry
+
+
+class FakeDispatcher:
+ """Minimal dispatcher stub for testing."""
+
+ def __init__(self, fd):
+ self.fd = fd
+ self._readable = True
+ self._writable = False
+
+ def readable(self):
+ return self._readable
+
+ def writable(self):
+ return self._writable
+
+
+class TestDispatcherRegistry(unittest.TestCase):
+ def setUp(self):
+ """Reset module-level global state and track fds for cleanup."""
+ self._fds = []
+ dispatcher_registry._dispatchers.clear()
+ dispatcher_registry._current_events.clear()
+ dispatcher_registry._selector.close()
+ dispatcher_registry._selector = selectors.DefaultSelector()
+
+ def tearDown(self):
+ """Close all pipe fds opened during the test."""
+ # Unregister any remaining dispatchers first to avoid selector complaints
+ for fd in list(dispatcher_registry._dispatchers):
+ dispatcher_registry.unregister(fd)
+ dispatcher_registry._selector.close()
+ dispatcher_registry._selector = selectors.DefaultSelector()
+ for fd in self._fds:
+ try:
+ os.close(fd)
+ except OSError:
+ pass
+
+ def _make_pipe(self):
+ """Create a pipe and track both fds for cleanup."""
+ r, w = os.pipe()
+ self._fds.extend([r, w])
+ return r, w
+
+ def test_register_and_get_dispatcher(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ self.assertIs(dispatcher_registry.get_dispatcher(r), d)
+ # Selector should have it registered with EVENT_READ
+ key = dispatcher_registry.get_selector().get_key(r)
+ self.assertEqual(key.events, selectors.EVENT_READ)
+ self.assertIs(key.data, d)
+
+ def test_register_sets_current_events(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ self.assertEqual(
+ dispatcher_registry._current_events[r], selectors.EVENT_READ
+ )
+
+ def test_unregister(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+ dispatcher_registry.unregister(r)
+
+ self.assertIsNone(dispatcher_registry.get_dispatcher(r))
+ self.assertNotIn(r, dispatcher_registry._current_events)
+ with self.assertRaises(KeyError):
+ dispatcher_registry.get_selector().get_key(r)
+
+ def test_unregister_unknown_fd_is_noop(self):
+ # Should not raise
+ dispatcher_registry.unregister(99999)
+
+ def test_unregister_already_unregistered(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+ dispatcher_registry.unregister(r)
+ # Second unregister should be a safe noop
+ dispatcher_registry.unregister(r)
+
+ def test_modify_events_caching(self):
+ """modify_events should skip syscall when events unchanged."""
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ # Initial state is EVENT_READ. Calling modify with same events is a noop.
+ dispatcher_registry.modify_events(r, selectors.EVENT_READ)
+ key = dispatcher_registry.get_selector().get_key(r)
+ self.assertEqual(key.events, selectors.EVENT_READ)
+
+ def test_modify_events_read_to_readwrite(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ both = selectors.EVENT_READ | selectors.EVENT_WRITE
+ dispatcher_registry.modify_events(r, both)
+
+ key = dispatcher_registry.get_selector().get_key(r)
+ self.assertEqual(key.events, both)
+ self.assertEqual(dispatcher_registry._current_events[r], both)
+
+ def test_modify_events_to_zero_unregisters_from_selector(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ dispatcher_registry.modify_events(r, 0)
+
+ # Should be unregistered from selector but still in _dispatchers
+ with self.assertRaises(KeyError):
+ dispatcher_registry.get_selector().get_key(r)
+ self.assertIs(dispatcher_registry.get_dispatcher(r), d)
+ self.assertEqual(dispatcher_registry._current_events[r], 0)
+
+ def test_modify_events_zero_to_read_reregisters(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ # Unregister from selector
+ dispatcher_registry.modify_events(r, 0)
+ # Re-register
+ dispatcher_registry.modify_events(r, selectors.EVENT_READ)
+
+ key = dispatcher_registry.get_selector().get_key(r)
+ self.assertEqual(key.events, selectors.EVENT_READ)
+
+ def test_modify_events_unknown_fd_is_noop(self):
+ # Should not raise
+ dispatcher_registry.modify_events(99999, selectors.EVENT_READ)
+
+ def test_all_dispatchers_returns_list_copy(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ result = dispatcher_registry.all_dispatchers()
+ self.assertIsInstance(result, list)
+ self.assertEqual(result, [d])
+ # Mutating the returned list should not affect the registry
+ result.clear()
+ self.assertEqual(dispatcher_registry.all_dispatchers(), [d])
+
+ def test_iter_dispatchers(self):
+ r, w = self._make_pipe()
+ d = FakeDispatcher(r)
+ dispatcher_registry.register(r, d)
+
+ result = list(dispatcher_registry.iter_dispatchers())
+ self.assertEqual(result, [d])
+
+ def test_multiple_dispatchers(self):
+ r1, w1 = self._make_pipe()
+ r2, w2 = self._make_pipe()
+ d1 = FakeDispatcher(r1)
+ d2 = FakeDispatcher(r2)
+ dispatcher_registry.register(r1, d1)
+ dispatcher_registry.register(r2, d2)
+
+ self.assertEqual(len(dispatcher_registry.all_dispatchers()), 2)
+ self.assertIs(dispatcher_registry.get_dispatcher(r1), d1)
+ self.assertIs(dispatcher_registry.get_dispatcher(r2), d2)
+
+ dispatcher_registry.unregister(r1)
+ self.assertEqual(len(dispatcher_registry.all_dispatchers()), 1)
+ self.assertIsNone(dispatcher_registry.get_dispatcher(r1))
+ self.assertIs(dispatcher_registry.get_dispatcher(r2), d2)
+
+ def test_get_dispatcher_returns_none_for_unknown(self):
+ self.assertIsNone(dispatcher_registry.get_dispatcher(99999))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_event_loop.py b/tests/test_event_loop.py
new file mode 100644
index 0000000..4ec610e
--- /dev/null
+++ b/tests/test_event_loop.py
@@ -0,0 +1,224 @@
+"""Polysh - Tests - Event Loop
+
+Unit tests for the selectors-based event loop.
+
+Copyright (c) 2024 InnoGames GmbH
+"""
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import os
+import selectors
+import unittest
+
+from polysh import dispatcher_registry
+from polysh.event_loop import loop_iteration
+
+
+class FakeDispatcher:
+ """Dispatcher stub that records which handlers were called."""
+
+ def __init__(self, fd, readable=True, writable=False):
+ self.fd = fd
+ self._readable = readable
+ self._writable = writable
+ self.read_called = 0
+ self.write_called = 0
+ self.close_called = 0
+
+ def readable(self):
+ return self._readable
+
+ def writable(self):
+ return self._writable
+
+ def handle_read(self):
+ self.read_called += 1
+
+ def handle_write(self):
+ self.write_called += 1
+
+ def handle_close(self):
+ self.close_called += 1
+
+
+class FailingReadDispatcher(FakeDispatcher):
+ """Dispatcher whose handle_read raises an exception."""
+
+ def handle_read(self):
+ self.read_called += 1
+ raise RuntimeError('read failed')
+
+
+class FailingWriteDispatcher(FakeDispatcher):
+ """Dispatcher whose handle_write raises an exception."""
+
+ def handle_write(self):
+ self.write_called += 1
+ raise RuntimeError('write failed')
+
+
+class SelfUnregisteringDispatcher(FakeDispatcher):
+ """Dispatcher that unregisters itself during handle_read."""
+
+ def handle_read(self):
+ self.read_called += 1
+ dispatcher_registry.unregister(self.fd)
+
+
+class TestEventLoop(unittest.TestCase):
+ def setUp(self):
+ self._fds = []
+ dispatcher_registry._dispatchers.clear()
+ dispatcher_registry._current_events.clear()
+ dispatcher_registry._selector.close()
+ dispatcher_registry._selector = selectors.DefaultSelector()
+
+ def tearDown(self):
+ for fd in list(dispatcher_registry._dispatchers):
+ dispatcher_registry.unregister(fd)
+ dispatcher_registry._selector.close()
+ dispatcher_registry._selector = selectors.DefaultSelector()
+ for fd in self._fds:
+ try:
+ os.close(fd)
+ except OSError:
+ pass
+
+ def _make_pipe(self):
+ r, w = os.pipe()
+ self._fds.extend([r, w])
+ return r, w
+
+ def test_readable_dispatcher_gets_handle_read(self):
+ r, w = self._make_pipe()
+ # Write something so the read end becomes ready
+ os.write(w, b'hello')
+
+ d = FakeDispatcher(r, readable=True, writable=False)
+ dispatcher_registry.register(r, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d.read_called, 1)
+ self.assertEqual(d.write_called, 0)
+
+ def test_writable_dispatcher_gets_handle_write(self):
+ r, w = self._make_pipe()
+ # Write end of a pipe is immediately writable
+ d = FakeDispatcher(w, readable=False, writable=True)
+ dispatcher_registry.register(w, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertEqual(d.read_called, 0)
+ self.assertGreaterEqual(d.write_called, 1)
+
+ def test_both_readable_and_writable(self):
+ r, w = self._make_pipe()
+ os.write(w, b'data')
+
+ # Use the read end for reading, write end for writing
+ dr = FakeDispatcher(r, readable=True, writable=False)
+ dw = FakeDispatcher(w, readable=False, writable=True)
+ dispatcher_registry.register(r, dr)
+ dispatcher_registry.register(w, dw)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(dr.read_called, 1)
+ self.assertGreaterEqual(dw.write_called, 1)
+
+ def test_not_readable_not_writable_skipped(self):
+ r, w = self._make_pipe()
+ os.write(w, b'data')
+
+ d = FakeDispatcher(r, readable=False, writable=False)
+ dispatcher_registry.register(r, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertEqual(d.read_called, 0)
+ self.assertEqual(d.write_called, 0)
+
+ def test_no_dispatchers_does_not_raise(self):
+ # Empty registry, should just return after timeout
+ loop_iteration(timeout=0.01)
+
+ def test_handle_read_exception_triggers_handle_close(self):
+ r, w = self._make_pipe()
+ os.write(w, b'data')
+
+ d = FailingReadDispatcher(r, readable=True, writable=False)
+ dispatcher_registry.register(r, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d.read_called, 1)
+ self.assertGreaterEqual(d.close_called, 1)
+
+ def test_handle_write_exception_triggers_handle_close(self):
+ r, w = self._make_pipe()
+
+ d = FailingWriteDispatcher(w, readable=False, writable=True)
+ dispatcher_registry.register(w, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d.write_called, 1)
+ self.assertGreaterEqual(d.close_called, 1)
+
+ def test_dispatcher_removed_during_handle_read_skips_handle_write(self):
+ """If handle_read unregisters the dispatcher, handle_write must not run."""
+ r, w = self._make_pipe()
+ os.write(w, b'data')
+
+ # Dispatcher that claims both readable and writable, but unregisters
+ # itself on read — handle_write must be skipped
+ d = SelfUnregisteringDispatcher(r, readable=True, writable=True)
+ dispatcher_registry.register(r, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d.read_called, 1)
+ self.assertEqual(d.write_called, 0)
+
+ def test_events_updated_between_iterations(self):
+ """Dispatcher state changes should be reflected in the next iteration."""
+ r, w = self._make_pipe()
+ os.write(w, b'data')
+
+ d = FakeDispatcher(r, readable=True, writable=False)
+ dispatcher_registry.register(r, d)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d.read_called, 1)
+
+ # Now make it not readable
+ read_count = d.read_called
+ d._readable = False
+ loop_iteration(timeout=0.05)
+ self.assertEqual(d.read_called, read_count)
+
+ def test_multiple_dispatchers_independent(self):
+ """Multiple dispatchers should be handled independently."""
+ r1, w1 = self._make_pipe()
+ r2, w2 = self._make_pipe()
+ os.write(w1, b'data1')
+ # r2 has no data, so not ready for read
+
+ d1 = FakeDispatcher(r1, readable=True, writable=False)
+ d2 = FakeDispatcher(r2, readable=True, writable=False)
+ dispatcher_registry.register(r1, d1)
+ dispatcher_registry.register(r2, d2)
+
+ loop_iteration(timeout=0.1)
+ self.assertGreaterEqual(d1.read_called, 1)
+ self.assertEqual(d2.read_called, 0)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/uv.lock b/uv.lock
index 011963b..197c441 100644
--- a/uv.lock
+++ b/uv.lock
@@ -109,7 +109,7 @@ wheels = [
[[package]]
name = "polysh"
-version = "0.14.0"
+version = "0.16.0"
source = { editable = "." }
[package.optional-dependencies]