diff --git a/PR_1090_CHIP_CALLABLE_ASYNC_REVISED_PLAN.md b/PR_1090_CHIP_CALLABLE_ASYNC_REVISED_PLAN.md new file mode 100644 index 000000000..c5f109b42 --- /dev/null +++ b/PR_1090_CHIP_CALLABLE_ASYNC_REVISED_PLAN.md @@ -0,0 +1,728 @@ +# PR 1090 Revised Plan + +## 1. L3 DAG Run Async + +Goal: + +```python +dag_h = l3_worker.run_async(orch_fn, args, config) +timing = dag_h.wait() +``` + +This API is async relative to the caller. The DAG itself still executes with +the current synchronous `Worker.run(orch_fn)` semantics: + +```text +orch_fn calls orch.submit_next_level(...) +current scheduler dispatches READY tasks +current WorkerThread waits task completion +current drain waits the whole DAG +``` + +Do not convert `orch.submit_next_level(...)` to child `RUN_ASYNC` in this step. +That is a separate lower-layer scheduler/completion redesign. + +Use one per-worker DAG run queue so sync and async DAG runs share ordering: + +```python +@dataclass +class DagRunState: + cv: Condition + completed: bool = False + result: RunTiming | None = None + error: BaseException | None = None + + +@dataclass +class DagRunRequest: + orch_fn: Callable + args: Any + config: CallConfig + state: DagRunState +``` + +Worker initialization: + +```python +def _start_dag_run_lane(self): + if self._dag_run_thread is not None: + return + + self._dag_run_queue = Queue() + self._dag_run_thread = Thread( + target=self._dag_run_thread_loop, + name="simpler-l3-dag-run", + daemon=True, + ) + self._dag_run_thread.start() +``` + +L3 public async DAG run: + +```python +def run_async(self, orch_fn, args=None, config=None) -> DagRunHandle: + assert self.level >= 3 + assert self._initialized + + state = DagRunState() + req = DagRunRequest( + orch_fn=orch_fn, + args=args, + config=copy_call_config(config), + state=state, + ) + + self._dag_run_queue.put(req) + return DagRunHandle(state) +``` + +L3 public sync DAG run: + +```python +def run(self, orch_fn, args=None, config=None) -> RunTiming: + assert self.level >= 
3 + + return self.run_async(orch_fn, args, config).wait() +``` + +The DAG lane calls the existing synchronous DAG implementation: + +```python +def _dag_run_thread_loop(self): + while True: + req = self._dag_run_queue.get() + if req is None: + break + + try: + timing = self._run_dag_sync_impl( + req.orch_fn, + req.args, + req.config, + ) + complete_success(req.state, timing) + except BaseException as exc: + complete_error(req.state, exc) +``` + +Move the current L3 body of `Worker.run(...)` into `_run_dag_sync_impl(...)`: + +```python +def _run_dag_sync_impl(self, orch_fn, args, config) -> RunTiming: + self._start_hierarchical() + self._orch._clear_error() + self._orch._scope_begin() + + t_start = time.perf_counter_ns() + try: + orch_fn(self._orch, args, config) + finally: + self._orch._scope_end() + self._orch._drain() + self._execute_pending_domain_releases() + self._release_all_live_domains() + + return RunTiming(time.perf_counter_ns() - t_start, 0) +``` + +Handle wait: + +```python +class DagRunHandle: + @property + def completed(self) -> bool: + return self._state.completed + + def wait(self) -> RunTiming: + with self._state.cv: + while not self._state.completed: + self._state.cv.wait() + + if self._state.error is not None: + raise self._state.error + + return self._state.result +``` + +Ordering requirement: + +```python +h1 = worker.run_async(orch_fn_1) +t2 = worker.run(orch_fn_2) + +# Required: +# orch_fn_1 completes before orch_fn_2 starts. +``` + +## 2. Public API Shape + +Use level-specific public semantics. 
+ +L2 direct worker APIs: + +```python +h = l2.register(chip_callable) +rh = l2.register_async(chip_callable) + +run_h = l2.run_async(h, args) +timing = run_h.wait() + +timing = l2.run(h, args) + +unreg_h = l2.unregister_async(h) +unreg_h.wait() + +l2.unregister(h) +``` + +L3 worker APIs: + +```python +h = l3.register(chip_callable) +rh = l3.register_async(chip_callable) + +dag_h = l3.run_async(orch_fn) +timing = dag_h.wait() + +timing = l3.run(orch_fn) + +unreg_h = l3.unregister_async(h) +unreg_h.wait() + +l3.unregister(h) +``` + +Do not expose an L3 public direct chip-run API: + +```python +# Not part of the public L3 API. +l3.run_chip_callable_async(...) +l3.run_chip_callable_sync(...) +``` + +L3 execution should go through `run(orch_fn)` or `run_async(orch_fn)`. + +PR scope rule: + +```text +register_async(target): + if target is not ChipCallable: + raise TypeError + +unregister_async(handle): + if handle is not a ChipCallable handle: + raise TypeError +``` + +Async register/unregister for Python callables or `RemoteCallable` is a future +extension, not part of this PR. + +`register_async(...)` and `unregister_async(...)` support `ChipCallable` +targets/handles only. Non-chip callable async registration or unregister must +fail explicitly: + +```python +def register_async(self, target, *, workers=None): + if not isinstance(target, ChipCallable): + raise TypeError("Worker.register_async only supports ChipCallable") + + return self._register_chip_async(target, workers=workers) + + +def unregister_async(self, handle): + if not is_chip_callable_handle(handle): + raise TypeError("Worker.unregister_async only supports ChipCallable") + + return self._unregister_chip_async(handle) +``` + +Generic `register(...)` remains synchronous and supports existing callable +kinds. 
For a `ChipCallable`, it delegates to `register_async(...).wait()`: + +```python +def register(self, target, *, workers=None): + if isinstance(target, ChipCallable): + return self.register_async( + target, + workers=workers, + ).wait() + + return self._register_non_chip_sync(target, workers=workers) +``` + +Async register implementation: + +```python +def _register_chip_async( + self, + target: ChipCallable, + *, + workers: list[int] | None = None, +) -> RegisterHandle: + if not isinstance(target, ChipCallable): + raise TypeError("expected ChipCallable") + + reg = build_callable_registration(self, target, workers=workers) + + with self._registry_lock: + handle, is_new = self._install_registration_locked(reg) + + if not self._initialized: + return completed_register_handle(handle) + + if self.level == 2: + return self._l2_submit_register_async(handle, target, is_new=is_new) + + return self._l3_submit_register_async(handle, target, is_new=is_new) +``` + +L2 async run: + +```python +def run_async( + self, + handle: CallableHandle, + args=None, + config=None, +) -> RunHandle: + assert self._initialized + assert self.level == 2 + + return self._l2_submit_run_async(handle, args, config) +``` + +L2 sync run: + +```python +def _run_l2_sync(self, handle, args=None, config=None): + return self.run_async(handle, args, config).wait() +``` + +Generic `unregister(...)` remains synchronous and delegates for chip callables: + +```python +def unregister(self, handle): + if is_chip_callable_handle(handle): + return self.unregister_async(handle).wait() + + return self._unregister_non_chip_sync(handle) +``` + +## 3. 
L2 Direct Worker Lanes + +L2 direct usage should have the same lane shape as an L3 chip child: + +```text +run lane: + serial chip runs + +register lane: + async chip prepare/register +``` + +L2 state: + +```python +class Worker: + _l2_run_queue: Queue[LocalRunRequest] + _l2_register_queue: Queue[LocalRegisterRequest] + + _l2_run_thread: Thread + _l2_register_thread: Thread + + _slot_inflight: dict[int, int] + _slot_tombstoned: set[int] + _slot_pending_unregister: dict[int, UnregisterState] +``` + +Start lanes during L2 init: + +```python +def _init_l2(self): + self._chip_worker = ChipWorker(...) + self._chip_worker.init(...) + + self._start_l2_lanes() + self._initialized = True + + +def _start_l2_lanes(self): + self._l2_run_queue = Queue() + self._l2_register_queue = Queue() + + self._l2_run_thread = Thread(target=self._l2_run_loop, daemon=True) + self._l2_register_thread = Thread( + target=self._l2_register_loop, + daemon=True, + ) + + self._l2_run_thread.start() + self._l2_register_thread.start() +``` + +L2 async register submit: + +```python +def _l2_submit_register_async(self, handle, target, *, is_new): + if not is_new: + return completed_register_handle(handle) + + state = RegisterState() + + callable_bytes = bytes_from_chip_callable(target) + with self._registry_lock: + slot_id = self._identity_registry[handle.digest].slot_id + + self._l2_register_queue.put(LocalRegisterRequest( + slot_id=slot_id, + digest=handle.digest, + callable_bytes=callable_bytes, + state=state, + )) + + return RegisterHandle(state, result=handle) +``` + +L2 register lane: + +```python +def _l2_register_loop(self): + while True: + req = self._l2_register_queue.get() + if req is None: + break + + try: + callable_obj = ChipCallable.from_bytes(req.callable_bytes) + validate_digest(callable_obj, req.digest) + + self._chip_worker._prepare_callable_at_slot( + req.slot_id, + callable_obj, + ) + + complete_success(req.state) + except BaseException as exc: + 
rollback_parent_registration(req.digest) + complete_error(req.state, exc) +``` + +L2 async run submit: + +```python +def _l2_submit_run_async(self, handle, args, config): + with self._registry_lock: + slot = self._resolve_handle_locked( + handle, + expected_namespace="LOCAL_CHIP", + ) + + if slot.slot_id in self._slot_tombstoned: + raise KeyError("callable handle is pending unregister") + + self._slot_inflight[slot.slot_id] += 1 + + run_state = RunState() + self._l2_run_queue.put(LocalRunRequest( + slot_id=slot.slot_id, + args=copy_run_args(args), + config=copy_call_config(config), + state=run_state, + )) + + return RunHandle(run_state) +``` + +L2 run lane: + +```python +def _l2_run_loop(self): + while True: + req = self._l2_run_queue.get() + if req is None: + break + + try: + timing = self._chip_worker._run_slot( + req.slot_id, + req.args, + req.config, + ) + complete_success(req.state, timing) + except BaseException as exc: + complete_error(req.state, exc) + finally: + self._release_slot_inflight(req.slot_id) +``` + +L2 sync APIs use the same queues: + +```python +def register(self, target): + if self.level == 2 and isinstance(target, ChipCallable): + return self.register_async(target).wait() + + +def run(self, handle, args=None, config=None): + if self.level == 2: + return self.run_async(handle, args, config).wait() +``` + +## 4. 
Nonblocking Tombstone And Deferred Free + +Unregister submit must be nonblocking: + +```python +unreg_h = worker.unregister_async(handle) +``` + +Return means: + +```text +public handle has been tombstoned +new runs using that handle are rejected +native unregister/free may still be pending +unreg_h.wait() waits for actual native cleanup +``` + +Free condition: + +```text +free when tombstoned(slot_id) and inflight(slot_id) == 0 +never free merely because inflight becomes zero +``` + +Shared state: + +```python +@dataclass +class SlotLifetime: + slot_id: int + digest: bytes + ref_count: int + tombstoned: bool = False + inflight: int = 0 + unregister_state: UnregisterState | None = None +``` + +Run submit holds an in-flight reference before enqueue: + +```python +def hold_slot_for_run(handle) -> tuple[int, bytes]: + with registry_lock: + slot = resolve_live_chip_handle(handle) + + if slot.tombstoned: + raise KeyError("callable handle is pending unregister") + + slot.inflight += 1 + return slot.slot_id, slot.digest +``` + +Run completion releases that reference: + +```python +def release_slot_after_run(slot_id): + cleanup_state = None + cleanup_digest = b"" + + with registry_lock: + slot = slot_by_id[slot_id] + slot.inflight -= 1 + + if slot.inflight == 0 and slot.tombstoned: + cleanup_state = slot.unregister_state + cleanup_digest = slot.digest + + if cleanup_state is not None: + native_unregister_and_finish(slot_id, cleanup_digest, cleanup_state) +``` + +Async unregister: + +```python +def unregister_async(handle) -> UnregisterHandle: + cleanup_slot_id = -1 + cleanup_digest = b"" + cleanup_state = None + + with registry_lock: + slot = resolve_live_chip_handle(handle) + + remove_public_handle(handle) + slot.ref_count -= 1 + + state = UnregisterState() + + if slot.ref_count > 0: + complete_success(state) + return UnregisterHandle(state) + + slot.tombstoned = True + slot.unregister_state = state + pending_unregister_cids.add(slot.slot_id) + + if slot.inflight == 0: + 
cleanup_slot_id = slot.slot_id + cleanup_digest = slot.digest + cleanup_state = state + + if cleanup_state is not None: + native_unregister_and_finish( + cleanup_slot_id, + cleanup_digest, + cleanup_state, + ) + + return UnregisterHandle(state) +``` + +Local native cleanup: + +```python +def native_unregister_and_finish(slot_id, digest, state): + try: + chip_worker._unregister_slot(slot_id) + complete_success(state) + except BaseException as exc: + mark_cleanup_uncertain(digest) + complete_error(state, exc) + finally: + with registry_lock: + callable_registry.pop(slot_id, None) + identity_registry.pop(digest, None) + pending_unregister_cids.discard(slot_id) + slot_lifetime.pop(slot_id, None) +``` + +L3 parent remote cleanup: + +```python +def _l3_submit_unregister_async(slot) -> UnregisterHandle: + parent_state = slot.unregister_state + remote_handles = [] + + for child in chip_children: + remote_handles.append( + child.control_unregister_async(slot.digest) + ) + + wait_thread = Thread( + target=_wait_remote_unregisters, + args=(slot, remote_handles, parent_state), + daemon=True, + ) + wait_thread.start() + + return UnregisterHandle(parent_state) +``` + +L3 child unregister submit: + +```python +def handle_unregister_async(digest): + state = UnregisterState() + cleanup_cid = None + + with registry_cv: + cid = identity_table.get(digest) + if cid is None: + complete_success(state) + return make_remote_handle(state) + + remove_identity_mapping(digest) + tombstoned_cids.add(cid) + + if inflight_cids.get(cid, 0) == 0: + cleanup_cid = cid + else: + deferred_unregister[cid] = state + + if cleanup_cid is not None: + child_native_unregister_and_finish(cleanup_cid, digest, state) + + return make_remote_handle(state) +``` + +L3 child run submit stores `cid`, not only `digest`: + +```python +def handle_run_async(digest, args_blob, config): + state = RunState() + + with registry_cv: + cid = identity_table.get(digest) + if cid is None or cid in tombstoned_cids: + 
complete_error(state, KeyError("callable not live")) + return make_remote_handle(state) + + inflight_cids[cid] += 1 + + run_queue.put(ChildRunRequest( + cid=cid, + args_blob=copy(args_blob), + config=config, + state=state, + )) + + return make_remote_handle(state) +``` + +L3 child run completion triggers deferred cleanup: + +```python +def child_release_inflight(cid): + cleanup_digest = b"" + cleanup_state = None + + with registry_cv: + inflight_cids[cid] -= 1 + if inflight_cids[cid] == 0: + del inflight_cids[cid] + + if cid in tombstoned_cids: + state = deferred_unregister.pop(cid, None) + if state is not None: + cleanup_digest = digest_by_cid[cid] + cleanup_state = state + + if cleanup_state is not None: + child_native_unregister_and_finish(cid, cleanup_digest, cleanup_state) +``` + +L3 child native cleanup: + +```python +def child_native_unregister_and_finish(cid, digest, state): + try: + chip_worker._unregister_slot(cid) + complete_success(state) + except BaseException as exc: + complete_error(state, exc) + finally: + with registry_cv: + registry.pop(cid, None) + identity_table.pop(digest, None) + identity_refs.pop(digest, None) + prepared.discard(cid) + tombstoned_cids.discard(cid) + digest_by_cid.pop(cid, None) +``` + +Unregister wait is a resource cleanup barrier: + +```python +run_h = l2.run_async(handle, args, config) +unreg_h = l2.unregister_async(handle) + +unreg_h.wait() +``` + +`unreg_h.wait()` guarantees native unregister/free completed after all accepted +runs that held the slot stopped using it. It does not return run timing or +rethrow the run error. 
Call `run_h.wait()` when the caller needs the run result: + +```python +timing = run_h.wait() +unreg_h.wait() +``` diff --git a/docs/callable-identity-registration.md b/docs/callable-identity-registration.md index 9bd80b181..bf9b8374a 100644 --- a/docs/callable-identity-registration.md +++ b/docs/callable-identity-registration.md @@ -399,6 +399,42 @@ The L3+ orchestration function captures `CallableHandle` values and passes them to `orch.submit_next_level` or `orch.submit_sub`. Hashid does not add a new top-level registration requirement for `Worker.run`. +Async worker APIs are level-specific: + +```python +# L2 direct chip worker +handle = worker.register(chip_callable) +pending_handle = worker.register_async(chip_callable) +run_handle = worker.run_async(handle, args, config) +unregister_handle = worker.unregister_async(handle) + +# L3+ orchestration worker +handle = worker.register(chip_callable) +pending_handle = worker.register_async(chip_callable) +dag_handle = worker.run_async(orch_fn, args, config) +unregister_handle = worker.unregister_async(handle) +``` + +For L2, `run_async(handle, ...)` submits one chip callable run to the local +chip run lane and returns a run completion handle. Synchronous `run(handle, +...)` submits to the same lane and waits, so sync and async L2 runs share one +ordering queue. + +For L3+, `run_async(orch_fn, ...)` is async relative to the caller but runs the +same orchestration DAG body as `run(orch_fn, ...)`. The current DAG scheduler +inside the run remains synchronous: `orch.submit_next_level(...)` uses the +existing ready-task dispatch path, and the DAG run drains before its handle +completes. Synchronous `run(orch_fn, ...)` submits to the same DAG run queue +and waits, so a later sync run cannot overtake an earlier async run. + +`register_async(...)` and `unregister_async(...)` are chip-callable APIs in the +current implementation. Passing a Python callable, `RemoteCallable`, or a +non-chip handle raises `TypeError`. 
Generic `register(...)` remains +synchronous and still supports the existing callable kinds; for `ChipCallable` +it delegates to `register_async(...).wait()`. Generic `unregister(...)` +remains synchronous; for chip handles it delegates to +`unregister_async(...).wait()`. + ### Registry Contract Each target namespace records local identity state. In the current local @@ -435,22 +471,30 @@ when the refcount reaches zero. Current local slot reuse rule: - A child resolves `hashid -> local_slot` immediately before execution. -- Each endpoint has one local mailbox operation in flight at a time. -- Parent-side dispatch and control operations to the same endpoint are - serialized by the per-WorkerThread mailbox lock. -- Final unregister removes the hashid from resolution and releases the private - slot only after the current mailbox operation has completed. +- A run submit holds an in-flight reference to the resolved private slot before + it enters the run lane. +- Async run/register/unregister controls may overlap a chip child task after + the child has copied that task's args and published `TASK_RUNNING`. +- Memory and CommDomain controls still wait for an in-flight task dispatch to + finish before claiming the mailbox. +- Final chip unregister tombstones the identity, rejects new runs through that + public handle, and releases executable state only when the slot is both + tombstoned and has no in-flight runs. This rule prevents stale slot reuse without exposing any extra public field. -A future remote or multi-flight control channel must add explicit +A future remote or multi-flight control channel must preserve the same `INSTALLED` / `TOMBSTONED` / `FAILED` target states, sequence numbers, and in-flight user draining before it can reuse private slots safely. ### Registration Failure Contract -Registration remains synchronous and whole-scope. 
For a given -`target_namespace`, the scope is every active child endpoint in the current -`Worker`'s corresponding resolver domain at register start. +Synchronous registration remains whole-scope. For a given `target_namespace`, +the scope is every active child endpoint in the current `Worker`'s +corresponding resolver domain at register start. `register_async` for a +`ChipCallable` returns after the async prepare/register request has been +submitted; its `RegisterHandle.wait()` is the whole-scope completion barrier +that either returns the public `CallableHandle` or raises the registration +error. 1. Parent builds the canonical descriptor and computes the `hashid`. 2. Parent allocates an unpublished parent-side registration entry and handle @@ -461,7 +505,8 @@ Registration remains synchronous and whole-scope. For a given 5. Target installs `hashid -> local_slot`, or increments `ref_count` when the same descriptor and payload are already installed. 6. Parent returns the `CallableHandle` only after every target in the scope - reports success. + reports success. For `register_async`, this happens from + `RegisterHandle.wait()`. If any target fails or times out: @@ -488,10 +533,18 @@ Parent-side scheduling assumes the handle's `hashid` is installed on every active target in its registration scope. Dispatch choices are constrained by the handle namespace, submit-time affinity, and tensor/buffer accessibility. Submit-time live validation is a preflight check only. It does not pin the -target identity through later drain or child dispatch. Callers must not -concurrently unregister a handle while `Worker.run()` or any in-flight task may -submit or use that handle; wait for the relevant run/drain to return before -unregistering it. +target identity through later drain or child dispatch. For chip callables, a +run submit pins the resolved slot before enqueueing work. 
A later +`unregister_async(handle)` removes that public handle from live resolution and +tombstones the slot only when the final public reference is removed. Runs that +already acquired the slot continue to completion; new runs through the +tombstoned handle are rejected. `UnregisterHandle.wait()` is the cleanup +barrier for target-local executable state, not a run result barrier. Call +`RunHandle.wait()` separately when the caller needs run timing or run errors. + +Non-chip callable unregister remains synchronous in this PR. Async unregister +for Python callables or remote task dispatcher handles is not part of the +current contract. Parent-side `TaskSlotState` stores the submitted callable's stable identity: the 32-byte `sha256` digest plus parent-side scheduling metadata such as @@ -597,15 +650,22 @@ Target unregister sequence: 1. Decrement the target-local refcount for `hashid`. 2. If the refcount remains nonzero, keep the mapping installed. -3. If the refcount reaches zero, stop new local resolutions from `hashid` to - private slot. -4. Clear executable state. -5. Release the private slot for reuse. -6. Remove or archive the `hashid` entry. +3. If the refcount reaches zero, remove the public digest resolution and mark + the private slot tombstoned. +4. If no run holds that private slot, call the target-local native unregister + path immediately and then release the slot for reuse. +5. If any run already holds the private slot, leave executable state installed + and attach the unregister state to that tombstone. +6. Each run completion decrements the private slot's in-flight count. The + completion that observes `tombstoned && inflight == 0` performs the native + unregister/free and completes the unregister state. This sequence is the concrete unregister form of the target-local slot reuse -rule for the current single-flight local mailbox. A future multi-flight target -must insert a tombstone/drain phase before clearing executable state. +rule. 
Final unregister is non-blocking at submit time for chip callables: +`unregister_async(handle)` returns an `UnregisterHandle` after the tombstone is +submitted, and `UnregisterHandle.wait()` is the cleanup barrier for native +unregister/free. It does not replace `RunHandle.wait()` when the caller needs +the run's timing or error result. If failed-register cleanup cannot be confirmed, the parent must not dispatch that hashid to the uncertain target again during the current Worker lifetime. @@ -625,9 +685,18 @@ The implementation provides canonical descriptor and hash helpers: The public API is handle-based: - `Worker.register` returns `CallableHandle`. +- `Worker.register_async` returns `RegisterHandle` for `ChipCallable` + targets only; non-chip targets raise `TypeError`. +- `Worker.unregister_async` returns `UnregisterHandle` for chip handles only; + non-chip handles raise `TypeError`. - `CallableHandle` validation rejects forged, stale, mutated, or wrong-namespace handles. -- L3+ `Worker.run(raw_orch_fn, ...)` behavior is unchanged. +- L2 `Worker.run_async(handle, ...)` submits to the local chip run lane. +- L3+ `Worker.run_async(raw_orch_fn, ...)` submits one whole DAG run to the + Worker-level DAG run lane. +- L3+ `Worker.run(raw_orch_fn, ...)` behavior is unchanged except that it + delegates to `run_async(...).wait()` so sync and async DAG runs share one + ordering queue. - Integer execution slots remain private to the target child process. Each target owns identity state: @@ -649,10 +718,23 @@ Local mailbox task frames are hashid-based: - The local mailbox task payload is prefixed with the 32-byte `sha256` digest. - The existing `TaskArgs` blob follows the digest prefix. -- Chip and sub child loops resolve `hashid -> local_slot` immediately before - execution. +- Chip child loops resolve `hashid -> local_slot`, copy the args blob, pin the + private slot, enqueue work on their run lane, and publish `TASK_RUNNING`. 
+- Sub and Worker-child loops keep the historical synchronous path and publish + `TASK_DONE` after their Python callable or inner `Worker.run()` returns. - `ChipWorker.run(local_slot)` remains private to the child process. +Async control overlap is chip-specific: + +- `CTRL_REGISTER_ASYNC`, `CTRL_WAIT_REGISTER`, `CTRL_RUN_ASYNC`, + `CTRL_WAIT_RUN`, `CTRL_UNREGISTER_ASYNC`, and `CTRL_WAIT_UNREGISTER` may + claim the mailbox while the child is in `TASK_RUNNING`. +- Memory and CommDomain controls still wait for the task to publish + `TASK_DONE` before claiming the mailbox. +- The control command restores the previous `TASK_RUNNING` state after + `CONTROL_DONE`, so the parent dispatch path can keep waiting for the + original task completion. + Register failure cleanup is conservative: - Handles are not published until every target in scope installed the hashid. @@ -681,6 +763,11 @@ Required tests: | Pre-start register | Startup hashid mappings are visible after ready. | | Partial register failure | No public handle is returned. | | Cleanup uncertainty | Unconfirmed cleanup blocks that target/hashid pair. | +| L2 run queue | Sync and async direct chip runs share FIFO ordering. | +| L3 DAG run queue | Sync and async DAG runs share FIFO ordering. | +| Async API type guard | Non-chip async register/unregister inputs raise `TypeError`. | +| Task/control overlap | Chip `TASK_RUNNING` permits async register/run/unregister controls. | +| Deferred unregister | Final unregister tombstones before native free. | | Unregister cleanup | Hashid resolution stops before final slot cleanup. | | Unsupported kind | Target rejects unsupported kind before install. | | Hashid format fuzz | Bad prefix, length, or hex encoding is rejected. 
| diff --git a/docs/hierarchical_level_runtime.md b/docs/hierarchical_level_runtime.md index bc361ad97..1e10fa935 100644 --- a/docs/hierarchical_level_runtime.md +++ b/docs/hierarchical_level_runtime.md @@ -115,11 +115,15 @@ See [scheduler.md](scheduler.md) for the dispatch loop and coordination. The **execution layer**. `WorkerManager` holds two pools of `WorkerThread`s (next-level pool and sub pool). Each `WorkerThread` owns one std::thread that encodes `(callable, config, args_blob)` into a `MAILBOX_SIZE`-byte shared -memory region, signals the pre-forked Python child, and spin-polls - `TASK_DONE`, returning an explicit completion outcome to the Scheduler. +memory region, signals the pre-forked Python child, and waits for +`TASK_DONE`, returning an explicit completion outcome to the Scheduler. +Chip children may first publish `TASK_RUNNING` after copying the payload and +enqueueing the run on their child-local run lane; selected async controls can +use the mailbox during that running window. - Next-level (chip) children run `_chip_process_loop`, which constructs a - `ChipWorker` and dispatches each kernel through it. + `ChipWorker`, owns child-local run/register lanes, and dispatches each + kernel through the run lane. - SUB children run `_sub_worker_loop`, which decodes the args blob into a `TaskArgs` and calls the registered Python callable as `fn(args)`. There is no C++ `SubWorker` class — SUB workers exist only as a worker-type @@ -149,8 +153,8 @@ what flows through `ChipWorker::run`. 
│ │ pop ready_queue │ │ pick idle WorkerThread │ │ wt.dispatch(slot_id) ──────► WorkerThread - │ │ encode mailbox → spin-poll TASK_DONE - │ │ (blocking; child runs the kernel) + │ │ encode mailbox → wait TASK_DONE + │ │ (chip child may publish TASK_RUNNING) │ │◄── completion_queue ────── on_complete_(completion) │ │ on_task_complete: │ │ success → COMPLETED diff --git a/docs/remote-l3-worker-design.md b/docs/remote-l3-worker-design.md index 13780a910..083231dd2 100644 --- a/docs/remote-l3-worker-design.md +++ b/docs/remote-l3-worker-design.md @@ -131,8 +131,11 @@ Relevant code paths: - `_child_worker_loop()` runs a nested `Worker` child via shm mailbox. - `_run_chip_main_loop()` handles task and control mailbox states. - `src/common/hierarchical/worker_manager.{h,cpp}` - - `WorkerThread` owns one local mailbox and blocks until `TASK_DONE`. - - Control commands share the same mailbox and serialize on `mailbox_mu_`. + - `WorkerThread` owns one local mailbox and waits until `TASK_DONE`. + - The mailbox lock serializes payload writes and task acknowledgement. + After a chip child publishes `TASK_RUNNING`, selected async controls can + temporarily claim the same mailbox and restore `TASK_RUNNING` after + `CONTROL_DONE`. - Errors are reported through `MAILBOX_OFF_ERROR` and `MAILBOX_OFF_ERROR_MSG`. 
- `src/common/hierarchical/orchestrator.{h,cpp}` diff --git a/docs/task-flow.md b/docs/task-flow.md index 0aa67d6cf..0455aec7d 100644 --- a/docs/task-flow.md +++ b/docs/task-flow.md @@ -58,7 +58,7 @@ to C++: | Context | Namespace | How it's consumed | | ------- | --------- | ----------------- | -| `w3.submit_next_level(handle, …)` dispatched to a chip child | `LOCAL_CHIP` | child resolves digest to its private chip slot, then calls `ChipWorker::run(local_slot, …)` | +| `w3.submit_next_level(handle, …)` dispatched to a chip child | `LOCAL_CHIP` | child resolves digest to its private chip slot, copies args, and enqueues the run on its chip run lane | | `w4.submit_next_level(handle, …)` dispatched to an L3 `Worker` child | `LOCAL_PYTHON` | child resolves digest to an orchestration function and calls `inner_worker.run(orch_fn, …)` | | remote `w4.submit_next_level(handle, …)` dispatched to remote L3 | `REMOTE_TASK_DISPATCHER` | remote endpoint resolves digest in its dispatcher registry and calls its embedded L3 Worker | | `w3.submit_sub(handle, …)` dispatched to a SUB child | `LOCAL_PYTHON` | child resolves digest to a Python callable and calls `fn(args)` | @@ -188,7 +188,8 @@ View does **not** own memory. Valid for the duration of a single │ child decodes header → builds TaskArgsView over the blob bytes ▼ child resolves digest -> local slot - ChipWorker::run(local_slot, view, config) (in the forked child) + child copies args and enqueues run_prepared_from_blob(local_slot, ...) + on the child-local chip run lane │ (L2 ABI edge) ▼ @@ -381,11 +382,17 @@ reclaim independently of outer-scope tasks. See ## 8. Data flow on completion -When the child finishes the kernel, it writes `TASK_DONE` to the mailbox; -`LocalMailboxEndpoint::run` exits its spin-poll, reads the mailbox error -fields, and returns a `WorkerCompletion`. `MAILBOX_OFF_ERROR == 0` maps to -success; a non-zero child error maps to task failure. 
The parent -`WorkerThread` pushes that completion onto `Scheduler::completion_queue_`. +For chip children, the child copies the task payload out of the mailbox, +pins the private callable slot, enqueues the work on its chip run lane, and +publishes `TASK_RUNNING`. `LocalMailboxEndpoint::run` still waits until the +same mailbox reaches `TASK_DONE`. When the chip run lane finishes the kernel, +it writes `TASK_DONE`; the parent reads the mailbox error fields and returns a +`WorkerCompletion`. `MAILBOX_OFF_ERROR == 0` maps to success; a non-zero child +error maps to task failure. The parent `WorkerThread` pushes that completion +onto `Scheduler::completion_queue_`. + +Sub-worker and Worker-child dispatches publish `TASK_DONE` directly after the +Python callable or inner `Worker.run()` returns. At this point: @@ -473,7 +480,7 @@ L4 parent process | 1 | L4 parent Python | `w4.run(my_l4_orch)` → `scope_begin` → `my_l4_orch(orch4, ...)` | | 2 | L4 `Orchestrator.submit_next_level` | the L3 callable handle digest is stored in the slot's callable identity; slot pushed to L4's ready queue | | 3 | L4 Scheduler | pop slot; pick idle WorkerThread → the L3 child's mailbox | -| 4 | L4 WorkerThread (PROCESS) | encode `(callable digest, config, args_blob)` into mailbox; write `TASK_READY`; spin-poll | +| 4 | L4 WorkerThread (PROCESS) | encode `(callable digest, config, args_blob)` into mailbox; write `TASK_READY`; wait for completion | | 5 | L3 child `_child_worker_loop` | wake on `TASK_READY`; read digest → child-local slot → `my_l3_orch` | | 6 | L3 child | `inner_worker.run(my_l3_orch, args, cfg)` → `scope_begin` → `my_l3_orch(orch3, ...)` | | 7 | L3 `Orchestrator.submit_sub` | `l3_sub_handle` digest dispatched to L3's own sub worker child | @@ -524,14 +531,15 @@ Step-by-step (one chip worker): | 2 | `Worker::run` | `scope_begin` → call `my_orch(&orch_, args.view(), cfg)` | | 3 | `Orchestrator::submit_next_level` | `slot = ring.alloc()`; move `chip_args` into `slot.task_args`; walk 
tags → `tensormap.lookup(a.data)`, `tensormap.lookup(b.data)`, `tensormap.insert(c.data, slot)`; push ready | | 4 | Scheduler thread | pop `slot`; `wt = manager.pick_idle(NEXT_LEVEL)` (WT_chip_0); `wt->dispatch(slot)` | -| 5 | WT_chip_0 parent side | encode mailbox: write reserved callable field, `config`, digest prefix, `write_blob` of task_args; set `TASK_READY`; spin-poll | -| 6 | chip_0 child process | wake on `TASK_READY`; resolve digest to local slot; `read_blob` → `view`; call `ChipWorker::run(local_slot, view, cfg)` | -| 7 | `ChipWorker::run` | assemble `ChipStorageTaskArgs` POD (memcpy view); call `pto2_run_runtime(local_slot, &chip_storage, &cfg)` | -| 8 | runtime.so | translate host ptrs → device ptrs; dispatch AICPU / AICore; write output into `c`'s shm | -| 9 | chip_0 child | `run` returns; write `TASK_DONE` | -| 10 | WT_chip_0 parent | see `TASK_DONE`; push success completion | -| 11 | Scheduler | mark slot COMPLETED; fanout release (none in this DAG); scope_end will release scope ref | -| 12 | `Worker::run` returns | user's `w3.run(...)` returns; `c` contains result in shm, visible to user | +| 5 | WT_chip_0 parent side | encode mailbox: write reserved callable field, `config`, digest prefix, `write_blob` of task_args; set `TASK_READY`; wait for completion | +| 6 | chip_0 child process | wake on `TASK_READY`; resolve digest to local slot; copy `args_blob`; enqueue run; publish `TASK_RUNNING` | +| 7 | chip run lane | call `run_prepared_from_blob(local_slot, copied_args, cfg)` | +| 8 | `ChipWorker::run` path | assemble `ChipStorageTaskArgs` POD (memcpy view); call `pto2_run_runtime(local_slot, &chip_storage, &cfg)` | +| 9 | runtime.so | translate host ptrs → device ptrs; dispatch AICPU / AICore; write output into `c`'s shm | +| 10 | chip run lane | `run` returns; release in-flight slot ref; write `TASK_DONE` | +| 11 | WT_chip_0 parent | see `TASK_DONE`; push success completion | +| 12 | Scheduler | mark slot COMPLETED; fanout release (none in this 
DAG); scope_end will release scope ref | +| 13 | `Worker::run` returns | user's `w3.run(...)` returns; `c` contains result in shm, visible to user | --- diff --git a/docs/worker-manager.md b/docs/worker-manager.md index 93d5ef666..480b71da9 100644 --- a/docs/worker-manager.md +++ b/docs/worker-manager.md @@ -163,10 +163,19 @@ WorkerCompletion LocalMailboxEndpoint::run(Ring *ring, WorkerDispatch d) { const TaskArgs &args = s.is_group() ? s.task_args_list[d.group_index] : s.task_args; write_blob(m + MAILBOX_OFF_TASK_ARGS_BLOB, args); - // Signal child + // Signal child. The chip child may acknowledge TASK_RUNNING after it has + // copied the args blob and enqueued the run on its private run lane; sub + // and Worker-child paths may go straight to TASK_DONE. write_state(mailbox_, MailboxState::TASK_READY); + while (true) { + MailboxState state = read_state(mailbox_); + if (state == MailboxState::TASK_RUNNING || + state == MailboxState::TASK_DONE) + break; + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } - // Poll for completion + // Poll for final completion. while (read_state(mailbox_) != MailboxState::TASK_DONE) std::this_thread::sleep_for(std::chrono::microseconds(50)); @@ -187,15 +196,53 @@ Parent-side cost per dispatch: - Poll loop with `sleep_for(50us)` (not busy-wait) - One explicit completion outcome: success, task failure, or endpoint failure -Total ~nanoseconds overhead; the wait is dominated by actual kernel execution. +The parent `run(...)` call still blocks until `TASK_DONE` and returns one +`WorkerCompletion`. The mailbox lock is held only through the payload write +and the `TASK_RUNNING` / `TASK_DONE` acknowledgement. After `TASK_RUNNING`, +the parent run path releases the mailbox lock while it continues waiting for +final completion. 
The early `TASK_RUNNING` state is a control-channel handoff +point, not a task completion: it means the child has copied the task payload +out of the mailbox so selected async controls can temporarily claim the +mailbox while the task run lane continues. ### 3.2 Child loop The child loop lives in Python — see `_chip_process_loop` and `_sub_worker_loop` in `python/simpler/worker.py`. Each child polls -`MAILBOX_OFF_STATE`, decodes the digest-prefixed args blob on `TASK_READY`, -resolves the digest to its private local slot/callable, writes back any error, -and publishes `TASK_DONE`. +`MAILBOX_OFF_STATE` and decodes the digest-prefixed args blob on +`TASK_READY`. + +Chip children have separate run and register lanes inside the child process: + +```text +TASK_READY: + digest -> private chip slot + copy config and TaskArgs blob out of the mailbox + increment the slot's in-flight run count + enqueue run request on the chip run lane + publish TASK_RUNNING + +chip run lane: + run_prepared_from_blob(private_slot, copied_args, config) + decrement in-flight run count + if slot is tombstoned and in-flight count is zero: + native unregister/free private slot + complete pending unregister handle + publish TASK_DONE for mailbox-dispatched tasks +``` + +While a chip child is in `TASK_RUNNING`, the parent may issue +`CTRL_REGISTER_ASYNC`, `CTRL_WAIT_REGISTER`, `CTRL_RUN_ASYNC`, +`CTRL_WAIT_RUN`, `CTRL_UNREGISTER_ASYNC`, or `CTRL_WAIT_UNREGISTER`. +The control command claims the mailbox as `CONTROL_REQUEST`, the child +publishes `CONTROL_DONE`, and the parent restores `TASK_RUNNING` so the +original dispatch can continue waiting for `TASK_DONE`. Memory and CommDomain +controls are still serialized behind the running task. + +Sub-worker and Worker-child loops do not have a chip run lane. They keep the +historical synchronous behavior: resolve the digest, execute the Python +callable or inner `Worker.run()`, write back any error, and publish +`TASK_DONE`. 
The child inherits the parent's full address space at fork time, so: - ChipCallable objects (pre-fork allocated) are COW-visible at the same VA diff --git a/examples/a2a3/tensormap_and_ringbuffer/l3_l2_orch_comm_stream/l3_l2_orch_comm_stream.py b/examples/a2a3/tensormap_and_ringbuffer/l3_l2_orch_comm_stream/l3_l2_orch_comm_stream.py index 0feeab4a2..555edb4a3 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/l3_l2_orch_comm_stream/l3_l2_orch_comm_stream.py +++ b/examples/a2a3/tensormap_and_ringbuffer/l3_l2_orch_comm_stream/l3_l2_orch_comm_stream.py @@ -66,7 +66,7 @@ def _build_chip_callable(platform: str) -> ChipCallable: signature=[], func_name="l3_l2_orch_comm_orchestration", binary=orch, - children=[(0, CoreCallable.build(signature=[D.IN, D.OUT], binary=aiv))], + children=[(0, CoreCallable.build(signature=[D.IN, D.OUT], arg_index=[0, 1], binary=aiv))], ) diff --git a/python/bindings/task_interface.cpp b/python/bindings/task_interface.cpp index f42846684..9b1f4e6be 100644 --- a/python/bindings/task_interface.cpp +++ b/python/bindings/task_interface.cpp @@ -903,6 +903,7 @@ NB_MODULE(_task_interface, m) { .def( "prepare_callable", [](ChipWorker &self, int32_t callable_id, const PyChipCallable &callable) { + nb::gil_scoped_release release; self.prepare_callable(callable_id, callable.buffer_.data()); }, nb::arg("callable_id"), nb::arg("callable"), @@ -912,6 +913,7 @@ NB_MODULE(_task_interface, m) { .def( "prepare_callable_from_blob", [](ChipWorker &self, int32_t callable_id, uint64_t blob_ptr) { + nb::gil_scoped_release release; self.prepare_callable(callable_id, reinterpret_cast(blob_ptr)); }, nb::arg("callable_id"), nb::arg("blob_ptr"), @@ -925,6 +927,7 @@ NB_MODULE(_task_interface, m) { .def( "run", [](ChipWorker &self, int32_t callable_id, ChipStorageTaskArgs &args, const CallConfig &config) { + nb::gil_scoped_release release; return self.run(callable_id, &args, config); }, nb::arg("callable_id"), nb::arg("args"), nb::arg("config"), @@ -935,6 +938,7 @@ 
NB_MODULE(_task_interface, m) { "run", [](ChipWorker &self, int32_t callable_id, TaskArgs &args, const CallConfig &config) { TaskArgsView view = make_view(args); + nb::gil_scoped_release release; return self.run(callable_id, view, config); }, nb::arg("callable_id"), nb::arg("args"), nb::arg("config"), @@ -952,6 +956,7 @@ NB_MODULE(_task_interface, m) { // loops never re-implement the tensor/scalar layout in Python // (where it has historically dropped fields like child_memory). TaskArgsView view = read_blob(reinterpret_cast(args_blob_ptr), blob_capacity); + nb::gil_scoped_release release; return self.run(callable_id, view, config); }, nb::arg("callable_id"), nb::arg("args_blob_ptr"), nb::arg("blob_capacity"), nb::arg("config"), @@ -963,6 +968,7 @@ NB_MODULE(_task_interface, m) { .def( "unregister_callable", [](ChipWorker &self, int32_t callable_id) { + nb::gil_scoped_release release; self.unregister_callable(callable_id); }, nb::arg("callable_id"), diff --git a/python/bindings/worker_bind.h b/python/bindings/worker_bind.h index 5b4893b63..4bd59e685 100644 --- a/python/bindings/worker_bind.h +++ b/python/bindings/worker_bind.h @@ -242,6 +242,13 @@ inline void bind_worker(nb::module_ &m) { .def_ro("ok", &ControlResult::ok) .def_ro("error_message", &ControlResult::error_message); + nb::class_(m, "AsyncControlResult") + .def_ro("worker_type", &AsyncControlResult::worker_type) + .def_ro("worker_id", &AsyncControlResult::worker_id) + .def_ro("ok", &AsyncControlResult::ok) + .def_ro("remote_handle", &AsyncControlResult::remote_handle) + .def_ro("error_message", &AsyncControlResult::error_message); + // --- TaskState --- nb::enum_(m, "TaskState") .value("FREE", TaskState::FREE) @@ -422,10 +429,10 @@ inline void bind_worker(nb::module_ &m) { ) // --- Mailbox control plane (parent side) --- - // These hold the per-WorkerThread mailbox_mu_ inside C++, so they - // serialize against dispatch_process without any Python-side lock. 
- // Release the GIL during the spin-poll wait so other Python threads - // (e.g. a concurrent Worker.run) can keep running. + // These hold the per-WorkerThread mailbox_mu_ only while claiming and + // driving a control request. TASK_RUNNING dispatches release that + // mutex after the child has copied its args, so control requests can + // overlap the child run lane. .def( "control_prepare", [](Worker &self, int worker_id, nb::object digest) { @@ -448,6 +455,54 @@ inline void bind_worker(nb::module_ &m) { "Stage `blob_size` bytes from `blob_ptr` into a POSIX shm and broadcast " "CTRL_REGISTER to every NEXT_LEVEL child in parallel. Returns per-child status." ) + .def( + "broadcast_register_async_all", + [](Worker &self, uint64_t blob_ptr, uint64_t blob_size, nb::object digest) { + std::string digest_bytes = bytes_from_digest_arg(digest); + nb::gil_scoped_release release; + return self.broadcast_register_async_all( + blob_ptr, blob_size, reinterpret_cast(digest_bytes.data()) + ); + }, + nb::arg("blob_ptr"), nb::arg("blob_size"), nb::arg("digest"), + "Broadcast CTRL_REGISTER_ASYNC to every NEXT_LEVEL child and return remote register handles." + ) + .def( + "control_run_async", + [](Worker &self, int worker_id, nb::object digest, const TaskArgs &args, const CallConfig &config) { + std::string digest_bytes = bytes_from_digest_arg(digest); + nb::gil_scoped_release release; + return self.control_run_async( + worker_id, reinterpret_cast(digest_bytes.data()), args, config + ); + }, + nb::arg("worker_id"), nb::arg("digest"), nb::arg("args"), nb::arg("config"), + "Submit one chip-callable run to a NEXT_LEVEL child run lane and return its remote handle." + ) + .def( + "control_wait_run", + [](Worker &self, int worker_id, uint64_t handle_id) { + nb::gil_scoped_release release; + return self.control_wait_run(worker_id, handle_id); + }, + nb::arg("worker_id"), nb::arg("handle_id"), "Wait for a remote chip run handle and return RunTiming." 
+ ) + .def( + "control_wait_register", + [](Worker &self, int worker_id, uint64_t handle_id) { + nb::gil_scoped_release release; + self.control_wait_register(worker_id, handle_id); + }, + nb::arg("worker_id"), nb::arg("handle_id"), "Wait for a remote async register handle." + ) + .def( + "control_wait_unregister", + [](Worker &self, int worker_id, uint64_t handle_id) { + nb::gil_scoped_release release; + self.control_wait_unregister(worker_id, handle_id); + }, + nb::arg("worker_id"), nb::arg("handle_id"), "Wait for a remote async unregister handle." + ) .def( "control_digest_only", [](Worker &self, WorkerType worker_type, int worker_id, uint64_t sub_cmd, nb::object digest, @@ -675,6 +730,16 @@ inline void bind_worker(nb::module_ &m) { "Best-effort broadcast of CTRL_UNREGISTER to every NEXT_LEVEL child in parallel. " "Returns a list of per-child error strings (empty on full success)." ) + .def( + "broadcast_unregister_async_all", + [](Worker &self, nb::object digest) { + std::string digest_bytes = bytes_from_digest_arg(digest); + nb::gil_scoped_release release; + return self.broadcast_unregister_async_all(reinterpret_cast(digest_bytes.data())); + }, + nb::arg("digest"), + "Broadcast CTRL_UNREGISTER_ASYNC to every NEXT_LEVEL child and return remote unregister handles." + ) .def( "broadcast_control_all", [](Worker &self, WorkerType worker_type, uint64_t sub_cmd, nb::object payload, nb::object digest, @@ -708,8 +773,8 @@ inline void bind_worker(nb::module_ &m) { .def( "control_alloc_domain", &Worker::control_alloc_domain, nb::arg("worker_id"), nb::arg("request_shm_name"), nb::arg("reply_shm_name"), nb::call_guard(), - "Drive one NEXT_LEVEL chip child through CTRL_ALLOC_DOMAIN. Holds mailbox_mu_ " - "so it serialises with task dispatch on the same mailbox. Caller fans out to all " + "Drive one NEXT_LEVEL chip child through CTRL_ALLOC_DOMAIN. Memory/domain controls " + "wait behind in-flight task dispatch. 
Caller fans out to all " "participating chips in parallel (one Python thread per chip)." ) .def( diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 3e51edc93..c51706490 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -63,6 +63,7 @@ def my_l4_orch(orch, args, config): import importlib import json import os +import queue import re import signal import socket @@ -71,7 +72,7 @@ def my_l4_orch(orch, args, config): import threading import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from multiprocessing.shared_memory import SharedMemory from typing import Any @@ -107,6 +108,7 @@ def my_l4_orch(orch, args, config): CallConfig, ChipCallable, ChipDomainContext, + ChipStorageTaskArgs, ChipWorker, CommBufferSpec, CommDomainHandle, @@ -175,6 +177,7 @@ def my_l4_orch(orch, args, config): # INIT_DONE before allowing any dispatch — keeps cross-rank init skew out # of the per-rank host-side stream sync budget (issue #897). _INIT_DONE = 6 +_TASK_RUNNING = 7 # Control sub-commands (written at _OFF_CALLABLE as uint64) _CTRL_MALLOC = 0 @@ -210,6 +213,12 @@ def my_l4_orch(orch, args, config): _CTRL_PY_UNREGISTER = 11 _CTRL_PY_IMPORT_REGISTER = 12 _CTRL_L3_L2_ORCH_COMM_INIT = 13 +_CTRL_REGISTER_ASYNC = 14 +_CTRL_WAIT_REGISTER = 15 +_CTRL_RUN_ASYNC = 16 +_CTRL_WAIT_RUN = 17 +_CTRL_UNREGISTER_ASYNC = 18 +_CTRL_WAIT_UNREGISTER = 19 # Layout of the CTRL_COMM_INIT request shm. 
_COMM_INIT_HEADER = struct.Struct(" None: + self._waiter = waiter + self._completed_probe = completed_probe + self._cached: RunTiming | None = None + self._waited = False + self._lock = threading.Lock() + + @property + def completed(self) -> bool: + if self._waited: + return True + return bool(self._completed_probe()) + + def wait(self) -> RunTiming: + with self._lock: + if not self._waited: + self._cached = self._waiter() + self._waited = True + assert self._cached is not None + return self._cached + + +class RegisterHandle: + """Completion handle returned by Worker.register_async().""" + + def __init__(self, waiter, completed_probe) -> None: + self._waiter = waiter + self._completed_probe = completed_probe + self._cached: CallableHandle | None = None + self._waited = False + self._lock = threading.Lock() + + @property + def completed(self) -> bool: + if self._waited: + return True + return bool(self._completed_probe()) + + def wait(self) -> CallableHandle: + with self._lock: + if not self._waited: + self._cached = self._waiter() + self._waited = True + assert self._cached is not None + return self._cached + + +class UnregisterHandle: + """Completion handle returned by Worker.unregister_async().""" + + def __init__(self, waiter, completed_probe) -> None: + self._waiter = waiter + self._completed_probe = completed_probe + self._waited = False + self._lock = threading.Lock() + + @property + def completed(self) -> bool: + if self._waited: + return True + return bool(self._completed_probe()) + + def wait(self) -> None: + with self._lock: + if not self._waited: + self._waiter() + self._waited = True + + @dataclass(frozen=True) class RemoteCallable: """Import-path descriptor for a parent-facing remote L3 callable.""" @@ -528,6 +713,10 @@ def _validate_chip_payload_digest( _validate_descriptor_digest(expected=digest, descriptor=descriptor, context=context) +def _chip_callable_bytes(target: ChipCallable) -> bytes: + return ctypes.string_at(int(target.buffer_ptr()), 
int(target.buffer_size())) + + def _read_py_callable_payload_from_shm(shm_name: str) -> bytes: shm = SharedMemory(name=shm_name) shm_buf = shm.buf @@ -555,7 +744,7 @@ def _read_py_callable_payload_from_shm(shm_name: str) -> bytes: def _read_raw_payload_from_shm(shm_name: str, payload_size: int) -> bytes: - shm = SharedMemory(name=shm_name) + shm = _attach_cpp_owned_shm(shm_name) shm_buf = shm.buf assert shm_buf is not None try: @@ -568,7 +757,7 @@ def _read_raw_payload_from_shm(shm_name: str, payload_size: int) -> bytes: def _read_chip_callable_from_shm(shm_name: str, payload_size: int) -> ChipCallable: - shm = SharedMemory(name=shm_name) + shm = _attach_cpp_owned_shm(shm_name) shm_buf = shm.buf assert shm_buf is not None try: @@ -592,6 +781,19 @@ def _load_py_callable_from_shm(shm_name: str): return _load_py_callable_from_payload(_read_py_callable_payload_from_shm(shm_name)) +def _attach_cpp_owned_shm(shm_name: str) -> SharedMemory: + shm = SharedMemory(name=shm_name) + normalized = shm_name.lstrip("/") + if normalized.startswith("simpler-cb-"): + try: + from multiprocessing import resource_tracker # noqa: PLC0415 + + resource_tracker.unregister(shm._name, "shared_memory") # pyright: ignore[reportAttributeAccessIssue] + except Exception: + pass + return shm + + def _load_py_import_target(target: str): module_name, qualname = parse_python_import_target(target) obj = importlib.import_module(module_name) @@ -614,6 +816,88 @@ def _format_digest(digest: bytes) -> str: return "sha256:" + digest.hex() +def _read_config_from_bytes(data: bytes | bytearray | memoryview, offset: int) -> CallConfig: + ( + block_dim, + aicpu_tn, + swl, + dt, + pmu, + dep_gen, + scope_stats, + *ring_values, + prefix_bytes, + ) = _CFG_FMT.unpack_from(data, offset) + ring_task_window = list(ring_values[:RUNTIME_ENV_RING_COUNT]) + ring_heap = list(ring_values[RUNTIME_ENV_RING_COUNT : 2 * RUNTIME_ENV_RING_COUNT]) + ring_dep_pool = list(ring_values[2 * RUNTIME_ENV_RING_COUNT : 3 * 
RUNTIME_ENV_RING_COUNT]) + cfg = CallConfig() + cfg.block_dim = block_dim + cfg.aicpu_thread_num = aicpu_tn + cfg.enable_l2_swimlane = swl + cfg.enable_dump_tensor = int(dt) + cfg.enable_pmu = pmu + cfg.enable_dep_gen = bool(dep_gen) + cfg.enable_scope_stats = bool(scope_stats) + cfg.runtime_env.ring_task_window = ring_task_window + cfg.runtime_env.ring_heap = ring_heap + cfg.runtime_env.ring_dep_pool = ring_dep_pool + cfg.output_prefix = prefix_bytes.split(b"\x00", 1)[0].decode("utf-8") + return cfg + + +def _decode_async_run_request(payload: bytes) -> tuple[bytes, bytes, CallConfig]: + if len(payload) < _ASYNC_RUN_OFF_ARGS_BLOB + 8: + raise RuntimeError(f"RUN_ASYNC payload too small: {len(payload)} bytes") + digest = bytes(payload[_ASYNC_RUN_OFF_DIGEST : _ASYNC_RUN_OFF_DIGEST + CALLABLE_HASH_DIGEST_BYTES]) + cfg = _read_config_from_bytes(payload, _ASYNC_RUN_OFF_CONFIG) + return digest, bytes(payload[_ASYNC_RUN_OFF_ARGS_BLOB:]), cfg + + +def _copy_call_config(config: CallConfig | None) -> CallConfig: + src = config if config is not None else CallConfig() + dst = CallConfig() + dst.block_dim = src.block_dim + dst.aicpu_thread_num = src.aicpu_thread_num + dst.enable_l2_swimlane = src.enable_l2_swimlane + dst.enable_dump_tensor = int(src.enable_dump_tensor) + dst.enable_pmu = src.enable_pmu + dst.enable_dep_gen = bool(src.enable_dep_gen) + dst.enable_scope_stats = bool(src.enable_scope_stats) + dst.runtime_env.ring_task_window = list(src.runtime_env.ring_task_window) + dst.runtime_env.ring_heap = list(src.runtime_env.ring_heap) + dst.runtime_env.ring_dep_pool = list(src.runtime_env.ring_dep_pool) + dst.output_prefix = src.output_prefix + return dst + + +def _copy_run_args(args, *, prefer_task_args: bool) -> ChipStorageTaskArgs | TaskArgs: + if args is None: + return TaskArgs() if prefer_task_args else ChipStorageTaskArgs() + if isinstance(args, TaskArgs): + copied = TaskArgs() + for i in range(args.tensor_count()): + copied.add_tensor(args.tensor(i), args.tag(i)) + 
for i in range(args.scalar_count()): + copied.add_scalar(args.scalar(i)) + return copied + if isinstance(args, ChipStorageTaskArgs): + if prefer_task_args: + copied_task = TaskArgs() + for i in range(args.tensor_count()): + copied_task.add_tensor(args.tensor(i)) + for i in range(args.scalar_count()): + copied_task.add_scalar(args.scalar(i)) + return copied_task + copied_chip = ChipStorageTaskArgs() + for i in range(args.tensor_count()): + copied_chip.add_tensor(args.tensor(i)) + for i in range(args.scalar_count()): + copied_chip.add_scalar(args.scalar(i)) + return copied_chip + raise TypeError("run_async args must be TaskArgs, ChipStorageTaskArgs, or None") + + def _handle_py_callable_control( buf, registry: dict[int, Any], @@ -954,6 +1238,122 @@ def _ensure_prepared(cw, registry, prepared, cid: int, *, lazy: bool, device_id: prepared.add(cid) +def _wait_child_run_state(state: _ChildRunState) -> RunTiming: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + assert state.result is not None + return state.result + + +def _wait_child_register_state(state: _ChildRegisterState) -> None: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + + +def _wait_child_unregister_state(state: _ChildUnregisterState) -> None: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + + +def _complete_child_run(state: _ChildRunState, *, result: RunTiming | None = None, error: BaseException | None = None): + with state.cv: + state.result = result + state.error = error + state.completed = True + state.cv.notify_all() + + +def _complete_child_register(state: _ChildRegisterState, error: BaseException | None = None) -> None: + with state.cv: + state.error = error + state.completed = True + state.cv.notify_all() + + +def _complete_child_unregister(state: _ChildUnregisterState, error: BaseException | None = 
None) -> None: + with state.cv: + state.error = error + state.completed = True + state.cv.notify_all() + + +def _prepare_child_register_bytes( # noqa: PLR0913 + cw: ChipWorker, + registry: dict[int, Any], + identity_table: dict[bytes, int], + identity_refs: dict[bytes, int], + prepared: set[int], + registry_lock: threading.Condition, + digest: bytes, + callable_bytes: bytearray, + *, + chip_platform: str, + chip_runtime: str, + device_id: int, +) -> int: + callable_obj = ChipCallable.from_bytes(bytes(callable_bytes)) + _validate_chip_payload_digest( + callable_obj, + digest, + platform=chip_platform, + runtime=chip_runtime, + context=f"chip_process dev={device_id}", + ) + with registry_lock: + if digest in identity_table: + identity_refs[digest] = identity_refs.get(digest, 1) + 1 + return int(identity_table[digest]) + cid = _install_local_identity(registry, identity_table, identity_refs, digest, callable_obj) + if int(cid) in prepared: + try: + cw._unregister_slot(int(cid)) + except Exception: # noqa: BLE001 + pass + prepared.discard(int(cid)) + exported = ctypes.c_char.from_buffer(callable_bytes) + try: + cw._impl.prepare_callable_from_blob(int(cid), ctypes.addressof(exported)) + except Exception: + if registry.get(int(cid)) is callable_obj: + registry.pop(int(cid), None) + identity_table.pop(digest, None) + identity_refs.pop(digest, None) + raise + finally: + del exported + prepared.add(int(cid)) + return int(cid) + + +def _unregister_child_digest( + cw: ChipWorker, + registry: dict[int, Any], + identity_table: dict[bytes, int], + identity_refs: dict[bytes, int], + prepared: set[int], + registry_lock: threading.Condition, + active_cids: dict[int, int], + digest: bytes, +) -> None: + with registry_lock: + cid, removed = _remove_local_identity(registry, identity_table, identity_refs, digest) + if not removed or cid is None: + return + while active_cids.get(int(cid), 0) > 0: + registry_lock.wait() + cw._unregister_slot(int(cid)) + prepared.discard(int(cid)) + + def 
_run_chip_main_loop( # noqa: PLR0912, PLR0913, PLR0915 -- unified TASK_READY / CONTROL_REQUEST state machine cw: ChipWorker, buf: memoryview, @@ -984,42 +1384,187 @@ def _run_chip_main_loop( # noqa: PLR0912, PLR0913, PLR0915 -- unified TASK_READ """ prepared: set[int] = set() l3_l2_control_shms: list[SharedMemory] = [] + registry_cv = threading.Condition(threading.RLock()) + active_cids: dict[int, int] = {} + tombstoned_cids: set[int] = set() + tombstoned_digests: set[bytes] = set() + deferred_unregister: dict[int, _ChildUnregisterState] = {} + digest_by_cid: dict[int, bytes] = {int(cid): digest for digest, cid in identity_table.items()} + run_queue: queue.Queue = queue.Queue() + control_queue: queue.Queue = queue.Queue() + run_states: dict[int, _ChildRunState] = {} + register_states: dict[int, _ChildRegisterState] = {} + unregister_states: dict[int, _ChildUnregisterState] = {} + next_run_handle = 1 + next_register_handle = 1 + next_unregister_handle = 1 + + def child_native_unregister_and_finish(cid: int, digest: bytes, state: _ChildUnregisterState) -> None: + try: + cw._unregister_slot(int(cid)) + _complete_child_unregister(state) + except BaseException as e: # noqa: BLE001 + _complete_child_unregister(state, e) + finally: + with registry_cv: + registry.pop(int(cid), None) + prepared.discard(int(cid)) + tombstoned_cids.discard(int(cid)) + tombstoned_digests.discard(digest) + deferred_unregister.pop(int(cid), None) + digest_by_cid.pop(int(cid), None) + registry_cv.notify_all() + + def release_child_inflight(cid: int) -> None: + cleanup_state: _ChildUnregisterState | None = None + cleanup_digest = b"" + with registry_cv: + count = active_cids.get(int(cid), 0) + if count <= 1: + active_cids.pop(int(cid), None) + if int(cid) in tombstoned_cids: + cleanup_state = deferred_unregister.get(int(cid)) + cleanup_digest = digest_by_cid.get(int(cid), b"") + else: + active_cids[int(cid)] = count - 1 + registry_cv.notify_all() + if cleanup_state is not None: + 
child_native_unregister_and_finish(int(cid), cleanup_digest, cleanup_state) + + def hold_child_cid_for_run(digest: bytes) -> int: + with registry_cv: + cid = identity_table.get(digest) + if cid is None or int(cid) in tombstoned_cids: + raise RuntimeError(f"callable hash {_format_digest(digest)} not registered") + _ensure_prepared(cw, registry, prepared, int(cid), lazy=True, device_id=device_id) + active_cids[int(cid)] = active_cids.get(int(cid), 0) + 1 + digest_by_cid[int(cid)] = digest + return int(cid) + + def submit_child_unregister(digest: bytes, state: _ChildUnregisterState) -> None: + cleanup_cid: int | None = None + cleanup_digest = b"" + with registry_cv: + cid = identity_table.get(digest) + if cid is None: + _complete_child_unregister(state) + return + refs = identity_refs.get(digest, 1) - 1 + if refs > 0: + identity_refs[digest] = refs + _complete_child_unregister(state) + return + + cid_i = int(cid) + identity_refs.pop(digest, None) + identity_table.pop(digest, None) + tombstoned_cids.add(cid_i) + tombstoned_digests.add(digest) + digest_by_cid[cid_i] = digest + + if active_cids.get(cid_i, 0) == 0: + cleanup_cid = cid_i + cleanup_digest = digest + else: + deferred_unregister[cid_i] = state + + if cleanup_cid is not None: + child_native_unregister_and_finish(cleanup_cid, cleanup_digest, state) + + def publish_task_done() -> None: + while _mailbox_load_i32(state_addr) in (_CONTROL_REQUEST, _CONTROL_DONE): + time.sleep(0.00001) + _mailbox_store_i32(state_addr, _TASK_DONE) + + def run_thread_loop() -> None: + while True: + req = run_queue.get() + if req is None: + break + try: + args_blob = bytearray(req.args_blob) + exported = ctypes.c_char.from_buffer(args_blob) + try: + timing = cw._impl.run_prepared_from_blob( + int(req.cid), + ctypes.addressof(exported), + len(args_blob), + req.config, + ) + finally: + del exported + code = 0 + msg = "" + if req.publish_task_done and on_task_done_success is not None: + code, msg = on_task_done_success() + 
_complete_child_run(req.state, result=timing) + except BaseException as e: # noqa: BLE001 + code = 1 + msg = _format_exc(f"chip_process dev={device_id}", e) + _complete_child_run(req.state, error=e) + finally: + release_child_inflight(int(req.cid)) + if req.publish_task_done: + _write_error(buf, code, msg) + publish_task_done() + + def control_thread_loop() -> None: + while True: + req = control_queue.get() + if req is None: + break + try: + with registry_cv: + if req.digest in tombstoned_digests: + raise RuntimeError(f"callable hash {_format_digest(req.digest)} is pending unregister") + cid = _prepare_child_register_bytes( + cw, + registry, + identity_table, + identity_refs, + prepared, + registry_cv, + req.digest, + req.callable_bytes, + chip_platform=chip_platform, + chip_runtime=chip_runtime, + device_id=device_id, + ) + digest_by_cid[int(cid)] = req.digest + _complete_child_register(req.state) + except BaseException as e: # noqa: BLE001 + _complete_child_register(req.state, e) + + run_thread = threading.Thread(target=run_thread_loop, name=f"simpler-chip-run-{device_id}", daemon=True) + control_thread = threading.Thread( + target=control_thread_loop, + name=f"simpler-chip-control-{device_id}", + daemon=True, + ) + run_thread.start() + control_thread.start() try: while True: state = _mailbox_load_i32(state_addr) if state == _TASK_READY: digest = _read_task_digest(buf) - cid = identity_table.get(digest) cfg = _read_config_from_mailbox(buf) + args_blob = bytes(buf[_OFF_TASK_ARGS_BLOB : _OFF_TASK_ARGS_BLOB + _MAILBOX_ARGS_CAPACITY]) code = 0 msg = "" + cid: int | None = None + run_state: _ChildRunState | None = None try: - if cid is None: - raise RuntimeError(f"callable hash {_format_digest(digest)} not registered") - _ensure_prepared(cw, registry, prepared, cid, lazy=True, device_id=device_id) - # Hand the mailbox bytes straight to C++ (zero-copy zero-decode): - # the blob layout is what `write_blob` already wrote, so re-parsing - # it in Python is N×40B of 
avoidable work and a permanent - # opportunity to drop a field. C++ reinterpret_cast - # is the source of truth. - cw._impl.run_prepared_from_blob( - cid, mailbox_addr + _OFF_TASK_ARGS_BLOB, _MAILBOX_ARGS_CAPACITY, cfg - ) + cid = hold_child_cid_for_run(digest) + run_state = _ChildRunState() + run_queue.put(_ChildRunRequest(int(cid), args_blob, cfg, run_state, True)) + _mailbox_store_i32(state_addr, _TASK_RUNNING) except Exception as e: # noqa: BLE001 code = 1 msg = _format_exc(f"chip_process dev={device_id}", e) - - # On a successful kernel run, give the caller a chance to do - # post-run work (e.g. store_to_host D2H staging) before the - # parent sees TASK_DONE. The kernel's failure path skips the - # hook because the device output region is undefined and - # staging garbage would mask the real error in post-mortems. - if code == 0 and on_task_done_success is not None: - code, msg = on_task_done_success() - - _write_error(buf, code, msg) - _mailbox_store_i32(state_addr, _TASK_DONE) + _write_error(buf, code, msg) + _mailbox_store_i32(state_addr, _TASK_DONE) elif state == _CONTROL_REQUEST: sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0] code = 0 @@ -1044,68 +1589,100 @@ def _run_chip_main_loop( # noqa: PLR0912, PLR0913, PLR0915 -- unified TASK_READ cw.copy_from(dst, src, n) elif sub_cmd == _CTRL_PREPARE: digest = _read_control_digest(buf) - cid = identity_table.get(digest) - if cid is None: - raise RuntimeError( - f"prepare chip={device_id}: callable hash {_format_digest(digest)} not registered" - ) - _ensure_prepared(cw, registry, prepared, int(cid), lazy=False, device_id=device_id) + with registry_cv: + cid = identity_table.get(digest) + if cid is None or int(cid) in tombstoned_cids: + raise RuntimeError( + f"prepare chip={device_id}: callable hash {_format_digest(digest)} not registered" + ) + _ensure_prepared(cw, registry, prepared, int(cid), lazy=False, device_id=device_id) elif sub_cmd == _CTRL_REGISTER: digest = _read_control_digest(buf) 
payload_size = struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0] - raw = bytes(buf[_OFF_ARGS : _OFF_ARGS + _CTRL_SHM_NAME_BYTES]) - nul = raw.find(b"\x00") - shm_name = raw[: nul if nul >= 0 else _CTRL_SHM_NAME_BYTES].decode("utf-8", "replace") - shm = SharedMemory(name=shm_name) - shm_buf = shm.buf - assert shm_buf is not None - try: - if payload_size <= 0 or payload_size > shm.size: - raise RuntimeError( - f"CTRL_REGISTER payload size mismatch: payload={payload_size}, shm={shm.size}" - ) - callable_obj = ChipCallable.from_bytes(bytes(shm_buf[:payload_size])) - _validate_chip_payload_digest( - callable_obj, + shm_name = _read_shm_name(buf, _OFF_ARGS) + callable_bytes = bytearray(_read_raw_payload_from_shm(shm_name, int(payload_size))) + with registry_cv: + if digest in tombstoned_digests: + raise RuntimeError(f"callable hash {_format_digest(digest)} is pending unregister") + cid = _prepare_child_register_bytes( + cw, + registry, + identity_table, + identity_refs, + prepared, + registry_cv, digest, - platform=chip_platform, - runtime=chip_runtime, - context=f"chip_process dev={device_id}", + callable_bytes, + chip_platform=chip_platform, + chip_runtime=chip_runtime, + device_id=device_id, ) - if digest in identity_table: - identity_refs[digest] = identity_refs.get(digest, 1) + 1 - else: - cid = _install_local_identity( - registry, identity_table, identity_refs, digest, callable_obj - ) - # Self-heal when a prior unregister popped the local - # identity table but failed before clearing device - # prepared state for the reusable private slot. 
- if int(cid) in prepared: - try: - cw._unregister_slot(int(cid)) - except Exception: # noqa: BLE001 - pass - prepared.discard(int(cid)) - exported = ctypes.c_char.from_buffer(shm_buf) - try: - addr = ctypes.addressof(exported) - cw._impl.prepare_callable_from_blob(int(cid), addr) - finally: - del exported - prepared.add(int(cid)) + digest_by_cid[int(cid)] = digest + elif sub_cmd == _CTRL_REGISTER_ASYNC: + digest = _read_control_digest(buf) + payload_size = struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0] + shm_name = _read_shm_name(buf, _OFF_ARGS) + callable_bytes = bytearray(_read_raw_payload_from_shm(shm_name, int(payload_size))) + reg_state = _ChildRegisterState() + handle_id = next_register_handle + next_register_handle += 1 + register_states[handle_id] = reg_state + control_queue.put(_ChildRegisterRequest(digest, callable_bytes, reg_state)) + struct.pack_into("Q", buf, _CTRL_OFF_RESULT, int(handle_id)) + elif sub_cmd == _CTRL_WAIT_REGISTER: + handle_id = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) + reg_state = register_states.get(handle_id) + if reg_state is None: + raise RuntimeError(f"WAIT_REGISTER unknown handle {handle_id}") + try: + _wait_child_register_state(reg_state) + finally: + register_states.pop(handle_id, None) + elif sub_cmd == _CTRL_RUN_ASYNC: + payload_size = struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0] + shm_name = _read_shm_name(buf, _OFF_ARGS) + payload = _read_raw_payload_from_shm(shm_name, int(payload_size)) + digest, args_blob, cfg = _decode_async_run_request(payload) + run_state = _ChildRunState() + handle_id = next_run_handle + next_run_handle += 1 + run_states[handle_id] = run_state + cid = hold_child_cid_for_run(digest) + run_queue.put(_ChildRunRequest(cid, args_blob, cfg, run_state)) + struct.pack_into("Q", buf, _CTRL_OFF_RESULT, int(handle_id)) + elif sub_cmd == _CTRL_WAIT_RUN: + handle_id = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) + run_state = run_states.get(handle_id) + if run_state is None: + raise 
RuntimeError(f"WAIT_RUN unknown handle {handle_id}") + try: + timing = _wait_child_run_state(run_state) finally: - shm_buf.release() - # Release the local mmap as soon as prepare returns; - # prepare_callable has already H2D-copied the bytes to - # device GM, so the child no longer needs the shm. - shm.close() + run_states.pop(handle_id, None) + struct.pack_into("Q", buf, _CTRL_OFF_RESULT, int(timing.host_wall_ns)) + struct.pack_into("Q", buf, _CTRL_OFF_RESULT1, int(timing.device_wall_ns)) elif sub_cmd == _CTRL_UNREGISTER: digest = _read_control_digest(buf) - cid, removed = _remove_local_identity(registry, identity_table, identity_refs, digest) - if removed and cid is not None: - cw._unregister_slot(int(cid)) - prepared.discard(int(cid)) + unreg_state = _ChildUnregisterState() + submit_child_unregister(digest, unreg_state) + _wait_child_unregister_state(unreg_state) + elif sub_cmd == _CTRL_UNREGISTER_ASYNC: + digest = _read_control_digest(buf) + unreg_state = _ChildUnregisterState() + handle_id = next_unregister_handle + next_unregister_handle += 1 + unregister_states[handle_id] = unreg_state + submit_child_unregister(digest, unreg_state) + struct.pack_into("Q", buf, _CTRL_OFF_RESULT, int(handle_id)) + elif sub_cmd == _CTRL_WAIT_UNREGISTER: + handle_id = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) + unreg_state = unregister_states.get(handle_id) + if unreg_state is None: + raise RuntimeError(f"WAIT_UNREGISTER unknown handle {handle_id}") + try: + _wait_child_unregister_state(unreg_state) + finally: + unregister_states.pop(handle_id, None) elif sub_cmd == _CTRL_ALLOC_DOMAIN: _handle_ctrl_alloc_domain(cw, buf) elif sub_cmd == _CTRL_RELEASE_DOMAIN: @@ -1118,16 +1695,33 @@ def _run_chip_main_loop( # noqa: PLR0912, PLR0913, PLR0915 -- unified TASK_READ raise RuntimeError(f"unknown control sub-command {int(sub_cmd)}") except Exception as e: # noqa: BLE001 code = 1 - if sub_cmd in (_CTRL_REGISTER, _CTRL_UNREGISTER): - op = "register" if sub_cmd == _CTRL_REGISTER 
else "unregister" + if sub_cmd in ( + _CTRL_REGISTER, + _CTRL_REGISTER_ASYNC, + _CTRL_WAIT_REGISTER, + _CTRL_UNREGISTER, + _CTRL_UNREGISTER_ASYNC, + _CTRL_WAIT_UNREGISTER, + ): + op = ( + "register" + if sub_cmd in (_CTRL_REGISTER, _CTRL_REGISTER_ASYNC, _CTRL_WAIT_REGISTER) + else "unregister" + ) msg = _format_exc(f"{op} hash={_format_digest(_read_control_digest(buf))} chip={device_id}", e) else: msg = _format_exc(f"chip_process dev={device_id} ctrl={int(sub_cmd)}", e) _write_error(buf, code, msg) _mailbox_store_i32(state_addr, _CONTROL_DONE) + elif state == _TASK_RUNNING: + time.sleep(0.00001) elif state == _SHUTDOWN: break finally: + run_queue.put(None) + control_queue.put(None) + run_thread.join() + control_thread.join() if l3_l2_control_shms: try: cw.l3_l2_orch_comm_shutdown() @@ -1208,34 +1802,7 @@ def _chip_process_loop( def _read_config_from_mailbox(buf: memoryview) -> CallConfig: """Reconstruct a CallConfig from the unified mailbox layout.""" - ( - block_dim, - aicpu_tn, - swl, - dt, - pmu, - dep_gen, - scope_stats, - *ring_values, - prefix_bytes, - ) = _CFG_FMT.unpack_from(buf, _OFF_CONFIG) - ring_task_window = list(ring_values[:RUNTIME_ENV_RING_COUNT]) - ring_heap = list(ring_values[RUNTIME_ENV_RING_COUNT : 2 * RUNTIME_ENV_RING_COUNT]) - ring_dep_pool = list(ring_values[2 * RUNTIME_ENV_RING_COUNT : 3 * RUNTIME_ENV_RING_COUNT]) - cfg = CallConfig() - cfg.block_dim = block_dim - cfg.aicpu_thread_num = aicpu_tn - cfg.enable_l2_swimlane = swl - cfg.enable_dump_tensor = int(dt) - cfg.enable_pmu = pmu - cfg.enable_dep_gen = bool(dep_gen) - cfg.enable_scope_stats = bool(scope_stats) - cfg.runtime_env.ring_task_window = ring_task_window - cfg.runtime_env.ring_heap = ring_heap - cfg.runtime_env.ring_dep_pool = ring_dep_pool - # NUL-terminated C string in a 1024-byte field. 
- cfg.output_prefix = prefix_bytes.split(b"\x00", 1)[0].decode("utf-8") - return cfg + return _read_config_from_bytes(buf, _OFF_CONFIG) def _child_worker_loop( @@ -1387,10 +1954,20 @@ def __init__( # Level-2 internals self._chip_worker: ChipWorker | None = None + self._run_queue: queue.Queue | None = None + self._run_thread: threading.Thread | None = None + self._register_queue: queue.Queue | None = None + self._register_thread: threading.Thread | None = None + self._run_thread_stop = False + self._slot_inflight: dict[int, int] = {} + self._slot_tombstoned: set[int] = set() + self._slot_pending_unregister: dict[int, _UnregisterState] = {} # Level-3+ internals self._worker: _Worker | None = None self._orch: Orchestrator | None = None + self._dag_run_queue: queue.Queue | None = None + self._dag_run_thread: threading.Thread | None = None self._chip_shms: list[SharedMemory] = [] self._chip_pids: list[int] = [] self._sub_shms: list[SharedMemory] = [] @@ -1483,6 +2060,154 @@ def _remote_session_timeout_s(self) -> float: raise ValueError("Worker remote_session_timeout_s must be positive") return timeout_s + def _start_l2_lanes(self) -> None: + if self._run_thread is not None: + return + self._run_queue = queue.Queue() + self._register_queue = queue.Queue() + self._run_thread_stop = False + self._run_thread = threading.Thread(target=self._l2_run_thread_loop, name="simpler-l2-run", daemon=True) + self._register_thread = threading.Thread( + target=self._l2_register_thread_loop, + name="simpler-l2-register", + daemon=True, + ) + self._run_thread.start() + self._register_thread.start() + + def _stop_l2_lanes(self) -> None: + q = self._run_queue + t = self._run_thread + rq = self._register_queue + rt = self._register_thread + if q is None or t is None: + return + self._run_thread_stop = True + q.put(None) + if rq is not None: + rq.put(None) + t.join() + if rt is not None: + rt.join() + self._run_queue = None + self._run_thread = None + self._register_queue = None + 
self._register_thread = None + + def _start_l2_run_lane(self) -> None: + self._start_l2_lanes() + + def _stop_l2_run_lane(self) -> None: + self._stop_l2_lanes() + + def _complete_local_register(self, state: _LocalRegisterState, error: BaseException | None = None) -> None: + with state.cv: + state.error = error + state.completed = True + state.cv.notify_all() + + def _wait_local_register(self, state: _LocalRegisterState) -> None: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + + def _l2_run_thread_loop(self) -> None: + assert self._run_queue is not None + while True: + req = self._run_queue.get() + if req is None: + break + assert isinstance(req, _LocalRunRequest) + try: + assert self._chip_worker is not None + timing = self._chip_worker._run_slot(req.slot_id, req.args, req.config) + with req.state.cv: + req.state.result = timing + req.state.completed = True + req.state.cv.notify_all() + except BaseException as e: # noqa: BLE001 + with req.state.cv: + req.state.error = e + req.state.completed = True + req.state.cv.notify_all() + finally: + self._release_l2_slot_inflight(req.slot_id) + + def _l2_register_thread_loop(self) -> None: + assert self._register_queue is not None + while True: + req = self._register_queue.get() + if req is None: + break + assert isinstance(req, _LocalRegisterRequest) + try: + assert self._chip_worker is not None + callable_obj = ChipCallable.from_bytes(req.callable_bytes) + platform = str(self._config.get("platform", "")) + runtime = str(self._config.get("runtime", "")) + _validate_chip_payload_digest( + callable_obj, + req.digest, + platform=platform, + runtime=runtime, + context="Worker.register_async level=2", + ) + self._chip_worker._prepare_callable_at_slot(req.slot_id, callable_obj) + self._complete_local_register(req.state) + except BaseException as e: # noqa: BLE001 + self._complete_local_register(req.state, e) + + def _complete_unregister_state(self, state: 
_UnregisterState, error: BaseException | None = None) -> None: + with state.cv: + state.error = error + state.completed = True + state.cv.notify_all() + + def _wait_unregister_state(self, state: _UnregisterState) -> None: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + + def _release_l2_slot_inflight(self, slot_id: int) -> None: + cleanup_state: _UnregisterState | None = None + cleanup_digest = b"" + with self._registry_lock: + count = self._slot_inflight.get(int(slot_id), 0) + if count <= 1: + self._slot_inflight.pop(int(slot_id), None) + if int(slot_id) in self._slot_tombstoned: + cleanup_state = self._slot_pending_unregister.get(int(slot_id)) + for digest, state in self._identity_registry.items(): + if state.slot_id == int(slot_id): + cleanup_digest = digest + break + else: + self._slot_inflight[int(slot_id)] = count - 1 + if cleanup_state is not None: + self._l2_native_unregister_and_finish(int(slot_id), cleanup_digest, cleanup_state) + + def _l2_native_unregister_and_finish(self, slot_id: int, digest: bytes, state: _UnregisterState) -> None: + try: + assert self._chip_worker is not None + self._chip_worker._unregister_slot(int(slot_id)) + self._complete_unregister_state(state) + except BaseException as e: # noqa: BLE001 + if digest: + self._uncertain_hashids.add(digest) + self._complete_unregister_state(state, e) + finally: + with self._registry_lock: + self._callable_registry.pop(int(slot_id), None) + if digest: + self._identity_registry.pop(digest, None) + self._pending_unregister_cids.discard(int(slot_id)) + self._slot_tombstoned.discard(int(slot_id)) + self._slot_pending_unregister.pop(int(slot_id), None) + @staticmethod def _send_remote_daemon_json(sock: socket.socket, payload: dict[str, Any]) -> None: data = json.dumps(payload, sort_keys=True).encode("utf-8") @@ -2062,6 +2787,95 @@ def _resolve_handle( return self._resolve_handle_locked(handle, expected_namespace=expected_namespace) def 
register(self, target, *, workers: list[int] | None = None) -> CallableHandle: + if isinstance(target, ChipCallable) and not isinstance(target, RemoteCallable): + return self.register_async(target, workers=workers).wait() + return self._register_sync_impl(target, workers=workers) + + def register_async(self, target, *, workers: list[int] | None = None) -> RegisterHandle: + if not isinstance(target, ChipCallable) or isinstance(target, RemoteCallable): + raise TypeError("Worker.register_async only supports ChipCallable") + if workers is not None: + raise TypeError("Worker.register_async: workers= is only supported for synchronous RemoteCallable register") + + if self.level >= 3 and self._initialized: + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + reg = _build_callable_registration(self, target, workers=workers) + with self._registry_lock: + handle, is_new = self._install_registration_locked(reg) + try: + submitted = self._post_init_register_async(target, handle.digest, is_new=is_new) + except Exception: + with self._registry_lock: + self._rollback_handle_locked(handle) + raise + + completed = False + + def wait() -> CallableHandle: + nonlocal completed + try: + if submitted: + self._wait_async_register_results(handle.digest, submitted) + completed = True + return handle + except Exception: + with self._registry_lock: + self._rollback_handle_locked(handle) + raise + + return RegisterHandle(wait, lambda: completed) + + if self.level == 2 and self._initialized: + reg = _build_callable_registration(self, target, workers=workers) + with self._registry_lock: + handle, is_new = self._install_registration_locked(reg) + return self._l2_submit_register_async(handle, target, is_new=is_new) + + reg = _build_callable_registration(self, target, 
workers=workers) + with self._registry_lock: + handle, _is_new = self._install_registration_locked(reg) + completed = True + return RegisterHandle(lambda: handle, lambda: completed) + + def _l2_submit_register_async( + self, handle: CallableHandle, target: ChipCallable, *, is_new: bool + ) -> RegisterHandle: + if not is_new: + completed = True + return RegisterHandle(lambda: handle, lambda: completed) + if self._register_queue is None: + raise RuntimeError("Worker.register_async: L2 register lane is not started") + with self._registry_lock: + slot_id = int(self._identity_registry[handle.digest].slot_id) + state = _LocalRegisterState() + self._register_queue.put( + _LocalRegisterRequest( + digest=handle.digest, + slot_id=slot_id, + callable_bytes=_chip_callable_bytes(target), + state=state, + ) + ) + completed = False + + def wait() -> CallableHandle: + nonlocal completed + try: + self._wait_local_register(state) + except Exception: + with self._registry_lock: + self._rollback_handle_locked(handle) + raise + completed = True + return handle + + return RegisterHandle(wait, lambda: completed or state.completed) + + def _register_sync_impl(self, target, *, workers: list[int] | None = None) -> CallableHandle: """Register a callable for dispatch and return an opaque handle. Integer execution slots remain private to the local target process. @@ -2504,9 +3318,9 @@ def _post_init_register(self, target: ChipCallable, digest: bytes, *, is_new: bo """Broadcast a new ChipCallable to every NEXT_LEVEL child via C++. Delegates the entire shm-staging + per-child mailbox handshake to - ``_Worker.broadcast_register_all``, which holds per-WorkerThread - ``mailbox_mu_`` so the broadcast serializes against any in-flight - dispatch on each child mailbox. No Python lock required. + ``_Worker.broadcast_register_all``. Public ``register`` now reaches + the async path, but this helper remains for internal synchronous + cleanup paths. 
""" # Chip children are forked lazily on the first Worker.run() via # _start_hierarchical; before that point the chip mailboxes have no @@ -2531,6 +3345,69 @@ def _post_init_register(self, target: ChipCallable, digest: bytes, *, is_new: bo self._uncertain_hashids.add(digest) raise RuntimeError(self._format_register_partial_failure(digest, errors, cleanup_errors)) + def _post_init_register_async(self, target: ChipCallable, digest: bytes, *, is_new: bool) -> list[Any]: + if not getattr(self, "_hierarchical_started", False): + return [] + assert self._worker is not None + try: + results = self._worker.broadcast_register_async_all( + int(target.buffer_ptr()), + int(target.buffer_size()), + digest, + ) + except Exception: + cleanup_errors = self._cleanup_chip_registration(digest) if is_new else [] + if cleanup_errors: + with self._registry_lock: + self._uncertain_hashids.add(digest) + raise + errors = self._control_errors(list(results)) + if errors: + cleanup_errors = self._cleanup_async_register_successes(list(results), digest) + if cleanup_errors: + with self._registry_lock: + self._uncertain_hashids.add(digest) + raise RuntimeError(self._format_register_partial_failure(digest, errors, cleanup_errors)) + return list(results) + + def _wait_async_register_results(self, digest: bytes, results: list[Any]) -> None: + assert self._worker is not None + errors: list[str] = [] + completed: list[Any] = [] + for result in results: + try: + self._worker.control_wait_register(int(result.worker_id), int(result.remote_handle)) + completed.append(result) + except Exception as exc: # noqa: BLE001 + errors.append(f"{result.worker_type}[{int(result.worker_id)}]: {exc}") + if errors: + cleanup_errors = self._cleanup_async_register_successes(completed, digest) + if cleanup_errors: + with self._registry_lock: + self._uncertain_hashids.add(digest) + raise RuntimeError(self._format_register_partial_failure(digest, errors, cleanup_errors)) + + def _cleanup_async_register_successes(self, 
results: list[Any], digest: bytes) -> list[str]: + if self._worker is None: + return [] + errors: list[str] = [] + for result in results: + if not result.ok: + continue + try: + cleanup = self._worker.control_digest_only( + WorkerType.NEXT_LEVEL, + int(result.worker_id), + _CTRL_UNREGISTER, + digest, + timeout_s=self._py_control_timeout_s, + ) + if not cleanup.ok: + errors.append(f"{cleanup.worker_type}[{cleanup.worker_id}]: {cleanup.error_message}") + except Exception as exc: # noqa: BLE001 + errors.append(f"{result.worker_type}[{int(result.worker_id)}]: {exc}") + return errors + @staticmethod def _format_register_partial_failure(digest: bytes, errors: list[str], cleanup_errors: list[str]) -> str: msg = ( @@ -2603,7 +3480,158 @@ def _pre_start_unregister_if_needed(self, handle_or_slot) -> bool: self._identity_registry.pop(digest, None) return True + @staticmethod + def _is_chip_callable_handle(handle_or_slot) -> bool: + return ( + isinstance(handle_or_slot, CallableHandle) + and handle_or_slot.kind == "CHIP_CALLABLE" + and handle_or_slot.target_namespace == "LOCAL_CHIP" + ) + + def unregister_async(self, handle_or_slot) -> UnregisterHandle: + if not self._is_chip_callable_handle(handle_or_slot): + raise TypeError("Worker.unregister_async only supports ChipCallable handles") + assert isinstance(handle_or_slot, CallableHandle) + state = _UnregisterState() + + if not self._initialized: + with self._registry_lock: + _handle_id, digest, info = self._coerce_handle_state(handle_or_slot) + self._live_handles.pop(handle_or_slot._handle_id, None) + info.ref_count -= 1 + if info.ref_count <= 0: + self._callable_registry.pop(info.slot_id, None) + self._identity_registry.pop(digest, None) + self._complete_unregister_state(state) + return UnregisterHandle(lambda: self._wait_unregister_state(state), lambda: state.completed) + + if self.level == 2: + return self._l2_unregister_async(handle_or_slot, state) + if self.level >= 3: + return self._l3_unregister_async(handle_or_slot, 
state) + raise ValueError(f"Worker: level {self.level} not supported") + + def _l2_unregister_async(self, handle: CallableHandle, unreg_state: _UnregisterState) -> UnregisterHandle: + cleanup_slot_id = -1 + cleanup_digest = b"" + with self._registry_lock: + _handle_id, digest, state = self._coerce_handle_state(handle) + slot_id = int(state.slot_id) + if slot_id in self._pending_unregister_cids: + raise KeyError("UNREGISTER_TOMBSTONE_ACTIVE: callable handle already pending unregister") + self._live_handles.pop(handle._handle_id, None) + state.ref_count -= 1 + if state.ref_count > 0: + self._complete_unregister_state(unreg_state) + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + + self._pending_unregister_cids.add(slot_id) + self._slot_tombstoned.add(slot_id) + self._slot_pending_unregister[slot_id] = unreg_state + if self._slot_inflight.get(slot_id, 0) == 0: + cleanup_slot_id = slot_id + cleanup_digest = digest + + if cleanup_slot_id >= 0: + self._l2_native_unregister_and_finish(cleanup_slot_id, cleanup_digest, unreg_state) + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + + def _l3_unregister_async(self, handle: CallableHandle, unreg_state: _UnregisterState) -> UnregisterHandle: + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + started = self._hierarchical_start_state == "started" or getattr(self, "_hierarchical_started", False) + + remove_target = False + cid = -1 + digest = b"" + with self._registry_lock: + _handle_id, digest, state = self._coerce_handle_state(handle) + cid = int(state.slot_id) + if cid in self._pending_unregister_cids: + raise KeyError("UNREGISTER_TOMBSTONE_ACTIVE: callable handle already pending 
unregister") + self._live_handles.pop(handle._handle_id, None) + state.ref_count -= 1 + remove_target = state.ref_count <= 0 + if remove_target: + self._pending_unregister_cids.add(cid) + + if not started: + if remove_target: + self._callable_registry.pop(cid, None) + self._identity_registry.pop(digest, None) + self._pending_unregister_cids.discard(cid) + self._complete_unregister_state(unreg_state) + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + + worker_endpoint = self._worker + assert worker_endpoint is not None + try: + submitted = list(worker_endpoint.broadcast_unregister_async_all(digest)) + except BaseException as e: # noqa: BLE001 + with self._registry_lock: + if remove_target: + self._uncertain_hashids.add(digest) + self._pending_unregister_cids.discard(cid) + self._complete_unregister_state(unreg_state, e) + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + + errors = self._control_errors(submitted) + if errors: + err = RuntimeError( + f"UNREGISTER_PARTIAL_FAILURE: Worker.unregister(hash={_format_digest(digest)}) failed on " + f"{len(errors)} child workers; first error: {errors[0]}" + ) + with self._registry_lock: + if remove_target: + self._uncertain_hashids.add(digest) + self._pending_unregister_cids.discard(cid) + self._complete_unregister_state(unreg_state, err) + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + + state_for_wait = state + + def wait_remote() -> None: + try: + wait_errors: list[str] = [] + for result in submitted: + try: + worker_endpoint.control_wait_unregister(int(result.worker_id), int(result.remote_handle)) + except BaseException as exc: # noqa: BLE001 + wait_errors.append(f"{result.worker_type}[{int(result.worker_id)}]: {exc}") + if wait_errors: + raise RuntimeError( + f"UNREGISTER_PARTIAL_FAILURE: Worker.unregister(hash={_format_digest(digest)}) failed on " 
+ f"{len(wait_errors)} child workers; first error: {wait_errors[0]}" + ) + self._complete_unregister_state(unreg_state) + except BaseException as e: # noqa: BLE001 + with self._registry_lock: + if remove_target: + self._uncertain_hashids.add(digest) + self._complete_unregister_state(unreg_state, e) + finally: + if remove_target: + with self._registry_lock: + current = self._identity_registry.get(digest) + if current is not None and current is state_for_wait and current.ref_count <= 0: + self._callable_registry.pop(cid, None) + self._identity_registry.pop(digest, None) + self._pending_unregister_cids.discard(cid) + + waiter = threading.Thread(target=wait_remote, name="simpler-l3-unregister-wait", daemon=True) + waiter.start() + return UnregisterHandle(lambda: self._wait_unregister_state(unreg_state), lambda: unreg_state.completed) + def unregister(self, handle_or_slot) -> None: + if self._is_chip_callable_handle(handle_or_slot): + self.unregister_async(handle_or_slot).wait() + return + self._unregister_non_chip_sync(handle_or_slot) + + def _unregister_non_chip_sync(self, handle_or_slot) -> None: """Drop a ``CallableHandle`` from the registry and propagate cleanup. Symmetric to ``Worker.register`` for the dynamic post-init path. @@ -2878,13 +3906,32 @@ def _init_level2(self) -> None: self._chip_worker = ChipWorker() self._chip_worker.init(device_id, binaries) + self._start_l2_lanes() # Pre-warm any registered ChipCallable so the first run(handle, …) # does not pay the H2D upload cost. 
assert self._chip_worker is not None - for cid, target in self._callable_registry.items(): + for cid, target in list(self._callable_registry.items()): if isinstance(target, ChipCallable): - self._chip_worker._prepare_callable_at_slot(cid, target) + digest = b"" + with self._registry_lock: + for known_digest, state in self._identity_registry.items(): + if state.slot_id == int(cid): + digest = known_digest + break + if not digest: + raise RuntimeError(f"Worker.init(level=2): no digest for callable slot {cid}") + state = _LocalRegisterState() + assert self._register_queue is not None + self._register_queue.put( + _LocalRegisterRequest( + digest=digest, + slot_id=int(cid), + callable_bytes=_chip_callable_bytes(target), + state=state, + ) + ) + self._wait_local_register(state) def _init_hierarchical(self) -> None: device_ids = self._config.get("device_ids", []) @@ -3191,6 +4238,7 @@ def _cleanup_partial_init(self) -> None: self._close_remote_sessions(remote_sessions) if self._chip_worker is not None: try: + self._stop_l2_run_lane() self._chip_worker.finalize() except BaseException: # noqa: BLE001 pass @@ -3733,8 +4781,8 @@ def _release_all_live_domains(self) -> None: # ------------------------------------------------------------------ # memory management — forward to C++ Orchestrator, which holds - # per-WorkerThread mailbox_mu_ so these are safe to call concurrently - # with in-flight dispatch on the same chip mailbox. + # per-WorkerThread mailbox control. Memory/domain controls intentionally + # wait behind in-flight task dispatch. 
# ------------------------------------------------------------------ def _check_chip_worker_id(self, worker_id: int) -> None: @@ -3755,6 +4803,7 @@ def malloc(self, size: int, worker_id: int = 0) -> int: if self.level == 2: assert self._chip_worker is not None return self._chip_worker.malloc(size) + self._start_hierarchical() self._check_chip_worker_id(worker_id) assert self._orch is not None return self._orch.malloc(worker_id, size) @@ -3765,6 +4814,7 @@ def free(self, ptr: int, worker_id: int = 0) -> None: assert self._chip_worker is not None self._chip_worker.free(ptr) return + self._start_hierarchical() self._check_chip_worker_id(worker_id) assert self._orch is not None self._orch.free(worker_id, ptr) @@ -3775,6 +4825,7 @@ def copy_to(self, dst: int, src: int, size: int, worker_id: int = 0) -> None: assert self._chip_worker is not None self._chip_worker.copy_to(dst, src, size) return + self._start_hierarchical() self._check_chip_worker_id(worker_id) assert self._orch is not None self._orch.copy_to(worker_id, dst, src, size) @@ -3785,10 +4836,123 @@ def copy_from(self, dst: int, src: int, size: int, worker_id: int = 0) -> None: assert self._chip_worker is not None self._chip_worker.copy_from(dst, src, size) return + self._start_hierarchical() self._check_chip_worker_id(worker_id) assert self._orch is not None self._orch.copy_from(worker_id, dst, src, size) + @staticmethod + def _wait_local_run(state: _LocalRunState) -> RunTiming: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + assert state.result is not None + return state.result + + @staticmethod + def _wait_dag_run(state: _DagRunState) -> RunTiming: + with state.cv: + while not state.completed: + state.cv.wait() + if state.error is not None: + raise state.error + assert state.result is not None + return state.result + + def _complete_dag_run( + self, + state: _DagRunState, + *, + result: RunTiming | None = None, + error: BaseException | None = 
None, + ) -> None: + with state.cv: + state.result = result + state.error = error + state.completed = True + state.cv.notify_all() + + def _start_dag_run_lane(self) -> None: + if self._dag_run_thread is not None: + return + self._dag_run_queue = queue.Queue() + self._dag_run_thread = threading.Thread( + target=self._dag_run_thread_loop, + name="simpler-l3-dag-run", + daemon=True, + ) + self._dag_run_thread.start() + + def _stop_dag_run_lane(self) -> None: + q = self._dag_run_queue + t = self._dag_run_thread + if q is None or t is None: + return + q.put(None) + t.join() + self._dag_run_queue = None + self._dag_run_thread = None + + def _dag_run_thread_loop(self) -> None: + assert self._dag_run_queue is not None + while True: + req = self._dag_run_queue.get() + if req is None: + break + assert isinstance(req, _DagRunRequest) + try: + timing = self._run_dag_sync_impl(req.orch_fn, req.args, req.config) + self._complete_dag_run(req.state, result=timing) + except BaseException as e: # noqa: BLE001 + self._complete_dag_run(req.state, error=e) + + def _l2_submit_run_async(self, handle: CallableHandle, args=None, config=None) -> RunHandle: + assert self._initialized, "Worker not initialized; call init() first" + if self._run_queue is None: + raise RuntimeError("Worker.run_async: L2 run lane is not started") + cfg = _copy_call_config(config) + with self._registry_lock: + state_info = self._resolve_handle_locked(handle, expected_namespace="LOCAL_CHIP") + slot_id = int(state_info.slot_id) + if slot_id in self._slot_tombstoned or slot_id in self._pending_unregister_cids: + raise KeyError("callable handle is pending unregister") + self._slot_inflight[slot_id] = self._slot_inflight.get(slot_id, 0) + 1 + run_state = _LocalRunState() + self._run_queue.put( + _LocalRunRequest( + slot_id=slot_id, + args=_copy_run_args(args, prefer_task_args=False), + config=cfg, + state=run_state, + ) + ) + return RunHandle(lambda: self._wait_local_run(run_state), lambda: run_state.completed) + + def 
run_async(self, callable, args=None, config=None) -> RunHandle: + """Run asynchronously with level-specific semantics. + + L2 submits a direct ChipCallable handle to the local run lane. + L3+ submits the whole orchestration function to the per-worker DAG lane. + """ + assert self._initialized, "Worker not initialized; call init() first" + if self.level == 2: + return self._l2_submit_run_async(callable, args, config) + + self._start_dag_run_lane() + assert self._dag_run_queue is not None + state = _DagRunState() + self._dag_run_queue.put( + _DagRunRequest( + orch_fn=callable, + args=args, + config=_copy_call_config(config), + state=state, + ) + ) + return RunHandle(lambda: self._wait_dag_run(state), lambda: state.completed) + # ------------------------------------------------------------------ # run — uniform entry point # ------------------------------------------------------------------ @@ -3814,13 +4978,14 @@ def run(self, callable, args=None, config=None) -> RunTiming: device timings are not aggregated here. 
""" assert self._initialized, "Worker not initialized; call init() first" - cfg = config if config is not None else CallConfig() if self.level == 2: - assert self._chip_worker is not None - state = self._resolve_handle(callable, expected_namespace="LOCAL_CHIP") - return self._chip_worker._run_slot(state.slot_id, args, cfg) + return self.run_async(callable, args, config).wait() + + return self.run_async(callable, args, config).wait() + def _run_dag_sync_impl(self, callable, args=None, config=None) -> RunTiming: + cfg = config if config is not None else CallConfig() self._start_hierarchical() assert self._orch is not None assert self._worker is not None @@ -3919,8 +5084,10 @@ def close(self) -> None: # noqa: PLR0912 -- parallel teardown for _worker + sub if self.level == 2: if self._chip_worker: + self._stop_l2_lanes() self._chip_worker.finalize() else: + self._stop_dag_run_lane() if self._worker: self._worker.close() self._worker = None diff --git a/src/common/hierarchical/worker.h b/src/common/hierarchical/worker.h index e915b428f..fe972d29e 100644 --- a/src/common/hierarchical/worker.h +++ b/src/common/hierarchical/worker.h @@ -99,6 +99,18 @@ class Worker { // Forward CTRL_PREPARE to a specific NEXT_LEVEL worker (prewarm path // used by the Python facade at end of _start_hierarchical). 
void control_prepare(int worker_id, const uint8_t *digest) { manager_.control_prepare(worker_id, digest); } + uint64_t control_run_async(int worker_id, const uint8_t *digest, const TaskArgs &args, const CallConfig &config) { + return manager_.control_run_async(worker_id, digest, args, config); + } + RunTiming control_wait_run(int worker_id, uint64_t handle_id) { + return manager_.control_wait_run(worker_id, handle_id); + } + void control_wait_register(int worker_id, uint64_t handle_id) { + manager_.control_wait_register(worker_id, handle_id); + } + void control_wait_unregister(int worker_id, uint64_t handle_id) { + manager_.control_wait_unregister(worker_id, handle_id); + } // Drive a single chip child through one CommDomain alloc / release. The // Python orch facade is expected to call this on every participating chip @@ -179,6 +191,15 @@ class Worker { reinterpret_cast(blob_ptr), static_cast(blob_size), digest ); } + std::vector + broadcast_register_async_all(uint64_t blob_ptr, uint64_t blob_size, const uint8_t *digest) { + return manager_.broadcast_register_async_all( + reinterpret_cast(blob_ptr), static_cast(blob_size), digest + ); + } + std::vector broadcast_unregister_async_all(const uint8_t *digest) { + return manager_.broadcast_unregister_async_all(digest); + } std::vector broadcast_unregister_all(const uint8_t *digest) { return manager_.broadcast_unregister_all(digest); } diff --git a/src/common/hierarchical/worker_manager.cpp b/src/common/hierarchical/worker_manager.cpp index ff8c94bf7..5f22080b3 100644 --- a/src/common/hierarchical/worker_manager.cpp +++ b/src/common/hierarchical/worker_manager.cpp @@ -31,6 +31,10 @@ namespace { +static constexpr size_t ASYNC_RUN_OFF_DIGEST = 0; +static constexpr size_t ASYNC_RUN_OFF_CONFIG = ASYNC_RUN_OFF_DIGEST + CALLABLE_HASH_DIGEST_SIZE; +static constexpr size_t ASYNC_RUN_OFF_ARGS_BLOB = (ASYNC_RUN_OFF_CONFIG + sizeof(CallConfig) + 7) & ~size_t{7}; + // Read the child-written error message from the mailbox, 
guaranteeing // NUL-termination even if the child wrote exactly MAILBOX_ERROR_MSG_SIZE // bytes without a terminator. @@ -62,6 +66,73 @@ namespace { throw std::runtime_error(std::string(op_name) + " is not supported by this WorkerEndpoint"); } +std::atomic g_shm_counter{0}; + +std::string make_shm_name() { + char buf[CTRL_SHM_NAME_BYTES]; + int pid = static_cast(getpid()); + uint64_t counter = g_shm_counter.fetch_add(1, std::memory_order_relaxed); + int n = std::snprintf(buf, sizeof(buf), "simpler-cb-%d-%llu", pid, static_cast(counter)); + if (n < 0 || static_cast(n) >= sizeof(buf)) { + throw std::runtime_error("broadcast_register: shm name overflow"); + } + return std::string(buf); +} + +std::string strip_control_prefix(const std::string &msg, const std::string &op_name) { + const std::string needle = op_name + " failed on child: "; + if (msg.compare(0, needle.size(), needle) == 0) { + return msg.substr(needle.size()); + } + return msg; +} + +class PosixShmHolder { +public: + PosixShmHolder(const std::string &name, size_t size) : + name_(name), + size_(size) { + std::string full_name = "/" + name_; + fd_ = shm_open(full_name.c_str(), O_CREAT | O_RDWR | O_EXCL, 0600); + if (fd_ < 0) { + throw std::runtime_error( + std::string("broadcast_register: shm_open(") + full_name + ") failed: " + std::strerror(errno) + ); + } + if (ftruncate(fd_, static_cast(size)) != 0) { + int err = errno; + ::close(fd_); + shm_unlink(full_name.c_str()); + throw std::runtime_error(std::string("broadcast_register: ftruncate failed: ") + std::strerror(err)); + } + addr_ = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0); + if (addr_ == MAP_FAILED) { + int err = errno; + ::close(fd_); + shm_unlink(full_name.c_str()); + addr_ = nullptr; + throw std::runtime_error(std::string("broadcast_register: mmap failed: ") + std::strerror(err)); + } + } + ~PosixShmHolder() { + if (addr_ != nullptr) munmap(addr_, size_); + if (fd_ >= 0) ::close(fd_); + std::string full_name = "/" + name_; + 
shm_unlink(full_name.c_str()); + } + PosixShmHolder(const PosixShmHolder &) = delete; + PosixShmHolder &operator=(const PosixShmHolder &) = delete; + + void *addr() { return addr_; } + const std::string &name() const { return name_; } + +private: + std::string name_; + size_t size_{0}; + int fd_{-1}; + void *addr_{nullptr}; +}; + } // namespace uint64_t WorkerEndpoint::control_malloc(size_t) { throw_unsupported_control("control_malloc"); } @@ -72,6 +143,18 @@ void WorkerEndpoint::control_prepare(const uint8_t *) { throw_unsupported_contro void WorkerEndpoint::control_register(const char *, size_t, const uint8_t *) { throw_unsupported_control("control_register"); } +uint64_t WorkerEndpoint::control_register_async(const char *, size_t, const uint8_t *) { + throw_unsupported_control("control_register_async"); +} +uint64_t WorkerEndpoint::control_run_async(const uint8_t *, const TaskArgs &, const CallConfig &) { + throw_unsupported_control("control_run_async"); +} +void WorkerEndpoint::control_wait_register(uint64_t) { throw_unsupported_control("control_wait_register"); } +RunTiming WorkerEndpoint::control_wait_run(uint64_t) { throw_unsupported_control("control_wait_run"); } +uint64_t WorkerEndpoint::control_unregister_async(const uint8_t *) { + throw_unsupported_control("control_unregister_async"); +} +void WorkerEndpoint::control_wait_unregister(uint64_t) { throw_unsupported_control("control_wait_unregister"); } void WorkerEndpoint::control_unregister(const uint8_t *) { throw_unsupported_control("control_unregister"); } void WorkerEndpoint::control_remote_prepare_register( remote_l3::RemoteRegistryTarget, CallableKind, const uint8_t *, const void *, size_t @@ -156,6 +239,13 @@ void LocalMailboxEndpoint::write_mailbox_state(MailboxState s) { #endif } +bool LocalMailboxEndpoint::compare_exchange_mailbox_state(MailboxState expected, MailboxState desired) { + volatile int32_t *ptr = reinterpret_cast(mbox() + MAILBOX_OFF_STATE); + int32_t expected_v = 
static_cast(expected); + int32_t desired_v = static_cast(desired); + return __atomic_compare_exchange_n(ptr, &expected_v, desired_v, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE); +} + void LocalMailboxEndpoint::shutdown_child() { write_mailbox_state(MailboxState::SHUTDOWN); } // ============================================================================= @@ -266,62 +356,65 @@ WorkerCompletion LocalMailboxEndpoint::run(Ring *ring, const WorkerDispatch &dis completion.task_slot = dispatch.task_slot; completion.group_index = group_index; - // Hold mailbox_mu_ for the entire round trip (write payload + state + - // spin-poll TASK_DONE + reset to IDLE). Any control_* request from the - // orch thread waits for the dispatch to finish before claiming the - // mailbox; without this they would race on MAILBOX_OFF_STATE. - std::lock_guard lk(mailbox_mu_); - if (mailbox_control_timed_out_) { - completion.outcome = EndpointOutcome::ENDPOINT_FAILURE; - completion.error_message = "LocalMailboxEndpoint::run: mailbox has an unresolved timed-out control command"; - return completion; - } + { + std::lock_guard lk(mailbox_mu_); + if (mailbox_control_timed_out_) { + completion.outcome = EndpointOutcome::ENDPOINT_FAILURE; + completion.error_message = "LocalMailboxEndpoint::run: mailbox has an unresolved timed-out control command"; + return completion; + } - // Clear the child-writable error fields so stale bytes from a prior - // dispatch cannot masquerade as a fresh failure. - int32_t zero_err = 0; - std::memcpy(mbox() + MAILBOX_OFF_ERROR, &zero_err, sizeof(int32_t)); - std::memset(mbox() + MAILBOX_OFF_ERROR_MSG, 0, MAILBOX_ERROR_MSG_SIZE); + // Clear the child-writable error fields so stale bytes from a prior + // dispatch cannot masquerade as a fresh failure. 
+ int32_t zero_err = 0; + std::memcpy(mbox() + MAILBOX_OFF_ERROR, &zero_err, sizeof(int32_t)); + std::memset(mbox() + MAILBOX_OFF_ERROR_MSG, 0, MAILBOX_ERROR_MSG_SIZE); - uint64_t reserved_callable = 0; - std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &reserved_callable, sizeof(uint64_t)); + uint64_t reserved_callable = 0; + std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &reserved_callable, sizeof(uint64_t)); - // Write config as a single packed POD block (see call_config.h). - std::memcpy(mbox() + MAILBOX_OFF_CONFIG, &s.config, sizeof(CallConfig)); + // Write config as a single packed POD block (see call_config.h). + std::memcpy(mbox() + MAILBOX_OFF_CONFIG, &s.config, sizeof(CallConfig)); - // Write length-prefixed TaskArgs blob: [T][S][tensors][scalars]. - size_t blob_bytes = TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(Tensor) + - static_cast(view.scalar_count) * sizeof(uint64_t); - if (blob_bytes > MAILBOX_ARGS_CAPACITY) { - completion.outcome = EndpointOutcome::ENDPOINT_FAILURE; - completion.error_message = - "LocalMailboxEndpoint::run: args blob exceeds mailbox capacity: need " + std::to_string(blob_bytes) + - " bytes, capacity " + std::to_string(MAILBOX_ARGS_CAPACITY) + - " bytes, tensors=" + std::to_string(view.tensor_count) + ", scalars=" + std::to_string(view.scalar_count); - return completion; - } - uint8_t *hash_dst = reinterpret_cast(mbox() + MAILBOX_OFF_TASK_CALLABLE_HASH); - std::memcpy(hash_dst, s.callable.digest.data(), CALLABLE_HASH_DIGEST_SIZE); + // Write length-prefixed TaskArgs blob: [T][S][tensors][scalars]. 
+ size_t blob_bytes = TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(Tensor) + + static_cast(view.scalar_count) * sizeof(uint64_t); + if (blob_bytes > MAILBOX_ARGS_CAPACITY) { + completion.outcome = EndpointOutcome::ENDPOINT_FAILURE; + completion.error_message = "LocalMailboxEndpoint::run: args blob exceeds mailbox capacity: need " + + std::to_string(blob_bytes) + " bytes, capacity " + + std::to_string(MAILBOX_ARGS_CAPACITY) + + " bytes, tensors=" + std::to_string(view.tensor_count) + + ", scalars=" + std::to_string(view.scalar_count); + return completion; + } + uint8_t *hash_dst = reinterpret_cast(mbox() + MAILBOX_OFF_TASK_CALLABLE_HASH); + std::memcpy(hash_dst, s.callable.digest.data(), CALLABLE_HASH_DIGEST_SIZE); + + uint8_t *d = reinterpret_cast(mbox() + MAILBOX_OFF_TASK_ARGS_BLOB); + std::memcpy(d + 0, &view.tensor_count, sizeof(int32_t)); + std::memcpy(d + 4, &view.scalar_count, sizeof(int32_t)); + if (view.tensor_count > 0) { + std::memcpy( + d + TASK_ARGS_BLOB_HEADER_SIZE, view.tensor_bytes, + static_cast(view.tensor_count) * sizeof(Tensor) + ); + } + if (view.scalar_count > 0) { + std::memcpy( + d + TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(Tensor), view.scalars, + static_cast(view.scalar_count) * sizeof(uint64_t) + ); + } - uint8_t *d = reinterpret_cast(mbox() + MAILBOX_OFF_TASK_ARGS_BLOB); - std::memcpy(d + 0, &view.tensor_count, sizeof(int32_t)); - std::memcpy(d + 4, &view.scalar_count, sizeof(int32_t)); - if (view.tensor_count > 0) { - std::memcpy( - d + TASK_ARGS_BLOB_HEADER_SIZE, view.tensor_bytes, static_cast(view.tensor_count) * sizeof(Tensor) - ); - } - if (view.scalar_count > 0) { - std::memcpy( - d + TASK_ARGS_BLOB_HEADER_SIZE + static_cast(view.tensor_count) * sizeof(Tensor), view.scalars, - static_cast(view.scalar_count) * sizeof(uint64_t) - ); + write_mailbox_state(MailboxState::TASK_READY); + while (true) { + MailboxState state = read_mailbox_state(); + if (state == MailboxState::TASK_RUNNING 
|| state == MailboxState::TASK_DONE) break; + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } } - // Signal child process. - write_mailbox_state(MailboxState::TASK_READY); - - // Spin-poll until child signals TASK_DONE. while (read_mailbox_state() != MailboxState::TASK_DONE) { std::this_thread::sleep_for(std::chrono::microseconds(50)); } @@ -556,16 +649,40 @@ void LocalMailboxEndpoint::run_control_command(const char *op_name, double timeo if (mailbox_control_timed_out_) { throw std::runtime_error(std::string(op_name) + " failed: mailbox has an unresolved timed-out control command"); } - int32_t zero_err = 0; - std::memcpy(mbox() + MAILBOX_OFF_ERROR, &zero_err, sizeof(int32_t)); - std::memset(mbox() + MAILBOX_OFF_ERROR_MSG, 0, MAILBOX_ERROR_MSG_SIZE); - write_mailbox_state(MailboxState::CONTROL_REQUEST); auto deadline = std::chrono::steady_clock::time_point::max(); if (timeout_s >= 0.0) { deadline = std::chrono::steady_clock::now() + std::chrono::duration_cast(std::chrono::duration(timeout_s)); } + uint64_t sub_cmd = 0; + std::memcpy(&sub_cmd, mbox() + MAILBOX_OFF_CALLABLE, sizeof(uint64_t)); + bool can_overlap_task = sub_cmd == CTRL_REGISTER_ASYNC || sub_cmd == CTRL_WAIT_REGISTER || + sub_cmd == CTRL_RUN_ASYNC || sub_cmd == CTRL_WAIT_RUN || sub_cmd == CTRL_UNREGISTER_ASYNC || + sub_cmd == CTRL_WAIT_UNREGISTER; + MailboxState restore_state = MailboxState::IDLE; + while (true) { + MailboxState current = read_mailbox_state(); + if (current == MailboxState::TASK_DONE || (current == MailboxState::TASK_RUNNING && !can_overlap_task)) { + if (std::chrono::steady_clock::now() >= deadline) { + mailbox_control_timed_out_ = true; + throw std::runtime_error(std::string(op_name) + " timed out waiting for task dispatch to finish"); + } + std::this_thread::sleep_for(std::chrono::microseconds(50)); + continue; + } + if (current == MailboxState::IDLE || current == MailboxState::INIT_DONE || + current == MailboxState::TASK_RUNNING) { + restore_state = (current == 
MailboxState::TASK_RUNNING) ? MailboxState::TASK_RUNNING : MailboxState::IDLE; + if (compare_exchange_mailbox_state(current, MailboxState::CONTROL_REQUEST)) break; + continue; + } + if (std::chrono::steady_clock::now() >= deadline) { + mailbox_control_timed_out_ = true; + throw std::runtime_error(std::string(op_name) + " timed out waiting for control mailbox availability"); + } + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } while (read_mailbox_state() != MailboxState::CONTROL_DONE) { if (std::chrono::steady_clock::now() >= deadline) { mailbox_control_timed_out_ = true; @@ -576,10 +693,10 @@ void LocalMailboxEndpoint::run_control_command(const char *op_name, double timeo std::memcpy(&err, mbox() + MAILBOX_OFF_ERROR, sizeof(int32_t)); if (err != 0) { std::string msg = read_error_msg(mbox()); - write_mailbox_state(MailboxState::IDLE); + write_mailbox_state(restore_state); throw std::runtime_error(std::string(op_name) + " failed on child: " + msg); } - write_mailbox_state(MailboxState::IDLE); + write_mailbox_state(restore_state); } uint64_t LocalMailboxEndpoint::control_malloc(size_t size) { @@ -616,6 +733,81 @@ void LocalMailboxEndpoint::control_register(const char *shm_name, size_t blob_si run_control_command("control_register"); } +uint64_t LocalMailboxEndpoint::control_register_async(const char *shm_name, size_t blob_size, const uint8_t *digest) { + std::lock_guard lk(mailbox_mu_); + uint64_t sub_cmd = CTRL_REGISTER_ASYNC; + std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &sub_cmd, sizeof(uint64_t)); + uint64_t payload_size = static_cast(blob_size); + std::memcpy(mbox() + CTRL_OFF_ARG0, &payload_size, sizeof(uint64_t)); + write_control_digest(mbox(), digest); + size_t name_len = std::strlen(shm_name); + if (name_len + 1 > CTRL_SHM_NAME_BYTES) { + throw std::runtime_error(std::string("control_register_async: shm name too long: ") + shm_name); + } + std::memcpy(mbox() + MAILBOX_OFF_ARGS, shm_name, name_len); + std::memset(mbox() + MAILBOX_OFF_ARGS + 
name_len, 0, CTRL_SHM_NAME_BYTES - name_len); + run_control_command("control_register_async"); + return read_control_result(mbox()); +} + +uint64_t +LocalMailboxEndpoint::control_run_async(const uint8_t *digest, const TaskArgs &args, const CallConfig &config) { + size_t args_blob_size = task_args_blob_size(args); + size_t payload_size = ASYNC_RUN_OFF_ARGS_BLOB + args_blob_size; + std::string shm_name = make_shm_name(); + PosixShmHolder shm(shm_name, payload_size); + auto *payload = static_cast(shm.addr()); + std::memset(payload, 0, payload_size); + std::memcpy(payload + ASYNC_RUN_OFF_DIGEST, digest, CALLABLE_HASH_DIGEST_SIZE); + std::memcpy(payload + ASYNC_RUN_OFF_CONFIG, &config, sizeof(CallConfig)); + write_blob(payload + ASYNC_RUN_OFF_ARGS_BLOB, args); + + std::lock_guard lk(mailbox_mu_); + uint64_t sub_cmd = CTRL_RUN_ASYNC; + std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &sub_cmd, sizeof(uint64_t)); + uint64_t staged_payload_size = static_cast(payload_size); + std::memcpy(mbox() + CTRL_OFF_ARG0, &staged_payload_size, sizeof(uint64_t)); + write_control_digest(mbox(), digest); + size_t name_len = shm.name().size(); + if (name_len + 1 > CTRL_SHM_NAME_BYTES) { + throw std::runtime_error(std::string("control_run_async: shm name too long: ") + shm.name()); + } + std::memcpy(mbox() + MAILBOX_OFF_ARGS, shm.name().data(), name_len); + std::memset(mbox() + MAILBOX_OFF_ARGS + name_len, 0, CTRL_SHM_NAME_BYTES - name_len); + run_control_command("control_run_async"); + return read_control_result(mbox()); +} + +void LocalMailboxEndpoint::control_wait_register(uint64_t handle_id) { + std::lock_guard lk(mailbox_mu_); + write_control_args(mbox(), CTRL_WAIT_REGISTER, handle_id); + run_control_command("control_wait_register"); +} + +RunTiming LocalMailboxEndpoint::control_wait_run(uint64_t handle_id) { + std::lock_guard lk(mailbox_mu_); + write_control_args(mbox(), CTRL_WAIT_RUN, handle_id); + run_control_command("control_wait_run"); + uint64_t host_wall_ns = 
read_control_result(mbox()); + uint64_t device_wall_ns = 0; + std::memcpy(&device_wall_ns, mbox() + CTRL_OFF_RESULT1, sizeof(uint64_t)); + return RunTiming{host_wall_ns, device_wall_ns}; +} + +uint64_t LocalMailboxEndpoint::control_unregister_async(const uint8_t *digest) { + std::lock_guard lk(mailbox_mu_); + write_control_args(mbox(), CTRL_UNREGISTER_ASYNC); + write_control_digest(mbox(), digest); + run_control_command("control_unregister_async"); + return read_control_result(mbox()); +} + +void LocalMailboxEndpoint::control_wait_unregister(uint64_t handle_id) { + std::lock_guard lk(mailbox_mu_); + write_control_args(mbox(), CTRL_WAIT_UNREGISTER, handle_id); + run_control_command("control_wait_unregister"); +} + void LocalMailboxEndpoint::control_unregister(const uint8_t *digest) { std::lock_guard lk(mailbox_mu_); write_control_args(mbox(), CTRL_UNREGISTER); @@ -787,6 +979,36 @@ void WorkerThread::control_register(const char *shm_name, size_t blob_size, cons endpoint_->control_register(shm_name, blob_size, digest); } +uint64_t WorkerThread::control_register_async(const char *shm_name, size_t blob_size, const uint8_t *digest) { + if (!endpoint_) throw std::runtime_error("control_register_async: null endpoint"); + return endpoint_->control_register_async(shm_name, blob_size, digest); +} + +uint64_t WorkerThread::control_run_async(const uint8_t *digest, const TaskArgs &args, const CallConfig &config) { + if (!endpoint_) throw std::runtime_error("control_run_async: null endpoint"); + return endpoint_->control_run_async(digest, args, config); +} + +void WorkerThread::control_wait_register(uint64_t handle_id) { + if (!endpoint_) throw std::runtime_error("control_wait_register: null endpoint"); + endpoint_->control_wait_register(handle_id); +} + +RunTiming WorkerThread::control_wait_run(uint64_t handle_id) { + if (!endpoint_) throw std::runtime_error("control_wait_run: null endpoint"); + return endpoint_->control_wait_run(handle_id); +} + +uint64_t 
WorkerThread::control_unregister_async(const uint8_t *digest) { + if (!endpoint_) throw std::runtime_error("control_unregister_async: null endpoint"); + return endpoint_->control_unregister_async(digest); +} + +void WorkerThread::control_wait_unregister(uint64_t handle_id) { + if (!endpoint_) throw std::runtime_error("control_wait_unregister: null endpoint"); + endpoint_->control_wait_unregister(handle_id); +} + void WorkerThread::control_unregister(const uint8_t *digest) { if (!endpoint_) throw std::runtime_error("control_unregister: null endpoint"); endpoint_->control_unregister(digest); @@ -917,97 +1139,45 @@ bool WorkerManager::any_busy() const { // Dynamic register/unregister broadcast (POSIX shm staging + parallel fan-out) // ============================================================================= -namespace { - -// Process-wide monotonic counter so concurrent broadcasts do not collide on shm -// name. Atomic increment is enough — no need to lock. -std::atomic g_shm_counter{0}; - -// Build the per-broadcast POSIX shm name. The name itself does NOT carry the -// leading '/' that shm_open requires (Python's multiprocessing.SharedMemory -// uses the same convention, so the child Python side reads the field as a -// plain name). Caller adds '/' when opening. 
-std::string make_shm_name() { - char buf[CTRL_SHM_NAME_BYTES]; - int pid = static_cast(getpid()); - uint64_t counter = g_shm_counter.fetch_add(1, std::memory_order_relaxed); - int n = std::snprintf(buf, sizeof(buf), "simpler-cb-%d-%llu", pid, static_cast(counter)); - if (n < 0 || static_cast(n) >= sizeof(buf)) { - throw std::runtime_error("broadcast_register: shm name overflow"); +void WorkerManager::control_prepare(int worker_id, const uint8_t *digest) { + auto *wt = get_worker_by_id(WorkerType::NEXT_LEVEL, worker_id); + if (wt == nullptr) { + throw std::runtime_error("control_prepare: invalid worker_id " + std::to_string(worker_id)); } - return std::string(buf); + wt->control_prepare(digest); } -// Strip the outer " failed on child: " prefix that -// run_control_command prepends to every control failure, so the broadcast -// caller can surface the child-side message (`register hash=sha256:... -// chip=: `) directly under its own one-line Worker.register prefix. -std::string strip_control_prefix(const std::string &msg, const std::string &op_name) { - const std::string needle = op_name + " failed on child: "; - if (msg.compare(0, needle.size(), needle) == 0) { - return msg.substr(needle.size()); +uint64_t +WorkerManager::control_run_async(int worker_id, const uint8_t *digest, const TaskArgs &args, const CallConfig &config) { + auto *wt = get_worker_by_id(WorkerType::NEXT_LEVEL, worker_id); + if (wt == nullptr) { + throw std::runtime_error("control_run_async: invalid worker_id " + std::to_string(worker_id)); } - return msg; + return wt->control_run_async(digest, args, config); } -// RAII guard for a POSIX shm segment: create on construction, unlink on -// destruction. mmaps the region so the staged blob can be memcpy'd in -// place; the mmap is released in the destructor as well. The shm is only -// unlinked once — children open by name *before* this guard is destroyed. 
-class PosixShmHolder { -public: - PosixShmHolder(const std::string &name, size_t size) : - name_(name), - size_(size) { - std::string full_name = "/" + name_; - fd_ = shm_open(full_name.c_str(), O_CREAT | O_RDWR | O_EXCL, 0600); - if (fd_ < 0) { - throw std::runtime_error( - std::string("broadcast_register: shm_open(") + full_name + ") failed: " + std::strerror(errno) - ); - } - if (ftruncate(fd_, static_cast(size)) != 0) { - int err = errno; - ::close(fd_); - shm_unlink(full_name.c_str()); - throw std::runtime_error(std::string("broadcast_register: ftruncate failed: ") + std::strerror(err)); - } - addr_ = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0); - if (addr_ == MAP_FAILED) { - int err = errno; - ::close(fd_); - shm_unlink(full_name.c_str()); - addr_ = nullptr; - throw std::runtime_error(std::string("broadcast_register: mmap failed: ") + std::strerror(err)); - } - } - ~PosixShmHolder() { - if (addr_ != nullptr) munmap(addr_, size_); - if (fd_ >= 0) ::close(fd_); - std::string full_name = "/" + name_; - shm_unlink(full_name.c_str()); +RunTiming WorkerManager::control_wait_run(int worker_id, uint64_t handle_id) { + auto *wt = get_worker_by_id(WorkerType::NEXT_LEVEL, worker_id); + if (wt == nullptr) { + throw std::runtime_error("control_wait_run: invalid worker_id " + std::to_string(worker_id)); } - PosixShmHolder(const PosixShmHolder &) = delete; - PosixShmHolder &operator=(const PosixShmHolder &) = delete; - - void *addr() { return addr_; } - const std::string &name() const { return name_; } - -private: - std::string name_; - size_t size_{0}; - int fd_{-1}; - void *addr_{nullptr}; -}; + return wt->control_wait_run(handle_id); +} -} // namespace +void WorkerManager::control_wait_register(int worker_id, uint64_t handle_id) { + auto *wt = get_worker_by_id(WorkerType::NEXT_LEVEL, worker_id); + if (wt == nullptr) { + throw std::runtime_error("control_wait_register: invalid worker_id " + std::to_string(worker_id)); + } + 
wt->control_wait_register(handle_id); +} -void WorkerManager::control_prepare(int worker_id, const uint8_t *digest) { +void WorkerManager::control_wait_unregister(int worker_id, uint64_t handle_id) { auto *wt = get_worker_by_id(WorkerType::NEXT_LEVEL, worker_id); if (wt == nullptr) { - throw std::runtime_error("control_prepare: invalid worker_id " + std::to_string(worker_id)); + throw std::runtime_error("control_wait_unregister: invalid worker_id " + std::to_string(worker_id)); } - wt->control_prepare(digest); + wt->control_wait_unregister(handle_id); } void WorkerManager::control_alloc_domain(int worker_id, const char *request_shm_name, const char *reply_shm_name) { @@ -1248,6 +1418,45 @@ WorkerManager::broadcast_register_all(const void *blob_ptr, size_t blob_size, co return results; } +std::vector +WorkerManager::broadcast_register_async_all(const void *blob_ptr, size_t blob_size, const uint8_t *digest) { + std::vector results; + results.reserve(next_level_threads_.size()); + for (size_t i = 0; i < next_level_threads_.size(); ++i) { + results.push_back(AsyncControlResult{"NEXT_LEVEL", next_level_threads_[i]->worker_id(), true, 0, ""}); + } + if (next_level_threads_.empty()) return results; + + std::string shm_name = make_shm_name(); + PosixShmHolder shm(shm_name, blob_size); + std::memcpy(shm.addr(), blob_ptr, blob_size); + + std::vector workers; + workers.reserve(next_level_threads_.size()); + for (size_t i = 0; i < next_level_threads_.size(); ++i) { + workers.emplace_back([this, i, digest, blob_size, name = shm.name(), &results]() { + try { + results[i].remote_handle = + next_level_threads_[i]->control_register_async(name.c_str(), blob_size, digest); + } catch (const std::exception &e) { + results[i].ok = false; + results[i].error_message = strip_control_prefix(e.what(), "control_register_async"); + } + }); + } + for (auto &t : workers) + t.join(); + + std::string hash = format_digest(digest); + for (auto &result : results) { + if (!result.ok && 
result.error_message.find("hash=") == std::string::npos) { + result.error_message = "Worker.register_async(hash=" + hash + ") failed on next_level " + + std::to_string(result.worker_id) + ": " + result.error_message; + } + } + return results; +} + std::vector WorkerManager::broadcast_unregister_all(const uint8_t *digest) { std::vector errors; if (next_level_threads_.empty()) return errors; @@ -1275,6 +1484,39 @@ std::vector WorkerManager::broadcast_unregister_all(const uint8_t * return errors; } +std::vector WorkerManager::broadcast_unregister_async_all(const uint8_t *digest) { + std::vector results; + results.reserve(next_level_threads_.size()); + for (size_t i = 0; i < next_level_threads_.size(); ++i) { + results.push_back(AsyncControlResult{"NEXT_LEVEL", next_level_threads_[i]->worker_id(), true, 0, ""}); + } + if (next_level_threads_.empty()) return results; + + std::vector workers; + workers.reserve(next_level_threads_.size()); + for (size_t i = 0; i < next_level_threads_.size(); ++i) { + workers.emplace_back([this, i, digest, &results]() { + try { + results[i].remote_handle = next_level_threads_[i]->control_unregister_async(digest); + } catch (const std::exception &e) { + results[i].ok = false; + results[i].error_message = strip_control_prefix(e.what(), "control_unregister_async"); + } + }); + } + for (auto &t : workers) + t.join(); + + std::string hash = format_digest(digest); + for (auto &result : results) { + if (!result.ok && result.error_message.find("hash=") == std::string::npos) { + result.error_message = "Worker.unregister_async(hash=" + hash + ") failed on next_level " + + std::to_string(result.worker_id) + ": " + result.error_message; + } + } + return results; +} + std::vector WorkerManager::broadcast_control_all( WorkerType type, uint64_t sub_cmd, const void *payload, size_t payload_size, const uint8_t *digest, double timeout_s ) { diff --git a/src/common/hierarchical/worker_manager.h b/src/common/hierarchical/worker_manager.h index 
db66a45e5..f518bde72 100644 --- a/src/common/hierarchical/worker_manager.h +++ b/src/common/hierarchical/worker_manager.h @@ -69,6 +69,7 @@ enum class MailboxState : int32_t { // across distributed ranks so cross-rank init skew never charges against // the per-rank PLATFORM_STREAM_SYNC_TIMEOUT_MS budget (issue #897). INIT_DONE = 6, + TASK_RUNNING = 7, }; // Sized so the args region can hold any TaskArgs the runtime itself accepts @@ -151,7 +152,14 @@ static constexpr uint64_t CTRL_RELEASE_DOMAIN = 8; static constexpr uint64_t CTRL_COMM_INIT = 9; static constexpr uint64_t CTRL_PY_REGISTER = 10; static constexpr uint64_t CTRL_PY_UNREGISTER = 11; +static constexpr uint64_t CTRL_PY_IMPORT_REGISTER = 12; static constexpr uint64_t CTRL_L3_L2_ORCH_COMM_INIT = 13; +static constexpr uint64_t CTRL_REGISTER_ASYNC = 14; +static constexpr uint64_t CTRL_WAIT_REGISTER = 15; +static constexpr uint64_t CTRL_RUN_ASYNC = 16; +static constexpr uint64_t CTRL_WAIT_RUN = 17; +static constexpr uint64_t CTRL_UNREGISTER_ASYNC = 18; +static constexpr uint64_t CTRL_WAIT_UNREGISTER = 19; // Control args reuse the task mailbox region (mutually exclusive with task dispatch): // offset 16: uint64 arg0 (size for malloc/register; ptr for free; dst for copy) @@ -162,6 +170,7 @@ static constexpr ptrdiff_t CTRL_OFF_ARG0 = 16; static constexpr ptrdiff_t CTRL_OFF_ARG1 = 24; static constexpr ptrdiff_t CTRL_OFF_ARG2 = 32; static constexpr ptrdiff_t CTRL_OFF_RESULT = 40; +static constexpr ptrdiff_t CTRL_OFF_RESULT1 = 48; // CTRL_REGISTER puts the NUL-terminated POSIX shm name at MAILBOX_OFF_ARGS, // the exact staged blob size at CTRL_OFF_ARG0, and the callable digest @@ -176,6 +185,14 @@ struct ControlResult { std::string error_message; }; +struct AsyncControlResult { + std::string worker_type; + int32_t worker_id{0}; + bool ok{false}; + uint64_t remote_handle{0}; + std::string error_message; +}; + struct WorkerDispatch; enum class WorkerEndpointKind : int32_t { @@ -206,6 +223,12 @@ class WorkerEndpoint { 
virtual void control_copy_from(uint64_t dst, uint64_t src, size_t size); virtual void control_prepare(const uint8_t *digest); virtual void control_register(const char *shm_name, size_t blob_size, const uint8_t *digest); + virtual uint64_t control_register_async(const char *shm_name, size_t blob_size, const uint8_t *digest); + virtual uint64_t control_run_async(const uint8_t *digest, const TaskArgs &args, const CallConfig &config); + virtual void control_wait_register(uint64_t handle_id); + virtual RunTiming control_wait_run(uint64_t handle_id); + virtual uint64_t control_unregister_async(const uint8_t *digest); + virtual void control_wait_unregister(uint64_t handle_id); virtual void control_unregister(const uint8_t *digest); virtual void control_remote_prepare_register( remote_l3::RemoteRegistryTarget target_registry, CallableKind callable_kind, const uint8_t *digest, @@ -256,6 +279,12 @@ class LocalMailboxEndpoint : public WorkerEndpoint { void control_copy_from(uint64_t dst, uint64_t src, size_t size) override; void control_prepare(const uint8_t *digest) override; void control_register(const char *shm_name, size_t blob_size, const uint8_t *digest) override; + uint64_t control_register_async(const char *shm_name, size_t blob_size, const uint8_t *digest) override; + uint64_t control_run_async(const uint8_t *digest, const TaskArgs &args, const CallConfig &config) override; + void control_wait_register(uint64_t handle_id) override; + RunTiming control_wait_run(uint64_t handle_id) override; + uint64_t control_unregister_async(const uint8_t *digest) override; + void control_wait_unregister(uint64_t handle_id) override; void control_unregister(const uint8_t *digest) override; void control_remote_prepare_register( remote_l3::RemoteRegistryTarget target_registry, CallableKind callable_kind, const uint8_t *digest, @@ -300,6 +329,7 @@ class LocalMailboxEndpoint : public WorkerEndpoint { char *mbox() const { return static_cast(mailbox_); } MailboxState read_mailbox_state() 
const; void write_mailbox_state(MailboxState s); + bool compare_exchange_mailbox_state(MailboxState expected, MailboxState desired); void run_control_command(const char *op_name, double timeout_s = -1.0); }; @@ -364,10 +394,10 @@ class WorkerThread { // thread may be running a task. Issues a control command via the // mailbox and blocks until the child responds. // - // The mailbox is a single shared region; dispatch_process and the - // control_* methods both write its state field. They serialize on - // `mailbox_mu_` so a control request issued mid-dispatch waits for - // TASK_DONE before claiming the mailbox. + // The mailbox is a single shared region; task dispatch owns the payload + // until the child acknowledges TASK_RUNNING. After that, control_* may + // claim the state field while the child run lane continues from its + // private args copy. uint64_t control_malloc(size_t size); void control_free(uint64_t ptr); void control_copy_to(uint64_t dst, uint64_t src, size_t size); @@ -380,10 +410,15 @@ class WorkerThread { // Dynamic post-init register/unregister of a ChipCallable identity. // `shm_name` is the (NUL-terminated, ≤ CTRL_SHM_NAME_BYTES-1) POSIX shm // name where the ChipCallable bytes are staged; `blob_size` is the exact - // byte span to read from that shm. Both methods hold mailbox_mu_, so a - // CTRL_REGISTER concurrent with dispatch_process waits for the in-flight - // TASK_DONE before claiming the mailbox. + // byte span to read from that shm. Async run/register/unregister controls + // can overlap TASK_RUNNING after the child has copied the task args. 
void control_register(const char *shm_name, size_t blob_size, const uint8_t *digest); + uint64_t control_register_async(const char *shm_name, size_t blob_size, const uint8_t *digest); + uint64_t control_run_async(const uint8_t *digest, const TaskArgs &args, const CallConfig &config); + void control_wait_register(uint64_t handle_id); + RunTiming control_wait_run(uint64_t handle_id); + uint64_t control_unregister_async(const uint8_t *digest); + void control_wait_unregister(uint64_t handle_id); void control_unregister(const uint8_t *digest); void control_remote_prepare_register( remote_l3::RemoteRegistryTarget target_registry, CallableKind callable_kind, const uint8_t *digest, @@ -418,8 +453,8 @@ class WorkerThread { // request payload (header + rank_ids + buffer_nbytes); for alloc the child // writes its (device_ctx, local_window_base, buffer_ptrs) into // `reply_shm_name`. Both names are NUL-terminated and ≤ - // CTRL_SHM_NAME_BYTES-1. Holds mailbox_mu_ so it serialises with task - // dispatch on the same chip mailbox. + // CTRL_SHM_NAME_BYTES-1. Memory/domain controls still wait for any + // in-flight task dispatch to finish before claiming the mailbox. void control_alloc_domain(const char *request_shm_name, const char *reply_shm_name); void control_release_domain(const char *request_shm_name); @@ -483,6 +518,10 @@ class WorkerManager { // over WorkerThread::control_prepare; exposed at manager level so the // Python facade can prewarm without reaching into individual WorkerThreads. void control_prepare(int worker_id, const uint8_t *digest); + uint64_t control_run_async(int worker_id, const uint8_t *digest, const TaskArgs &args, const CallConfig &config); + RunTiming control_wait_run(int worker_id, uint64_t handle_id); + void control_wait_register(int worker_id, uint64_t handle_id); + void control_wait_unregister(int worker_id, uint64_t handle_id); // Forward CTRL_ALLOC_DOMAIN / CTRL_RELEASE_DOMAIN to a specific NEXT_LEVEL // worker. 
Used by the Python orch facade to drive collective domain @@ -530,6 +569,9 @@ class WorkerManager { // target so the Python facade can clean up only targets that confirmed // install/refcount increment on a partial failure. std::vector broadcast_register_all(const void *blob_ptr, size_t blob_size, const uint8_t *digest); + std::vector + broadcast_register_async_all(const void *blob_ptr, size_t blob_size, const uint8_t *digest); + std::vector broadcast_unregister_async_all(const uint8_t *digest); // Best-effort: broadcast CTRL_UNREGISTER for `digest` to every NEXT_LEVEL // worker in parallel. Returns a vector of per-worker error strings diff --git a/src/common/platform/onboard/host/device_runner_base.cpp b/src/common/platform/onboard/host/device_runner_base.cpp index bbe37e305..4b0f9da77 100644 --- a/src/common/platform/onboard/host/device_runner_base.cpp +++ b/src/common/platform/onboard/host/device_runner_base.cpp @@ -558,25 +558,33 @@ int DeviceRunnerBase::prewarm_callable(int32_t callable_id) { LOG_ERROR("prewarm_callable: ensure_device_initialized failed: %d", rc); return rc; } - + if (stream_aicpu_prewarm_ == nullptr) { + rc = rtStreamCreate(&stream_aicpu_prewarm_, 0); + if (rc != 0) { + LOG_ERROR("prewarm_callable: rtStreamCreate(prewarm) failed: %d", rc); + return rc; + } + } Runtime runtime; rc = stamp_orch_so(runtime, callable_id, /*force_reload=*/true); if (rc != 0) return rc; - rc = init_runtime_args_with_metadata(runtime); + rc = init_runtime_args_with_metadata(runtime, prewarm_kernel_args_); if (rc != 0) return rc; auto runtime_args_cleanup = RAIIScopeGuard([this]() { - kernel_args_.finalize_runtime_args(); + prewarm_kernel_args_.finalize_runtime_args(); }); LOG_INFO_V0("=== launch_aicpu_kernel %s ===", host::KernelNames::PrewarmName); - rc = launch_aicpu_kernel(stream_aicpu_, &kernel_args_.args, host::KernelNames::PrewarmName, /*aicpu_num=*/1); + rc = launch_aicpu_kernel( + stream_aicpu_prewarm_, &prewarm_kernel_args_.args, 
host::KernelNames::PrewarmName, /*aicpu_num=*/1 + ); if (rc != 0) { LOG_ERROR("prewarm_callable: launch_aicpu_kernel failed: %d", rc); return rc; } - rc = aclrtSynchronizeStreamWithTimeout(stream_aicpu_, PLATFORM_STREAM_SYNC_TIMEOUT_MS); + rc = aclrtSynchronizeStreamWithTimeout(stream_aicpu_prewarm_, PLATFORM_STREAM_SYNC_TIMEOUT_MS); if (rc == ACL_ERROR_RT_STREAM_SYNC_TIMEOUT) { LOG_ERROR( "prewarm_callable: stream sync timeout timeout_ms=%d device_id=%d", PLATFORM_STREAM_SYNC_TIMEOUT_MS, @@ -785,6 +793,16 @@ int DeviceRunnerBase::finalize_common() { capture(rtStreamDestroy(stream_aicore_)); stream_aicore_ = nullptr; } + if (stream_aicpu_prewarm_ != nullptr) { + capture(rtStreamDestroy(stream_aicpu_prewarm_)); + stream_aicpu_prewarm_ = nullptr; + } + + // Cleanup kernel args; device-side KernelArgs + runtime args + // are released by runtime_args_cleanup RAII so they also unwind on errors. + capture(kernel_args_.finalize_device_kernel_args()); + capture(prewarm_kernel_args_.finalize_runtime_args()); + capture(prewarm_kernel_args_.finalize_device_kernel_args()); // load_aicpu_op_ has no per-task host-side state to release — // rtsLaunchCpuKernel does not hand back any per-launch handle, and the @@ -1072,7 +1090,11 @@ void DeviceRunnerBase::read_device_wall_ns() { } int DeviceRunnerBase::init_runtime_args_with_metadata(Runtime &runtime) { - int rc = kernel_args_.init_runtime_args(runtime, mem_alloc_); + return init_runtime_args_with_metadata(runtime, kernel_args_); +} + +int DeviceRunnerBase::init_runtime_args_with_metadata(Runtime &runtime, KernelArgsHelper &helper) { + int rc = helper.init_runtime_args(runtime, mem_alloc_); if (rc != 0) { LOG_ERROR("init_runtime_args failed: %d", rc); return rc; @@ -1081,10 +1103,10 @@ int DeviceRunnerBase::init_runtime_args_with_metadata(Runtime &runtime) { // HostLogger is the single source of truth for log config (seeded by // libsimpler_log.so via simpler_log_init before host_runtime.so was even // dlopen'd). 
Read it directly when populating KernelArgs. - kernel_args_.args.log_level = static_cast(HostLogger::get_instance().level()); - kernel_args_.args.log_info_v = static_cast(HostLogger::get_instance().info_v()); + helper.args.log_level = static_cast(HostLogger::get_instance().level()); + helper.args.log_info_v = static_cast(HostLogger::get_instance().info_v()); // Device ordinal for the AICPU executor's per-device orchestration-SO name. - kernel_args_.args.device_id = static_cast(device_id_); + helper.args.device_id = static_cast(device_id_); return 0; } diff --git a/src/common/platform/onboard/host/device_runner_base.h b/src/common/platform/onboard/host/device_runner_base.h index c844d2546..b09fca31e 100644 --- a/src/common/platform/onboard/host/device_runner_base.h +++ b/src/common/platform/onboard/host/device_runner_base.h @@ -560,6 +560,7 @@ class DeviceRunnerBase : public L3L2OrchCommBackend { * @return 0 on success, the underlying init_runtime_args rc on failure. */ int init_runtime_args_with_metadata(Runtime &runtime); + int init_runtime_args_with_metadata(Runtime &runtime, KernelArgsHelper &helper); /** * Start collector mgmt + poll threads for the four shared @@ -746,7 +747,9 @@ class DeviceRunnerBase : public L3L2OrchCommBackend { // `finalize()`. `nullptr` before init. rtStream_t stream_aicpu_{nullptr}; rtStream_t stream_aicore_{nullptr}; + rtStream_t stream_aicpu_prewarm_{nullptr}; KernelArgsHelper kernel_args_; + KernelArgsHelper prewarm_kernel_args_; // Platform-level device wall buffer: 8-byte device-resident slot // whose address rides on `KernelArgs.device_wall_data_base`. 
AICPU diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/repeat_vector_add_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/repeat_vector_add_orch.cpp new file mode 100644 index 000000000..3f6d534da --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/repeat_vector_add_orch.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include "pto_orchestration_api.h" // NOLINT(build/include_subdir) + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +repeat_vector_add_orchestration_config(const L2TaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 4, + }; +} + +__attribute__((visibility("default"))) void repeat_vector_add_orchestration(const L2TaskArgs &orch_args) { + const Tensor &a = orch_args.tensor(0).ref(); + const Tensor &b = orch_args.tensor(1).ref(); + const Tensor &out = orch_args.tensor(2).ref(); + volatile uint64_t spin_count = orch_args.scalar(0); + + while (spin_count--) { + __asm__ __volatile__("" ::: "memory"); + } + + L0TaskArgs params; + params.add_input(a); + params.add_input(b); + params.add_output(out); + rt_submit_aiv_task(0, params); +} + +} // extern "C" diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/simple_vector_add_orch.cpp b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/simple_vector_add_orch.cpp new file mode 100644 index 000000000..c4bfc0315 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/kernels/orchestration/simple_vector_add_orch.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include "pto_orchestration_api.h" // NOLINT(build/include_subdir) + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +simple_vector_add_orchestration_config(const L2TaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 3, + }; +} + +__attribute__((visibility("default"))) void simple_vector_add_orchestration(const L2TaskArgs &orch_args) { + const Tensor &a = orch_args.tensor(0).ref(); + const Tensor &b = orch_args.tensor(1).ref(); + const Tensor &out = orch_args.tensor(2).ref(); + + L0TaskArgs params; + params.add_input(a); + params.add_input(b); + params.add_output(out); + rt_submit_aiv_task(0, params); +} + +} // extern "C" diff --git a/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/test_run_async_overlap.py b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/test_run_async_overlap.py new file mode 100644 index 000000000..7b45dc368 --- /dev/null +++ b/tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/test_run_async_overlap.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +"""Hardware acceptance for L3 DAG run_async/register_async overlap. + +Run through task-submit, for example: + + task-submit --device auto --device-num 1 --run \ + "python tests/st/a2a3/tensormap_and_ringbuffer/run_async_overlap/test_run_async_overlap.py \ + --platform a2a3 --device \\$TASK_DEVICE" +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time + +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") + +import torch # noqa: E402 +from simpler.task_interface import ( # noqa: E402 + ArgDirection, + CallConfig, + ChipCallable, + CoreCallable, + DataType, + TaskArgs, + Tensor, + TensorArgType, +) +from simpler.worker import Worker # noqa: E402 + +from simpler_setup.elf_parser import extract_text_section # noqa: E402 +from simpler_setup.kernel_compiler import KernelCompiler # noqa: E402 +from simpler_setup.pto_isa import ensure_pto_isa_root # noqa: E402 +from simpler_setup.torch_interop import make_tensor_arg # noqa: E402 + +HERE = os.path.dirname(os.path.abspath(__file__)) +VECTOR_EXAMPLE = os.path.abspath( + os.path.join(HERE, "../../../../../examples/a2a3/tensormap_and_ringbuffer/vector_example") +) +RUNTIME = "tensormap_and_ringbuffer" +N_ROWS = 128 +N_COLS = 128 +N_ELEMS = N_ROWS * N_COLS +NBYTES = N_ELEMS * 4 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--platform", default="a2a3", choices=["a2a3", "a2a3sim"]) + parser.add_argument("--device", type=int, required=True) + parser.add_argument("--repeat-count", type=int, default=2000) + parser.add_argument("--skip-register", action="store_true") + parser.add_argument("--dag-baseline", action="store_true") + return parser.parse_args() + + +def _kernel_compiler(platform: str): + kc = KernelCompiler(platform=platform) + pto_isa_root = ensure_pto_isa_root(clone_protocol="https") + 
include_dirs = kc.get_orchestration_include_dirs(RUNTIME) + return kc, pto_isa_root, include_dirs + + +def build_repeat_vector_add_callable(platform: str) -> ChipCallable: + kc, pto_isa_root, include_dirs = _kernel_compiler(platform) + kernel_bytes = kc.compile_incore( + source_path=os.path.join(VECTOR_EXAMPLE, "kernels/aiv/kernel_add.cpp"), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=include_dirs, + ) + if not platform.endswith("sim"): + kernel_bytes = extract_text_section(kernel_bytes) + orch_bytes = kc.compile_orchestration( + runtime_name=RUNTIME, + source_path=os.path.join(HERE, "kernels/orchestration/repeat_vector_add_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + arg_index=[0, 1, 2], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + func_name="repeat_vector_add_orchestration", + config_name="repeat_vector_add_orchestration_config", + binary=orch_bytes, + children=[(0, core_callable)], + ) + + +def build_vector_add_callable(platform: str) -> ChipCallable: + kc, pto_isa_root, include_dirs = _kernel_compiler(platform) + kernel_bytes = kc.compile_incore( + source_path=os.path.join(VECTOR_EXAMPLE, "kernels/aiv/kernel_add.cpp"), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=include_dirs, + ) + if not platform.endswith("sim"): + kernel_bytes = extract_text_section(kernel_bytes) + orch_bytes = kc.compile_orchestration( + runtime_name=RUNTIME, + source_path=os.path.join(HERE, "kernels/orchestration/simple_vector_add_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + arg_index=[0, 1, 2], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + func_name="simple_vector_add_orchestration", + 
config_name="simple_vector_add_orchestration_config", + binary=orch_bytes, + children=[(0, core_callable)], + ) + + +def make_host_data(): + host_a = torch.full((N_ELEMS,), 2.0, dtype=torch.float32).share_memory_() + host_w = torch.full((N_ELEMS,), 3.0, dtype=torch.float32).share_memory_() + host_out = torch.zeros(N_ELEMS, dtype=torch.float32).share_memory_() + expected = host_a + host_w + return host_a, host_w, host_out, expected + + +def make_vector_args(host_a: torch.Tensor, host_out: torch.Tensor, dev_w: int): + w_dev = Tensor.make(dev_w, (N_ELEMS,), DataType.FLOAT32, child_memory=True) + args = TaskArgs() + args.add_tensor(make_tensor_arg(host_a), TensorArgType.INPUT) + args.add_tensor(w_dev, TensorArgType.INPUT) + args.add_tensor(make_tensor_arg(host_out), TensorArgType.OUTPUT_EXISTING) + return args + + +def run( + platform: str, device: int, repeat_count: int, *, skip_register: bool = False, dag_baseline: bool = False +) -> None: + repeat_callable = None if skip_register else build_repeat_vector_add_callable(platform) + vector_callable = build_vector_add_callable(platform) + + worker = Worker(level=3, platform=platform, runtime=RUNTIME, device_ids=[device], num_sub_workers=0) + repeat_handle = worker.register(repeat_callable) if repeat_callable is not None else None + skip_vector_handle = worker.register(vector_callable) if skip_register else None + host_a, host_w, host_out, expected = make_host_data() + worker.init() + + dev_w: int | None = None + try: + dev_w = worker.malloc(NBYTES, worker_id=0) + worker.copy_to(dev_w, host_w.data_ptr(), NBYTES, worker_id=0) + vec_args = make_vector_args(host_a, host_out, dev_w) + repeat_args = TaskArgs() + for i in range(vec_args.tensor_count()): + repeat_args.add_tensor(vec_args.tensor(i)) + repeat_args.add_scalar(int(repeat_count)) + cfg = CallConfig() + + submit_ns = time.perf_counter_ns() + c1_handle = skip_vector_handle if skip_register else repeat_handle + assert c1_handle is not None + c1_args = vec_args if 
skip_register else repeat_args + if dag_baseline: + worker.run(lambda orch, _args, _cfg: orch.submit_next_level(c1_handle, c1_args, cfg, worker=0)) + max_diff = float(torch.max(torch.abs(host_out - expected))) + print(f"[run_async_overlap] dag_baseline max_diff={max_diff:.3e}") + assert torch.allclose(host_out, expected, rtol=1e-5, atol=1e-5) + return + + def run_c1(orch, _args, _cfg): + orch.submit_next_level(c1_handle, c1_args, cfg, worker=0) + + run_handle = worker.run_async(run_c1) + if skip_register: + repeat_timing = run_handle.wait() + max_diff = float(torch.max(torch.abs(host_out - expected))) + print( + "[run_async_overlap] " + f"skip_register repeat_host_us={repeat_timing.host_wall_us:.1f} max_diff={max_diff:.3e}" + ) + assert torch.allclose(host_out, expected, rtol=1e-5, atol=1e-5) + return + time.sleep(0.005) + + reg_start_ns = time.perf_counter_ns() + vector_pending = worker.register_async(vector_callable) + vector_handle = vector_pending.wait() + reg_done_ns = time.perf_counter_ns() + + repeat_timing = run_handle.wait() + run_done_ns = time.perf_counter_ns() + + register_wait_ns = reg_done_ns - reg_start_ns + run_total_ns = run_done_ns - submit_ns + sequential_estimate_ns = repeat_timing.host_wall_ns + register_wait_ns + print( + "[run_async_overlap] " + f"register_wait_us={register_wait_ns / 1000.0:.1f} " + f"run_total_us={run_total_ns / 1000.0:.1f} " + f"repeat_host_us={repeat_timing.host_wall_us:.1f} " + f"sequential_estimate_us={sequential_estimate_ns / 1000.0:.1f}" + ) + if run_total_ns >= sequential_estimate_ns * 0.85: + raise AssertionError( + "run/register did not overlap enough: " + f"run_total_ns={run_total_ns}, sequential_estimate_ns={sequential_estimate_ns}" + ) + + worker.run(lambda orch, _args, _cfg: orch.submit_next_level(vector_handle, vec_args, cfg, worker=0)) + max_diff = float(torch.max(torch.abs(host_out - expected))) + print(f"[run_async_overlap] vector_add max_diff={max_diff:.3e}") + assert torch.allclose(host_out, expected, 
rtol=1e-5, atol=1e-5) + finally: + if dev_w is not None: + worker.free(dev_w, worker_id=0) + worker.close() + + +def main() -> int: + args = parse_args() + run(args.platform, args.device, args.repeat_count, skip_register=args.skip_register, dag_baseline=args.dag_baseline) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/st/a5/tensormap_and_ringbuffer/l3_l2_orch_comm/test_l3_l2_orch_comm.py b/tests/st/a5/tensormap_and_ringbuffer/l3_l2_orch_comm/test_l3_l2_orch_comm.py index 7449c194d..d5813fd15 100644 --- a/tests/st/a5/tensormap_and_ringbuffer/l3_l2_orch_comm/test_l3_l2_orch_comm.py +++ b/tests/st/a5/tensormap_and_ringbuffer/l3_l2_orch_comm/test_l3_l2_orch_comm.py @@ -67,7 +67,7 @@ def _build_chip_callable(platform: str) -> ChipCallable: signature=[], func_name="l3_l2_orch_comm_orchestration", binary=orch, - children=[(0, CoreCallable.build(signature=[D.IN, D.OUT], binary=aiv))], + children=[(0, CoreCallable.build(signature=[D.IN, D.OUT], arg_index=[0, 1], binary=aiv))], ) diff --git a/tests/ut/py/test_worker/test_host_worker.py b/tests/ut/py/test_worker/test_host_worker.py index 1bbb30970..8ae8c5520 100644 --- a/tests/ut/py/test_worker/test_host_worker.py +++ b/tests/ut/py/test_worker/test_host_worker.py @@ -107,11 +107,19 @@ def _slot_for(worker: Worker, handle: CallableHandle) -> int: class _FakeControlResult: - def __init__(self, worker_type: str, worker_id: int = 0, ok: bool = True, error_message: str = ""): + def __init__( + self, + worker_type: str, + worker_id: int = 0, + ok: bool = True, + error_message: str = "", + remote_handle: int = 1, + ): self.worker_type = worker_type self.worker_id = worker_id self.ok = ok self.error_message = error_message + self.remote_handle = remote_handle def _chip_payload_shm(callable_obj: ChipCallable) -> SharedMemory: @@ -454,13 +462,14 @@ def test_prepare_chip_callable_broadcast_runs_without_registry_lock(self): callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", 
children=[]) observed = {} - def fake_post_init_register(target, digest, *, is_new): + def fake_post_init_register_async(target, digest, *, is_new): observed["target"] = target observed["digest"] = digest observed["is_new"] = is_new observed["locked"] = hw._registry_lock.locked() + return [] - hw._post_init_register = fake_post_init_register + hw._post_init_register_async = fake_post_init_register_async handle = hw.register(callable_obj) @@ -1183,10 +1192,13 @@ def test_duplicate_chip_prepare_broadcasts_ref_increment_without_new_slot(self): calls = [] class FakeWorker: - def broadcast_register_all(self, blob_ptr, blob_size, digest): + def broadcast_register_async_all(self, blob_ptr, blob_size, digest): calls.append(("binary_register", blob_size, digest)) return [_FakeControlResult("NEXT_LEVEL", 0, True)] + def control_wait_register(self, worker_id, handle_id): + calls.append(("wait_register", worker_id, handle_id)) + hw = Worker(level=3, num_sub_workers=1) hw._initialized = True hw._hierarchical_started = True @@ -1202,7 +1214,9 @@ def broadcast_register_all(self, blob_ptr, blob_size, digest): assert hw._identity_registry[first.digest].ref_count == 2 assert calls == [ ("binary_register", int(callable_obj.buffer_size()), first.digest), + ("wait_register", 0, 1), ("binary_register", int(callable_obj.buffer_size()), second.digest), + ("wait_register", 0, 1), ] def test_duplicate_chip_prepare_partial_failure_preserves_existing_handle(self): @@ -1212,13 +1226,16 @@ class FakeWorker: def __init__(self): self.register_count = 0 - def broadcast_register_all(self, blob_ptr, blob_size, digest): + def broadcast_register_async_all(self, blob_ptr, blob_size, digest): self.register_count += 1 calls.append(("binary_register", self.register_count, digest)) if self.register_count == 1: return [_FakeControlResult("NEXT_LEVEL", 0, True), _FakeControlResult("NEXT_LEVEL", 1, True)] return [_FakeControlResult("NEXT_LEVEL", 0, True), _FakeControlResult("NEXT_LEVEL", 1, False, "boom")] + 
def control_wait_register(self, worker_id, handle_id): + calls.append(("wait_register", worker_id, handle_id)) + def control_digest_only(self, worker_type, worker_id, sub_cmd, digest, timeout_s=None): calls.append(("cleanup_one", worker_type, worker_id, sub_cmd, digest)) return _FakeControlResult("NEXT_LEVEL", worker_id, True) @@ -1239,6 +1256,8 @@ def control_digest_only(self, worker_type, worker_id, sub_cmd, digest, timeout_s assert first.digest not in hw._uncertain_hashids assert calls == [ ("binary_register", 1, first.digest), + ("wait_register", 0, 1), + ("wait_register", 1, 1), ("binary_register", 2, first.digest), ("cleanup_one", WorkerType.NEXT_LEVEL, 0, _CTRL_UNREGISTER, first.digest), ] @@ -1247,7 +1266,7 @@ def test_chip_prepare_failure_rolls_back_handle_and_marks_uncertain_when_cleanup calls = [] class FakeWorker: - def broadcast_register_all(self, blob_ptr, blob_size, digest): + def broadcast_register_async_all(self, blob_ptr, blob_size, digest): calls.append(("binary_register", digest)) raise RuntimeError("register failed") diff --git a/tests/ut/py/test_worker/test_run_async.py b/tests/ut/py/test_worker/test_run_async.py new file mode 100644 index 000000000..80d945db6 --- /dev/null +++ b/tests/ut/py/test_worker/test_run_async.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
# -----------------------------------------------------------------------------------------------------------

import threading
from types import SimpleNamespace

from _task_interface import RunTiming
from simpler.task_interface import CallConfig, ChipCallable, ChipStorageTaskArgs
from simpler.worker import RegisterHandle, RunHandle, Worker

# Phase trace recorded when two runs execute strictly one after the other.
_SERIALIZED_PHASES = [
    ("start", 0),
    ("end", 0),
    ("start", 1),
    ("end", 1),
]


def _chip_callable() -> ChipCallable:
    """Build a minimal ChipCallable (empty signature, one-byte binary, no children)."""
    return ChipCallable.build(signature=[], func_name="test", binary=b"\x00", children=[])


def _phases(trace):
    """Reduce recorded call tuples to their leading ``(phase, run_index)`` pair."""
    return [record[:2] for record in trace]


def _noop_orch(orch, args, cfg):
    """Orchestration function that does nothing; the DAG runner is faked in these tests."""
    return None


def _assert_raises(expected, action, failure_message):
    """Invoke ``action`` and require that it raises ``expected``.

    Raises:
        AssertionError: with ``failure_message`` when ``action`` returns normally.
    """
    try:
        action()
    except expected:
        return
    raise AssertionError(failure_message)


class _FakeChipWorker:
    """Stand-in chip worker that records run/unregister calls.

    The first ``_run_slot`` call signals ``entered`` and then waits on
    ``release_first`` (at most two seconds), so a test can observe a run while
    it is still in flight.
    """

    def __init__(self):
        self.calls = []  # ("start", index, slot_id, args, config) / ("end", index, slot_id)
        self.unregisters = []  # slot ids handed to _unregister_slot, in call order
        self.run_count = 0
        self.entered = threading.Event()
        self.release_first = threading.Event()

    def _run_slot(self, slot_id, args, config):
        """Record one run and return ``RunTiming(index + 1, 0)`` for run number ``index``."""
        index = self.run_count
        self.run_count = index + 1
        self.calls.append(("start", index, slot_id, args, config))
        is_first_run = index == 0
        if is_first_run:
            # Announce that the run began, then hold it until the test releases it.
            self.entered.set()
            self.release_first.wait(timeout=2.0)
        self.calls.append(("end", index, slot_id))
        # The tests read the first constructor argument back through ``host_wall_ns``.
        return RunTiming(index + 1, 0)

    def _unregister_slot(self, slot_id):
        """Record the slot id that was asked to be unregistered."""
        self.unregisters.append(slot_id)


def _make_l2_worker():
    """Return ``(worker, handle, fake)``: an L2 worker wired to a ``_FakeChipWorker``.

    The callable is registered before the fake is installed; the worker is then
    flagged as initialized and its L2 run lane is started.
    """
    l2_worker = Worker(level=2, platform="a2a3sim", runtime="tensormap_and_ringbuffer")
    registered = l2_worker.register(_chip_callable())
    chip = _FakeChipWorker()
    l2_worker._chip_worker = chip
    l2_worker._initialized = True
    l2_worker._start_l2_run_lane()
    return l2_worker, registered, chip


def test_l2_run_async_executes_fifo_on_one_run_lane():
    """Two async runs submitted back-to-back execute one after the other, in order."""
    l2_worker, registered, chip = _make_l2_worker()
    try:
        first = l2_worker.run_async(registered, ChipStorageTaskArgs(), CallConfig())
        second = l2_worker.run_async(registered, ChipStorageTaskArgs(), CallConfig())

        assert isinstance(first, RunHandle)
        assert chip.entered.wait(timeout=2.0)
        # The first run is parked inside the fake, so the second cannot be done yet.
        assert not second.completed
        chip.release_first.set()

        assert first.wait().host_wall_ns == 1
        assert second.wait().host_wall_ns == 2
        assert _phases(chip.calls) == _SERIALIZED_PHASES
    finally:
        l2_worker._stop_l2_run_lane()


def test_l2_sync_run_waits_behind_prior_async_run():
    """A synchronous ``run`` issued while an async run is in flight executes after it."""
    l2_worker, registered, chip = _make_l2_worker()
    sync_timing = []

    def _blocking_run():
        sync_timing.append(l2_worker.run(registered, ChipStorageTaskArgs(), CallConfig()))

    try:
        first = l2_worker.run_async(registered, ChipStorageTaskArgs(), CallConfig())
        assert chip.entered.wait(timeout=2.0)

        sync_thread = threading.Thread(target=_blocking_run)
        sync_thread.start()
        assert not sync_timing

        chip.release_first.set()
        assert first.wait().host_wall_ns == 1
        sync_thread.join(timeout=2.0)
        assert not sync_thread.is_alive()
        assert sync_timing[0].host_wall_ns == 2
        assert _phases(chip.calls) == _SERIALIZED_PHASES
    finally:
        l2_worker._stop_l2_run_lane()


def test_l2_unregister_async_tombstones_and_defers_free_until_run_finishes():
    """Unregistering mid-run rejects new runs at once but frees the slot only afterwards."""
    l2_worker, registered, chip = _make_l2_worker()
    try:
        first = l2_worker.run_async(registered, ChipStorageTaskArgs(), CallConfig())
        assert chip.entered.wait(timeout=2.0)

        unreg = l2_worker.unregister_async(registered)
        assert not unreg.completed
        _assert_raises(
            KeyError,
            lambda: l2_worker.run_async(registered, ChipStorageTaskArgs(), CallConfig()),
            "tombstoned handle should reject new runs",
        )

        chip.release_first.set()
        assert first.wait().host_wall_ns == 1
        unreg.wait()
        assert chip.unregisters == [0]
    finally:
        l2_worker._stop_l2_run_lane()


def test_async_register_unregister_reject_non_chip_targets():
    """``register_async`` / ``unregister_async`` raise ``TypeError`` for non-chip arguments."""
    l3_worker = Worker(level=3, device_ids=[0])
    _assert_raises(
        TypeError,
        lambda: l3_worker.register_async(lambda: None),
        "register_async should reject non-ChipCallable targets",
    )
    _assert_raises(
        TypeError,
        lambda: l3_worker.unregister_async(object()),
        "unregister_async should reject non-ChipCallable handles",
    )


def test_l3_run_async_runs_dag_on_worker_queue():
    """An L3 async DAG run and a later sync ``run`` are executed one after the other."""
    l3_worker = Worker(level=3, device_ids=[0])
    l3_worker._initialized = True
    entered = threading.Event()
    release_first = threading.Event()
    calls = []
    run_count = 0

    def fake_run_dag(orch_fn, args, config):
        # Replacement for the worker's synchronous DAG runner: records the call
        # and parks the first invocation until the test releases it.
        nonlocal run_count
        index = run_count
        run_count = index + 1
        calls.append(("start", index, orch_fn, args, config))
        if index == 0:
            entered.set()
            release_first.wait(timeout=2.0)
        calls.append(("end", index))
        return RunTiming(index + 10, 0)

    l3_worker._run_dag_sync_impl = fake_run_dag
    sync_result = []

    def _blocking_run():
        sync_result.append(l3_worker.run(_noop_orch, "second", CallConfig()))

    try:
        first = l3_worker.run_async(_noop_orch, "first", CallConfig())
        assert isinstance(first, RunHandle)
        assert entered.wait(timeout=2.0)

        sync_thread = threading.Thread(target=_blocking_run)
        sync_thread.start()
        assert not sync_result

        release_first.set()
        assert first.wait().host_wall_ns == 10
        sync_thread.join(timeout=2.0)
        assert not sync_thread.is_alive()
        assert sync_result[0].host_wall_ns == 11
        assert _phases(calls) == _SERIALIZED_PHASES
    finally:
        l3_worker._stop_dag_run_lane()


def test_l3_run_async_does_not_accept_worker_keyword():
    """Passing ``worker=`` to the L3 ``run_async`` raises ``TypeError``."""
    l3_worker = Worker(level=3, device_ids=[0])
    l3_worker._initialized = True
    try:
        _assert_raises(
            TypeError,
            lambda: l3_worker.run_async(_noop_orch, worker=0),  # type: ignore[call-arg]
            "L3 public run_async must not expose direct chip worker selection",
        )
    finally:
        l3_worker._stop_dag_run_lane()


def test_l3_register_async_returns_handle_after_remote_wait():
    """``register_async`` yields a pending handle; ``wait()`` performs the remote wait."""

    class RecordingWorker:
        """Fake lower-level worker that records broadcast and wait calls."""

        def __init__(self):
            self.broadcasts = []
            self.waits = []

        def broadcast_register_async_all(self, blob_ptr, blob_size, digest):
            self.broadcasts.append((blob_ptr, blob_size, digest))
            reply = SimpleNamespace(
                worker_type="NEXT_LEVEL",
                worker_id=0,
                ok=True,
                remote_handle=23,
                error_message="",
            )
            return [reply]

        def control_wait_register(self, worker_id, handle_id):
            self.waits.append((worker_id, handle_id))

    recorder = RecordingWorker()
    l3_worker = Worker(level=3, device_ids=[0])
    l3_worker._initialized = True
    l3_worker._hierarchical_started = True
    l3_worker._hierarchical_start_state = "started"
    l3_worker._worker = recorder

    pending = l3_worker.register_async(_chip_callable())
    assert isinstance(pending, RegisterHandle)
    assert not pending.completed

    handle = pending.wait()
    assert pending.completed
    assert recorder.broadcasts[0][2] == handle.digest
    # The wait is addressed to the worker id / remote handle reported by the broadcast.
    assert recorder.waits == [(0, 23)]