From 92e25d7c6c669caba502ebc0c0093d6a1bdd6b23 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 01:30:57 +0000 Subject: [PATCH 1/3] Initial plan From 52dd8a8d037b44a2457e7f76e80dde89dd69eb79 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 01:48:14 +0000 Subject: [PATCH 2/3] Implement power-of-two VMem allocator (VMemPow2Allocator) Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/allocators/__init__.py | 3 +- iris/allocators/vmem_pow2_allocator.py | 477 ++++++++++++++++++++ iris/iris.py | 13 +- iris/symmetric_heap.py | 10 +- tests/unittests/test_vmem_pow2_allocator.py | 412 +++++++++++++++++ 5 files changed, 907 insertions(+), 8 deletions(-) create mode 100644 iris/allocators/vmem_pow2_allocator.py create mode 100644 tests/unittests/test_vmem_pow2_allocator.py diff --git a/iris/allocators/__init__.py b/iris/allocators/__init__.py index 460c53d3..0c58494d 100644 --- a/iris/allocators/__init__.py +++ b/iris/allocators/__init__.py @@ -8,5 +8,6 @@ from .base import BaseAllocator from .torch_allocator import TorchAllocator from .vmem_allocator import VMemAllocator +from .vmem_pow2_allocator import VMemPow2Allocator -__all__ = ["BaseAllocator", "TorchAllocator", "VMemAllocator"] +__all__ = ["BaseAllocator", "TorchAllocator", "VMemAllocator", "VMemPow2Allocator"] diff --git a/iris/allocators/vmem_pow2_allocator.py b/iris/allocators/vmem_pow2_allocator.py new file mode 100644 index 00000000..da9df0c7 --- /dev/null +++ b/iris/allocators/vmem_pow2_allocator.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Power-of-two VMem allocator using HIP's virtual memory management APIs. + +This allocator provides efficient reuse of virtual memory allocations by +rounding all requests up to the next power-of-two size class and maintaining +per-class free lists for O(1) allocation and deallocation. +""" + +import os +from typing import Dict, List, Tuple +from threading import Lock + +import torch + +from .base import BaseAllocator +from ..hip import ( + get_allocation_granularity, + get_address_range, + export_dmabuf_handle, + mem_import_from_shareable_handle, + mem_create, + mem_address_reserve, + mem_map, + mem_unmap, + mem_address_free, + mem_release, + mem_set_access, + hipMemAccessDesc, + hipMemLocationTypeDevice, + hipMemAccessFlagsProtReadWrite, +) + + +def _next_pow2(n: int) -> int: + """Round n up to the next power of two (>= 1).""" + if n <= 1: + return 1 + return 1 << (n - 1).bit_length() + + +class VMemPow2Allocator(BaseAllocator): + """ + Power-of-two virtual memory allocator using HIP's VMem APIs. + + All allocation requests are rounded up to the nearest power-of-two size + class. Freed blocks are returned to per-class free lists and reused by + subsequent allocations of the same (or smaller) size class, giving O(1) + amortised alloc and free. + + Physical memory is **never unmapped** when a block is freed; only its + logical ownership changes. This keeps ``get_allocation_segments()`` + correct for the symmetric-heap multi-rank DMA-BUF exchange: every segment + that has ever been allocated is still physically present at the same VA + offset, so peer ranks can import it once and it stays valid. + + Args: + heap_size: Total virtual address space to reserve, in bytes. + device_id: HIP/CUDA device index. + rank: Current process rank. + world_size: Total number of ranks. + va_multiplier: Reserved for future use (currently unused). + """ + + def __init__( + self, + heap_size: int, + device_id: int, + rank: int, + world_size: int, + va_multiplier: float = 1.0, + ): + super().__init__(heap_size, device_id, rank, world_size) + self.va_multiplier = va_multiplier + self.device = torch.device(f"cuda:{device_id}") + self.lock = Lock() + + # HIP allocation granularity (always a power of two, e.g. 2 MiB on MI300X). + self.granularity = get_allocation_granularity(self.device_id) + + # The minimum size class is the granularity itself. + # Because granularity is always a power of two this doubles as min_size_class. + self.min_size_class: int = self.granularity + + # Align the heap to the granularity. + self.aligned_heap_size = (heap_size + self.granularity - 1) & ~(self.granularity - 1) + self.va_size = self.aligned_heap_size + self.base_va: int = mem_address_reserve(self.va_size, self.granularity, 0) + + # Bootstrap: map a minimal chunk at VA base so mem_set_access has + # something to work on (hipMemCreate(0) is invalid). + self.minimal_size: int = self.min_size_class + self.minimal_handle = mem_create(self.minimal_size, self.device_id) + mem_map(self.base_va, self.minimal_size, 0, self.minimal_handle) + + # Access descriptors: allow read/write from every peer device. + self.access_descs: List[hipMemAccessDesc] = [] + for peer_device_id in range(world_size): + desc = hipMemAccessDesc() + desc.location.type = hipMemLocationTypeDevice + desc.location.id = peer_device_id + desc.flags = hipMemAccessFlagsProtReadWrite + self.access_descs.append(desc) + + self.cumulative_mapped_size: int = self.minimal_size + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + # Physical-segment tracking (for get_allocation_segments / cleanup). + # Maps VA-offset -> (size, is_imported, handle, va). + self.allocations: Dict[int, Tuple] = {} + # Ordered list of (offset, size) for get_allocation_segments(). + self.allocation_order: List[Tuple[int, int]] = [] + self._track_allocation(0, self.minimal_size, False, self.minimal_handle, self.base_va) + + # Next available VA offset for a brand-new physical segment. + self.current_offset: int = self.minimal_size + + # Free lists: size_class (power-of-two bytes) -> [(offset, va), …] + self.free_lists: Dict[int, List[Tuple[int, int]]] = {} + + # Logical-allocation tracking: va -> size_class (needed by free()). + self.logical_allocations: Dict[int, int] = {} + + self.world_size = world_size + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _track_allocation(self, offset: int, size: int, is_imported: bool, handle, va: int): + """Record a physical segment for cleanup and segmented DMA-BUF export.""" + self.allocations[offset] = (size, is_imported, handle, va) + self.allocation_order.append((offset, size)) + + def _size_class(self, size_bytes: int) -> int: + """Return the smallest power-of-two size class >= size_bytes and >= granularity.""" + raw = _next_pow2(max(size_bytes, 1)) + return max(raw, self.min_size_class) + + def _map_new_segment(self, size_class: int) -> Tuple[int, int]: + """ + Map a fresh physical segment of ``size_class`` bytes at the next + available VA offset. + + The VA offset is aligned to the HIP allocation granularity (not to + ``size_class``) so that consecutive segments occupy contiguous VA + ranges. Contiguous mapping is required because HIP's + ``hipMemSetAccess`` must be called cumulatively from ``base_va`` and + treats any unmapped gap as an invalid argument. + + Returns: + (offset, va) of the newly mapped segment. + + Raises: + RuntimeError: If the heap VA space is exhausted. + """ + # Align to granularity (not size_class) to keep the VA range gap-free. + aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) + + if aligned_offset + size_class > self.aligned_heap_size: + raise RuntimeError( + f"VMemPow2Allocator: out of VA space. " + f"Need {size_class} bytes at offset {aligned_offset}, " + f"heap size is {self.aligned_heap_size}, " + f"current offset is {self.current_offset}." + ) + + va = self.base_va + aligned_offset + handle = mem_create(size_class, self.device_id) + mem_map(va, size_class, 0, handle) + + new_cumulative = aligned_offset + size_class + if new_cumulative > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self._track_allocation(aligned_offset, size_class, False, handle, va) + self.current_offset = aligned_offset + size_class + return aligned_offset, va + + # ------------------------------------------------------------------ + # BaseAllocator interface + # ------------------------------------------------------------------ + + def get_base_address(self) -> int: + """Return the base virtual address of this allocator's VA range.""" + return self.base_va + + def get_minimum_allocation_size(self) -> int: + """Minimum allocation size in bytes (one size-class / granule).""" + return self.granularity + + def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) -> torch.Tensor: + """ + Allocate a tensor on the power-of-two symmetric heap. + + The physical size is rounded up to the next power-of-two size class + (and is at least ``granularity`` bytes). If a block of the required + size class is already on the free list it is reused; otherwise a new + physical segment is mapped. + + Args: + num_elements: Number of tensor elements. + dtype: PyTorch data type. + alignment: Ignored for this allocator (alignment is provided by + the power-of-two size class itself). + + Returns: + A PyTorch tensor of shape ``(num_elements,)`` backed by symmetric + heap memory. + + Raises: + RuntimeError: If the heap VA space is exhausted. + """ + with self.lock: + element_size = torch.tensor([], dtype=dtype).element_size() + size_bytes = num_elements * element_size + size_class = self._size_class(size_bytes) + + # Try the free list first. + free_entry = None + if size_class in self.free_lists and self.free_lists[size_class]: + free_entry = self.free_lists[size_class].pop() + + if free_entry is not None: + offset, va = free_entry + else: + offset, va = self._map_new_segment(size_class) + + # Record the logical allocation so free() can find the size class. + self.logical_allocations[va] = size_class + + # Expose the physical memory as a PyTorch tensor via __cuda_array_interface__. + interface_size = (size_class // element_size) * element_size + + class _CUDAArrayInterface: + def __init__(self_, ptr: int, nbytes: int, device: torch.device): + self_.ptr = ptr + self_.nbytes = nbytes + self_.device = device + + @property + def __cuda_array_interface__(self_): + return { + "shape": (self_.nbytes,), + "typestr": "|u1", + "data": (self_.ptr, False), + "version": 3, + } + + cuda_array = _CUDAArrayInterface(va, interface_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + full = tensor_bytes.view(dtype) + if num_elements == 0: + return full.narrow(0, 1, 0) + return full.narrow(0, 0, num_elements) + + def free(self, tensor: torch.Tensor) -> None: + """ + Return a tensor's physical block to the appropriate free list. + + The physical memory is **not** unmapped; it remains accessible at its + VA so that peer-rank DMA-BUF imports stay valid. The block is simply + made available for reuse by the next ``allocate`` call of the same + size class. + + Args: + tensor: A tensor previously returned by :meth:`allocate`. + + Raises: + ValueError: If the tensor was not allocated by this allocator. + """ + if tensor.numel() == 0: + # Zero-element tensors share the minimal bootstrap block; skip. + return + + with self.lock: + va = tensor.data_ptr() + if va not in self.logical_allocations: + raise ValueError( + f"VMemPow2Allocator.free(): tensor at VA 0x{va:x} was not " + "allocated by this allocator (or was already freed)." + ) + size_class = self.logical_allocations.pop(va) + offset = va - self.base_va + self.free_lists.setdefault(size_class, []).append((offset, va)) + + def get_device(self) -> torch.device: + """Return the PyTorch device for this allocator.""" + return self.device + + def owns_tensor(self, tensor: torch.Tensor) -> bool: + """ + Return True if *tensor* was allocated from this allocator's heap. + + Args: + tensor: PyTorch tensor to check. + + Returns: + True if the tensor's data pointer lies within the heap VA range. + """ + if not tensor.is_cuda: + return False + if tensor.numel() == 0: + return True + ptr = tensor.data_ptr() + return self.base_va <= ptr < self.base_va + self.aligned_heap_size + + # ------------------------------------------------------------------ + # Symmetric-heap segment API (used by SymmetricHeap.refresh_peer_access) + # ------------------------------------------------------------------ + + def get_allocation_segments(self) -> List[Tuple[int, int, int]]: + """ + Return the ordered list of physical segments for DMA-BUF export. + + Each element is ``(offset, size, va)`` describing one physically-backed + segment that must be exported and imported across ranks. Segments on + the free list are included because they are still physically mapped and + their peer imports must remain valid. + + Returns: + List of ``(offset, size, va)`` tuples in allocation order. + """ + segments = [] + for offset, size in self.allocation_order: + va = self.base_va + offset + segments.append((offset, size, va)) + return segments + + # ------------------------------------------------------------------ + # as_symmetric() support + # ------------------------------------------------------------------ + + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Import an external PyTorch tensor into the symmetric heap. + + This remaps the external tensor's physical memory into the symmetric + heap VA range so that peer ranks can access it via the standard + DMA-BUF exchange. The returned tensor **shares physical memory** with + the original; changes to one are immediately visible in the other. + + Args: + external_tensor: A contiguous CUDA tensor allocated by PyTorch. + + Returns: + A tensor view in the symmetric heap that shares memory with + *external_tensor*. + + Raises: + RuntimeError: If the tensor is not a contiguous CUDA tensor, or + if the heap VA space is exhausted. + """ + with self.lock: + if not external_tensor.is_cuda: + raise RuntimeError("VMemPow2Allocator: can only import CUDA tensors.") + if not external_tensor.is_contiguous(): + raise RuntimeError( + "VMemPow2Allocator: only contiguous tensors can be imported; " + "call .contiguous() before as_symmetric()." + ) + + external_ptr = external_tensor.data_ptr() + alloc_base, alloc_size = get_address_range(external_ptr) + offset_in_alloc = external_ptr - alloc_base + aligned_size = (alloc_size + self.granularity - 1) & ~(self.granularity - 1) + aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) + + if aligned_offset + aligned_size > self.aligned_heap_size: + raise RuntimeError( + f"VMemPow2Allocator: out of VA space for import. " + f"Need {aligned_size} bytes at offset {aligned_offset}, " + f"heap size is {self.aligned_heap_size}." + ) + + dmabuf_fd, export_base, export_size = export_dmabuf_handle(alloc_base, alloc_size) + aligned_export_size = (export_size + self.granularity - 1) & ~(self.granularity - 1) + target_va = self.base_va + aligned_offset + imported_handle = mem_import_from_shareable_handle(dmabuf_fd) + os.close(dmabuf_fd) + + mem_map(target_va, aligned_export_size, 0, imported_handle) + + new_cumulative = aligned_offset + aligned_export_size + if new_cumulative > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + tensor_va = target_va + offset_in_alloc + self._track_allocation(aligned_offset, aligned_export_size, True, imported_handle, target_va) + self.current_offset = aligned_offset + aligned_export_size + + tensor_size = external_tensor.numel() * external_tensor.element_size() + + class _CUDAArrayInterface: + def __init__(self_, ptr: int, nbytes: int, device: torch.device): + self_.ptr = ptr + self_.nbytes = nbytes + self_.device = device + + @property + def __cuda_array_interface__(self_): + return { + "shape": (self_.nbytes,), + "typestr": "|u1", + "data": (self_.ptr, False), + "version": 3, + } + + cuda_array = _CUDAArrayInterface(tensor_va, tensor_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + return tensor_bytes.view(external_tensor.dtype).reshape(external_tensor.shape) + + # ------------------------------------------------------------------ + # Resource management + # ------------------------------------------------------------------ + + def close(self) -> None: + """Release all VMem resources (unmap and free physical handles).""" + if getattr(self, "_closed", False): + return + + with self.lock: + for offset, alloc_info in self.allocations.items(): + if len(alloc_info) == 4: + size, is_imported, handle, va = alloc_info + if handle is not None: + aligned_size = (size + self.granularity - 1) & ~(self.granularity - 1) + mem_unmap(va, aligned_size) + mem_release(handle) + + self.allocations.clear() + self.free_lists.clear() + self.logical_allocations.clear() + + if getattr(self, "base_va", 0): + mem_address_free(self.base_va, self.va_size) + self.base_va = 0 + + self._closed = True + + def __del__(self) -> None: + """Cleanup VMem resources on garbage collection.""" + self.close() + + # ------------------------------------------------------------------ + # Diagnostics + # ------------------------------------------------------------------ + + def stats(self) -> dict: + """ + Return a snapshot of allocator statistics. + + Returns: + A dict with keys: + + * ``heap_size`` – requested heap size in bytes + * ``aligned_heap_size`` – actual VA reservation in bytes + * ``granularity`` – HIP allocation granularity in bytes + * ``current_offset`` – bytes consumed from VA space + * ``num_segments`` – number of physical segments ever mapped + * ``num_live_allocations`` – logical allocations currently in use + * ``free_list_counts`` – dict of {size_class: count} for free lists + """ + with self.lock: + return { + "heap_size": self.heap_size, + "aligned_heap_size": self.aligned_heap_size, + "granularity": self.granularity, + "current_offset": self.current_offset, + "num_segments": len(self.allocation_order), + "num_live_allocations": len(self.logical_allocations), + "free_list_counts": {sc: len(bl) for sc, bl in self.free_lists.items() if bl}, + } diff --git a/iris/iris.py b/iris/iris.py index f0effbb2..a6e9c891 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -78,7 +78,7 @@ class Iris: Args: heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) - allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem" + allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem", "vmem_pow2" Example: >>> ctx = iris.iris(heap_size=2**31) # 2GB heap with torch allocator @@ -87,6 +87,9 @@ class Iris: >>> # Use VMem allocator for memory oversubscription >>> ctx = iris.iris(heap_size=2**31, allocator_type="vmem") + + >>> # Use power-of-two VMem allocator for efficient memory reuse + >>> ctx = iris.iris(heap_size=2**31, allocator_type="vmem_pow2") """ def __init__(self, heap_size=1 << 30, allocator_type="torch"): @@ -2399,8 +2402,8 @@ def iris(heap_size=1 << 30, allocator_type="torch"): Args: heap_size (int): Size of the heap in bytes. Defaults to 1GB. - allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem". - Can be overridden with IRIS_ALLOCATOR environment variable. + allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem", + "vmem_pow2". Can be overridden with IRIS_ALLOCATOR environment variable. Returns: Iris: An initialized Iris instance. @@ -2413,5 +2416,9 @@ def iris(heap_size=1 << 30, allocator_type="torch"): >>> # Use VMem allocator >>> iris_ctx = iris.iris(2**30, allocator_type="vmem") >>> tensor = iris_ctx.zeros(1024, 1024) + + >>> # Use power-of-two VMem allocator for efficient memory reuse + >>> iris_ctx = iris.iris(2**30, allocator_type="vmem_pow2") + >>> tensor = iris_ctx.zeros(1024, 1024) """ return Iris(heap_size, allocator_type) diff --git a/iris/symmetric_heap.py b/iris/symmetric_heap.py index eef39197..c6f8e901 100644 --- a/iris/symmetric_heap.py +++ b/iris/symmetric_heap.py @@ -12,7 +12,7 @@ import torch import os -from iris.allocators import TorchAllocator, VMemAllocator +from iris.allocators import TorchAllocator, VMemAllocator, VMemPow2Allocator from iris.fd_passing import setup_fd_infrastructure from iris._distributed_helpers import distributed_allgather @@ -24,7 +24,7 @@ class SymmetricHeap: Manages distributed memory with symmetric addressing across ranks, handling all allocator coordination and memory sharing internally. - Supports multiple allocator backends: 'torch' (default) and 'vmem'. + Supports multiple allocator backends: 'torch' (default), 'vmem', and 'vmem_pow2'. """ def __init__( @@ -43,7 +43,7 @@ def __init__( device_id: GPU device ID cur_rank: Current process rank num_ranks: Total number of ranks - allocator_type: Type of allocator ("torch" or "vmem"); default "torch" + allocator_type: Type of allocator ("torch", "vmem", or "vmem_pow2"); default "torch" Raises: ValueError: If allocator_type is not supported @@ -58,8 +58,10 @@ def __init__( self.allocator = TorchAllocator(heap_size, device_id, cur_rank, num_ranks) elif allocator_type == "vmem": self.allocator = VMemAllocator(heap_size, device_id, cur_rank, num_ranks) + elif allocator_type == "vmem_pow2": + self.allocator = VMemPow2Allocator(heap_size, device_id, cur_rank, num_ranks) else: - raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem'") + raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem', 'vmem_pow2'") self.fd_conns = setup_fd_infrastructure(cur_rank, num_ranks) device = self.allocator.get_device() diff --git a/tests/unittests/test_vmem_pow2_allocator.py b/tests/unittests/test_vmem_pow2_allocator.py new file mode 100644 index 00000000..5ce63fc9 --- /dev/null +++ b/tests/unittests/test_vmem_pow2_allocator.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for the power-of-two VMem allocator (VMemPow2Allocator). +""" + +import gc + +import pytest +import torch + +import iris +from iris.allocators.vmem_pow2_allocator import _next_pow2 + + +# --------------------------------------------------------------------------- +# Unit tests for _next_pow2 helper +# --------------------------------------------------------------------------- + + +def test_next_pow2_small(): + assert _next_pow2(1) == 1 + assert _next_pow2(2) == 2 + assert _next_pow2(3) == 4 + assert _next_pow2(4) == 4 + assert _next_pow2(5) == 8 + assert _next_pow2(7) == 8 + assert _next_pow2(8) == 8 + assert _next_pow2(9) == 16 + + +def test_next_pow2_large(): + assert _next_pow2(1 << 20) == 1 << 20 + assert _next_pow2((1 << 20) + 1) == 1 << 21 + + +def test_next_pow2_one(): + assert _next_pow2(0) == 1 + + +# --------------------------------------------------------------------------- +# Allocator creation +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_allocator_creation(): + """VMemPow2Allocator can be created via the iris context.""" + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + + assert ctx.cur_rank >= 0 + assert ctx.num_ranks >= 1 + assert ctx.heap_size == 4 << 20 + + from iris.allocators.vmem_pow2_allocator import VMemPow2Allocator + + assert isinstance(ctx.heap.allocator, VMemPow2Allocator) + print(f"Rank {ctx.cur_rank}: VMemPow2Allocator created successfully.") + + +# --------------------------------------------------------------------------- +# Basic allocation +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_basic_allocation(): + """Basic tensor allocation and write.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + + tensor = ctx.zeros(1024, dtype=torch.float32) + + assert tensor.shape == (1024,) + assert tensor.device.type == "cuda" + assert torch.all(tensor == 0) + + tensor.fill_(42.0) + assert torch.all(tensor == 42.0) + + print(f"Rank {ctx.cur_rank}: basic allocation test passed.") + + +def test_vmem_pow2_multiple_allocations(): + """Multiple allocations from the same heap.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + + tensors = [] + for i in range(8): + t = ctx.zeros(256, dtype=torch.float32) + t.fill_(float(i)) + tensors.append(t) + + for i, t in enumerate(tensors): + assert torch.all(t == float(i)), f"Tensor {i} has wrong value." + + print(f"Rank {ctx.cur_rank}: multiple allocations test passed.") + + +def test_vmem_pow2_different_dtypes(): + """Allocations with different dtypes.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + + t_f32 = ctx.zeros(128, dtype=torch.float32) + t_f16 = ctx.zeros(128, dtype=torch.float16) + t_i32 = ctx.zeros(128, dtype=torch.int32) + + t_f32.fill_(1.0) + t_f16.fill_(2.0) + t_i32.fill_(3) + + assert torch.all(t_f32 == 1.0) + assert torch.all(t_f16 == 2.0) + assert torch.all(t_i32 == 3) + + print(f"Rank {ctx.cur_rank}: different dtypes test passed.") + + +# --------------------------------------------------------------------------- +# owns_tensor +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_owns_tensor(): + """owns_tensor correctly identifies heap vs. non-heap tensors.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + + heap_tensor = ctx.zeros(100, dtype=torch.float32) + assert ctx.heap.allocator.owns_tensor(heap_tensor), "Heap tensor should be owned." + + external = torch.zeros(100, dtype=torch.float32, device=ctx.device) + assert not ctx.heap.allocator.owns_tensor(external), "External tensor should not be owned." + + del heap_tensor, external + torch.cuda.synchronize() + torch.cuda.empty_cache() + print(f"Rank {ctx.cur_rank}: owns_tensor test passed.") + + +# --------------------------------------------------------------------------- +# Free-list reuse +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_free_reuse(): + """After free(), the same physical block is returned on the next allocate().""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + # Allocate a tensor and record its pointer. + t1 = ctx.zeros(512, dtype=torch.float32) + ptr1 = t1.data_ptr() + + # Free it. + allocator.free(t1) + + # The next allocation of the same size class must reuse the same VA. + t2 = ctx.zeros(512, dtype=torch.float32) + ptr2 = t2.data_ptr() + + assert ptr2 == ptr1, f"Expected reuse of VA 0x{ptr1:x}, got 0x{ptr2:x}." + print(f"Rank {ctx.cur_rank}: free-list reuse test passed (VA 0x{ptr1:x}).") + + +def test_vmem_pow2_free_reuse_multiple(): + """Multiple free + realloc cycles for different size classes.""" + ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + for num_elems in [64, 256, 1024]: + t = ctx.zeros(num_elems, dtype=torch.float32) + ptr = t.data_ptr() + allocator.free(t) + + t2 = ctx.zeros(num_elems, dtype=torch.float32) + assert t2.data_ptr() == ptr, f"Expected VA reuse for {num_elems} elements." + + print(f"Rank {ctx.cur_rank}: multi-size free-list reuse test passed.") + + +def test_vmem_pow2_free_wrong_tensor_raises(): + """Freeing a non-heap tensor raises ValueError.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + external = torch.zeros(64, dtype=torch.float32, device=ctx.device) + with pytest.raises(ValueError): + allocator.free(external) + + print(f"Rank {ctx.cur_rank}: free() error-check test passed.") + + +# --------------------------------------------------------------------------- +# Heap bases +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_heap_bases(): + """Heap bases are properly initialised.""" + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + + assert ctx.heap_bases.shape == (ctx.num_ranks,) + assert int(ctx.heap_bases[ctx.cur_rank].item()) > 0 + + if ctx.num_ranks > 1: + for peer in range(ctx.num_ranks): + if peer != ctx.cur_rank: + assert int(ctx.heap_bases[peer].item()) > 0 + assert int(ctx.heap_bases[peer].item()) != int(ctx.heap_bases[ctx.cur_rank].item()) + + print(f"Rank {ctx.cur_rank}: heap bases test passed.") + + +# --------------------------------------------------------------------------- +# Granularity alignment +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_granularity_alignment(): + """The aligned heap size must be a multiple of the HIP granularity.""" + from iris.hip import get_allocation_granularity + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + granularity = get_allocation_granularity(ctx.gpu_id) + + assert ctx.heap.allocator.aligned_heap_size % granularity == 0 + print(f"Rank {ctx.cur_rank}: granularity alignment test passed (granularity={granularity}).") + + +# --------------------------------------------------------------------------- +# Size-class rounding +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_size_class_rounding(): + """ + Each allocation is rounded up to the nearest power-of-two >= granularity. + Verify by checking that two allocations of slightly different sizes that + round to the same size class produce VA blocks of the same physical size + and can be interchangeably reused. + """ + ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + granularity = allocator.granularity + + # Two sizes that both round up to 2*granularity. + size_a = granularity + 1 # bytes + size_b = granularity + granularity // 2 # bytes + + elem_size = torch.tensor([], dtype=torch.int8).element_size() # 1 + elems_a = size_a // elem_size + elems_b = size_b // elem_size + + # Allocate with size_a, free, then allocate with size_b – should reuse. + t_a = allocator.allocate(elems_a, torch.int8) + ptr_a = t_a.data_ptr() + allocator.free(t_a) + + t_b = allocator.allocate(elems_b, torch.int8) + ptr_b = t_b.data_ptr() + + assert ptr_b == ptr_a, ( + f"Expected VA reuse: both sizes should share size class. ptr_a=0x{ptr_a:x}, ptr_b=0x{ptr_b:x}" + ) + print(f"Rank {ctx.cur_rank}: size-class rounding test passed.") + + +# --------------------------------------------------------------------------- +# stats() +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_stats(): + """stats() returns sensible values.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + s0 = allocator.stats() + assert s0["heap_size"] == 8 << 20 + assert s0["granularity"] > 0 + assert s0["num_live_allocations"] == 0 + + t = ctx.zeros(512, dtype=torch.float32) + s1 = allocator.stats() + assert s1["num_live_allocations"] == 1 + + allocator.free(t) + s2 = allocator.stats() + assert s2["num_live_allocations"] == 0 + + print(f"Rank {ctx.cur_rank}: stats() test passed.") + + +# --------------------------------------------------------------------------- +# get_allocation_segments() +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_allocation_segments_grow(): + """ + get_allocation_segments() grows when new physical segments are mapped + but does NOT grow when free-listed blocks are reused. + """ + ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + # Segments after init (bootstrap only). + seg_count_0 = len(allocator.get_allocation_segments()) + + # First allocation -> maps a new segment. + t1 = ctx.zeros(512, dtype=torch.float32) + seg_count_1 = len(allocator.get_allocation_segments()) + assert seg_count_1 == seg_count_0 + 1 + + # Free and reallocate same size -> reuses free-list, no new segment. + allocator.free(t1) + t2 = ctx.zeros(512, dtype=torch.float32) + seg_count_2 = len(allocator.get_allocation_segments()) + assert seg_count_2 == seg_count_1, "Free-list reuse must not create a new segment." + + print(f"Rank {ctx.cur_rank}: allocation segments grow test passed.") + + +# --------------------------------------------------------------------------- +# as_symmetric() (import_external_tensor) +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_import_external_tensor(): + """ + Importing an external tensor gives a symmetric-heap view that shares + physical memory; the original tensor remains valid after ctx is destroyed. + """ + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + + original = torch.randn(64, dtype=torch.float32, device=ctx.device) + original_data = original.clone() + + imported = ctx.as_symmetric(original) + assert torch.allclose(imported, original_data), "Imported data should match original." + + # Mutation via imported is visible in original. + imported.fill_(7.0) + assert torch.all(original == 7.0), "Original should see changes through shared memory." + + # Mutation via original is visible in imported. + original.fill_(13.0) + assert torch.all(imported == 13.0), "Imported should see changes through shared memory." + + # Destroy ctx – original must survive. + del ctx, imported + gc.collect() + torch.cuda.synchronize() + + assert torch.all(original == 13.0), "Original tensor should survive ctx destruction." + original.fill_(99.0) + assert torch.all(original == 99.0), "Original tensor should still be writable." + + print("import_external_tensor test passed.") + + +# --------------------------------------------------------------------------- +# Multi-rank tests +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_multirank_heap_bases(): + """Multi-rank: each rank sees all peers' heap bases.""" + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + + tensor = ctx.zeros(1024, dtype=torch.float32) + tensor.fill_(float(ctx.cur_rank * 100)) + + assert ctx.heap_bases.shape == (ctx.num_ranks,) + assert int(ctx.heap_bases[ctx.cur_rank].item()) > 0 + + if ctx.num_ranks > 1: + for peer in range(ctx.num_ranks): + if peer != ctx.cur_rank: + assert int(ctx.heap_bases[peer].item()) > 0 + assert int(ctx.heap_bases[peer].item()) != int(ctx.heap_bases[ctx.cur_rank].item()) + + ctx.barrier() + tensor.fill_(float(ctx.cur_rank * 100)) + ctx.barrier() + assert torch.all(tensor == float(ctx.cur_rank * 100)) + + print(f"Rank {ctx.cur_rank}: multi-rank heap-bases test passed.") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + test_next_pow2_small() + test_next_pow2_large() + test_next_pow2_one() + test_vmem_pow2_allocator_creation() + test_vmem_pow2_basic_allocation() + test_vmem_pow2_multiple_allocations() + test_vmem_pow2_different_dtypes() + test_vmem_pow2_owns_tensor() + test_vmem_pow2_free_reuse() + test_vmem_pow2_free_reuse_multiple() + test_vmem_pow2_heap_bases() + test_vmem_pow2_granularity_alignment() + test_vmem_pow2_stats() + test_vmem_pow2_allocation_segments_grow() + test_vmem_pow2_import_external_tensor() + print("All VMemPow2Allocator tests passed.") From 17edd0be3f569f797febb67e9f9f8d56dad350db Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 02:53:27 +0000 Subject: [PATCH 3/3] Address review feedback: fix all 12 issues in VMemPow2Allocator Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- iris/allocators/vmem_allocator.py | 10 +- iris/allocators/vmem_pow2_allocator.py | 401 +++++++++++++------- iris/symmetric_heap.py | 49 ++- tests/unittests/test_vmem_pow2_allocator.py | 205 +++++++--- 4 files changed, 459 insertions(+), 206 deletions(-) diff --git a/iris/allocators/vmem_allocator.py b/iris/allocators/vmem_allocator.py index e5427edf..f8bb2080 100644 --- a/iris/allocators/vmem_allocator.py +++ b/iris/allocators/vmem_allocator.py @@ -107,14 +107,16 @@ def get_allocation_segments(self): Get list of allocation segments for segmented DMA-BUF export. Returns: - List of (offset, size, va) tuples for each allocation in order. - Each tuple describes one physically-backed segment that needs - to be exported/imported separately. + List of ``(offset, size, va, generation)`` tuples for each + allocation in order. Each tuple describes one physically-backed + segment that needs to be exported/imported separately. + *generation* is always ``0`` for this allocator since it never + remaps a VA with fresh physical memory. """ segments = [] for offset, size in self.allocation_order: va = self.base_va + offset - segments.append((offset, size, va)) + segments.append((offset, size, va, 0)) return segments def get_minimum_allocation_size(self) -> int: diff --git a/iris/allocators/vmem_pow2_allocator.py b/iris/allocators/vmem_pow2_allocator.py index da9df0c7..3ac59867 100644 --- a/iris/allocators/vmem_pow2_allocator.py +++ b/iris/allocators/vmem_pow2_allocator.py @@ -10,6 +10,8 @@ """ import os +import weakref +from collections import deque from typing import Dict, List, Tuple from threading import Lock @@ -41,27 +43,74 @@ def _next_pow2(n: int) -> int: return 1 << (n - 1).bit_length() +# Module-level element-size cache: avoids creating a temporary tensor on every allocation. +_DTYPE_ELEMENT_SIZE: Dict[torch.dtype, int] = {} + + +def _element_size(dtype: torch.dtype) -> int: + """Return the element size in bytes for *dtype*, using a module-level cache.""" + if dtype not in _DTYPE_ELEMENT_SIZE: + _DTYPE_ELEMENT_SIZE[dtype] = torch.empty((), dtype=dtype).element_size() + return _DTYPE_ELEMENT_SIZE[dtype] + + +class _CUDAArrayInterface: + """ + Minimal ``__cuda_array_interface__`` wrapper. + + Lets ``torch.as_tensor`` create a tensor view over a raw device-memory + pointer without going through PyTorch's caching allocator. + """ + + __slots__ = ("ptr", "nbytes", "device") + + def __init__(self, ptr: int, nbytes: int, device: torch.device) -> None: + self.ptr = ptr + self.nbytes = nbytes + self.device = device + + @property + def __cuda_array_interface__(self) -> dict: + return { + "shape": (self.nbytes,), + "typestr": "|u1", + "data": (self.ptr, False), + "version": 3, + } + + class VMemPow2Allocator(BaseAllocator): """ Power-of-two virtual memory allocator using HIP's VMem APIs. All allocation requests are rounded up to the nearest power-of-two size - class. Freed blocks are returned to per-class free lists and reused by - subsequent allocations of the same (or smaller) size class, giving O(1) - amortised alloc and free. - - Physical memory is **never unmapped** when a block is freed; only its - logical ownership changes. This keeps ``get_allocation_segments()`` - correct for the symmetric-heap multi-rank DMA-BUF exchange: every segment - that has ever been allocated is still physically present at the same VA - offset, so peer ranks can import it once and it stays valid. + class (minimum: HIP allocation granularity). Freed blocks are returned to + per-class free lists. When a free-listed VA is reused, the old physical + handle is released (``mem_unmap`` + ``mem_release``) and fresh physical + memory is mapped in its place (``mem_create`` + ``mem_map``). + + Physical memory is therefore renewed at **reuse time** rather than at + free time. This design respects the ROCm constraint that + ``hipMemSetAccess`` must be called cumulatively from ``base_va`` + (see rocm-systems#2667): the VA range always remains contiguous, so the + cumulative access call never encounters unmapped gaps. + + A ``weakref`` finalizer is registered on every returned tensor's storage + so that blocks are automatically returned to the free list when the last + view of a tensor is garbage collected, without requiring explicit calls to + :meth:`free`. + + .. note:: + Callers that release a tensor whose memory may still be in use by an + in-flight CUDA kernel should call ``torch.cuda.synchronize()`` before + dropping the last reference (or before calling :meth:`free`) to avoid + races during the physical remap that happens on next reuse. Args: - heap_size: Total virtual address space to reserve, in bytes. - device_id: HIP/CUDA device index. - rank: Current process rank. - world_size: Total number of ranks. - va_multiplier: Reserved for future use (currently unused). + heap_size: Total virtual address space to reserve, in bytes. + device_id: HIP/CUDA device index. + rank: Current process rank. + world_size: Total number of ranks. """ def __init__( @@ -70,18 +119,13 @@ def __init__( device_id: int, rank: int, world_size: int, - va_multiplier: float = 1.0, ): super().__init__(heap_size, device_id, rank, world_size) - self.va_multiplier = va_multiplier self.device = torch.device(f"cuda:{device_id}") self.lock = Lock() # HIP allocation granularity (always a power of two, e.g. 2 MiB on MI300X). self.granularity = get_allocation_granularity(self.device_id) - - # The minimum size class is the granularity itself. - # Because granularity is always a power of two this doubles as min_size_class. self.min_size_class: int = self.granularity # Align the heap to the granularity. @@ -89,8 +133,7 @@ def __init__( self.va_size = self.aligned_heap_size self.base_va: int = mem_address_reserve(self.va_size, self.granularity, 0) - # Bootstrap: map a minimal chunk at VA base so mem_set_access has - # something to work on (hipMemCreate(0) is invalid). + # Bootstrap: map one minimal segment so we always have something mapped. self.minimal_size: int = self.min_size_class self.minimal_handle = mem_create(self.minimal_size, self.device_id) mem_map(self.base_va, self.minimal_size, 0, self.minimal_handle) @@ -104,59 +147,85 @@ def __init__( desc.flags = hipMemAccessFlagsProtReadWrite self.access_descs.append(desc) + # ROCm: mem_set_access must be called cumulatively from base_va + # (see rocm-systems#2667). We therefore always call it as + # mem_set_access(base_va, cumulative_size, ...) so the range is + # always contiguous from the reservation start. self.cumulative_mapped_size: int = self.minimal_size mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) - # Physical-segment tracking (for get_allocation_segments / cleanup). + # Physical-segment tracking. # Maps VA-offset -> (size, is_imported, handle, va). + # Segments remain physically mapped at all times (even when on the free list). self.allocations: Dict[int, Tuple] = {} # Ordered list of (offset, size) for get_allocation_segments(). self.allocation_order: List[Tuple[int, int]] = [] self._track_allocation(0, self.minimal_size, False, self.minimal_handle, self.base_va) + # Per-offset generation counter. Incremented every time a VA block is + # remapped with fresh physical memory (i.e., each reuse from the free + # list). The symmetric heap uses (offset, size, generation) as the + # deduplication key so peers re-import when physical backing changes. + self._segment_generation: Dict[int, int] = {0: 0} + # Next available VA offset for a brand-new physical segment. self.current_offset: int = self.minimal_size - # Free lists: size_class (power-of-two bytes) -> [(offset, va), …] + # Free lists: size_class (power-of-two bytes) -> [(offset, va), ...] + # Physical memory at these VAs is STILL MAPPED; it will be replaced + # when the VA is next popped from the list (_remap_free_block). self.free_lists: Dict[int, List[Tuple[int, int]]] = {} # Logical-allocation tracking: va -> size_class (needed by free()). self.logical_allocations: Dict[int, int] = {} + # Weak-reference finalizers: va -> weakref.finalize object. + # Stored so we can detach them on explicit free() to prevent double-free. + # NOTE: finalizers are attached to tensor.untyped_storage(), not to the + # Python tensor wrapper, so they survive tensor.reshape() / other view ops. + self._finalizers: Dict[int, weakref.finalize] = {} + + # Pending GC-triggered frees. The GC callback appends here instead of + # acquiring the lock directly, which would deadlock if GC fires while + # the lock is already held on the same thread. Entries are processed at + # the start of allocate() and free(). + # Thread-safety note: ``deque.append`` and ``deque.popleft`` are atomic + # in CPython due to the GIL. This is a documented property of + # ``collections.deque`` for single-element operations and is relied on + # here to allow GC finalizers (running without the allocator lock) to + # safely enqueue work for the owning thread. + self._pending_free: deque = deque() + self.world_size = world_size # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ - def _track_allocation(self, offset: int, size: int, is_imported: bool, handle, va: int): + def _track_allocation(self, offset: int, size: int, is_imported: bool, handle, va: int) -> None: """Record a physical segment for cleanup and segmented DMA-BUF export.""" self.allocations[offset] = (size, is_imported, handle, va) self.allocation_order.append((offset, size)) def _size_class(self, size_bytes: int) -> int: """Return the smallest power-of-two size class >= size_bytes and >= granularity.""" - raw = _next_pow2(max(size_bytes, 1)) - return max(raw, self.min_size_class) + return max(_next_pow2(max(size_bytes, 1)), self.min_size_class) def _map_new_segment(self, size_class: int) -> Tuple[int, int]: """ - Map a fresh physical segment of ``size_class`` bytes at the next - available VA offset. + Map a fresh physical segment of *size_class* bytes at the next + available VA offset and extend cumulative access. - The VA offset is aligned to the HIP allocation granularity (not to - ``size_class``) so that consecutive segments occupy contiguous VA - ranges. Contiguous mapping is required because HIP's - ``hipMemSetAccess`` must be called cumulatively from ``base_va`` and - treats any unmapped gap as an invalid argument. + VA offsets are aligned to the HIP granularity so that consecutive + segments remain contiguous, which is required by the cumulative + ``hipMemSetAccess`` call. Returns: - (offset, va) of the newly mapped segment. + ``(offset, va)`` of the newly mapped segment. Raises: RuntimeError: If the heap VA space is exhausted. """ - # Align to granularity (not size_class) to keep the VA range gap-free. aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) if aligned_offset + size_class > self.aligned_heap_size: @@ -171,15 +240,95 @@ def _map_new_segment(self, size_class: int) -> Tuple[int, int]: handle = mem_create(size_class, self.device_id) mem_map(va, size_class, 0, handle) + # Extend cumulative access to include the new segment. new_cumulative = aligned_offset + size_class if new_cumulative > self.cumulative_mapped_size: self.cumulative_mapped_size = new_cumulative mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) self._track_allocation(aligned_offset, size_class, False, handle, va) + self._segment_generation[aligned_offset] = 0 self.current_offset = aligned_offset + size_class return aligned_offset, va + def _remap_free_block(self, offset: int, va: int, size_class: int) -> None: + """ + Refresh the physical backing of a free-listed VA block. + + The *old* physical handle is released (``mem_unmap`` + ``mem_release``) + and fresh physical memory is created (``mem_create`` + ``mem_map``) at + the same VA. Cumulative access is re-set for the full range after the + physical swap so the remapped segment is accessible. + + The generation counter for *offset* is incremented so the symmetric + heap detects the physical change and re-imports the segment on peers. + """ + alloc_info = self.allocations.get(offset) + if alloc_info is not None: + _old_size, _is_imported, old_handle, _old_va = alloc_info + if old_handle is not None: + mem_unmap(_old_va, _old_size) + mem_release(old_handle) + + new_handle = mem_create(size_class, self.device_id) + mem_map(va, size_class, 0, new_handle) + + # Re-set cumulative access after the physical swap (access is tied to + # the physical mapping and must be restored after unmap/remap). + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self.allocations[offset] = (size_class, False, new_handle, va) + self._segment_generation[offset] = self._segment_generation.get(offset, 0) + 1 + + def _process_pending_frees(self) -> None: + """ + Process GC-triggered frees that were queued to avoid lock re-entry. + + Must be called while holding ``self.lock``. + """ + while self._pending_free: + va, size_class = self._pending_free.popleft() + if va not in self.logical_allocations: + continue # already freed manually + del self.logical_allocations[va] + self._finalizers.pop(va, None) + offset = va - self.base_va + self.free_lists.setdefault(size_class, []).append((offset, va)) + + def _register_finalizer(self, tensor: torch.Tensor, va: int, size_class: int) -> None: + """Register a weak-reference GC finalizer on the tensor's storage. + + The finalizer is attached to ``tensor.untyped_storage()`` (the shared + C++ storage object) rather than to the Python tensor wrapper. This is + essential because callers routinely create new wrappers over the same + storage (e.g. via ``.reshape()``): if the finalizer were on the wrapper + it would fire as soon as the first wrapper was discarded, even while + other wrappers still hold the storage alive. Attaching to the storage + ensures the block is only freed when *every* view has been released. + """ + allocator_ref = weakref.ref(self) + + def _gc_free() -> None: + alloc = allocator_ref() + if alloc is None: + return + # Enqueue rather than locking directly to avoid deadlock when GC + # fires inside a locked section on the same thread. + alloc._pending_free.append((va, size_class)) + + self._finalizers[va] = weakref.finalize(tensor.untyped_storage(), _gc_free) + + def _make_tensor_view(self, va: int, size_class: int, num_elements: int, dtype: torch.dtype) -> torch.Tensor: + """Create a PyTorch tensor view over the given VA-backed device memory.""" + elem_sz = _element_size(dtype) + interface_size = (size_class // elem_sz) * elem_sz + cuda_array = _CUDAArrayInterface(va, interface_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + full = tensor_bytes.view(dtype) + if num_elements == 0: + return full.narrow(0, 1, 0) + return full.narrow(0, 0, num_elements) + # ------------------------------------------------------------------ # BaseAllocator interface # ------------------------------------------------------------------ @@ -196,87 +345,64 @@ def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) """ Allocate a tensor on the power-of-two symmetric heap. - The physical size is rounded up to the next power-of-two size class - (and is at least ``granularity`` bytes). If a block of the required - size class is already on the free list it is reused; otherwise a new - physical segment is mapped. + If a block of the required size class is on the free list it is reused + (old physical memory released, fresh physical memory mapped at the same + VA); otherwise a new segment is mapped at the next available offset. Args: num_elements: Number of tensor elements. dtype: PyTorch data type. - alignment: Ignored for this allocator (alignment is provided by - the power-of-two size class itself). + alignment: Ignored; alignment is guaranteed by the size class. Returns: - A PyTorch tensor of shape ``(num_elements,)`` backed by symmetric - heap memory. + A PyTorch tensor of shape ``(num_elements,)``. Raises: RuntimeError: If the heap VA space is exhausted. """ with self.lock: - element_size = torch.tensor([], dtype=dtype).element_size() - size_bytes = num_elements * element_size + self._process_pending_frees() + + elem_sz = _element_size(dtype) + size_bytes = num_elements * elem_sz size_class = self._size_class(size_bytes) - # Try the free list first. - free_entry = None + # Try the free list first; otherwise map a new segment. if size_class in self.free_lists and self.free_lists[size_class]: - free_entry = self.free_lists[size_class].pop() - - if free_entry is not None: - offset, va = free_entry + offset, va = self.free_lists[size_class].pop() + self._remap_free_block(offset, va, size_class) else: offset, va = self._map_new_segment(size_class) - # Record the logical allocation so free() can find the size class. self.logical_allocations[va] = size_class - - # Expose the physical memory as a PyTorch tensor via __cuda_array_interface__. - interface_size = (size_class // element_size) * element_size - - class _CUDAArrayInterface: - def __init__(self_, ptr: int, nbytes: int, device: torch.device): - self_.ptr = ptr - self_.nbytes = nbytes - self_.device = device - - @property - def __cuda_array_interface__(self_): - return { - "shape": (self_.nbytes,), - "typestr": "|u1", - "data": (self_.ptr, False), - "version": 3, - } - - cuda_array = _CUDAArrayInterface(va, interface_size, self.device) - tensor_bytes = torch.as_tensor(cuda_array, device=self.device) - full = tensor_bytes.view(dtype) - if num_elements == 0: - return full.narrow(0, 1, 0) - return full.narrow(0, 0, num_elements) + tensor = self._make_tensor_view(va, size_class, num_elements, dtype) + self._register_finalizer(tensor, va, size_class) + return tensor def free(self, tensor: torch.Tensor) -> None: """ - Return a tensor's physical block to the appropriate free list. + Return a tensor's VA block to the free list. - The physical memory is **not** unmapped; it remains accessible at its - VA so that peer-rank DMA-BUF imports stay valid. The block is simply - made available for reuse by the next ``allocate`` call of the same - size class. + The block remains physically mapped so that the VA range stays + contiguous (required for the cumulative ``hipMemSetAccess`` call). + Physical memory is released and renewed when the block is next + reused from the free list. + + Zero-element tensors are silently ignored. Args: tensor: A tensor previously returned by :meth:`allocate`. Raises: - ValueError: If the tensor was not allocated by this allocator. + ValueError: If the tensor was not allocated by this allocator or + was already freed. """ if tensor.numel() == 0: - # Zero-element tensors share the minimal bootstrap block; skip. return with self.lock: + self._process_pending_frees() + va = tensor.data_ptr() if va not in self.logical_allocations: raise ValueError( @@ -284,6 +410,12 @@ def free(self, tensor: torch.Tensor) -> None: "allocated by this allocator (or was already freed)." ) size_class = self.logical_allocations.pop(va) + + # Detach the GC finalizer to prevent double-free. + fin = self._finalizers.pop(va, None) + if fin is not None: + fin.detach() + offset = va - self.base_va self.free_lists.setdefault(size_class, []).append((offset, va)) @@ -293,41 +425,50 @@ def get_device(self) -> torch.device: def owns_tensor(self, tensor: torch.Tensor) -> bool: """ - Return True if *tensor* was allocated from this allocator's heap. + Return True if *tensor*'s data pointer lies within this heap's VA range. + + The check is purely address-based; zero-element tensors are checked by + pointer rather than being unconditionally claimed as owned, which would + incorrectly claim externally-created zero-element tensors. Args: tensor: PyTorch tensor to check. - - Returns: - True if the tensor's data pointer lies within the heap VA range. """ if not tensor.is_cuda: return False - if tensor.numel() == 0: - return True ptr = tensor.data_ptr() + # data_ptr() returns 0 for tensors that have no storage (e.g. meta tensors) + # or for certain zero-element tensors on some backends. Such tensors are + # never part of this allocator's heap. + if ptr == 0: + return False return self.base_va <= ptr < self.base_va + self.aligned_heap_size # ------------------------------------------------------------------ # Symmetric-heap segment API (used by SymmetricHeap.refresh_peer_access) # ------------------------------------------------------------------ - def get_allocation_segments(self) -> List[Tuple[int, int, int]]: + def get_allocation_segments(self) -> List[Tuple[int, int, int, int]]: """ Return the ordered list of physical segments for DMA-BUF export. - Each element is ``(offset, size, va)`` describing one physically-backed - segment that must be exported and imported across ranks. Segments on - the free list are included because they are still physically mapped and - their peer imports must remain valid. + All tracked segments are included (both live and free-listed), since + free-listed blocks remain physically mapped. + + Each element is ``(offset, size, va, generation)`` where *generation* + is a monotonically increasing counter bumped each time the block is + remapped with fresh physical memory. The symmetric heap uses + ``(offset, size, generation)`` as the de-duplication key so that + remapped segments are recognised as new and peer ranks re-import them. Returns: - List of ``(offset, size, va)`` tuples in allocation order. + List of ``(offset, size, va, generation)`` tuples in allocation order. """ segments = [] for offset, size in self.allocation_order: va = self.base_va + offset - segments.append((offset, size, va)) + generation = self._segment_generation.get(offset, 0) + segments.append((offset, size, va, generation)) return segments # ------------------------------------------------------------------ @@ -338,10 +479,8 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: """ Import an external PyTorch tensor into the symmetric heap. - This remaps the external tensor's physical memory into the symmetric - heap VA range so that peer ranks can access it via the standard - DMA-BUF exchange. The returned tensor **shares physical memory** with - the original; changes to one are immediately visible in the other. + The returned tensor **shares physical memory** with the original; + changes to one are immediately visible in the other. Args: external_tensor: A contiguous CUDA tensor allocated by PyTorch. @@ -351,8 +490,8 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: *external_tensor*. Raises: - RuntimeError: If the tensor is not a contiguous CUDA tensor, or - if the heap VA space is exhausted. + RuntimeError: If the tensor is not a contiguous CUDA tensor or + the heap VA space is exhausted. """ with self.lock: if not external_tensor.is_cuda: @@ -366,22 +505,26 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: external_ptr = external_tensor.data_ptr() alloc_base, alloc_size = get_address_range(external_ptr) offset_in_alloc = external_ptr - alloc_base - aligned_size = (alloc_size + self.granularity - 1) & ~(self.granularity - 1) + + # Export first so we know the actual export_size before the OOM check. + dmabuf_fd, _export_base, export_size = export_dmabuf_handle(alloc_base, alloc_size) + aligned_export_size = (export_size + self.granularity - 1) & ~(self.granularity - 1) aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) - if aligned_offset + aligned_size > self.aligned_heap_size: + if aligned_offset + aligned_export_size > self.aligned_heap_size: + os.close(dmabuf_fd) raise RuntimeError( f"VMemPow2Allocator: out of VA space for import. " - f"Need {aligned_size} bytes at offset {aligned_offset}, " + f"Need {aligned_export_size} bytes at offset {aligned_offset}, " f"heap size is {self.aligned_heap_size}." ) - dmabuf_fd, export_base, export_size = export_dmabuf_handle(alloc_base, alloc_size) - aligned_export_size = (export_size + self.granularity - 1) & ~(self.granularity - 1) - target_va = self.base_va + aligned_offset - imported_handle = mem_import_from_shareable_handle(dmabuf_fd) - os.close(dmabuf_fd) + try: + imported_handle = mem_import_from_shareable_handle(dmabuf_fd) + finally: + os.close(dmabuf_fd) + target_va = self.base_va + aligned_offset mem_map(target_va, aligned_export_size, 0, imported_handle) new_cumulative = aligned_offset + aligned_export_size @@ -391,25 +534,10 @@ def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: tensor_va = target_va + offset_in_alloc self._track_allocation(aligned_offset, aligned_export_size, True, imported_handle, target_va) + self._segment_generation[aligned_offset] = 0 self.current_offset = aligned_offset + aligned_export_size tensor_size = external_tensor.numel() * external_tensor.element_size() - - class _CUDAArrayInterface: - def __init__(self_, ptr: int, nbytes: int, device: torch.device): - self_.ptr = ptr - self_.nbytes = nbytes - self_.device = device - - @property - def __cuda_array_interface__(self_): - return { - "shape": (self_.nbytes,), - "typestr": "|u1", - "data": (self_.ptr, False), - "version": 3, - } - cuda_array = _CUDAArrayInterface(tensor_va, tensor_size, self.device) tensor_bytes = torch.as_tensor(cuda_array, device=self.device) return tensor_bytes.view(external_tensor.dtype).reshape(external_tensor.shape) @@ -419,18 +547,23 @@ def __cuda_array_interface__(self_): # ------------------------------------------------------------------ def close(self) -> None: - """Release all VMem resources (unmap and free physical handles).""" + """Release all VMem resources (unmap, release handles, free VA range).""" if getattr(self, "_closed", False): return with self.lock: + # Detach all GC finalizers so they cannot fire after close(). + for fin in self._finalizers.values(): + fin.detach() + self._finalizers.clear() + self._pending_free.clear() + + # Release all physical segments (both live and free-listed). for offset, alloc_info in self.allocations.items(): - if len(alloc_info) == 4: - size, is_imported, handle, va = alloc_info - if handle is not None: - aligned_size = (size + self.granularity - 1) & ~(self.granularity - 1) - mem_unmap(va, aligned_size) - mem_release(handle) + _size, _is_imported, handle, va = alloc_info + if handle is not None: + mem_unmap(va, _size) + mem_release(handle) self.allocations.clear() self.free_lists.clear() diff --git a/iris/symmetric_heap.py b/iris/symmetric_heap.py index c6f8e901..1b7b13f8 100644 --- a/iris/symmetric_heap.py +++ b/iris/symmetric_heap.py @@ -162,6 +162,7 @@ def refresh_peer_access(self): export_dmabuf_handle, mem_import_from_shareable_handle, mem_map, + mem_unmap, mem_set_access, mem_address_reserve, hipMemAccessDesc, @@ -199,9 +200,11 @@ def refresh_peer_access(self): my_segments = self.allocator.get_allocation_segments() my_exported_fds = [] - for offset, size, va in my_segments: + for seg in my_segments: + offset, size, va = seg[0], seg[1], seg[2] + generation = seg[3] if len(seg) > 3 else 0 dmabuf_fd, export_base, export_size = export_dmabuf_handle(va, size) - my_exported_fds.append((dmabuf_fd, export_size, offset)) + my_exported_fds.append((dmabuf_fd, export_size, offset, generation)) access_desc = hipMemAccessDesc() access_desc.location.type = hipMemLocationTypeDevice @@ -222,7 +225,7 @@ def refresh_peer_access(self): peer_va_base = self._peer_va_ranges[peer] peer_fds = [] - for seg_idx, (my_fd, my_size, my_offset) in enumerate(my_exported_fds): + for my_fd, my_size, my_offset, my_gen in my_exported_fds: # Exchange FDs (higher rank sends first to avoid deadlock) if self.cur_rank > peer: send_fd(sock, my_fd) @@ -231,25 +234,38 @@ def refresh_peer_access(self): peer_fd, _ = recv_fd(sock) send_fd(sock, my_fd) - peer_fds.append((peer_fd, my_size, my_offset)) + peer_fds.append((peer_fd, my_size, my_offset, my_gen)) if not hasattr(self, "_peer_cumulative_sizes"): self._peer_cumulative_sizes = {} cumulative_size = self._peer_cumulative_sizes.get(peer, 0) - if not hasattr(self, "_peer_imported_segments"): - self._peer_imported_segments = {} - if peer not in self._peer_imported_segments: - self._peer_imported_segments[peer] = set() - - for peer_fd, segment_size, offset in peer_fds: - segment_key = (offset, segment_size) - if segment_key in self._peer_imported_segments[peer]: + # _peer_segment_generations maps (offset, size) -> generation for each + # segment already imported from this peer. When the generation changes + # the VA has been remapped with new physical memory and must be + # re-imported (unmap old handle, map new handle). + if not hasattr(self, "_peer_segment_generations"): + self._peer_segment_generations = {} + if peer not in self._peer_segment_generations: + self._peer_segment_generations[peer] = {} + + for peer_fd, segment_size, offset, generation in peer_fds: + seg_key = (offset, segment_size) + known_gen = self._peer_segment_generations[peer].get(seg_key) + + if known_gen == generation: + # Already imported this exact physical mapping; skip. import os os.close(peer_fd) continue + if known_gen is not None: + # Physical backing has changed (generation bumped after free+remap). + # Unmap the stale peer mapping before importing the new one. + peer_va = peer_va_base + offset + mem_unmap(peer_va, segment_size) + imported_handle = mem_import_from_shareable_handle(peer_fd) import os @@ -257,17 +273,22 @@ def refresh_peer_access(self): peer_va = peer_va_base + offset mem_map(peer_va, segment_size, 0, imported_handle) - self._peer_imported_segments[peer].add(segment_key) + self._peer_segment_generations[peer][seg_key] = generation new_cumulative = offset + segment_size if new_cumulative > cumulative_size: cumulative_size = new_cumulative mem_set_access(peer_va_base, cumulative_size, access_desc) + elif known_gen is not None: + # Physical backing changed but VA offset is already within the + # cumulative range (no extension needed). Re-set access for + # just this segment so the new physical mapping is accessible. + mem_set_access(peer_va, segment_size, access_desc) self._peer_cumulative_sizes[peer] = cumulative_size self.heap_bases[peer] = peer_va_base - for fd, _, _ in my_exported_fds: + for fd, _sz, _off, _gen in my_exported_fds: import os os.close(fd) diff --git a/tests/unittests/test_vmem_pow2_allocator.py b/tests/unittests/test_vmem_pow2_allocator.py index 5ce63fc9..d0fc258a 100644 --- a/tests/unittests/test_vmem_pow2_allocator.py +++ b/tests/unittests/test_vmem_pow2_allocator.py @@ -6,6 +6,7 @@ """ import gc +import threading import pytest import torch @@ -55,7 +56,6 @@ def test_vmem_pow2_allocator_creation(): from iris.allocators.vmem_pow2_allocator import VMemPow2Allocator assert isinstance(ctx.heap.allocator, VMemPow2Allocator) - print(f"Rank {ctx.cur_rank}: VMemPow2Allocator created successfully.") # --------------------------------------------------------------------------- @@ -76,8 +76,6 @@ def test_vmem_pow2_basic_allocation(): tensor.fill_(42.0) assert torch.all(tensor == 42.0) - print(f"Rank {ctx.cur_rank}: basic allocation test passed.") - def test_vmem_pow2_multiple_allocations(): """Multiple allocations from the same heap.""" @@ -92,8 +90,6 @@ def test_vmem_pow2_multiple_allocations(): for i, t in enumerate(tensors): assert torch.all(t == float(i)), f"Tensor {i} has wrong value." - print(f"Rank {ctx.cur_rank}: multiple allocations test passed.") - def test_vmem_pow2_different_dtypes(): """Allocations with different dtypes.""" @@ -111,8 +107,6 @@ def test_vmem_pow2_different_dtypes(): assert torch.all(t_f16 == 2.0) assert torch.all(t_i32 == 3) - print(f"Rank {ctx.cur_rank}: different dtypes test passed.") - # --------------------------------------------------------------------------- # owns_tensor @@ -129,10 +123,15 @@ def test_vmem_pow2_owns_tensor(): external = torch.zeros(100, dtype=torch.float32, device=ctx.device) assert not ctx.heap.allocator.owns_tensor(external), "External tensor should not be owned." - del heap_tensor, external + # Zero-element tensors NOT from the heap must NOT be claimed as owned. + external_empty = torch.zeros(0, dtype=torch.float32, device=ctx.device) + assert not ctx.heap.allocator.owns_tensor(external_empty), ( + "External zero-element tensor must not be claimed as owned." + ) + + del heap_tensor, external, external_empty torch.cuda.synchronize() torch.cuda.empty_cache() - print(f"Rank {ctx.cur_rank}: owns_tensor test passed.") # --------------------------------------------------------------------------- @@ -141,27 +140,23 @@ def test_vmem_pow2_owns_tensor(): def test_vmem_pow2_free_reuse(): - """After free(), the same physical block is returned on the next allocate().""" + """After free(), the same VA is returned on the next allocate() of the same size class.""" ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator - # Allocate a tensor and record its pointer. t1 = ctx.zeros(512, dtype=torch.float32) ptr1 = t1.data_ptr() - # Free it. allocator.free(t1) - # The next allocation of the same size class must reuse the same VA. t2 = ctx.zeros(512, dtype=torch.float32) ptr2 = t2.data_ptr() - assert ptr2 == ptr1, f"Expected reuse of VA 0x{ptr1:x}, got 0x{ptr2:x}." - print(f"Rank {ctx.cur_rank}: free-list reuse test passed (VA 0x{ptr1:x}).") + assert ptr2 == ptr1, f"Expected VA reuse 0x{ptr1:x}, got 0x{ptr2:x}." def test_vmem_pow2_free_reuse_multiple(): - """Multiple free + realloc cycles for different size classes.""" + """Multiple free + realloc cycles for different size classes all reuse VAs.""" ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator @@ -173,11 +168,9 @@ def test_vmem_pow2_free_reuse_multiple(): t2 = ctx.zeros(num_elems, dtype=torch.float32) assert t2.data_ptr() == ptr, f"Expected VA reuse for {num_elems} elements." - print(f"Rank {ctx.cur_rank}: multi-size free-list reuse test passed.") - def test_vmem_pow2_free_wrong_tensor_raises(): - """Freeing a non-heap tensor raises ValueError.""" + """Freeing a tensor not allocated by this allocator raises ValueError.""" ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator @@ -185,7 +178,27 @@ def test_vmem_pow2_free_wrong_tensor_raises(): with pytest.raises(ValueError): allocator.free(external) - print(f"Rank {ctx.cur_rank}: free() error-check test passed.") + +# --------------------------------------------------------------------------- +# GC-based auto-free +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_gc_auto_free(): + """Tensors that go out of scope are automatically returned to the free list.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + def alloc_and_drop(): + t = ctx.zeros(512, dtype=torch.float32) + return t.data_ptr() # tensor dies at function exit (CPython refcount → 0) + + ptr = alloc_and_drop() + gc.collect() # ensure finalizer has run across all Python implementations + + # The next allocation of the same size class must reuse the freed VA. + t2 = ctx.zeros(512, dtype=torch.float32) + assert t2.data_ptr() == ptr, f"Expected GC-freed VA 0x{ptr:x} to be reused, got 0x{t2.data_ptr():x}." # --------------------------------------------------------------------------- @@ -206,8 +219,6 @@ def test_vmem_pow2_heap_bases(): assert int(ctx.heap_bases[peer].item()) > 0 assert int(ctx.heap_bases[peer].item()) != int(ctx.heap_bases[ctx.cur_rank].item()) - print(f"Rank {ctx.cur_rank}: heap bases test passed.") - # --------------------------------------------------------------------------- # Granularity alignment @@ -225,7 +236,6 @@ def test_vmem_pow2_granularity_alignment(): granularity = get_allocation_granularity(ctx.gpu_id) assert ctx.heap.allocator.aligned_heap_size % granularity == 0 - print(f"Rank {ctx.cur_rank}: granularity alignment test passed (granularity={granularity}).") # --------------------------------------------------------------------------- @@ -235,24 +245,20 @@ def test_vmem_pow2_granularity_alignment(): def test_vmem_pow2_size_class_rounding(): """ - Each allocation is rounded up to the nearest power-of-two >= granularity. - Verify by checking that two allocations of slightly different sizes that - round to the same size class produce VA blocks of the same physical size - and can be interchangeably reused. + Two allocation sizes that map to the same power-of-two size class can + interchangeably reuse each other's freed VA block. """ ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator granularity = allocator.granularity - # Two sizes that both round up to 2*granularity. + # Both sizes round up to 2*granularity. size_a = granularity + 1 # bytes size_b = granularity + granularity // 2 # bytes - elem_size = torch.tensor([], dtype=torch.int8).element_size() # 1 - elems_a = size_a // elem_size - elems_b = size_b // elem_size + elems_a = size_a # dtype=torch.int8 → element_size == 1 + elems_b = size_b - # Allocate with size_a, free, then allocate with size_b – should reuse. t_a = allocator.allocate(elems_a, torch.int8) ptr_a = t_a.data_ptr() allocator.free(t_a) @@ -260,10 +266,7 @@ def test_vmem_pow2_size_class_rounding(): t_b = allocator.allocate(elems_b, torch.int8) ptr_b = t_b.data_ptr() - assert ptr_b == ptr_a, ( - f"Expected VA reuse: both sizes should share size class. ptr_a=0x{ptr_a:x}, ptr_b=0x{ptr_b:x}" - ) - print(f"Rank {ctx.cur_rank}: size-class rounding test passed.") + assert ptr_b == ptr_a, f"Expected VA reuse: both sizes share size class. ptr_a=0x{ptr_a:x}, ptr_b=0x{ptr_b:x}" # --------------------------------------------------------------------------- @@ -272,7 +275,7 @@ def test_vmem_pow2_size_class_rounding(): def test_vmem_pow2_stats(): - """stats() returns sensible values.""" + """stats() returns sensible values before, during, and after an allocation.""" ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator @@ -289,37 +292,130 @@ def test_vmem_pow2_stats(): s2 = allocator.stats() assert s2["num_live_allocations"] == 0 - print(f"Rank {ctx.cur_rank}: stats() test passed.") - # --------------------------------------------------------------------------- -# get_allocation_segments() +# get_allocation_segments() and generation counter # --------------------------------------------------------------------------- def test_vmem_pow2_allocation_segments_grow(): """ get_allocation_segments() grows when new physical segments are mapped - but does NOT grow when free-listed blocks are reused. + and does NOT grow when free-listed VAs are reused (remap same entry). """ ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") allocator = ctx.heap.allocator - # Segments after init (bootstrap only). seg_count_0 = len(allocator.get_allocation_segments()) - # First allocation -> maps a new segment. + # First allocation: new physical segment. t1 = ctx.zeros(512, dtype=torch.float32) seg_count_1 = len(allocator.get_allocation_segments()) assert seg_count_1 == seg_count_0 + 1 - # Free and reallocate same size -> reuses free-list, no new segment. + # Free + reallocate: reuse VA, no new entry in allocation_order. allocator.free(t1) t2 = ctx.zeros(512, dtype=torch.float32) seg_count_2 = len(allocator.get_allocation_segments()) - assert seg_count_2 == seg_count_1, "Free-list reuse must not create a new segment." + assert seg_count_2 == seg_count_1, "Free-list reuse must not create a new segment entry." + + +def test_vmem_pow2_generation_increments_on_remap(): + """Generation counter increases when a freed VA is remapped.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + t = ctx.zeros(512, dtype=torch.float32) + va = t.data_ptr() + offset = va - allocator.base_va + + gen_before = allocator._segment_generation[offset] - print(f"Rank {ctx.cur_rank}: allocation segments grow test passed.") + allocator.free(t) + _ = ctx.zeros(512, dtype=torch.float32) # remap same offset + + gen_after = allocator._segment_generation[offset] + assert gen_after == gen_before + 1, "Generation must increment on remap." + + +# --------------------------------------------------------------------------- +# OOM / heap exhaustion +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_oom(): + """Allocating beyond the VA space raises RuntimeError.""" + # A heap of exactly 4 granules; bootstrap uses 1, so we have 3 left. + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + tensors = [] + with pytest.raises(RuntimeError, match="out of VA space"): + while True: + tensors.append(allocator.allocate(1, torch.int8)) + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_thread_safety(): + """Concurrent alloc/free from multiple threads does not corrupt state.""" + ctx = iris.iris(256 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + errors: list = [] + + def worker(): + try: + for _ in range(20): + t = allocator.allocate(16, torch.float32) + allocator.free(t) + except Exception as exc: # noqa: BLE001 – capture any error from any thread + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for th in threads: + th.start() + for th in threads: + th.join() + + assert not errors, f"Thread-safety errors: {errors}" + + +# --------------------------------------------------------------------------- +# close() and resource cleanup +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_close(): + """close() releases all resources and is idempotent.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + t = ctx.zeros(512, dtype=torch.float32) + allocator.free(t) + + allocator.close() + assert allocator._closed + assert allocator.base_va == 0 + + # close() must be safe to call multiple times. + allocator.close() + assert allocator._closed + + +def test_vmem_pow2_close_disables_finalizers(): + """close() detaches GC finalizers so they cannot run after the allocator is gone.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + t = ctx.zeros(512, dtype=torch.float32) + va = t.data_ptr() + + assert va in allocator._finalizers + allocator.close() + assert not allocator._finalizers, "All finalizers must be cleared by close()." # --------------------------------------------------------------------------- @@ -340,15 +436,12 @@ def test_vmem_pow2_import_external_tensor(): imported = ctx.as_symmetric(original) assert torch.allclose(imported, original_data), "Imported data should match original." - # Mutation via imported is visible in original. imported.fill_(7.0) assert torch.all(original == 7.0), "Original should see changes through shared memory." - # Mutation via original is visible in imported. original.fill_(13.0) assert torch.all(imported == 13.0), "Imported should see changes through shared memory." - # Destroy ctx – original must survive. del ctx, imported gc.collect() torch.cuda.synchronize() @@ -357,8 +450,6 @@ def test_vmem_pow2_import_external_tensor(): original.fill_(99.0) assert torch.all(original == 99.0), "Original tensor should still be writable." - print("import_external_tensor test passed.") - # --------------------------------------------------------------------------- # Multi-rank tests @@ -366,7 +457,7 @@ def test_vmem_pow2_import_external_tensor(): def test_vmem_pow2_multirank_heap_bases(): - """Multi-rank: each rank sees all peers' heap bases.""" + """Multi-rank: each rank sees all peers' heap bases after setup.""" ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") tensor = ctx.zeros(1024, dtype=torch.float32) @@ -386,8 +477,6 @@ def test_vmem_pow2_multirank_heap_bases(): ctx.barrier() assert torch.all(tensor == float(ctx.cur_rank * 100)) - print(f"Rank {ctx.cur_rank}: multi-rank heap-bases test passed.") - # --------------------------------------------------------------------------- # Entry point @@ -404,9 +493,17 @@ def test_vmem_pow2_multirank_heap_bases(): test_vmem_pow2_owns_tensor() test_vmem_pow2_free_reuse() test_vmem_pow2_free_reuse_multiple() + test_vmem_pow2_gc_auto_free() test_vmem_pow2_heap_bases() test_vmem_pow2_granularity_alignment() + test_vmem_pow2_size_class_rounding() test_vmem_pow2_stats() test_vmem_pow2_allocation_segments_grow() + test_vmem_pow2_generation_increments_on_remap() + test_vmem_pow2_oom() + test_vmem_pow2_thread_safety() + test_vmem_pow2_close() + test_vmem_pow2_close_disables_finalizers() test_vmem_pow2_import_external_tensor() + test_vmem_pow2_multirank_heap_bases() print("All VMemPow2Allocator tests passed.")