diff --git a/src/portkeydrop/app.py b/src/portkeydrop/app.py index 906f1ba..94baab7 100644 --- a/src/portkeydrop/app.py +++ b/src/portkeydrop/app.py @@ -118,7 +118,10 @@ def __init__(self) -> None: self.build_tag = os.environ.get("PORTKEYDROP_BUILD_TAG") self._auto_update_check_timer: wx.Timer | None = None self._site_manager = SiteManager() - self._transfer_service = TransferService(notify_window=self) + self._transfer_service = TransferService( + notify_window=self, + max_workers=self._settings.transfer.concurrent_transfers, + ) self._transfer_state_by_id: dict[str, str] = {} self._last_failed_transfer: str | None = None self._announcer = ScreenReaderAnnouncer() @@ -1574,6 +1577,9 @@ def _on_settings(self, event: wx.CommandEvent) -> None: self._settings = dlg.get_settings() update_last_local_folder(self._settings, self._local_cwd) save_settings(self._settings) + self._transfer_service.set_max_workers( + self._settings.transfer.concurrent_transfers, + ) self.update_check_updates_menu_label() self._start_auto_update_checks() self._populate_file_list( diff --git a/src/portkeydrop/services/transfer_service.py b/src/portkeydrop/services/transfer_service.py index 644bf00..67a102d 100644 --- a/src/portkeydrop/services/transfer_service.py +++ b/src/portkeydrop/services/transfer_service.py @@ -92,15 +92,19 @@ def from_dict(cls, data: dict) -> TransferJob: class TransferService: - """Owns the transfer queue and a single daemon worker thread.""" + """Owns the transfer queue and a pool of daemon worker threads.""" - def __init__(self, notify_window: Any | None = None) -> None: + def __init__(self, notify_window: Any | None = None, max_workers: int = 1) -> None: self._notify_window = notify_window - self._queue: queue.Queue[TransferJob] = queue.Queue() + self._queue: queue.Queue[TransferJob | None] = queue.Queue() self._jobs: list[TransferJob] = [] self._lock = threading.Lock() - self._worker = threading.Thread(target=self._worker_loop, daemon=True) - self._worker.start() + self._max_workers = max(1, max_workers) + self._workers: list[threading.Thread] = [] + for _ in range(self._max_workers): + t = threading.Thread(target=self._worker_loop, daemon=True) + t.start() + self._workers.append(t) # ------------------------------------------------------------------ # Public API @@ -205,6 +209,27 @@ def cancel(self, job_id: str) -> None: break self._post_event() + def set_max_workers(self, n: int) -> None: + """Resize the worker pool to *n* threads. + + Extra workers are drained via a ``None`` sentinel on the queue; + missing workers are spawned immediately. + """ + n = max(1, n) + with self._lock: + # Prune threads that have already exited + self._workers = [t for t in self._workers if t.is_alive()] + current = len(self._workers) + if n > current: + for _ in range(n - current): + t = threading.Thread(target=self._worker_loop, daemon=True) + t.start() + self._workers.append(t) + elif n < current: + for _ in range(current - n): + self._queue.put(None) # sentinel to stop one worker + self._max_workers = n + # ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ @@ -218,6 +243,8 @@ def _enqueue(self, job: TransferJob) -> None: def _worker_loop(self) -> None: while True: job = self._queue.get() + if job is None: # shutdown sentinel + break if job.cancel_event.is_set(): job.status = TransferStatus.CANCELLED self._post_event() diff --git a/tests/test_app.py b/tests/test_app.py index 3b28817..2ea0f5a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -35,7 +35,8 @@ def _build_frame(module, tmp_path): sort_by="name", sort_ascending=True, ) - settings = SimpleNamespace(display=display) + transfer = SimpleNamespace(concurrent_transfers=2) + settings = SimpleNamespace(display=display, transfer=transfer) fake_manager = MagicMock(jobs=[]) fake_site_manager = MagicMock() @@ -85,7 +86,7 @@ def _hydrate_frame(module): def test_main_frame_init_sets_transfer_state(tmp_path, app_module): frame, _, transfer_service_cls = _build_frame(app_module, tmp_path) assert frame._transfer_state_by_id == {} - transfer_service_cls.assert_called_once_with(notify_window=frame) + transfer_service_cls.assert_called_once_with(notify_window=frame, max_workers=2) def test_bind_events_hooks_transfer_update(app_module): @@ -1177,6 +1178,7 @@ def test_on_settings_reconfigures_update_menu_and_timer(app_module): frame._settings = SimpleNamespace( app=SimpleNamespace(update_channel="stable"), display=SimpleNamespace(show_hidden_files=True), + transfer=SimpleNamespace(concurrent_transfers=2), ) frame._local_cwd = "/tmp" frame.remote_file_list = MagicMock() @@ -1192,6 +1194,7 @@ def test_on_settings_reconfigures_update_menu_and_timer(app_module): updated_settings = SimpleNamespace( app=SimpleNamespace(update_channel="nightly"), display=SimpleNamespace(show_hidden_files=True), + transfer=SimpleNamespace(concurrent_transfers=4), ) dialog = MagicMock( ShowModal=MagicMock(return_value=fake_wx.ID_OK), @@ -1215,6 +1218,7 @@ def test_on_settings_passes_check_updates_callback(app_module): frame._settings = SimpleNamespace( app=SimpleNamespace(update_channel="stable"), display=SimpleNamespace(show_hidden_files=True), + transfer=SimpleNamespace(concurrent_transfers=2), ) frame._local_cwd = "/tmp" frame.remote_file_list = MagicMock() diff --git a/tests/test_transfer_service.py b/tests/test_transfer_service.py index 79ba20e..25f4616 100644 --- a/tests/test_transfer_service.py +++ b/tests/test_transfer_service.py @@ -69,10 +69,20 @@ def test_cancel_event_independent(self): class TestTransferServiceInit: - def test_starts_daemon_worker_thread(self): + def test_starts_daemon_worker_threads(self): svc = TransferService(notify_window=None) - assert svc._worker.is_alive() - assert svc._worker.daemon is True + assert len(svc._workers) == 1 + assert all(t.is_alive() for t in svc._workers) + assert all(t.daemon for t in svc._workers) + + def test_starts_multiple_workers(self): + svc = TransferService(notify_window=None, max_workers=3) + assert len(svc._workers) == 3 + assert all(t.is_alive() for t in svc._workers) + + def test_max_workers_clamped_to_one(self): + svc = TransferService(notify_window=None, max_workers=0) + assert len(svc._workers) == 1 def test_jobs_returns_snapshot(self): svc = TransferService(notify_window=None) @@ -446,3 +456,61 @@ def test_status_values(self): assert TransferStatus.COMPLETE.value == "complete" assert TransferStatus.FAILED.value == "failed" assert TransferStatus.CANCELLED.value == "cancelled" + + +# --------------------------------------------------------------------------- +# Concurrent worker pool +# --------------------------------------------------------------------------- + + +class TestConcurrentWorkers: + def test_jobs_run_concurrently_with_multiple_workers(self): + """Two slow jobs should overlap when max_workers >= 2.""" + barrier = threading.Barrier(2, timeout=5) + completed_order: list[str] = [] + lock = threading.Lock() + + mock_client = MagicMock() + + def slow_download(src, fh, callback=None, offset=0): + name = PurePosixPath(src).name + barrier.wait() # both workers must reach here before either proceeds + with lock: + completed_order.append(name) + + mock_client.download.side_effect = slow_download + + svc = TransferService(notify_window=None, max_workers=2) + with patch("builtins.open", return_value=MagicMock(spec=io.BufferedWriter)): + j1 = svc.submit_download(mock_client, "/r/a.txt", "/tmp/a.txt") + j2 = svc.submit_download(mock_client, "/r/b.txt", "/tmp/b.txt") + _wait_for_terminal(j1) + _wait_for_terminal(j2) + + assert j1.status == TransferStatus.COMPLETE + assert j2.status == TransferStatus.COMPLETE + assert len(completed_order) == 2 + + def test_set_max_workers_increases_pool(self): + svc = TransferService(notify_window=None, max_workers=1) + assert len([t for t in svc._workers if t.is_alive()]) == 1 + + svc.set_max_workers(3) + time.sleep(0.1) + alive = [t for t in svc._workers if t.is_alive()] + assert len(alive) == 3 + + def test_set_max_workers_decreases_pool(self): + svc = TransferService(notify_window=None, max_workers=3) + assert len(svc._workers) == 3 + + svc.set_max_workers(1) + # Give sentinels time to be consumed + time.sleep(0.5) + alive = [t for t in svc._workers if t.is_alive()] + assert len(alive) == 1 + + def test_set_max_workers_clamps_to_one(self): + svc = TransferService(notify_window=None, max_workers=2) + svc.set_max_workers(0) + assert svc._max_workers == 1