diff --git a/docs/l3-l2-message-queue.md b/docs/l3-l2-message-queue.md new file mode 100644 index 000000000..12a0034ad --- /dev/null +++ b/docs/l3-l2-message-queue.md @@ -0,0 +1,372 @@ +# L3-L2 Message Queue + +L3-L2 Message Queue lets an L3 Host Orchestrator exchange ordered messages +with one persistent L2 AICPU Orchestrator task. + +The intended use case is repeated in-flight work: L3 enqueues input messages, +L2 consumes them while the L2 task stays alive, L2 publishes output messages, +and L3 dequeues those outputs. The queue is built on top of the lower-level +L3-L2 orchestration communication primitives described in +[l3-l2-orch-comm.md](l3-l2-orch-comm.md). For where L3 and L2 sit in +the runtime stack, see +[hierarchical_level_runtime.md](hierarchical_level_runtime.md). + +## 1. API + +L3 creates one queue for one chip worker: + +```python +queue = orch.create_l3_l2_queue( + worker_id=0, + depth=4, + input_arena_bytes=1 << 20, + output_arena_bytes=1 << 20, +) +``` + +The queue owns one underlying `L3L2OrchRegion`. Its payload range is split into +input/output descriptor rings and input/output payload arenas. Its counter +range stores descriptor head/tail signals and abort flags. + +L3 passes the primitive region descriptor and queue layout arguments to L2: + +```python +l2_args = TaskArgs() +for value in queue.l2_task_arg_scalars(): + l2_args.add_scalar(value) + +orch.submit_next_level(l2_handle, l2_args, cfg, worker=0) +``` + +`l2_task_arg_scalars()` returns: + +```text +primitive region descriptor scalars[0..5] +queue_magic_version +depth +input_arena_bytes +output_arena_bytes +payload_bytes +counter_bytes +``` + +L3 sends input messages through `queue.input`: + +```python +host_input = orch.alloc([nbytes], DataType.UINT8) +fill_input(host_input) + +queue.input.enqueue(host_input, nbytes=nbytes, timeout=timeout_s) +``` + +`try_enqueue(buffer, nbytes)` is the non-blocking form. 
It returns `False` +when the input descriptor ring or payload arena has no space. That result is +ordinary backpressure and does not poison the queue. + +L3 receives output messages through `queue.output`: + +```python +host_output = orch.alloc([max_output_nbytes], DataType.UINT8) + +message = queue.output.peek(timeout=timeout_s) +queue.output.read_into(message, host_output) +queue.output.release(message) +``` + +The convenience form reads and releases in one operation: + +```python +message = queue.output.dequeue_into(host_output, timeout=timeout_s) +``` + +`try_peek()` and `try_dequeue_into(buffer)` are the non-blocking forms. They +return `None` when no output message is available. + +The L3 buffer arguments currently must be runtime-managed tensors returned by +`orch.alloc(...)`. Ordinary Python `bytes`, `bytearray`, and private tensors +are rejected before shared queue state is modified. Zero-byte messages use +`buffer_or_none=None` and `nbytes=0`. + +L3 requests graceful shutdown by publishing an input-side `STOP` descriptor: + +```python +queue.request_stop(timeout=timeout_s) +queue.free() +``` + +`try_request_stop()` is the non-blocking form. `queue.free()` releases the L3 +queue handle and marks the underlying `L3L2OrchRegion` handle released. It does +not synchronously free device memory; physical cleanup follows the underlying +region lifetime model after submitted L2 work has drained. Small Python wrapper +scratch tensors used for descriptor staging are owned by the queue object and +follow normal Python object lifetime. 
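The L3-side calls above can be combined into a single non-blocking pump loop. The following is a minimal sketch only: `FakeEchoQueue`, `fake_l2_step`, and `pump` are hypothetical names introduced for illustration, and the fake accepts plain `bytes`/`bytearray` so the snippet runs without a runtime. The real queue requires tensors returned by `orch.alloc(...)`, and its `try_dequeue_into` returns a message handle rather than a byte count.

```python
from collections import deque
from types import SimpleNamespace


class FakeEchoQueue:
    """Hypothetical in-memory stand-in for the L3 queue handle; L2 is simulated as an echo."""

    def __init__(self, depth):
        self._depth = depth
        self._inputs = deque()
        self._outputs = deque()
        self.stopped = False
        self.input = SimpleNamespace(try_enqueue=self._try_enqueue)
        self.output = SimpleNamespace(try_dequeue_into=self._try_dequeue_into)

    def _try_enqueue(self, buffer, nbytes):
        if self.stopped or len(self._inputs) >= self._depth:
            return False  # ordinary backpressure (or stopped input): nothing is poisoned
        self._inputs.append(bytes(buffer[:nbytes]))
        return True

    def _try_dequeue_into(self, buffer):
        if not self._outputs:
            return None  # no output message available yet
        payload = self._outputs.popleft()
        buffer[: len(payload)] = payload
        return len(payload)  # simplification: the documented call returns a message handle

    def try_request_stop(self):
        self.stopped = True
        return True

    def fake_l2_step(self):
        """Simulate the persistent L2 task echoing one input to the output side."""
        if self._inputs and len(self._outputs) < self._depth:
            self._outputs.append(self._inputs.popleft())


def pump(queue, messages, scratch):
    """Send every message and collect one output per input using only try_* calls."""
    pending = deque(messages)
    results = []
    while len(results) < len(messages):
        # A False result leaves the message at the head of `pending` for the next pass.
        if pending and queue.input.try_enqueue(pending[0], len(pending[0])):
            pending.popleft()
        queue.fake_l2_step()  # in a real run, L2 makes progress on its own
        nbytes = queue.output.try_dequeue_into(scratch)
        if nbytes is not None:
            results.append(bytes(scratch[:nbytes]))
    queue.try_request_stop()  # stop only after every expected output was dequeued
    return results


echo_queue = FakeEchoQueue(depth=2)
outputs = pump(echo_queue, [b"a", b"bb", b"ccc"], bytearray(8))
print(outputs)
```

Interleaving enqueue and dequeue attempts in one loop keeps L3 from stalling on a full input ring while outputs are waiting to be read. A real caller would likely also bound the loop with a deadline.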
+ +On L2, orchestration code receives the primitive descriptor and queue args, +then constructs an endpoint: + +```cpp +L3L2OrchRegionDesc desc{/* scalars from TaskArgs */}; +L3L2QueueArgs queue_args{ + magic_version, + depth, + input_arena_bytes, + output_arena_bytes, + payload_bytes, + counter_bytes, +}; + +L3L2QueueEndpoint queue(desc, queue_args); +if (queue.error().kind != L3L2QueueErrorKind::NONE) { + return; +} +``` + +L2 consumes input messages from `queue.input()` and publishes outputs through +`queue.output()`: + +```cpp +while (true) { + L3L2QueueInputHandle input{}; + if (!queue.input().peek(timeout_ns, &input)) { + return; + } + + if (input.opcode == L3L2QueueOpcode::STOP) { + queue.input().release(input); + return; + } + + L3L2QueueOutputReservation output{}; + if (!queue.output().reserve(input.payload_nbytes, timeout_ns, &output)) { + return; + } + + launch_aicore(input.payload, output.payload); + wait_aicore_done(); + + queue.output().publish(output, L3L2QueueOpcode::DATA); + queue.input().release(input); +} +``` + +`queue.input().try_peek(&input)` and +`queue.output().try_reserve(nbytes, &reservation)` are non-blocking. A `false` +return can mean no progress, timeout, validation failure, or poison; check +`queue.error().kind` to distinguish ordinary no-progress from terminal error. + +## 2. Layout + +The physical region has one payload range: + +```text +payload region +|-- input descriptor ring +|-- output descriptor ring +|-- input payload arena +`-- output payload arena +``` + +The two payload arenas are separate: + +```text +input arena: producer = L3, consumer = L2 +output arena: producer = L2, consumer = L3 +``` + +`depth` is the descriptor-ring capacity in each direction. It must be a power +of two and at most `2^30`. Queue capacity is exactly `depth` messages, not +`depth - 1`. + +`input_arena_bytes` and `output_arena_bytes` must be positive 64-byte +multiples. They do not need to be powers of two. 
A single message payload must +fit as one contiguous span inside its direction's arena. Payloads are not split +across arena wrap. + +Python and C++ mirror the same deterministic queue layout calculation: + +```text +input_desc_offset +output_desc_offset +input_arena_offset +output_arena_offset +payload_bytes +counter_bytes +``` + +Python exposes this as `queue.layout`; L2 exposes it as `queue.layout()`. +L3 passes the derived `payload_bytes` and `counter_bytes` to L2. L2 rejects +initialization unless those values match both its local layout calculation and +the primitive region descriptor sizes. Lockstep tests cover representative +layout cases for the mirrored Python and C++ calculations. + +## 3. Descriptor ABI + +Each descriptor slot is 32 bytes: + +```cpp +struct L3L2QueueDescSlot { + uint64_t seq; + uint64_t opcode; + uint64_t payload_offset; + uint64_t payload_nbytes; +}; +``` + +`seq` is the transport sequence number for ring validation, wrap detection, and +diagnostics. It is not a user request ID. Applications that need request IDs, +batch IDs, final markers, or correlation fields should put them in their own +payload header. + +`payload_offset` is relative to the primitive region payload base. The payload +must be wholly inside the matching direction's arena. Zero-byte messages use +`payload_offset == 0` and `payload_nbytes == 0`. + +The queue currently defines these opcodes: + +| Opcode | Meaning | +| ------ | ------- | +| `DATA` | Ordinary application payload message. | +| `STOP` | Graceful input-side shutdown request. | +| `ERROR` | Ordinary application-level error payload message. | + +`STOP` is valid only on the input queue. The output queue has no `STOP` +message; L2 exit is observed through normal `Worker.run` drain. + +`ERROR` is a normal queue message. The queue layer does not interpret its +payload and does not poison the queue when an `ERROR` message is received. +Infrastructure failures use poison state instead. + +## 4. 
Signals And Ordering + +The queue uses the primitive signal counters as descriptor head/tail values. +Each shared signal is placed on a 64-byte stride: + +```text +offset 0: input_desc_tail writer=L3 +offset 64: input_desc_head writer=L2 +offset 128: output_desc_tail writer=L2 +offset 192: output_desc_head writer=L3 +offset 256: l3_abort_flag writer=L3 +offset 320: l2_abort_flag writer=L2 +``` + +Descriptor counters store the low 32 bits of monotonic logical head/tail +values. Each endpoint reconstructs its local 64-bit value from observed +progress. The unobserved progress must be between zero and `depth`; anything +else is inconsistent shared state and poisons the queue. + +The producer sequence is: + +```text +reserve payload space +write payload bytes +write descriptor fields +write descriptor seq +publish descriptor tail counter +``` + +The consumer sequence is: + +```text +observe descriptor tail progress +read and validate descriptor +use payload bytes or payload view +release descriptor and payload +publish descriptor head counter +``` + +All Python blocking queue operations require finite positive timeouts; passing +`timeout <= 0` is a caller error and raises `ValueError`. Python `try_*` APIs +are non-blocking and return `False` or `None` for ordinary no-progress. + +C++ blocking queue operations take `timeout_ns`; `timeout_ns == 0` is an +immediate timeout probe. They return `false` on no-progress, timeout, +validation failure, or poison. C++ `try_*` APIs are non-blocking and also +return `false` for ordinary no-progress. + +Timeout under ordinary backpressure is not poison. After timeout, an endpoint +samples the peer abort flag; if the peer flag is set, the local endpoint +reports remote abort. + +## 5. Ownership + +Queue ownership is per message. + +On L3 output, `peek()` returns a handle that remains active until +`release(handle)`. While a handle is active, repeated `try_peek()` returns the +same handle. 
The caller may read the payload with `read_into(handle, buffer)` +before releasing it. Releasing the wrong handle is an ownership error and +poisons the queue. + +On L2 input, `peek()` returns one active input handle. L2 must not call +`peek()` again before releasing that handle. L2 must not release an input until +all AICore work that reads the input payload has completed. + +On L2 output, `reserve()` returns one active output reservation. L2 fills the +reserved payload span, then calls `publish(reservation, opcode)`. Publishing an +unknown, stale, already-published, or cross-queue reservation is an ownership +error and poisons the queue. + +The base queue supports at most one active L2 input handle and one active L2 +output reservation. It does not provide a multi-input L2 window. + +## 6. STOP Semantics + +`STOP` is an input descriptor with no payload. It follows normal FIFO ordering: +L2 observes and releases messages before `STOP`, then releases `STOP` and +returns from the persistent run. + +After L3 successfully publishes `STOP`, the input queue rejects further input +messages locally without poisoning. L3 may still dequeue output messages that +L2 publishes before returning. + +`request_stop(timeout)` waits only until the `STOP` descriptor is published. +It does not wait for L2 exit and does not drain outputs. Applications that need +all outputs must keep dequeuing until their own protocol-level final condition +is satisfied before returning from the L3 orchestration function. + +## 7. Error Handling + +The queue distinguishes no-progress, application errors, and infrastructure +poison. + +No-progress is non-terminal: + +- descriptor ring full; +- payload arena full; +- empty output queue; +- blocking operation timeout with no peer abort flag. + +Application-level error is represented by `opcode=ERROR`. It is delivered to +the peer as a normal message and does not set an abort flag. 
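As a rough illustration of how an L3 caller might keep these outcomes apart on the output side, the sketch below classifies a single non-blocking poll. It assumes only what this change shows for the Python wrapper: `try_peek()` returns `None` when no output is available, returns a message carrying an `opcode`, and raises `RuntimeError` once the queue handle is no longer usable. `classify_poll`, the local `Opcode` mirror, and the lambda stand-ins are hypothetical names for illustration, not part of the queue API.

```python
from enum import IntEnum
from types import SimpleNamespace


class Opcode(IntEnum):
    """Local mirror of the documented opcode values (assumed to match the queue ABI)."""

    DATA = 1
    STOP = 2
    ERROR = 3


def classify_poll(try_peek):
    """Map one non-blocking poll to a coarse outcome label."""
    try:
        message = try_peek()
    except RuntimeError:
        # Poison, remote abort, or a released/expired handle: stop using the queue.
        return "terminal"
    if message is None:
        return "no_progress"  # empty output queue: safe to retry later
    if message.opcode == Opcode.ERROR:
        return "application_error"  # a normal message; the queue stays usable
    return "data"


def raise_poison():
    # Stand-in for a try_peek() call on a poisoned queue.
    raise RuntimeError("queue is poisoned")


outcomes = [
    classify_poll(lambda: None),
    classify_poll(lambda: SimpleNamespace(opcode=Opcode.DATA)),
    classify_poll(lambda: SimpleNamespace(opcode=Opcode.ERROR)),
    classify_poll(raise_poison),
]
print(outcomes)
```

An `ERROR` message still has to be read and released like any other output message; only the `terminal` case means the handle should be abandoned and cleaned up.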
+ +Infrastructure poison is terminal for the local queue handle: + +- descriptor sequence mismatch; +- invalid opcode in a published descriptor; +- output-side `STOP`; +- descriptor payload outside its direction's arena; +- impossible counter reconstruction or payload replay; +- payload command failure after shared mutation begins; +- counter notify failure; +- stale or invalid handle/reservation ownership. + +When an endpoint enters local infrastructure poison, it sets its own abort flag +for the peer. Observing the peer abort flag reports remote abort but does not +set the local abort flag. + +After poison, normal queue operations reject. Cleanup remains valid. + +## 8. Platform Support + +The message queue uses the existing L3-L2 orchestration communication region, +payload, and counter primitives. + +- `a2a3sim`: supported. +- `a5sim`: supported. +- `a2a3` onboard: supported where the underlying L3-L2 communication + primitives are supported. +- `a5` onboard: follows the underlying L3-L2 communication support status. + +Simulation backends preserve the same API, ordering, timeout, and error +semantics as onboard backends. diff --git a/docs/l3-l2-orch-comm.md b/docs/l3-l2-orch-comm.md index 6c541dfe5..256babbf5 100644 --- a/docs/l3-l2-orch-comm.md +++ b/docs/l3-l2-orch-comm.md @@ -3,6 +3,10 @@ L3-L2 Orchestrator Communication lets an L3 Host Orchestrator exchange payload bytes and signal counters with a running L2 AICPU Orchestrator task. +This page documents the low-level region, payload, and counter primitives. For +the ordered SPSC message queue wrapper built on these primitives, see +[l3-l2-message-queue.md](l3-l2-message-queue.md). + The intended use case is in-flight interaction: L3 can write input payload, publish a data-ready counter, wait for L2/AICore completion, and read output payload without ending the L2 orchestration task. 
For where L3 and L2 sit in diff --git a/python/simpler/l3_l2_message_queue.py b/python/simpler/l3_l2_message_queue.py new file mode 100644 index 000000000..91236fe5e --- /dev/null +++ b/python/simpler/l3_l2_message_queue.py @@ -0,0 +1,561 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""L3-side L3-L2 SPSC message queue wrapper.""" + +from __future__ import annotations + +import ctypes +import struct +import time +from dataclasses import dataclass +from enum import IntEnum +from typing import Any + +from .l3_l2_orch_comm import ( + L3L2OrchCommCmd, + L3L2OrchCommRequest, + L3L2OrchRegion, + NotifyOp, + WaitCmp, +) +from .task_interface import DataType, Tensor + +L3L2_QUEUE_MAGIC = 0x4C335132 +L3L2_QUEUE_ABI_MAJOR = 1 +L3L2_QUEUE_ABI_MINOR = 1 +L3L2_QUEUE_DESC_SLOT_BYTES = 32 +L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT = 64 +L3L2_QUEUE_COUNTER_STRIDE = 64 +L3L2_QUEUE_INPUT_DESC_TAIL_OFFSET = 0 +L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET = 64 +L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET = 128 +L3L2_QUEUE_OUTPUT_DESC_HEAD_OFFSET = 192 +L3L2_QUEUE_L3_ABORT_FLAG_OFFSET = 256 +L3L2_QUEUE_L2_ABORT_FLAG_OFFSET = 320 +L3L2_QUEUE_COUNTER_BYTES = 384 +L3L2_QUEUE_MAX_DEPTH = 1 << 30 +_UINT64_MAX = (1 << 64) - 1 + +_DESC = struct.Struct("<4Q") +_POLL_INTERVAL_S = 0.00005 + + +class L3L2QueueOpcode(IntEnum): + INVALID = 0 + 
DATA = 1 + STOP = 2 + ERROR = 3 + + +class _QueueState(IntEnum): + LIVE = 0 + RELEASED = 1 + POISONED_LOCAL = 2 + POISONED_REMOTE = 3 + EXPIRED = 4 + + +@dataclass(frozen=True) +class L3L2QueueLayout: + depth: int + input_desc_offset: int + output_desc_offset: int + input_arena_offset: int + output_arena_offset: int + input_arena_bytes: int + output_arena_bytes: int + payload_bytes: int + input_desc_tail_offset: int + input_desc_head_offset: int + output_desc_tail_offset: int + output_desc_head_offset: int + l3_abort_flag_offset: int + l2_abort_flag_offset: int + counter_bytes: int + + +@dataclass(frozen=True) +class L3L2QueueMessage: + seq: int + opcode: L3L2QueueOpcode + payload_offset: int + payload_nbytes: int + + +def l3_l2_queue_magic_version() -> int: + return (L3L2_QUEUE_MAGIC << 32) | (L3L2_QUEUE_ABI_MAJOR << 16) | L3L2_QUEUE_ABI_MINOR + + +def _align_up(value: int, align: int) -> int: + if value < 0 or value > _UINT64_MAX: + raise ValueError("L3-L2 queue layout calculation overflowed uint64") + remainder = value % align + bump = 0 if remainder == 0 else align - remainder + result = value + bump + if result > _UINT64_MAX: + raise ValueError("L3-L2 queue layout calculation overflowed uint64") + return result + + +def _checked_add_u64(lhs: int, rhs: int) -> int: + result = lhs + rhs + if lhs < 0 or rhs < 0 or result > _UINT64_MAX: + raise ValueError("L3-L2 queue layout calculation overflowed uint64") + return result + + +def make_l3_l2_queue_layout(depth: int, input_arena_bytes: int, output_arena_bytes: int) -> L3L2QueueLayout: + depth = int(depth) + input_arena_bytes = int(input_arena_bytes) + output_arena_bytes = int(output_arena_bytes) + if depth <= 0 or depth & (depth - 1) != 0 or depth > L3L2_QUEUE_MAX_DEPTH: + raise ValueError("L3-L2 queue depth must be a power of two and <= 2^30") + if input_arena_bytes <= 0 or input_arena_bytes % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT != 0: + raise ValueError("L3-L2 queue input_arena_bytes must be a positive 64-byte 
multiple") + if output_arena_bytes <= 0 or output_arena_bytes % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT != 0: + raise ValueError("L3-L2 queue output_arena_bytes must be a positive 64-byte multiple") + + desc_ring_bytes = depth * L3L2_QUEUE_DESC_SLOT_BYTES + if desc_ring_bytes > _UINT64_MAX: + raise ValueError("L3-L2 queue layout calculation overflowed uint64") + input_desc_offset = 0 + output_desc_offset = _checked_add_u64(input_desc_offset, desc_ring_bytes) + desc_end = _checked_add_u64(output_desc_offset, desc_ring_bytes) + input_arena_offset = _align_up(desc_end, L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT) + input_arena_end = _checked_add_u64(input_arena_offset, input_arena_bytes) + output_arena_offset = _align_up(input_arena_end, L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT) + payload_bytes = _checked_add_u64(output_arena_offset, output_arena_bytes) + return L3L2QueueLayout( + depth=depth, + input_desc_offset=input_desc_offset, + output_desc_offset=output_desc_offset, + input_arena_offset=input_arena_offset, + output_arena_offset=output_arena_offset, + input_arena_bytes=input_arena_bytes, + output_arena_bytes=output_arena_bytes, + payload_bytes=payload_bytes, + input_desc_tail_offset=L3L2_QUEUE_INPUT_DESC_TAIL_OFFSET, + input_desc_head_offset=L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET, + output_desc_tail_offset=L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET, + output_desc_head_offset=L3L2_QUEUE_OUTPUT_DESC_HEAD_OFFSET, + l3_abort_flag_offset=L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, + l2_abort_flag_offset=L3L2_QUEUE_L2_ABORT_FLAG_OFFSET, + counter_bytes=L3L2_QUEUE_COUNTER_BYTES, + ) + + +def create_l3_l2_queue( + orch: Any, + *, + worker_id: int, + depth: int, + input_arena_bytes: int, + output_arena_bytes: int, +) -> L3L2Queue: + layout = make_l3_l2_queue_layout(depth, input_arena_bytes, output_arena_bytes) + region = orch.create_l3_l2_region( + worker_id=int(worker_id), + payload_bytes=layout.payload_bytes, + counter_bytes=layout.counter_bytes, + ) + try: + desc_fields = orch.alloc([24], DataType.UINT8) + 
desc_seq = orch.alloc([8], DataType.UINT8) + desc_read = orch.alloc([L3L2_QUEUE_DESC_SLOT_BYTES], DataType.UINT8) + for offset in ( + layout.input_desc_tail_offset, + layout.input_desc_head_offset, + layout.output_desc_tail_offset, + layout.output_desc_head_offset, + layout.l3_abort_flag_offset, + layout.l2_abort_flag_offset, + ): + region.counter(offset).notify(0, NotifyOp.Set) + except Exception: + try: + region.free() + except Exception: + pass + raise + return L3L2Queue(orch, region, layout, desc_fields, desc_seq, desc_read) + + +class L3L2Queue: + def __init__( + self, + orch: Any, + region: L3L2OrchRegion, + layout: L3L2QueueLayout, + desc_fields: Tensor, + desc_seq: Tensor, + desc_read: Tensor, + ) -> None: + self._orch = orch + self._region = region + self._layout = layout + self._desc_fields = desc_fields + self._desc_seq = desc_seq + self._desc_read = desc_read + self._state = _QueueState.LIVE + self._input_head = 0 + self._input_tail = 0 + self._output_head = 0 + self._output_tail = 0 + self._input_payload_tail = 0 + self._input_payload_head = 0 + self._output_payload_head = 0 + self._output_active: L3L2QueueMessage | None = None + self._stop_published = False + self.input = _L3InputQueue(self) + self.output = _L3OutputQueue(self) + + @property + def region(self) -> L3L2OrchRegion: + return self._region + + @property + def layout(self) -> L3L2QueueLayout: + return self._layout + + @property + def magic_version(self) -> int: + return l3_l2_queue_magic_version() + + def l2_task_arg_scalars(self) -> list[int]: + self._ensure_live() + return [ + *self._region.descriptor_scalars(), + self.magic_version, + self._layout.depth, + self._layout.input_arena_bytes, + self._layout.output_arena_bytes, + self._layout.payload_bytes, + self._layout.counter_bytes, + ] + + def try_request_stop(self) -> bool: + return self.input._try_enqueue(None, 0, L3L2QueueOpcode.STOP) + + def request_stop(self, timeout: float) -> None: + self.input._enqueue(None, 0, 
L3L2QueueOpcode.STOP, timeout) + + def free(self) -> None: + if self._state == _QueueState.RELEASED: + return + self._state = _QueueState.RELEASED + self._region.free() + + def _ensure_live(self) -> None: + if self._state == _QueueState.RELEASED: + raise RuntimeError("L3-L2 queue has been released") + if self._state == _QueueState.POISONED_REMOTE: + raise RuntimeError("L3-L2 queue is remote-aborted") + if self._state == _QueueState.POISONED_LOCAL: + raise RuntimeError("L3-L2 queue is poisoned") + if self._state == _QueueState.EXPIRED: + raise RuntimeError("L3-L2 queue expired after orchestration run") + if getattr(self._region, "_expired", False): + self._state = _QueueState.EXPIRED + raise RuntimeError("L3-L2 queue expired after orchestration run") + self._region._ensure_live() + + def _validate_registered_buffer(self, buffer: Any, nbytes: int) -> Tensor: + if not isinstance(buffer, Tensor): + raise ValueError("L3-L2 queue PR1 requires a registered Tensor returned by orch.alloc(...)") + self._region._owner._validate_l3_l2_orch_comm_host_buffer(buffer) + if int(nbytes) > int(buffer.nbytes()): + raise ValueError(f"L3-L2 queue nbytes={nbytes} exceeds registered Tensor size {int(buffer.nbytes())}") + return buffer + + def _refresh_counter(self, offset: int, local_value: int, depth: int) -> int: + result = self._signal_test(offset, local_value & 0xFFFF_FFFF, WaitCmp.NE) + if not result.matched: + return local_value + observed = int(result.observed) & 0xFFFF_FFFF + local_low = local_value & 0xFFFF_FFFF + delta = ctypes.c_int32((observed - local_low) & 0xFFFF_FFFF).value + if delta < 0 or delta > depth: + self._poison_local() + raise RuntimeError("L3-L2 queue counter reconstruction failed") + return local_value + delta + + def _sample_peer_abort_after_timeout(self) -> None: + result = self._signal_test(self._layout.l2_abort_flag_offset, 1, WaitCmp.GE) + if result.matched: + self._state = _QueueState.POISONED_REMOTE + raise RuntimeError("L3-L2 queue remote abort 
observed") + raise TimeoutError("L3-L2 queue operation timed out") + + def _poison_local(self) -> None: + if self._state != _QueueState.LIVE: + return + self._state = _QueueState.POISONED_LOCAL + try: + self._region._owner._l3_l2_orch_comm_submit( + self._region._worker_id, + L3L2OrchCommRequest( + cmd=L3L2OrchCommCmd.SIGNAL_NOTIFY, + op=int(NotifyOp.Set), + region_id=self._region.region_id, + counter_addr=int(self._region.descriptor.counter_base) + self._layout.l3_abort_flag_offset, + counter_operand=1, + ), + 5.0, + ) + except Exception: + pass + + def _run_primitive(self, fn: Any, *args: Any, **kwargs: Any) -> Any: + try: + return fn(*args, **kwargs) + except Exception: + self._poison_local() + raise + + def _signal_test(self, offset: int, cmp_value: int, cmp: WaitCmp) -> Any: + return self._run_primitive(lambda: self._region.counter(offset).test(cmp_value, cmp)) + + def _signal_notify(self, offset: int, value: int) -> None: + self._run_primitive(lambda: self._region.counter(offset).notify(value, NotifyOp.Set)) + + def _write_descriptor( + self, offset: int, seq: int, opcode: L3L2QueueOpcode, payload_offset: int, nbytes: int + ) -> None: + fields_buf = (ctypes.c_uint8 * 24).from_address(int(self._desc_fields.data)) + fields_buf[:] = _DESC.pack(0, int(opcode), int(payload_offset), int(nbytes))[8:] + seq_buf = (ctypes.c_uint8 * 8).from_address(int(self._desc_seq.data)) + seq_buf[:] = struct.pack(" L3L2QueueMessage: + self._run_primitive(self._region.payload_read, offset, self._desc_read, nbytes=L3L2_QUEUE_DESC_SLOT_BYTES) + raw = ctypes.string_at(int(self._desc_read.data), L3L2_QUEUE_DESC_SLOT_BYTES) + seq, opcode_value, payload_offset, payload_nbytes = _DESC.unpack(raw) + try: + opcode = L3L2QueueOpcode(opcode_value) + except ValueError: + self._poison_local() + raise RuntimeError("L3-L2 queue observed invalid descriptor opcode") from None + return L3L2QueueMessage( + seq=int(seq), + opcode=opcode, + payload_offset=int(payload_offset), + 
payload_nbytes=int(payload_nbytes), + ) + + def _advance_payload_head( + self, + cursor: int, + payload_offset: int, + payload_nbytes: int, + arena_offset: int, + arena_bytes: int, + ) -> int: + if payload_nbytes == 0: + return cursor + expected_offset = arena_offset + (cursor % arena_bytes) + if expected_offset != payload_offset: + if payload_offset != arena_offset: + self._poison_local() + raise RuntimeError("L3-L2 queue payload replay offset mismatch") + cursor += arena_bytes - (cursor % arena_bytes) + return cursor + payload_nbytes + + def _replay_released_input_descriptors(self, old_head: int, new_head: int) -> None: + cursor = old_head + while cursor < new_head: + slot_index = cursor & (self._layout.depth - 1) + slot_offset = self._layout.input_desc_offset + slot_index * L3L2_QUEUE_DESC_SLOT_BYTES + message = self._read_descriptor(slot_offset) + if message.seq != cursor + 1: + self._poison_local() + raise RuntimeError("L3-L2 queue input release replay seq mismatch") + self._input_payload_head = self._advance_payload_head( + self._input_payload_head, + message.payload_offset, + message.payload_nbytes, + self._layout.input_arena_offset, + self._layout.input_arena_bytes, + ) + cursor += 1 + + +class _L3InputQueue: + def __init__(self, queue: L3L2Queue) -> None: + self._queue = queue + + def enqueue(self, buffer_or_none: Any, nbytes: int, timeout: float) -> None: + self._enqueue(buffer_or_none, nbytes, L3L2QueueOpcode.DATA, timeout) + + def try_enqueue(self, buffer_or_none: Any, nbytes: int) -> bool: + return self._try_enqueue(buffer_or_none, nbytes, L3L2QueueOpcode.DATA) + + def _enqueue(self, buffer_or_none: Any, nbytes: int, opcode: L3L2QueueOpcode, timeout: float) -> None: + if timeout is None or float(timeout) <= 0: + raise ValueError("L3-L2 queue blocking operations require a positive timeout") + deadline = time.monotonic() + float(timeout) + while True: + if self._try_enqueue(buffer_or_none, nbytes, opcode): + return + if self._queue._stop_published: + 
raise RuntimeError("L3-L2 queue input is stopped") + if time.monotonic() >= deadline: + self._queue._sample_peer_abort_after_timeout() + time.sleep(_POLL_INTERVAL_S) + + def _try_enqueue(self, buffer_or_none: Any, nbytes: int, opcode: L3L2QueueOpcode) -> bool: + queue = self._queue + nbytes = int(nbytes) + if nbytes < 0: + raise ValueError("L3-L2 queue nbytes must be non-negative") + payload_tensor = None + if nbytes == 0: + if buffer_or_none is not None: + raise ValueError("L3-L2 queue zero-byte enqueue requires buffer_or_none == None") + else: + payload_tensor = queue._validate_registered_buffer(buffer_or_none, nbytes) + + queue._ensure_live() + if queue._stop_published: + return False + if opcode == L3L2QueueOpcode.STOP and nbytes != 0: + raise ValueError("L3-L2 queue STOP must be zero-byte") + + old_head = queue._input_head + queue._input_head = queue._refresh_counter( + queue._layout.input_desc_head_offset, queue._input_head, queue._layout.depth + ) + if queue._input_head != old_head: + queue._replay_released_input_descriptors(old_head, queue._input_head) + if queue._input_tail - queue._input_head >= queue._layout.depth: + return False + if nbytes > queue._layout.input_arena_bytes: + return False + + payload_offset = 0 + if nbytes != 0: + arena_pos = queue._input_payload_tail % queue._layout.input_arena_bytes + if arena_pos + nbytes > queue._layout.input_arena_bytes: + queue._input_payload_tail += queue._layout.input_arena_bytes - arena_pos + arena_pos = 0 + if queue._input_payload_tail + nbytes - queue._input_payload_head > queue._layout.input_arena_bytes: + return False + payload_offset = queue._layout.input_arena_offset + arena_pos + queue._run_primitive(queue._region.payload_write, payload_offset, payload_tensor, nbytes=nbytes) + queue._input_payload_tail += nbytes + + seq = queue._input_tail + 1 + slot_index = queue._input_tail & (queue._layout.depth - 1) + slot_offset = queue._layout.input_desc_offset + slot_index * L3L2_QUEUE_DESC_SLOT_BYTES + 
queue._write_descriptor(slot_offset, seq, opcode, payload_offset, nbytes) + queue._input_tail += 1 + queue._signal_notify(queue._layout.input_desc_tail_offset, queue._input_tail) + if opcode == L3L2QueueOpcode.STOP: + queue._stop_published = True + return True + + +class _L3OutputQueue: + def __init__(self, queue: L3L2Queue) -> None: + self._queue = queue + + def try_peek(self) -> L3L2QueueMessage | None: + queue = self._queue + queue._ensure_live() + if queue._output_active is not None: + return queue._output_active + queue._output_tail = queue._refresh_counter( + queue._layout.output_desc_tail_offset, queue._output_tail, queue._layout.depth + ) + if queue._output_tail == queue._output_head: + return None + slot_index = queue._output_head & (queue._layout.depth - 1) + slot_offset = queue._layout.output_desc_offset + slot_index * L3L2_QUEUE_DESC_SLOT_BYTES + message = queue._read_descriptor(slot_offset) + if message.seq != queue._output_head + 1: + queue._poison_local() + raise RuntimeError("L3-L2 queue output descriptor seq mismatch") + if message.opcode == L3L2QueueOpcode.STOP: + queue._poison_local() + raise RuntimeError("L3-L2 queue output descriptor cannot be STOP") + if message.payload_nbytes == 0: + if message.payload_offset != 0: + queue._poison_local() + raise RuntimeError("L3-L2 queue zero-byte output descriptor has nonzero offset") + else: + begin = queue._layout.output_arena_offset + end = begin + queue._layout.output_arena_bytes + if message.payload_offset < begin or message.payload_offset + message.payload_nbytes > end: + queue._poison_local() + raise RuntimeError("L3-L2 queue output payload outside output arena") + queue._advance_payload_head( + queue._output_payload_head, + message.payload_offset, + message.payload_nbytes, + queue._layout.output_arena_offset, + queue._layout.output_arena_bytes, + ) + queue._output_active = message + return message + + def peek(self, timeout: float) -> L3L2QueueMessage: + if timeout is None or float(timeout) <= 0: + 
raise ValueError("L3-L2 queue blocking operations require a positive timeout") + deadline = time.monotonic() + float(timeout) + while True: + message = self.try_peek() + if message is not None: + return message + if time.monotonic() >= deadline: + self._queue._sample_peer_abort_after_timeout() + time.sleep(_POLL_INTERVAL_S) + + def read_into(self, handle: L3L2QueueMessage, buffer: Any) -> None: + queue = self._queue + queue._ensure_live() + if queue._output_active != handle: + raise RuntimeError("L3-L2 queue output handle is not active") + if handle.payload_nbytes == 0: + if buffer is not None: + raise ValueError("L3-L2 queue zero-byte output read requires buffer == None") + return + target = queue._validate_registered_buffer(buffer, handle.payload_nbytes) + queue._run_primitive(queue._region.payload_read, handle.payload_offset, target, nbytes=handle.payload_nbytes) + + def release(self, handle: L3L2QueueMessage) -> None: + queue = self._queue + queue._ensure_live() + if queue._output_active != handle: + queue._poison_local() + raise RuntimeError("L3-L2 queue output handle is not active") + queue._output_payload_head = queue._advance_payload_head( + queue._output_payload_head, + handle.payload_offset, + handle.payload_nbytes, + queue._layout.output_arena_offset, + queue._layout.output_arena_bytes, + ) + queue._output_head += 1 + queue._output_active = None + queue._signal_notify(queue._layout.output_desc_head_offset, queue._output_head) + + def dequeue_into(self, buffer: Any, timeout: float) -> L3L2QueueMessage: + handle = self.peek(timeout) + self.read_into(handle, buffer) + self.release(handle) + return handle + + def try_dequeue_into(self, buffer: Any) -> L3L2QueueMessage | None: + handle = self.try_peek() + if handle is None: + return None + self.read_into(handle, buffer) + self.release(handle) + return handle diff --git a/python/simpler/orchestrator.py b/python/simpler/orchestrator.py index 87ec02e16..f998b48af 100644 --- a/python/simpler/orchestrator.py +++ 
b/python/simpler/orchestrator.py @@ -359,6 +359,20 @@ def create_l3_l2_region(self, *, worker_id: int, payload_bytes: int, counter_byt raise RuntimeError("create_l3_l2_region requires an Orchestrator bound to a Worker") return self._worker._create_l3_l2_region(int(worker_id), int(payload_bytes), int(counter_bytes)) + def create_l3_l2_queue(self, *, worker_id: int, depth: int, input_arena_bytes: int, output_arena_bytes: int): + """Create an L3-L2 message queue backed by one L3-L2 communication region.""" + if self._worker is None: + raise RuntimeError("create_l3_l2_queue requires an Orchestrator bound to a Worker") + from .l3_l2_message_queue import create_l3_l2_queue # noqa: PLC0415 + + return create_l3_l2_queue( + self, + worker_id=int(worker_id), + depth=int(depth), + input_arena_bytes=int(input_arena_bytes), + output_arena_bytes=int(output_arena_bytes), + ) + # ------------------------------------------------------------------ # Nested scope (Strict-1 per-scope rings) # ------------------------------------------------------------------ diff --git a/src/common/platform/include/aicpu/l3_l2_message_queue.h b/src/common/platform/include/aicpu/l3_l2_message_queue.h new file mode 100644 index 000000000..4c149ba7e --- /dev/null +++ b/src/common/platform/include/aicpu/l3_l2_message_queue.h @@ -0,0 +1,734 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#ifndef SRC_COMMON_PLATFORM_INCLUDE_AICPU_L3_L2_MESSAGE_QUEUE_H_ +#define SRC_COMMON_PLATFORM_INCLUDE_AICPU_L3_L2_MESSAGE_QUEUE_H_ + +#include +#include +#include + +#include "aicpu/l3_l2_orch_endpoint.h" + +static constexpr uint32_t L3L2_QUEUE_MAGIC = 0x4C335132u; // "L3Q2" +static constexpr uint16_t L3L2_QUEUE_ABI_MAJOR = 1; +static constexpr uint16_t L3L2_QUEUE_ABI_MINOR = 1; +static constexpr uint64_t L3L2_QUEUE_DESC_SLOT_BYTES = 32; +static constexpr uint64_t L3L2_QUEUE_DESC_RING_ALIGNMENT = 8; +static constexpr uint64_t L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT = 64; +static constexpr uint64_t L3L2_QUEUE_COUNTER_STRIDE = 64; +static constexpr uint64_t L3L2_QUEUE_INPUT_DESC_TAIL_OFFSET = 0; +static constexpr uint64_t L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET = 64; +static constexpr uint64_t L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET = 128; +static constexpr uint64_t L3L2_QUEUE_OUTPUT_DESC_HEAD_OFFSET = 192; +static constexpr uint64_t L3L2_QUEUE_L3_ABORT_FLAG_OFFSET = 256; +static constexpr uint64_t L3L2_QUEUE_L2_ABORT_FLAG_OFFSET = 320; +static constexpr uint64_t L3L2_QUEUE_COUNTER_BYTES = 384; +static constexpr uint64_t L3L2_QUEUE_MAX_DEPTH = 1ull << 30; + +struct L3L2QueueDescSlot { + uint64_t seq; + uint64_t opcode; + uint64_t payload_offset; + uint64_t payload_nbytes; +}; + +enum class L3L2QueueOpcode : uint64_t { + INVALID = 0, + DATA = 1, + STOP = 2, + ERROR = 3, +}; + +enum class L3L2QueueErrorKind : uint32_t { + NONE = 0, + BAD_ARGUMENT = 1, + BAD_DESCRIPTOR = 2, + INVALID_DESCRIPTOR = 3, + OUT_OF_SPACE = 4, + OWNERSHIP = 5, + REMOTE_ABORTED = 6, + ENDPOINT_ERROR = 7, +}; + +enum class L3L2QueueTimeoutStatus : uint32_t { + ORDINARY_TIMEOUT = 0, + REMOTE_ABORTED = 1, +}; + +struct L3L2QueueError { + L3L2QueueErrorKind kind; + const char *op; + uint64_t region_id; + const char *message; +}; + +struct L3L2QueueLayout { + uint64_t depth; + uint64_t 
input_desc_offset; + uint64_t output_desc_offset; + uint64_t input_arena_offset; + uint64_t output_arena_offset; + uint64_t input_arena_bytes; + uint64_t output_arena_bytes; + uint64_t payload_bytes; + uint64_t input_desc_tail_offset; + uint64_t input_desc_head_offset; + uint64_t output_desc_tail_offset; + uint64_t output_desc_head_offset; + uint64_t l3_abort_flag_offset; + uint64_t l2_abort_flag_offset; + uint64_t counter_bytes; +}; + +struct L3L2QueueArgs { + uint64_t magic_version; + uint64_t depth; + uint64_t input_arena_bytes; + uint64_t output_arena_bytes; + uint64_t payload_bytes; + uint64_t counter_bytes; +}; + +struct L3L2QueueInputHandle { + uint64_t seq; + L3L2QueueOpcode opcode; + uint64_t payload_offset; + uint64_t payload_nbytes; + L3L2OrchPayloadView payload; +}; + +struct L3L2QueueOutputReservation { + uint64_t seq; + uint64_t payload_offset; + uint64_t payload_nbytes; + L3L2OrchPayloadView payload; + bool valid; +}; + +static inline uint64_t l3_l2_queue_magic_version() { + return l3_l2_orch_comm_pack_magic_version(L3L2_QUEUE_MAGIC, L3L2_QUEUE_ABI_MAJOR, L3L2_QUEUE_ABI_MINOR); +} + +static inline bool l3_l2_queue_is_power_of_two(uint64_t value) { return value != 0 && (value & (value - 1)) == 0; } + +static inline uint64_t l3_l2_queue_align_up(uint64_t value, uint64_t align) { + if (align == 0) { + return value; + } + uint64_t remainder = value % align; + return remainder == 0 ? value : value + (align - remainder); +} + +static inline bool l3_l2_queue_align_up_checked(uint64_t value, uint64_t align, uint64_t *out) { + if (out == nullptr || align == 0) { + return false; + } + uint64_t remainder = value % align; + uint64_t bump = remainder == 0 ? 
0 : align - remainder; + if (l3_l2_orch_comm_add_overflows(value, bump)) { + return false; + } + *out = value + bump; + return true; +} + +static inline bool l3_l2_queue_valid_opcode(L3L2QueueOpcode opcode) { + return opcode == L3L2QueueOpcode::DATA || opcode == L3L2QueueOpcode::STOP || opcode == L3L2QueueOpcode::ERROR; +} + +static inline bool +l3_l2_queue_make_layout(uint64_t depth, uint64_t input_arena_bytes, uint64_t output_arena_bytes, L3L2QueueLayout *out) { + if (out == nullptr || !l3_l2_queue_is_power_of_two(depth) || depth > L3L2_QUEUE_MAX_DEPTH || + input_arena_bytes == 0 || output_arena_bytes == 0 || + input_arena_bytes % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT != 0 || + output_arena_bytes % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT != 0) { + return false; + } + + uint64_t desc_ring_bytes = depth * L3L2_QUEUE_DESC_SLOT_BYTES; + uint64_t input_desc_offset = 0; + if (l3_l2_orch_comm_add_overflows(input_desc_offset, desc_ring_bytes)) { + return false; + } + uint64_t output_desc_offset = input_desc_offset + desc_ring_bytes; + if (l3_l2_orch_comm_add_overflows(output_desc_offset, desc_ring_bytes)) { + return false; + } + uint64_t desc_end = output_desc_offset + desc_ring_bytes; + uint64_t input_arena_offset = 0; + if (!l3_l2_queue_align_up_checked(desc_end, L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT, &input_arena_offset)) { + return false; + } + if (l3_l2_orch_comm_add_overflows(input_arena_offset, input_arena_bytes)) { + return false; + } + uint64_t input_arena_end = input_arena_offset + input_arena_bytes; + uint64_t output_arena_offset = 0; + if (!l3_l2_queue_align_up_checked(input_arena_end, L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT, &output_arena_offset)) { + return false; + } + if (l3_l2_orch_comm_add_overflows(output_arena_offset, output_arena_bytes)) { + return false; + } + uint64_t payload_bytes = output_arena_offset + output_arena_bytes; + + *out = L3L2QueueLayout{ + depth, + input_desc_offset, + output_desc_offset, + input_arena_offset, + output_arena_offset, + 
input_arena_bytes, + output_arena_bytes, + payload_bytes, + L3L2_QUEUE_INPUT_DESC_TAIL_OFFSET, + L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET, + L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET, + L3L2_QUEUE_OUTPUT_DESC_HEAD_OFFSET, + L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, + L3L2_QUEUE_L2_ABORT_FLAG_OFFSET, + L3L2_QUEUE_COUNTER_BYTES, + }; + return output_desc_offset % L3L2_QUEUE_DESC_RING_ALIGNMENT == 0 && + input_arena_offset % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT == 0 && + output_arena_offset % L3L2_QUEUE_PAYLOAD_ARENA_ALIGNMENT == 0; +} + +static inline bool +l3_l2_queue_validate_region(const L3L2OrchRegionDesc &desc, const L3L2QueueArgs &args, L3L2QueueLayout *out_layout) { + L3L2QueueLayout layout{}; + if (args.magic_version != l3_l2_queue_magic_version() || + l3_l2_orch_comm_validate_desc(desc) != L3L2OrchCommValidationError::OK || + !l3_l2_queue_make_layout(args.depth, args.input_arena_bytes, args.output_arena_bytes, &layout)) { + return false; + } + if (args.payload_bytes != layout.payload_bytes || args.counter_bytes != layout.counter_bytes || + desc.payload_bytes != layout.payload_bytes || desc.counter_bytes != layout.counter_bytes) { + return false; + } + if (out_layout != nullptr) { + *out_layout = layout; + } + return true; +} + +static inline void l3_l2_queue_encode_desc( + L3L2QueueDescSlot *slot, uint64_t seq, L3L2QueueOpcode opcode, uint64_t payload_offset, uint64_t payload_nbytes +) { + if (slot == nullptr) { + return; + } + slot->seq = seq; + slot->opcode = static_cast<uint64_t>(opcode); + slot->payload_offset = payload_offset; + slot->payload_nbytes = payload_nbytes; +} + +static inline bool l3_l2_queue_reconstruct_counter(int32_t observed_low32, uint64_t depth, uint64_t *local_value) { + if (local_value == nullptr || depth > L3L2_QUEUE_MAX_DEPTH) { + return false; + } + uint32_t local_low32 = static_cast<uint32_t>(*local_value); + int32_t delta = static_cast<int32_t>(static_cast<uint32_t>(observed_low32) - local_low32); + if (delta < 0 || static_cast<uint64_t>(delta) > depth) { + return false; + } + *local_value +=
static_cast<uint64_t>(delta); + return true; +} + +class L3L2QueueEndpoint { +public: + class InputQueue { + public: + explicit InputQueue(L3L2QueueEndpoint *parent) : + parent_(parent) {} + + bool peek(uint64_t timeout_ns, L3L2QueueInputHandle *out) { + if (out == nullptr) { + return false; + } + uint64_t start = l3_l2_orch_endpoint_now(); + uint64_t frequency_hz = l3_l2_orch_endpoint_timer_frequency_hz(); + uint64_t spins = 0; + while (true) { + if (try_peek(out)) { + return true; + } + if (parent_->error_.kind != L3L2QueueErrorKind::NONE) { + return false; + } + spins += 1; + if (timeout_ns == 0 || (spins & 1023ull) == 0) { + uint64_t now = l3_l2_orch_endpoint_now(); + if (timeout_ns == 0 || l3_l2_orch_endpoint_elapsed_ns(start, now, frequency_hz) >= timeout_ns) { + parent_->disambiguate_timeout(); + return false; + } + } + } + } + + bool try_peek(L3L2QueueInputHandle *out) { + if (out != nullptr) { + *out = L3L2QueueInputHandle{0, L3L2QueueOpcode::INVALID, 0, 0, L3L2OrchPayloadView{0, 0}}; + } + if (!parent_->ensure_live("input.try_peek") || out == nullptr) { + return false; + } + if (active_) { + parent_->poison(L3L2QueueErrorKind::OWNERSHIP, "input.try_peek", "input handle already active"); + return false; + } + if (!parent_->refresh_counter( + parent_->layout_.input_desc_tail_offset, parent_->input_tail_, parent_->layout_.depth, + "input.try_peek" + )) { + return false; + } + if (stopped_) { + if (parent_->input_tail_ != parent_->input_head_) { + parent_->poison( + L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", + "input descriptor published after STOP" + ); + } + return false; + } + if (parent_->input_tail_ == parent_->input_head_) { + return false; + } + if (parent_->input_tail_ - parent_->input_head_ > parent_->layout_.depth) { + parent_->poison( + L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", "input descriptor state invalid" + ); + return false; + } + + L3L2QueueDescSlot slot{}; + uint64_t slot_index = parent_->input_head_ &
(parent_->layout_.depth - 1); + uint64_t slot_offset = parent_->layout_.input_desc_offset + slot_index * sizeof(L3L2QueueDescSlot); + if (!parent_->read_desc_slot(slot_offset, &slot, "input.try_peek")) { + return false; + } + uint64_t expected_seq = parent_->input_head_ + 1; + if (slot.seq != expected_seq) { + parent_->poison( + L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", "input descriptor seq mismatch" + ); + return false; + } + L3L2QueueOpcode opcode = static_cast<L3L2QueueOpcode>(slot.opcode); + if (!l3_l2_queue_valid_opcode(opcode)) { + parent_->poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", "invalid input opcode"); + return false; + } + + L3L2OrchPayloadView view{0, 0}; + if (slot.payload_nbytes == 0) { + if (slot.payload_offset != 0) { + parent_->poison( + L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", + "zero-byte descriptor uses nonzero payload offset" + ); + return false; + } + } else if (!parent_->payload_in_arena( + slot.payload_offset, slot.payload_nbytes, parent_->layout_.input_arena_offset, + parent_->layout_.input_arena_bytes + )) { + parent_->poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, "input.try_peek", "input payload out of arena"); + return false; + } else if (!parent_->payload_matches_head( + parent_->input_payload_head_, slot.payload_offset, slot.payload_nbytes, + parent_->layout_.input_arena_offset, parent_->layout_.input_arena_bytes, "input.try_peek" + )) { + return false; + } else if (!parent_->endpoint_.payload_read(slot.payload_offset, slot.payload_nbytes, &view)) { + parent_->poison( + L3L2QueueErrorKind::ENDPOINT_ERROR, "input.try_peek", parent_->endpoint_.error().message + ); + return false; + } + + *out = L3L2QueueInputHandle{slot.seq, opcode, slot.payload_offset, slot.payload_nbytes, view}; + active_ = true; + active_seq_ = slot.seq; + active_opcode_ = opcode; + active_payload_offset_ = slot.payload_offset; + active_payload_nbytes_ = slot.payload_nbytes; + return true; + } + + bool release(const
L3L2QueueInputHandle &handle) { + if (!parent_->ensure_live("input.release")) { + return false; + } + if (!active_ || handle.seq != active_seq_ || handle.seq != parent_->input_head_ + 1 || + handle.opcode != active_opcode_ || handle.payload_offset != active_payload_offset_ || + handle.payload_nbytes != active_payload_nbytes_) { + parent_->poison(L3L2QueueErrorKind::OWNERSHIP, "input.release", "input handle is not active"); + return false; + } + if (active_payload_nbytes_ != 0) { + parent_->advance_payload_head( + parent_->input_payload_head_, active_payload_offset_, active_payload_nbytes_, + parent_->layout_.input_arena_offset, parent_->layout_.input_arena_bytes, "input.release" + ); + if (parent_->error_.kind != L3L2QueueErrorKind::NONE) { + return false; + } + } + parent_->input_head_ += 1; + if (active_opcode_ == L3L2QueueOpcode::STOP) { + stopped_ = true; + } + active_ = false; + active_seq_ = 0; + active_opcode_ = L3L2QueueOpcode::INVALID; + active_payload_offset_ = 0; + active_payload_nbytes_ = 0; + return parent_->notify_counter( + parent_->layout_.input_desc_head_offset, static_cast<int32_t>(parent_->input_head_), "input.release" + ); + } + + private: + L3L2QueueEndpoint *parent_; + bool active_{false}; + uint64_t active_seq_{0}; + L3L2QueueOpcode active_opcode_{L3L2QueueOpcode::INVALID}; + uint64_t active_payload_offset_{0}; + uint64_t active_payload_nbytes_{0}; + bool stopped_{false}; + }; + + class OutputQueue { + public: + explicit OutputQueue(L3L2QueueEndpoint *parent) : + parent_(parent) {} + + bool reserve(uint64_t nbytes, uint64_t timeout_ns, L3L2QueueOutputReservation *out) { + if (out == nullptr) { + return false; + } + uint64_t start = l3_l2_orch_endpoint_now(); + uint64_t frequency_hz = l3_l2_orch_endpoint_timer_frequency_hz(); + uint64_t spins = 0; + while (true) { + if (try_reserve(nbytes, out)) { + return true; + } + if (parent_->error_.kind != L3L2QueueErrorKind::NONE) { + return false; + } + spins += 1; + if (timeout_ns == 0 || (spins & 1023ull) ==
0) { + uint64_t now = l3_l2_orch_endpoint_now(); + if (timeout_ns == 0 || l3_l2_orch_endpoint_elapsed_ns(start, now, frequency_hz) >= timeout_ns) { + parent_->disambiguate_timeout(); + return false; + } + } + } + } + + bool try_reserve(uint64_t nbytes, L3L2QueueOutputReservation *out) { + if (out != nullptr) { + *out = L3L2QueueOutputReservation{0, 0, 0, L3L2OrchPayloadView{0, 0}, false}; + } + if (!parent_->ensure_live("output.try_reserve") || out == nullptr) { + return false; + } + if (reservation_active_) { + parent_->poison( + L3L2QueueErrorKind::OWNERSHIP, "output.try_reserve", "output reservation already active" + ); + return false; + } + if (nbytes > parent_->layout_.output_arena_bytes) { + return false; + } + uint64_t old_head = parent_->output_head_; + if (!parent_->refresh_counter( + parent_->layout_.output_desc_head_offset, parent_->output_head_, parent_->layout_.depth, + "output.try_reserve" + )) { + return false; + } + if (parent_->output_head_ != old_head && + !parent_->replay_output_releases(old_head, parent_->output_head_, "output.try_reserve")) { + return false; + } + if (parent_->output_tail_ - parent_->output_head_ >= parent_->layout_.depth) { + return false; + } + + uint64_t payload_offset = 0; + L3L2OrchPayloadView view{0, 0}; + if (nbytes != 0) { + uint64_t arena_base = parent_->layout_.output_arena_offset; + uint64_t arena_bytes = parent_->layout_.output_arena_bytes; + uint64_t arena_pos = parent_->output_payload_tail_ % arena_bytes; + if (arena_pos + nbytes > arena_bytes) { + // Payloads are never split across arena wrap. The skipped tail bytes are retired in the + // monotonic virtual cursor even if this reservation later finds the arena full. 
+ parent_->output_payload_tail_ += arena_bytes - arena_pos; + arena_pos = 0; + } + if (parent_->output_payload_tail_ + nbytes - parent_->output_payload_head_ > arena_bytes) { + return false; + } + payload_offset = arena_base + arena_pos; + view = L3L2OrchPayloadView{parent_->endpoint_.descriptor().payload_base + payload_offset, nbytes}; + parent_->output_payload_tail_ += nbytes; + } + + reservation_active_ = true; + reservation_seq_ = parent_->output_tail_ + 1; + reservation_offset_ = payload_offset; + reservation_nbytes_ = nbytes; + *out = L3L2QueueOutputReservation{reservation_seq_, payload_offset, nbytes, view, true}; + return true; + } + + bool publish(const L3L2QueueOutputReservation &reservation, L3L2QueueOpcode opcode) { + if (!parent_->ensure_live("output.publish")) { + return false; + } + if (!reservation_active_ || !reservation.valid || reservation.seq != reservation_seq_ || + reservation.payload_offset != reservation_offset_ || + reservation.payload_nbytes != reservation_nbytes_) { + parent_->poison(L3L2QueueErrorKind::OWNERSHIP, "output.publish", "unknown output reservation"); + return false; + } + if (opcode == L3L2QueueOpcode::STOP || !l3_l2_queue_valid_opcode(opcode)) { + parent_->poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, "output.publish", "invalid output opcode"); + return false; + } + L3L2QueueDescSlot slot{}; + l3_l2_queue_encode_desc(&slot, 0, opcode, reservation.payload_offset, reservation.payload_nbytes); + uint64_t slot_index = parent_->output_tail_ & (parent_->layout_.depth - 1); + uint64_t slot_offset = parent_->layout_.output_desc_offset + slot_index * sizeof(L3L2QueueDescSlot); + if (!parent_->write_desc_slot(slot_offset, slot, reservation.seq, "output.publish")) { + return false; + } + parent_->output_tail_ += 1; + reservation_active_ = false; + reservation_seq_ = 0; + reservation_offset_ = 0; + reservation_nbytes_ = 0; + return parent_->notify_counter( + parent_->layout_.output_desc_tail_offset, static_cast<int32_t>(parent_->output_tail_),
"output.publish" + ); + } + + private: + L3L2QueueEndpoint *parent_; + bool reservation_active_{false}; + uint64_t reservation_seq_{0}; + uint64_t reservation_offset_{0}; + uint64_t reservation_nbytes_{0}; + }; + + L3L2QueueEndpoint(const L3L2OrchRegionDesc &desc, const L3L2QueueArgs &args) : + endpoint_(desc), + input_queue_(this), + output_queue_(this) { + if (endpoint_.error().kind != L3L2EndpointErrorKind::NONE || + !l3_l2_queue_validate_region(desc, args, &layout_)) { + set_error(L3L2QueueErrorKind::BAD_DESCRIPTOR, "init", desc.region_id, "invalid queue descriptor"); + } + } + + const L3L2QueueError &error() const { return error_; } + const L3L2QueueLayout &layout() const { return layout_; } + InputQueue &input() { return input_queue_; } + OutputQueue &output() { return output_queue_; } + + L3L2QueueTimeoutStatus disambiguate_timeout() { + if (error_.kind != L3L2QueueErrorKind::NONE) { + return error_.kind == L3L2QueueErrorKind::REMOTE_ABORTED ? L3L2QueueTimeoutStatus::REMOTE_ABORTED : + L3L2QueueTimeoutStatus::ORDINARY_TIMEOUT; + } + L3L2OrchSignalTestResult result{}; + uint64_t addr = 0; + if (!endpoint_.counter_addr(layout_.l3_abort_flag_offset, &addr) || + !endpoint_.signal_test(addr, 1, L3L2OrchWaitCmp::GE, &result)) { + poison(L3L2QueueErrorKind::ENDPOINT_ERROR, "timeout", endpoint_.error().message); + return L3L2QueueTimeoutStatus::ORDINARY_TIMEOUT; + } + if (result.matched) { + set_error(L3L2QueueErrorKind::REMOTE_ABORTED, "timeout", endpoint_.descriptor().region_id, "remote abort"); + return L3L2QueueTimeoutStatus::REMOTE_ABORTED; + } + return L3L2QueueTimeoutStatus::ORDINARY_TIMEOUT; + } + +private: + bool ensure_live(const char *op) { + if (error_.kind == L3L2QueueErrorKind::NONE) { + return true; + } + (void)op; + return false; + } + + void set_error(L3L2QueueErrorKind kind, const char *op, uint64_t region_id, const char *message) { + if (error_.kind != L3L2QueueErrorKind::NONE) { + return; + } + error_ = L3L2QueueError{kind, op, region_id, 
message}; + } + + void poison(L3L2QueueErrorKind kind, const char *op, const char *message) { + set_error(kind, op, endpoint_.descriptor().region_id, message); + if (kind != L3L2QueueErrorKind::REMOTE_ABORTED) { + uint64_t addr = 0; + if (endpoint_.counter_addr(layout_.l2_abort_flag_offset, &addr)) { + endpoint_.signal_notify(addr, 1, L3L2OrchNotifyOp::Set); + } + } + } + + bool notify_counter(uint64_t offset, int32_t value, const char *op) { + uint64_t addr = 0; + if (!endpoint_.counter_addr(offset, &addr) || !endpoint_.signal_notify(addr, value, L3L2OrchNotifyOp::Set)) { + poison(L3L2QueueErrorKind::ENDPOINT_ERROR, op, endpoint_.error().message); + return false; + } + return true; + } + + bool refresh_counter(uint64_t offset, uint64_t &local, uint64_t depth, const char *op) { + uint64_t addr = 0; + L3L2OrchSignalTestResult result{}; + if (!endpoint_.counter_addr(offset, &addr) || + !endpoint_.signal_test(addr, static_cast(local), L3L2OrchWaitCmp::NE, &result)) { + poison(L3L2QueueErrorKind::ENDPOINT_ERROR, op, endpoint_.error().message); + return false; + } + if (!result.matched) { + return true; + } + if (!l3_l2_queue_reconstruct_counter(result.observed, depth, &local)) { + poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, op, "counter reconstruction failed"); + return false; + } + return true; + } + + bool read_desc_slot(uint64_t slot_offset, L3L2QueueDescSlot *slot, const char *op) { + L3L2OrchPayloadView view{}; + if (!endpoint_.payload_read(slot_offset, sizeof(L3L2QueueDescSlot), &view)) { + poison(L3L2QueueErrorKind::ENDPOINT_ERROR, op, endpoint_.error().message); + return false; + } + memcpy(slot, reinterpret_cast(static_cast(view.gm_addr)), sizeof(L3L2QueueDescSlot)); + return true; + } + + bool write_desc_slot(uint64_t slot_offset, const L3L2QueueDescSlot &slot, uint64_t seq, const char *op) { + L3L2QueueDescSlot fields = slot; + fields.seq = 0; + if (!endpoint_.payload_write(slot_offset + offsetof(L3L2QueueDescSlot, opcode), &fields.opcode, 24)) { + 
poison(L3L2QueueErrorKind::ENDPOINT_ERROR, op, endpoint_.error().message); + return false; + } + if (!endpoint_.payload_write(slot_offset + offsetof(L3L2QueueDescSlot, seq), &seq, sizeof(seq))) { + poison(L3L2QueueErrorKind::ENDPOINT_ERROR, op, endpoint_.error().message); + return false; + } + return true; + } + + bool payload_in_arena(uint64_t offset, uint64_t nbytes, uint64_t arena_offset, uint64_t arena_bytes) const { + if (nbytes == 0 || l3_l2_orch_comm_add_overflows(offset, nbytes)) { + return false; + } + return offset >= arena_offset && offset + nbytes <= arena_offset + arena_bytes; + } + + bool payload_matches_head( + uint64_t cursor, uint64_t payload_offset, uint64_t nbytes, uint64_t arena_offset, uint64_t arena_bytes, + const char *op + ) { + if (nbytes == 0) { + return true; + } + uint64_t arena_pos = cursor % arena_bytes; + uint64_t expected_offset = arena_pos + nbytes > arena_bytes ? arena_offset : arena_offset + arena_pos; + if (payload_offset != expected_offset) { + poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, op, "payload replay offset mismatch"); + return false; + } + return true; + } + + void advance_payload_head( + uint64_t &cursor, uint64_t payload_offset, uint64_t nbytes, uint64_t arena_offset, uint64_t arena_bytes, + const char *op + ) { + uint64_t arena_pos = cursor % arena_bytes; + uint64_t expected_offset = arena_pos + nbytes > arena_bytes ? 
arena_offset : arena_offset + arena_pos; + if (expected_offset != payload_offset) { + poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, op, "payload replay offset mismatch"); + return; + } + if (arena_pos + nbytes > arena_bytes) { + cursor += arena_bytes - (cursor % arena_bytes); + } + cursor += nbytes; + } + + bool replay_output_releases(uint64_t old_head, uint64_t new_head, const char *op) { + uint64_t cursor = old_head; + while (cursor < new_head) { + L3L2QueueDescSlot slot{}; + uint64_t slot_index = cursor & (layout_.depth - 1); + uint64_t slot_offset = layout_.output_desc_offset + slot_index * sizeof(L3L2QueueDescSlot); + if (!read_desc_slot(slot_offset, &slot, op)) { + return false; + } + if (slot.seq != cursor + 1) { + poison(L3L2QueueErrorKind::INVALID_DESCRIPTOR, op, "output release replay seq mismatch"); + return false; + } + if (slot.payload_nbytes != 0) { + advance_payload_head( + output_payload_head_, slot.payload_offset, slot.payload_nbytes, layout_.output_arena_offset, + layout_.output_arena_bytes, op + ); + if (error_.kind != L3L2QueueErrorKind::NONE) { + return false; + } + } + cursor += 1; + } + return true; + } + + L3L2OrchEndpoint endpoint_; + L3L2QueueLayout layout_{}; + L3L2QueueError error_{L3L2QueueErrorKind::NONE, "", 0, ""}; + uint64_t input_head_{0}; + uint64_t input_tail_{0}; + uint64_t output_head_{0}; + uint64_t output_tail_{0}; + uint64_t input_payload_head_{0}; + uint64_t output_payload_head_{0}; + uint64_t output_payload_tail_{0}; + InputQueue input_queue_; + OutputQueue output_queue_; +}; + +#endif // SRC_COMMON_PLATFORM_INCLUDE_AICPU_L3_L2_MESSAGE_QUEUE_H_ diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 5fe6dd186..d4fcc497f 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -369,6 +369,23 @@ add_common_utils_test(test_elf_build_id common/test_elf_build_id.cpp) add_common_utils_test(test_runtime_orch_so common/test_runtime_orch_so.cpp) add_common_utils_test(test_device_arena 
common/test_device_arena.cpp) add_common_utils_test(test_l3_l2_orch_comm common/test_l3_l2_orch_comm.cpp) +add_executable(test_l3_l2_message_queue + common/test_l3_l2_message_queue.cpp + stubs/test_stubs.cpp +) +target_include_directories(test_l3_l2_message_queue PRIVATE + ${GTEST_INCLUDE_DIRS} + ${CMAKE_SOURCE_DIR}/../../../src/a2a3/platform/include + ${CMAKE_SOURCE_DIR}/../../../src/common/platform/include +) +target_link_libraries(test_l3_l2_message_queue PRIVATE + ${GTEST_MAIN_LIB} + ${GTEST_LIB} + pthread +) +add_test(NAME test_l3_l2_message_queue COMMAND test_l3_l2_message_queue) +set_tests_properties(test_l3_l2_message_queue PROPERTIES LABELS "no_hardware") + add_executable(test_l3_l2_orch_endpoint common/test_l3_l2_orch_endpoint.cpp stubs/test_stubs.cpp diff --git a/tests/ut/cpp/common/test_l3_l2_message_queue.cpp b/tests/ut/cpp/common/test_l3_l2_message_queue.cpp new file mode 100644 index 000000000..e7db495d4 --- /dev/null +++ b/tests/ut/cpp/common/test_l3_l2_message_queue.cpp @@ -0,0 +1,528 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "aicpu/l3_l2_message_queue.h" + +namespace { + +struct RegionStorage { + alignas(64) std::array payload{}; + alignas(64) std::array counters{}; +}; + +L3L2OrchRegionDesc make_desc(RegionStorage *storage, uint64_t payload_bytes = 512, uint64_t counter_bytes = 512) { + return L3L2OrchRegionDesc{ + l3_l2_orch_comm_magic_version(), + 19, + reinterpret_cast(storage->payload.data()), + payload_bytes, + reinterpret_cast(storage->counters.data()), + counter_bytes, + }; +} + +size_t counter_index(uint64_t offset) { return static_cast(offset / sizeof(int32_t)); } + +L3L2QueueArgs make_args(uint64_t depth, uint64_t input_arena_bytes, uint64_t output_arena_bytes) { + L3L2QueueLayout layout{}; + EXPECT_TRUE(l3_l2_queue_make_layout(depth, input_arena_bytes, output_arena_bytes, &layout)); + return L3L2QueueArgs{ + l3_l2_queue_magic_version(), depth, input_arena_bytes, output_arena_bytes, layout.payload_bytes, + layout.counter_bytes, + }; +} + +L3L2OrchRegionDesc make_desc(RegionStorage *storage, const L3L2QueueArgs &args) { + return make_desc(storage, args.payload_bytes, args.counter_bytes); +} + +void publish_input_desc( + RegionStorage *storage, const L3L2QueueLayout &layout, uint64_t seq, L3L2QueueOpcode opcode, + uint64_t payload_offset = 0, uint64_t payload_nbytes = 0 +) { + L3L2QueueDescSlot slot{}; + l3_l2_queue_encode_desc(&slot, seq, opcode, payload_offset, payload_nbytes); + uint64_t desc_offset = layout.input_desc_offset + ((seq - 1) & (layout.depth - 1)) * sizeof(L3L2QueueDescSlot); + std::memcpy(storage->payload.data() + desc_offset, &slot, sizeof(slot)); + storage->counters[counter_index(layout.input_desc_tail_offset)] = static_cast(seq); +} + +TEST(L3L2MessageQueueTest, LayoutAssignsPayloadAndAbortCounterOffsets) { + L3L2QueueLayout layout{}; + + 
ASSERT_TRUE(l3_l2_queue_make_layout(4, 128, 192, &layout)); + + EXPECT_EQ(layout.input_desc_offset, 0u); + EXPECT_EQ(layout.output_desc_offset, 4u * sizeof(L3L2QueueDescSlot)); + EXPECT_EQ(layout.input_arena_offset % 64u, 0u); + EXPECT_EQ(layout.output_arena_offset % 64u, 0u); + EXPECT_EQ(layout.input_desc_tail_offset, 0u); + EXPECT_EQ(layout.input_desc_head_offset, 64u); + EXPECT_EQ(layout.output_desc_tail_offset, 128u); + EXPECT_EQ(layout.output_desc_head_offset, 192u); + EXPECT_EQ(layout.l3_abort_flag_offset, 256u); + EXPECT_EQ(layout.l2_abort_flag_offset, 320u); + EXPECT_EQ(layout.counter_bytes, 384u); + EXPECT_GE(layout.payload_bytes, layout.output_arena_offset + 192u); +} + +TEST(L3L2MessageQueueTest, LayoutLockstepCasesMatchPythonMirrorExpectations) { + struct LayoutCase { + uint64_t depth; + uint64_t input_arena_bytes; + uint64_t output_arena_bytes; + uint64_t output_desc_offset; + uint64_t input_arena_offset; + uint64_t output_arena_offset; + uint64_t payload_bytes; + }; + + const std::array<LayoutCase, 3> cases{{ + {1, 64, 64, 32, 64, 128, 192}, + {4, 128, 192, 128, 256, 384, 576}, + {8, 192, 64, 256, 512, 704, 768}, + }}; + + for (const auto &test_case : cases) { + L3L2QueueLayout layout{}; + ASSERT_TRUE( + l3_l2_queue_make_layout(test_case.depth, test_case.input_arena_bytes, test_case.output_arena_bytes, &layout) + ); + + EXPECT_EQ(layout.input_desc_offset, 0u); + EXPECT_EQ(layout.output_desc_offset, test_case.output_desc_offset); + EXPECT_EQ(layout.output_desc_offset, test_case.depth * sizeof(L3L2QueueDescSlot)); + EXPECT_EQ(layout.input_arena_offset, test_case.input_arena_offset); + EXPECT_EQ(layout.output_arena_offset, test_case.output_arena_offset); + EXPECT_EQ(layout.payload_bytes, test_case.payload_bytes); + EXPECT_EQ(layout.input_desc_tail_offset, 0u); + EXPECT_EQ(layout.input_desc_head_offset, 64u); + EXPECT_EQ(layout.output_desc_tail_offset, 128u); + EXPECT_EQ(layout.output_desc_head_offset, 192u); + EXPECT_EQ(layout.l3_abort_flag_offset, 256u); +
EXPECT_EQ(layout.l2_abort_flag_offset, 320u); + EXPECT_EQ(layout.counter_bytes, 384u); + } +} + +TEST(L3L2MessageQueueTest, LayoutRejectsInvalidDepthArenaAndCounterBytes) { + L3L2QueueLayout layout{}; + + EXPECT_FALSE(l3_l2_queue_make_layout(3, 64, 64, &layout)); + EXPECT_FALSE(l3_l2_queue_make_layout((1ull << 30) + 1, 64, 64, &layout)); + EXPECT_FALSE(l3_l2_queue_make_layout(2, 0, 64, &layout)); + EXPECT_FALSE(l3_l2_queue_make_layout(2, 65, 64, &layout)); + + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + EXPECT_FALSE(l3_l2_queue_validate_region(make_desc(&storage, 256, 320), args, &layout)); + EXPECT_FALSE(l3_l2_queue_validate_region(make_desc(&storage, 512, 384), args, &layout)); + EXPECT_TRUE(l3_l2_queue_validate_region(make_desc(&storage, args), args, &layout)); +} + +TEST(L3L2MessageQueueTest, LayoutOverflowFailsClosedWithoutModifyingOutput) { + L3L2QueueLayout layout{ + 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, + }; + const L3L2QueueLayout original = layout; + + EXPECT_FALSE(l3_l2_queue_make_layout(2, std::numeric_limits<uint64_t>::max() - 63, 64, &layout)); + + EXPECT_EQ(layout.depth, original.depth); + EXPECT_EQ(layout.input_desc_offset, original.input_desc_offset); + EXPECT_EQ(layout.output_desc_offset, original.output_desc_offset); + EXPECT_EQ(layout.input_arena_offset, original.input_arena_offset); + EXPECT_EQ(layout.output_arena_offset, original.output_arena_offset); + EXPECT_EQ(layout.input_arena_bytes, original.input_arena_bytes); + EXPECT_EQ(layout.output_arena_bytes, original.output_arena_bytes); + EXPECT_EQ(layout.payload_bytes, original.payload_bytes); + EXPECT_EQ(layout.counter_bytes, original.counter_bytes); +} + +TEST(L3L2MessageQueueTest, DescriptorSlotEncodingIsStable) { + static_assert(std::is_standard_layout<L3L2QueueDescSlot>::value, "descriptor must be POD-like"); + static_assert(std::is_trivially_copyable<L3L2QueueDescSlot>::value, "descriptor must be fixed-size"); + + EXPECT_EQ(sizeof(L3L2QueueDescSlot), 32u); + 
EXPECT_EQ(offsetof(L3L2QueueDescSlot, seq), 0u); + EXPECT_EQ(offsetof(L3L2QueueDescSlot, opcode), 8u); + EXPECT_EQ(offsetof(L3L2QueueDescSlot, payload_offset), 16u); + EXPECT_EQ(offsetof(L3L2QueueDescSlot, payload_nbytes), 24u); + + L3L2QueueDescSlot slot{}; + l3_l2_queue_encode_desc(&slot, 7, L3L2QueueOpcode::ERROR, 128, 16); + EXPECT_EQ(slot.seq, 7u); + EXPECT_EQ(slot.opcode, 3u); + EXPECT_EQ(slot.payload_offset, 128u); + EXPECT_EQ(slot.payload_nbytes, 16u); +} + +TEST(L3L2MessageQueueTest, Low32ReconstructionAcceptsWrapAndRejectsImpossibleDeltas) { + uint64_t value = 0xFFFF'FFFFull; + + EXPECT_TRUE(l3_l2_queue_reconstruct_counter(0, 4, &value)); + EXPECT_EQ(value, 0x1'0000'0000ull); + + value = (1ull << 31) - 2; + EXPECT_TRUE(l3_l2_queue_reconstruct_counter(static_cast(0x8000'0001u), 4, &value)); + EXPECT_EQ(value, (1ull << 31) + 1); + + value = 100; + EXPECT_TRUE(l3_l2_queue_reconstruct_counter(104, 4, &value)); + EXPECT_EQ(value, 104u); + + value = 100; + EXPECT_FALSE(l3_l2_queue_reconstruct_counter(99, 4, &value)); + + value = 100; + EXPECT_FALSE(l3_l2_queue_reconstruct_counter(105, 4, &value)); +} + +TEST(L3L2MessageQueueTest, L2InputPeekHandlesZeroByteDescriptorBeforeArenaValidation) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueDescSlot slot{}; + l3_l2_queue_encode_desc(&slot, 1, L3L2QueueOpcode::DATA, 0, 0); + std::memcpy(storage.payload.data() + queue.layout().input_desc_offset, &slot, sizeof(slot)); + storage.counters[0] = 1; + + L3L2QueueInputHandle handle{}; + ASSERT_TRUE(queue.input().try_peek(&handle)) << queue.error().message; + + EXPECT_EQ(handle.seq, 1u); + EXPECT_EQ(handle.opcode, L3L2QueueOpcode::DATA); + EXPECT_EQ(handle.payload_nbytes, 0u); + EXPECT_EQ(handle.payload.gm_addr, 0u); + EXPECT_TRUE(queue.input().release(handle)) << queue.error().message; + 
EXPECT_EQ(storage.counters[16], 1); +} + +TEST(L3L2MessageQueueTest, L2InputPeekPoisonsZeroByteDescriptorWithNonzeroOffset) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueDescSlot slot{}; + l3_l2_queue_encode_desc(&slot, 1, L3L2QueueOpcode::DATA, 8, 0); + std::memcpy(storage.payload.data() + queue.layout().input_desc_offset, &slot, sizeof(slot)); + storage.counters[0] = 1; + + L3L2QueueInputHandle handle{}; + EXPECT_FALSE(queue.input().try_peek(&handle)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::INVALID_DESCRIPTOR); + EXPECT_EQ(storage.counters[80], 1); +} + +TEST(L3L2MessageQueueTest, L2InputPeekExposesNonzeroPayloadBytes) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + const std::array payload{{0x11, 0x22, 0x33, 0x44}}; + std::memcpy(storage.payload.data() + queue.layout().input_arena_offset, payload.data(), payload.size()); + publish_input_desc( + &storage, queue.layout(), 1, L3L2QueueOpcode::DATA, queue.layout().input_arena_offset, payload.size() + ); + + L3L2QueueInputHandle handle{}; + ASSERT_TRUE(queue.input().try_peek(&handle)) << queue.error().message; + + ASSERT_EQ(handle.payload_nbytes, payload.size()); + const auto *observed = reinterpret_cast(static_cast(handle.payload.gm_addr)); + EXPECT_EQ(std::memcmp(observed, payload.data(), payload.size()), 0); + ASSERT_TRUE(queue.input().release(handle)) << queue.error().message; + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); +} + +TEST(L3L2MessageQueueTest, L2InputPeekAllowsArenaWrapAtExpectedPayloadHead) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 128, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); 
+ ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + publish_input_desc(&storage, queue.layout(), 1, L3L2QueueOpcode::DATA, queue.layout().input_arena_offset, 80); + L3L2QueueInputHandle first{}; + ASSERT_TRUE(queue.input().try_peek(&first)) << queue.error().message; + ASSERT_TRUE(queue.input().release(first)) << queue.error().message; + + publish_input_desc(&storage, queue.layout(), 2, L3L2QueueOpcode::DATA, queue.layout().input_arena_offset, 64); + L3L2QueueInputHandle second{}; + ASSERT_TRUE(queue.input().try_peek(&second)) << queue.error().message; + + EXPECT_EQ(second.payload_offset, queue.layout().input_arena_offset); + EXPECT_EQ(second.payload_nbytes, 64u); + ASSERT_TRUE(queue.input().release(second)) << queue.error().message; + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); +} + +TEST(L3L2MessageQueueTest, L2InputPeekRejectsPayloadOffsetMismatchBeforeRelease) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 128, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + publish_input_desc(&storage, queue.layout(), 1, L3L2QueueOpcode::DATA, queue.layout().input_arena_offset + 64, 16); + + L3L2QueueInputHandle handle{}; + EXPECT_FALSE(queue.input().try_peek(&handle)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::INVALID_DESCRIPTOR); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET)], 0); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 1); +} + +TEST(L3L2MessageQueueTest, L2OutputReservePublishWritesDescriptorAndTail) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueOutputReservation reservation{}; + ASSERT_TRUE(queue.output().try_reserve(16, &reservation)) << 
queue.error().message; + EXPECT_EQ(reservation.payload_nbytes, 16u); + EXPECT_NE(reservation.payload.gm_addr, 0u); + + ASSERT_TRUE(queue.output().publish(reservation, L3L2QueueOpcode::DATA)) << queue.error().message; + + L3L2QueueDescSlot slot{}; + std::memcpy(&slot, storage.payload.data() + queue.layout().output_desc_offset, sizeof(slot)); + EXPECT_EQ(slot.seq, 1u); + EXPECT_EQ(slot.opcode, 1u); + EXPECT_EQ(slot.payload_nbytes, 16u); + EXPECT_EQ(storage.counters[32], 1); +} + +TEST(L3L2MessageQueueTest, L2OutputReserveReplaysReleasedDescriptorsBeforeReusingArena) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(4, 64, 128); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueOutputReservation first{}; + ASSERT_TRUE(queue.output().try_reserve(80, &first)) << queue.error().message; + ASSERT_EQ(first.payload_offset, queue.layout().output_arena_offset); + ASSERT_TRUE(queue.output().publish(first, L3L2QueueOpcode::DATA)) << queue.error().message; + + storage.counters[48] = 1; + L3L2QueueOutputReservation second{}; + ASSERT_TRUE(queue.output().try_reserve(80, &second)) << queue.error().message; + + EXPECT_EQ(second.payload_offset, queue.layout().output_arena_offset); +} + +TEST(L3L2MessageQueueTest, RemoteAbortObservationDoesNotSetOwnAbortFlag) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + storage.counters[64] = 1; + + EXPECT_EQ(queue.disambiguate_timeout(), L3L2QueueTimeoutStatus::REMOTE_ABORTED); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::REMOTE_ABORTED); + EXPECT_EQ(storage.counters[80], 0); +} + +TEST(L3L2MessageQueueTest, OrdinaryTimeoutDoesNotSetOwnAbortFlag) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint 
queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + EXPECT_EQ(queue.disambiguate_timeout(), L3L2QueueTimeoutStatus::ORDINARY_TIMEOUT); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, OutputCapacityEqualsDepthAndFullIsNoProgressWithoutAbort) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + for (int i = 0; i < 2; ++i) { + L3L2QueueOutputReservation reservation{}; + ASSERT_TRUE(queue.output().try_reserve(0, &reservation)) << queue.error().message; + ASSERT_TRUE(queue.output().publish(reservation, L3L2QueueOpcode::DATA)) << queue.error().message; + } + L3L2QueueOutputReservation third{}; + EXPECT_FALSE(queue.output().try_reserve(0, &third)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET)], 2); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, FullAndEmptyUseMonotonicCountersNotMaskedIndices) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + for (int i = 0; i < 2; ++i) { + L3L2QueueOutputReservation reservation{}; + ASSERT_TRUE(queue.output().try_reserve(0, &reservation)) << queue.error().message; + ASSERT_TRUE(queue.output().publish(reservation, L3L2QueueOpcode::DATA)) << queue.error().message; + } + storage.counters[counter_index(L3L2_QUEUE_OUTPUT_DESC_HEAD_OFFSET)] = 1; + + L3L2QueueOutputReservation third{}; + ASSERT_TRUE(queue.output().try_reserve(0, &third)) << 
queue.error().message; + ASSERT_TRUE(queue.output().publish(third, L3L2QueueOpcode::DATA)) << queue.error().message; + + EXPECT_EQ(third.seq, 3u); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET)], 3); + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, OutputReserveTooLargeIsPreMutationNoProgressWithoutAbort) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueOutputReservation reservation{}; + EXPECT_FALSE(queue.output().try_reserve(65, &reservation)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_OUTPUT_DESC_TAIL_OFFSET)], 0); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, OutputPublishApplicationErrorDoesNotSetAbortFlag) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueOutputReservation reservation{}; + ASSERT_TRUE(queue.output().try_reserve(0, &reservation)) << queue.error().message; + ASSERT_TRUE(queue.output().publish(reservation, L3L2QueueOpcode::ERROR)) << queue.error().message; + + L3L2QueueDescSlot slot{}; + std::memcpy(&slot, storage.payload.data() + queue.layout().output_desc_offset, sizeof(slot)); + EXPECT_EQ(slot.opcode, static_cast(L3L2QueueOpcode::ERROR)); + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, OutputPublishStaleReservationPoisonsAndSetsOwnAbortFlag) { + RegionStorage storage{}; 
+ L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + L3L2QueueOutputReservation reservation{}; + ASSERT_TRUE(queue.output().try_reserve(0, &reservation)) << queue.error().message; + ASSERT_TRUE(queue.output().publish(reservation, L3L2QueueOpcode::DATA)) << queue.error().message; + EXPECT_FALSE(queue.output().publish(reservation, L3L2QueueOpcode::DATA)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::OWNERSHIP); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 1); +} + +TEST(L3L2MessageQueueTest, InputApplicationErrorIsNormalMessageAndDoesNotSetAbortFlag) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + publish_input_desc(&storage, queue.layout(), 1, L3L2QueueOpcode::ERROR); + + L3L2QueueInputHandle handle{}; + ASSERT_TRUE(queue.input().try_peek(&handle)) << queue.error().message; + EXPECT_EQ(handle.opcode, L3L2QueueOpcode::ERROR); + ASSERT_TRUE(queue.input().release(handle)) << queue.error().message; + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, InputReleaseRejectsCallerMutatedHandleMetadata) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + publish_input_desc(&storage, queue.layout(), 1, L3L2QueueOpcode::DATA, queue.layout().input_arena_offset, 16); + + L3L2QueueInputHandle handle{}; + ASSERT_TRUE(queue.input().try_peek(&handle)) << queue.error().message; + handle.payload_nbytes = 0; + + 
EXPECT_FALSE(queue.input().release(handle)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::OWNERSHIP); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_INPUT_DESC_HEAD_OFFSET)], 0); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 1); +} + +TEST(L3L2MessageQueueTest, InputStopReleaseRejectsLaterPublishedInputAsInvalidState) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + publish_input_desc(&storage, queue.layout(), 1, L3L2QueueOpcode::STOP); + + L3L2QueueInputHandle stop{}; + ASSERT_TRUE(queue.input().try_peek(&stop)) << queue.error().message; + ASSERT_TRUE(queue.input().release(stop)) << queue.error().message; + + publish_input_desc(&storage, queue.layout(), 2, L3L2QueueOpcode::DATA); + L3L2QueueInputHandle later{}; + EXPECT_FALSE(queue.input().try_peek(&later)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::INVALID_DESCRIPTOR); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 1); +} + +TEST(L3L2MessageQueueTest, NullInputPeekOutputIsPreMutationRejectionWithoutAbort) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + + EXPECT_FALSE(queue.input().try_peek(nullptr)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 0); +} + +TEST(L3L2MessageQueueTest, InputSecondPeekBeforeReleasePoisonsOwnershipAndSetsOwnAbortFlag) { + RegionStorage storage{}; + L3L2QueueArgs args = make_args(2, 64, 64); + L3L2QueueEndpoint queue(make_desc(&storage, args), args); + ASSERT_EQ(queue.error().kind, L3L2QueueErrorKind::NONE) << queue.error().message; + publish_input_desc(&storage, 
queue.layout(), 1, L3L2QueueOpcode::DATA); + + L3L2QueueInputHandle handle{}; + ASSERT_TRUE(queue.input().try_peek(&handle)) << queue.error().message; + L3L2QueueInputHandle second{}; + EXPECT_FALSE(queue.input().try_peek(&second)); + + EXPECT_EQ(queue.error().kind, L3L2QueueErrorKind::OWNERSHIP); + EXPECT_EQ(storage.counters[counter_index(L3L2_QUEUE_L2_ABORT_FLAG_OFFSET)], 1); +} + +} // namespace diff --git a/tests/ut/py/test_worker/test_l3_l2_message_queue.py b/tests/ut/py/test_worker/test_l3_l2_message_queue.py new file mode 100644 index 000000000..2a83b04a9 --- /dev/null +++ b/tests/ut/py/test_worker/test_l3_l2_message_queue.py @@ -0,0 +1,707 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- + +import ctypes +import math +import struct +from multiprocessing.shared_memory import SharedMemory +from typing import Optional + +import pytest +from simpler.l3_l2_message_queue import ( + L3L2_QUEUE_COUNTER_BYTES, + L3L2_QUEUE_DESC_SLOT_BYTES, + L3L2_QUEUE_L2_ABORT_FLAG_OFFSET, + L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, + L3L2QueueMessage, + L3L2QueueOpcode, + make_l3_l2_queue_layout, +) +from simpler.l3_l2_orch_comm import ( + L3L2OrchCommCmd, + L3L2OrchCommRequest, + L3L2OrchCommResponse, + L3L2OrchRegionDesc, + NotifyOp, + WaitCmp, +) +from simpler.orchestrator import Orchestrator +from simpler.task_interface import DataType, Tensor, get_element_size +from simpler.worker import _IDLE, _OFF_STATE, Worker, _buffer_field_addr, _mailbox_store_i32 + + +class _FakeCWorker: + def __init__(self): + self.bootstrap_calls: list[tuple[int, str]] = [] + + def control_l3_l2_orch_comm_init(self, worker_id: int, control_shm_name: str) -> None: + self.bootstrap_calls.append((int(worker_id), str(control_shm_name))) + + +class _FakeCOrch: + def __init__(self): + self._buffers = [] + + def alloc(self, shape, dtype): + nbytes = math.prod(int(x) for x in shape) * int(get_element_size(dtype)) + storage_t = ctypes.c_uint8 * nbytes + storage = storage_t() + self._buffers.append(storage) + return Tensor.make(ctypes.addressof(storage), tuple(int(x) for x in shape), dtype) + + +class _FakeClient: + def __init__(self): + self.requests: list[tuple[L3L2OrchCommRequest, float]] = [] + self.payload_writes: list[tuple[int, bytes]] = [] + self.next_region_id = 1 + self.payload_base = 0x1000_0000 + self.counter_base = 0x2000_0000 + self.payload = bytearray() + self.counters: dict[int, int] = {} + self.peer_abort = False + self.fail_next_cmd: Optional[L3L2OrchCommCmd] = None + + def submit(self, request, timeout_s: float): + self.requests.append((request, timeout_s)) + if self.fail_next_cmd == 
request.cmd: + self.fail_next_cmd = None + raise RuntimeError(f"injected failure for {request.cmd.name}") + if request.cmd == L3L2OrchCommCmd.ALLOC_REGION: + region_id = self.next_region_id + self.next_region_id += 1 + self.payload = bytearray(int(request.payload_bytes)) + self.counters = {} + return L3L2OrchCommResponse( + status=0, + error_kind=0, + region_id=region_id, + observed_counter=0, + matched=False, + desc=L3L2OrchRegionDesc( + magic_version=0x4C334C3200020000, + region_id=region_id, + payload_base=self.payload_base, + payload_bytes=request.payload_bytes, + counter_base=self.counter_base, + counter_bytes=request.counter_bytes, + ), + message="", + ) + if request.cmd == L3L2OrchCommCmd.PAYLOAD_WRITE: + data = ctypes.string_at(int(request.host_ptr), int(request.payload_bytes)) + self.payload_writes.append( + ( + int(request.payload_offset), + data, + ) + ) + begin = int(request.payload_offset) + self.payload[begin : begin + int(request.payload_bytes)] = data + if request.cmd == L3L2OrchCommCmd.PAYLOAD_READ: + begin = int(request.payload_offset) + data = bytes(self.payload[begin : begin + int(request.payload_bytes)]) + ctypes.memmove(int(request.host_ptr), data, len(data)) + if request.cmd == L3L2OrchCommCmd.SIGNAL_NOTIFY: + offset = int(request.counter_addr) - self.counter_base + if int(request.op) == int(NotifyOp.Add): + self.counters[offset] = int(self.counters.get(offset, 0)) + int(request.counter_operand) + else: + self.counters[offset] = int(request.counter_operand) + if request.cmd == L3L2OrchCommCmd.SIGNAL_TEST: + offset = int(request.counter_addr) - self.counter_base + observed = ( + 1 if self.peer_abort and offset == L3L2_QUEUE_L2_ABORT_FLAG_OFFSET else self.counters.get(offset, 0) + ) + matched = _compare_counter(observed, int(request.counter_operand), int(request.op)) + return L3L2OrchCommResponse( + status=0, + error_kind=0, + region_id=request.region_id, + observed_counter=observed, + matched=matched, + desc=None, + message="", + ) + return 
L3L2OrchCommResponse( + status=0, + error_kind=0, + region_id=request.region_id, + observed_counter=request.counter_operand, + matched=True, + desc=None, + message="", + ) + + +def _compare_counter(observed: int, operand: int, cmp: int) -> bool: + if cmp == int(WaitCmp.EQ): + return observed == operand + if cmp == int(WaitCmp.NE): + return observed != operand + if cmp == int(WaitCmp.GT): + return observed > operand + if cmp == int(WaitCmp.GE): + return observed >= operand + if cmp == int(WaitCmp.LT): + return observed < operand + if cmp == int(WaitCmp.LE): + return observed <= operand + return False + + +def _make_orchestrator() -> tuple[Orchestrator, Worker, SharedMemory, _FakeClient]: + worker = Worker(level=3, device_ids=[0], platform="a2a3", runtime="tensormap_and_ringbuffer") + shm = SharedMemory(create=True, size=4096) + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + fake_client = _FakeClient() + worker._initialized = True + worker._hierarchical_started = True + worker._worker = _FakeCWorker() + worker._chip_shms = [shm] + worker._make_l3_l2_orch_comm_client = lambda _shm: fake_client + return Orchestrator(_FakeCOrch(), worker), worker, shm, fake_client + + +def _close(worker: Worker, shm: SharedMemory) -> None: + worker._close_l3_l2_orch_comm() + shm.close() + shm.unlink() + + +def _publish_output( + fake_client: _FakeClient, + queue, + *, + seq: int = 1, + payload: bytes = b"", + opcode: int = int(L3L2QueueOpcode.DATA), + payload_offset: Optional[int] = None, +) -> None: + if payload_offset is None: + payload_offset = queue.layout.output_arena_offset if payload else 0 + if payload: + fake_client.payload[payload_offset : payload_offset + len(payload)] = payload + desc = struct.pack("<4Q", seq, int(opcode), payload_offset, len(payload)) + desc_offset = queue.layout.output_desc_offset + ((seq - 1) & (queue.layout.depth - 1)) * L3L2_QUEUE_DESC_SLOT_BYTES + fake_client.payload[desc_offset : desc_offset + 
L3L2_QUEUE_DESC_SLOT_BYTES] = desc + fake_client.counters[queue.layout.output_desc_tail_offset] = seq + + +def test_layout_rejects_invalid_pr1_parameters(): + invalid_args = [ + (3, 128, 128), + ((1 << 30) + 1, 128, 128), + (4, 0, 128), + (4, 127, 128), + (4, 128, 0), + (4, 128, 127), + ] + + for depth, input_arena_bytes, output_arena_bytes in invalid_args: + with pytest.raises(ValueError): + make_l3_l2_queue_layout(depth, input_arena_bytes, output_arena_bytes) + + +def test_layout_rejects_uint64_overflow_to_match_cpp_helper(): + with pytest.raises(ValueError, match="overflowed uint64"): + make_l3_l2_queue_layout(2, (1 << 64) - 64, 64) + + +@pytest.mark.parametrize( + ("depth", "input_arena_bytes", "output_arena_bytes", "expected"), + [ + ( + 1, + 64, + 64, + { + "output_desc_offset": 32, + "input_arena_offset": 64, + "output_arena_offset": 128, + "payload_bytes": 192, + }, + ), + ( + 4, + 128, + 192, + { + "output_desc_offset": 128, + "input_arena_offset": 256, + "output_arena_offset": 384, + "payload_bytes": 576, + }, + ), + ( + 8, + 192, + 64, + { + "output_desc_offset": 256, + "input_arena_offset": 512, + "output_arena_offset": 704, + "payload_bytes": 768, + }, + ), + ], +) +def test_layout_lockstep_cases_match_cpp_helper_expectations(depth, input_arena_bytes, output_arena_bytes, expected): + layout = make_l3_l2_queue_layout( + depth=depth, + input_arena_bytes=input_arena_bytes, + output_arena_bytes=output_arena_bytes, + ) + + assert layout.input_desc_offset == 0 + assert layout.output_desc_offset == expected["output_desc_offset"] + assert layout.output_desc_offset == depth * L3L2_QUEUE_DESC_SLOT_BYTES + assert layout.input_arena_offset == expected["input_arena_offset"] + assert layout.output_arena_offset == expected["output_arena_offset"] + assert layout.payload_bytes == expected["payload_bytes"] + assert layout.input_arena_offset % 64 == 0 + assert layout.output_arena_offset % 64 == 0 + assert layout.input_desc_tail_offset == 0 + assert 
layout.input_desc_head_offset == 64 + assert layout.output_desc_tail_offset == 128 + assert layout.output_desc_head_offset == 192 + assert layout.l3_abort_flag_offset == L3L2_QUEUE_L3_ABORT_FLAG_OFFSET + assert layout.l2_abort_flag_offset == L3L2_QUEUE_L2_ABORT_FLAG_OFFSET + assert layout.counter_bytes == L3L2_QUEUE_COUNTER_BYTES + + +def test_create_l3_l2_queue_allocates_region_and_exposes_l2_task_scalars(): + orch, worker, shm, fake_client = _make_orchestrator() + try: + queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=192) + + alloc_req = fake_client.requests[0][0] + assert alloc_req.cmd == L3L2OrchCommCmd.ALLOC_REGION + assert alloc_req.payload_bytes == queue.layout.payload_bytes + assert alloc_req.counter_bytes == L3L2_QUEUE_COUNTER_BYTES + assert queue.l2_task_arg_scalars() == [ + *queue.region.descriptor_scalars(), + queue.magic_version, + 4, + 128, + 192, + queue.layout.payload_bytes, + queue.layout.counter_bytes, + ] + assert fake_client.counters == { + queue.layout.input_desc_tail_offset: 0, + queue.layout.input_desc_head_offset: 0, + queue.layout.output_desc_tail_offset: 0, + queue.layout.output_desc_head_offset: 0, + queue.layout.l3_abort_flag_offset: 0, + queue.layout.l2_abort_flag_offset: 0, + } + finally: + _close(worker, shm) + + +def test_create_l3_l2_queue_frees_region_on_post_region_alloc_failure(): + orch, worker, shm, _fake_client = _make_orchestrator() + original_alloc = orch._o.alloc + + def fail_alloc(_shape, _dtype): + raise RuntimeError("injected alloc failure") + + orch._o.alloc = fail_alloc + try: + with pytest.raises(RuntimeError, match="injected alloc failure"): + orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128) + + assert len(worker._live_l3_l2_regions) == 1 + assert worker._live_l3_l2_regions[0]._released is True + finally: + orch._o.alloc = original_alloc + _close(worker, shm) + + +def 
test_zero_byte_enqueue_skips_message_payload_write_and_publishes_descriptor():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        fake_client.requests.clear()
+        fake_client.payload_writes.clear()
+
+        queue.input.enqueue(None, nbytes=0, timeout=0.001)
+
+        payload_write_offsets = [offset for offset, _data in fake_client.payload_writes]
+        assert queue.layout.input_arena_offset not in payload_write_offsets
+        assert queue.layout.input_desc_offset in payload_write_offsets
+        notify_req = fake_client.requests[-1][0]
+        assert notify_req.cmd == L3L2OrchCommCmd.SIGNAL_NOTIFY
+        assert notify_req.op == int(NotifyOp.Set)
+        assert notify_req.counter_addr == queue.region.descriptor.counter_base + queue.layout.input_desc_tail_offset
+        assert notify_req.counter_operand == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_enqueue_registered_tensor_uses_fast_path_without_staging():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        host = orch.alloc([16], DataType.UINT8)
+        fake_client.requests.clear()
+        fake_client.payload_writes.clear()
+
+        queue.input.enqueue(host, nbytes=16, timeout=0.001)
+
+        payload_write_offsets = [offset for offset, _data in fake_client.payload_writes]
+        assert queue.layout.input_arena_offset in payload_write_offsets
+        assert queue.layout.input_desc_offset in payload_write_offsets
+        assert all(req.cmd != L3L2OrchCommCmd.ALLOC_REGION for req, _timeout in fake_client.requests)
+    finally:
+        _close(worker, shm)
+
+
+def test_enqueue_replays_released_descriptors_before_reusing_input_arena():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        first = orch.alloc([80], DataType.UINT8)
+        second = orch.alloc([80], DataType.UINT8)
+
+        queue.input.enqueue(first, nbytes=80, timeout=0.001)
+        fake_client.counters[queue.layout.input_desc_head_offset] = 1
+        queue.input.enqueue(second, nbytes=80, timeout=0.001)
+
+        payload_offsets = [offset for offset, data in fake_client.payload_writes if len(data) == 80]
+        assert payload_offsets == [queue.layout.input_arena_offset, queue.layout.input_arena_offset]
+    finally:
+        _close(worker, shm)
+
+
+def test_enqueue_rejects_ordinary_host_bytes_before_shared_mutation():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        fake_client.requests.clear()
+
+        with pytest.raises(ValueError, match="registered.*orch.alloc"):
+            queue.input.enqueue(b"ordinary", nbytes=8, timeout=0.001)
+
+        assert fake_client.requests == []
+        assert queue.region.descriptor_scalars()[1] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_output_read_into_registered_tensor_uses_fast_path_and_release_notifies_head():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"abcdefghijklmnop")
+        output = orch.alloc([16], DataType.UINT8)
+
+        handle = queue.output.peek(timeout=0.001)
+        queue.output.read_into(handle, output)
+        queue.output.release(handle)
+
+        assert ctypes.string_at(int(output.data), 16) == b"abcdefghijklmnop"
+        assert fake_client.counters[queue.layout.output_desc_head_offset] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_dequeue_into_reads_and_releases_output():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"abcdefghijklmnop")
+        output = orch.alloc([16], DataType.UINT8)
+
+        message = queue.output.dequeue_into(output, timeout=0.001)
+
+        assert message.seq == 1
+        assert message.opcode == L3L2QueueOpcode.DATA
+        assert ctypes.string_at(int(output.data), 16) == b"abcdefghijklmnop"
+        assert fake_client.counters[queue.layout.output_desc_head_offset] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_output_error_opcode_is_delivered_without_poison():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"error-detail", opcode=int(L3L2QueueOpcode.ERROR))
+        output = orch.alloc([12], DataType.UINT8)
+
+        message = queue.output.dequeue_into(output, timeout=0.001)
+
+        assert message.opcode == L3L2QueueOpcode.ERROR
+        assert ctypes.string_at(int(output.data), 12) == b"error-detail"
+        assert fake_client.counters[queue.layout.output_desc_head_offset] == 1
+        assert fake_client.counters.get(L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, 0) == 0
+    finally:
+        _close(worker, shm)
+
+
+def test_try_dequeue_into_empty_returns_none_without_abort():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        output = orch.alloc([16], DataType.UINT8)
+        fake_client.requests.clear()
+
+        assert queue.output.try_dequeue_into(output) is None
+
+        assert fake_client.counters.get(queue.layout.output_desc_head_offset, 0) == 0
+        assert all(
+            not (
+                req.cmd == L3L2OrchCommCmd.SIGNAL_NOTIFY
+                and req.counter_addr == queue.region.descriptor.counter_base + L3L2_QUEUE_L3_ABORT_FLAG_OFFSET
+            )
+            for req, _timeout in fake_client.requests
+        )
+    finally:
+        _close(worker, shm)
+
+
+def test_output_read_rejects_ordinary_buffer_before_shared_mutation():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"abcdefghijklmnop")
+        handle = queue.output.peek(timeout=0.001)
+        fake_client.requests.clear()
+
+        with pytest.raises(ValueError, match="registered.*orch.alloc"):
+            queue.output.read_into(handle, bytearray(16))
+
+        assert fake_client.requests == []
+        assert fake_client.counters.get(queue.layout.output_desc_head_offset, 0) == 0
+    finally:
+        _close(worker, shm)
+
+
+def test_output_release_inactive_handle_poisons_and_sets_l3_abort_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"abcdefghijklmnop")
+        handle = queue.output.peek(timeout=0.001)
+        wrong = L3L2QueueMessage(handle.seq + 1, handle.opcode, handle.payload_offset, handle.payload_nbytes)
+        fake_client.requests.clear()
+
+        with pytest.raises(RuntimeError, match="not active"):
+            queue.output.release(wrong)
+
+        assert fake_client.counters[L3L2_QUEUE_L3_ABORT_FLAG_OFFSET] == 1
+        with pytest.raises(RuntimeError, match="poisoned"):
+            queue.output.try_peek()
+    finally:
+        _close(worker, shm)
+
+
+def test_output_stop_descriptor_poisons_and_sets_l3_abort_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, opcode=int(L3L2QueueOpcode.STOP))
+
+        with pytest.raises(RuntimeError, match="cannot be STOP"):
+            queue.output.peek(timeout=0.001)
+
+        assert fake_client.counters[L3L2_QUEUE_L3_ABORT_FLAG_OFFSET] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_zero_byte_output_descriptor_with_nonzero_offset_poisons_and_sets_l3_abort_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload_offset=queue.layout.output_arena_offset)
+
+        with pytest.raises(RuntimeError, match="zero-byte.*nonzero"):
+            queue.output.peek(timeout=0.001)
+
+        assert fake_client.counters[L3L2_QUEUE_L3_ABORT_FLAG_OFFSET] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_zero_byte_output_read_accepts_none_and_skips_payload_read():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(fake_client, queue, payload=b"")
+        handle = queue.output.peek(timeout=0.001)
+        fake_client.requests.clear()
+
+        queue.output.read_into(handle, None)
+        queue.output.release(handle)
+
+        assert all(req.cmd != L3L2OrchCommCmd.PAYLOAD_READ for req, _timeout in fake_client.requests)
+        assert fake_client.counters[queue.layout.output_desc_head_offset] == 1
+    finally:
+        _close(worker, shm)
+
+
+def test_try_enqueue_full_queue_returns_false_without_poison_or_publish():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=2, input_arena_bytes=128, output_arena_bytes=128)
+        queue.input.enqueue(None, nbytes=0, timeout=0.001)
+        queue.input.enqueue(None, nbytes=0, timeout=0.001)
+        fake_client.requests.clear()
+        fake_client.payload_writes.clear()
+
+        assert queue.input.try_enqueue(None, nbytes=0) is False
+
+        assert fake_client.payload_writes == []
+        assert fake_client.counters[queue.layout.input_desc_tail_offset] == 2
+        assert fake_client.counters.get(L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, 0) == 0
+    finally:
+        _close(worker, shm)
+
+
+def test_enqueue_after_stop_rejects_locally_without_polling_or_abort():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        queue.request_stop(timeout=0.001)
+        fake_client.requests.clear()
+
+        assert queue.input.try_enqueue(None, nbytes=0) is False
+        with pytest.raises(RuntimeError, match="stopped"):
+            queue.input.enqueue(None, nbytes=0, timeout=0.001)
+
+        assert fake_client.requests == []
+        assert fake_client.counters.get(L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, 0) == 0
+    finally:
+        _close(worker, shm)
+
+
+def test_try_enqueue_payload_larger_than_arena_returns_false_without_poison_or_publish():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        host = orch.alloc([256], DataType.UINT8)
+        fake_client.requests.clear()
+        fake_client.payload_writes.clear()
+
+        assert queue.input.try_enqueue(host, nbytes=256) is False
+
+        assert fake_client.payload_writes == []
+        assert fake_client.counters.get(queue.layout.input_desc_tail_offset, 0) == 0
+        assert fake_client.counters.get(L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, 0) == 0
+    finally:
+        _close(worker, shm)
+
+
+def test_output_payload_offset_mismatch_poisons_before_payload_read():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        _publish_output(
+            fake_client,
+            queue,
+            payload=b"abcdefghijklmnop",
+            payload_offset=queue.layout.output_arena_offset + 16,
+        )
+        fake_client.requests.clear()
+
+        with pytest.raises(RuntimeError, match="payload.*mismatch"):
+            queue.output.peek(timeout=0.001)
+
+        assert fake_client.counters[L3L2_QUEUE_L3_ABORT_FLAG_OFFSET] == 1
+        assert all(
+            not (
+                req.cmd == L3L2OrchCommCmd.PAYLOAD_READ and req.payload_offset == queue.layout.output_arena_offset + 16
+            )
+            for req, _timeout in fake_client.requests
+        )
+    finally:
+        _close(worker, shm)
+
+
+def test_enqueue_payload_write_failure_sets_l3_abort_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        host = orch.alloc([16], DataType.UINT8)
+        fake_client.fail_next_cmd = L3L2OrchCommCmd.PAYLOAD_WRITE
+
+        with pytest.raises(RuntimeError, match="injected failure"):
+            queue.input.enqueue(host, nbytes=16, timeout=0.001)
+
+        assert fake_client.counters[L3L2_QUEUE_L3_ABORT_FLAG_OFFSET] == 1
+        with pytest.raises(RuntimeError, match="poisoned"):
+            queue.input.try_enqueue(None, nbytes=0)
+    finally:
+        _close(worker, shm)
+
+
+def test_timeout_without_peer_abort_flag_returns_timeout_and_keeps_queue_live():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        fake_client.requests.clear()
+
+        with pytest.raises(TimeoutError, match="timed out"):
+            queue.output.peek(timeout=0.0001)
+
+        assert queue.region.descriptor_scalars()[1] == 1
+        assert all(
+            not (
+                req.cmd == L3L2OrchCommCmd.SIGNAL_NOTIFY
+                and req.counter_addr == queue.region.descriptor.counter_base + L3L2_QUEUE_L3_ABORT_FLAG_OFFSET
+            )
+            for req, _timeout in fake_client.requests
+        )
+    finally:
+        _close(worker, shm)
+
+
+def test_timeout_with_peer_abort_flag_reports_remote_aborted_without_setting_own_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        fake_client.peer_abort = True
+        fake_client.requests.clear()
+
+        with pytest.raises(RuntimeError, match="remote.*abort"):
+            queue.output.peek(timeout=0.0001)
+
+        with pytest.raises(RuntimeError, match="remote.*abort"):
+            queue.input.try_enqueue(None, nbytes=0)
+        assert all(
+            not (
+                req.cmd == L3L2OrchCommCmd.SIGNAL_NOTIFY
+                and req.counter_addr == queue.region.descriptor.counter_base + L3L2_QUEUE_L3_ABORT_FLAG_OFFSET
+            )
+            for req, _timeout in fake_client.requests
+        )
+    finally:
+        _close(worker, shm)
+
+
+def test_expired_queue_rejects_later_operations_without_abort_flag():
+    orch, worker, shm, fake_client = _make_orchestrator()
+    try:
+        queue = orch.create_l3_l2_queue(worker_id=0, depth=4, input_arena_bytes=128, output_arena_bytes=128)
+        queue.region._expire()
+        fake_client.requests.clear()
+
+        with pytest.raises(RuntimeError, match="expired"):
+            queue.input.try_enqueue(None, nbytes=0)
+        with pytest.raises(RuntimeError, match="expired"):
+            queue.output.try_peek()
+
+        assert fake_client.requests == []
+        assert fake_client.counters.get(L3L2_QUEUE_L3_ABORT_FLAG_OFFSET, 0) == 0
+    finally:
+        _close(worker, shm)