diff --git a/iris/allocators/__init__.py b/iris/allocators/__init__.py index 460c53d3..0c58494d 100644 --- a/iris/allocators/__init__.py +++ b/iris/allocators/__init__.py @@ -8,5 +8,6 @@ from .base import BaseAllocator from .torch_allocator import TorchAllocator from .vmem_allocator import VMemAllocator +from .vmem_pow2_allocator import VMemPow2Allocator -__all__ = ["BaseAllocator", "TorchAllocator", "VMemAllocator"] +__all__ = ["BaseAllocator", "TorchAllocator", "VMemAllocator", "VMemPow2Allocator"] diff --git a/iris/allocators/vmem_allocator.py b/iris/allocators/vmem_allocator.py index e5427edf..f8bb2080 100644 --- a/iris/allocators/vmem_allocator.py +++ b/iris/allocators/vmem_allocator.py @@ -107,14 +107,16 @@ def get_allocation_segments(self): Get list of allocation segments for segmented DMA-BUF export. Returns: - List of (offset, size, va) tuples for each allocation in order. - Each tuple describes one physically-backed segment that needs - to be exported/imported separately. + List of ``(offset, size, va, generation)`` tuples for each + allocation in order. Each tuple describes one physically-backed + segment that needs to be exported/imported separately. + *generation* is always ``0`` for this allocator since it never + remaps a VA with fresh physical memory. """ segments = [] for offset, size in self.allocation_order: va = self.base_va + offset - segments.append((offset, size, va)) + segments.append((offset, size, va, 0)) return segments def get_minimum_allocation_size(self) -> int: diff --git a/iris/allocators/vmem_pow2_allocator.py b/iris/allocators/vmem_pow2_allocator.py new file mode 100644 index 00000000..3ac59867 --- /dev/null +++ b/iris/allocators/vmem_pow2_allocator.py @@ -0,0 +1,610 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Power-of-two VMem allocator using HIP's virtual memory management APIs. 
+ +This allocator provides efficient reuse of virtual memory allocations by +rounding all requests up to the next power-of-two size class and maintaining +per-class free lists for O(1) allocation and deallocation. +""" + +import os +import weakref +from collections import deque +from typing import Dict, List, Tuple +from threading import Lock + +import torch + +from .base import BaseAllocator +from ..hip import ( + get_allocation_granularity, + get_address_range, + export_dmabuf_handle, + mem_import_from_shareable_handle, + mem_create, + mem_address_reserve, + mem_map, + mem_unmap, + mem_address_free, + mem_release, + mem_set_access, + hipMemAccessDesc, + hipMemLocationTypeDevice, + hipMemAccessFlagsProtReadWrite, +) + + +def _next_pow2(n: int) -> int: + """Round n up to the next power of two (>= 1).""" + if n <= 1: + return 1 + return 1 << (n - 1).bit_length() + + +# Module-level element-size cache: avoids creating a temporary tensor on every allocation. +_DTYPE_ELEMENT_SIZE: Dict[torch.dtype, int] = {} + + +def _element_size(dtype: torch.dtype) -> int: + """Return the element size in bytes for *dtype*, using a module-level cache.""" + if dtype not in _DTYPE_ELEMENT_SIZE: + _DTYPE_ELEMENT_SIZE[dtype] = torch.empty((), dtype=dtype).element_size() + return _DTYPE_ELEMENT_SIZE[dtype] + + +class _CUDAArrayInterface: + """ + Minimal ``__cuda_array_interface__`` wrapper. + + Lets ``torch.as_tensor`` create a tensor view over a raw device-memory + pointer without going through PyTorch's caching allocator. + """ + + __slots__ = ("ptr", "nbytes", "device") + + def __init__(self, ptr: int, nbytes: int, device: torch.device) -> None: + self.ptr = ptr + self.nbytes = nbytes + self.device = device + + @property + def __cuda_array_interface__(self) -> dict: + return { + "shape": (self.nbytes,), + "typestr": "|u1", + "data": (self.ptr, False), + "version": 3, + } + + +class VMemPow2Allocator(BaseAllocator): + """ + Power-of-two virtual memory allocator using HIP's VMem APIs. 
+ + All allocation requests are rounded up to the nearest power-of-two size + class (minimum: HIP allocation granularity). Freed blocks are returned to + per-class free lists. When a free-listed VA is reused, the old physical + handle is released (``mem_unmap`` + ``mem_release``) and fresh physical + memory is mapped in its place (``mem_create`` + ``mem_map``). + + Physical memory is therefore renewed at **reuse time** rather than at + free time. This design respects the ROCm constraint that + ``hipMemSetAccess`` must be called cumulatively from ``base_va`` + (see rocm-systems#2667): the VA range always remains contiguous, so the + cumulative access call never encounters unmapped gaps. + + A ``weakref`` finalizer is registered on every returned tensor's storage + so that blocks are automatically returned to the free list when the last + view of a tensor is garbage collected, without requiring explicit calls to + :meth:`free`. + + .. note:: + Callers that release a tensor whose memory may still be in use by an + in-flight CUDA kernel should call ``torch.cuda.synchronize()`` before + dropping the last reference (or before calling :meth:`free`) to avoid + races during the physical remap that happens on next reuse. + + Args: + heap_size: Total virtual address space to reserve, in bytes. + device_id: HIP/CUDA device index. + rank: Current process rank. + world_size: Total number of ranks. + """ + + def __init__( + self, + heap_size: int, + device_id: int, + rank: int, + world_size: int, + ): + super().__init__(heap_size, device_id, rank, world_size) + self.device = torch.device(f"cuda:{device_id}") + self.lock = Lock() + + # HIP allocation granularity (always a power of two, e.g. 2 MiB on MI300X). + self.granularity = get_allocation_granularity(self.device_id) + self.min_size_class: int = self.granularity + + # Align the heap to the granularity. 
+ self.aligned_heap_size = (heap_size + self.granularity - 1) & ~(self.granularity - 1) + self.va_size = self.aligned_heap_size + self.base_va: int = mem_address_reserve(self.va_size, self.granularity, 0) + + # Bootstrap: map one minimal segment so we always have something mapped. + self.minimal_size: int = self.min_size_class + self.minimal_handle = mem_create(self.minimal_size, self.device_id) + mem_map(self.base_va, self.minimal_size, 0, self.minimal_handle) + + # Access descriptors: allow read/write from every peer device. + self.access_descs: List[hipMemAccessDesc] = [] + for peer_device_id in range(world_size): + desc = hipMemAccessDesc() + desc.location.type = hipMemLocationTypeDevice + desc.location.id = peer_device_id + desc.flags = hipMemAccessFlagsProtReadWrite + self.access_descs.append(desc) + + # ROCm: mem_set_access must be called cumulatively from base_va + # (see rocm-systems#2667). We therefore always call it as + # mem_set_access(base_va, cumulative_size, ...) so the range is + # always contiguous from the reservation start. + self.cumulative_mapped_size: int = self.minimal_size + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + # Physical-segment tracking. + # Maps VA-offset -> (size, is_imported, handle, va). + # Segments remain physically mapped at all times (even when on the free list). + self.allocations: Dict[int, Tuple] = {} + # Ordered list of (offset, size) for get_allocation_segments(). + self.allocation_order: List[Tuple[int, int]] = [] + self._track_allocation(0, self.minimal_size, False, self.minimal_handle, self.base_va) + + # Per-offset generation counter. Incremented every time a VA block is + # remapped with fresh physical memory (i.e., each reuse from the free + # list). The symmetric heap uses (offset, size, generation) as the + # deduplication key so peers re-import when physical backing changes. 
+ self._segment_generation: Dict[int, int] = {0: 0} + + # Next available VA offset for a brand-new physical segment. + self.current_offset: int = self.minimal_size + + # Free lists: size_class (power-of-two bytes) -> [(offset, va), ...] + # Physical memory at these VAs is STILL MAPPED; it will be replaced + # when the VA is next popped from the list (_remap_free_block). + self.free_lists: Dict[int, List[Tuple[int, int]]] = {} + + # Logical-allocation tracking: va -> size_class (needed by free()). + self.logical_allocations: Dict[int, int] = {} + + # Weak-reference finalizers: va -> weakref.finalize object. + # Stored so we can detach them on explicit free() to prevent double-free. + # NOTE: finalizers are attached to tensor.untyped_storage(), not to the + # Python tensor wrapper, so they survive tensor.reshape() / other view ops. + self._finalizers: Dict[int, weakref.finalize] = {} + + # Pending GC-triggered frees. The GC callback appends here instead of + # acquiring the lock directly, which would deadlock if GC fires while + # the lock is already held on the same thread. Entries are processed at + # the start of allocate() and free(). + # Thread-safety note: ``deque.append`` and ``deque.popleft`` are atomic + # in CPython due to the GIL. This is a documented property of + # ``collections.deque`` for single-element operations and is relied on + # here to allow GC finalizers (running without the allocator lock) to + # safely enqueue work for the owning thread. 
+ self._pending_free: deque = deque() + + self.world_size = world_size + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _track_allocation(self, offset: int, size: int, is_imported: bool, handle, va: int) -> None: + """Record a physical segment for cleanup and segmented DMA-BUF export.""" + self.allocations[offset] = (size, is_imported, handle, va) + self.allocation_order.append((offset, size)) + + def _size_class(self, size_bytes: int) -> int: + """Return the smallest power-of-two size class >= size_bytes and >= granularity.""" + return max(_next_pow2(max(size_bytes, 1)), self.min_size_class) + + def _map_new_segment(self, size_class: int) -> Tuple[int, int]: + """ + Map a fresh physical segment of *size_class* bytes at the next + available VA offset and extend cumulative access. + + VA offsets are aligned to the HIP granularity so that consecutive + segments remain contiguous, which is required by the cumulative + ``hipMemSetAccess`` call. + + Returns: + ``(offset, va)`` of the newly mapped segment. + + Raises: + RuntimeError: If the heap VA space is exhausted. + """ + aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) + + if aligned_offset + size_class > self.aligned_heap_size: + raise RuntimeError( + f"VMemPow2Allocator: out of VA space. " + f"Need {size_class} bytes at offset {aligned_offset}, " + f"heap size is {self.aligned_heap_size}, " + f"current offset is {self.current_offset}." + ) + + va = self.base_va + aligned_offset + handle = mem_create(size_class, self.device_id) + mem_map(va, size_class, 0, handle) + + # Extend cumulative access to include the new segment. 
+ new_cumulative = aligned_offset + size_class + if new_cumulative > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self._track_allocation(aligned_offset, size_class, False, handle, va) + self._segment_generation[aligned_offset] = 0 + self.current_offset = aligned_offset + size_class + return aligned_offset, va + + def _remap_free_block(self, offset: int, va: int, size_class: int) -> None: + """ + Refresh the physical backing of a free-listed VA block. + + The *old* physical handle is released (``mem_unmap`` + ``mem_release``) + and fresh physical memory is created (``mem_create`` + ``mem_map``) at + the same VA. Cumulative access is re-set for the full range after the + physical swap so the remapped segment is accessible. + + The generation counter for *offset* is incremented so the symmetric + heap detects the physical change and re-imports the segment on peers. + """ + alloc_info = self.allocations.get(offset) + if alloc_info is not None: + _old_size, _is_imported, old_handle, _old_va = alloc_info + if old_handle is not None: + mem_unmap(_old_va, _old_size) + mem_release(old_handle) + + new_handle = mem_create(size_class, self.device_id) + mem_map(va, size_class, 0, new_handle) + + # Re-set cumulative access after the physical swap (access is tied to + # the physical mapping and must be restored after unmap/remap). + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self.allocations[offset] = (size_class, False, new_handle, va) + self._segment_generation[offset] = self._segment_generation.get(offset, 0) + 1 + + def _process_pending_frees(self) -> None: + """ + Process GC-triggered frees that were queued to avoid lock re-entry. + + Must be called while holding ``self.lock``. 
+ """ + while self._pending_free: + va, size_class = self._pending_free.popleft() + if va not in self.logical_allocations: + continue # already freed manually + del self.logical_allocations[va] + self._finalizers.pop(va, None) + offset = va - self.base_va + self.free_lists.setdefault(size_class, []).append((offset, va)) + + def _register_finalizer(self, tensor: torch.Tensor, va: int, size_class: int) -> None: + """Register a weak-reference GC finalizer on the tensor's storage. + + The finalizer is attached to ``tensor.untyped_storage()`` (the shared + C++ storage object) rather than to the Python tensor wrapper. This is + essential because callers routinely create new wrappers over the same + storage (e.g. via ``.reshape()``): if the finalizer were on the wrapper + it would fire as soon as the first wrapper was discarded, even while + other wrappers still hold the storage alive. Attaching to the storage + ensures the block is only freed when *every* view has been released. + """ + allocator_ref = weakref.ref(self) + + def _gc_free() -> None: + alloc = allocator_ref() + if alloc is None: + return + # Enqueue rather than locking directly to avoid deadlock when GC + # fires inside a locked section on the same thread. 
+ alloc._pending_free.append((va, size_class)) + + self._finalizers[va] = weakref.finalize(tensor.untyped_storage(), _gc_free) + + def _make_tensor_view(self, va: int, size_class: int, num_elements: int, dtype: torch.dtype) -> torch.Tensor: + """Create a PyTorch tensor view over the given VA-backed device memory.""" + elem_sz = _element_size(dtype) + interface_size = (size_class // elem_sz) * elem_sz + cuda_array = _CUDAArrayInterface(va, interface_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + full = tensor_bytes.view(dtype) + if num_elements == 0: + return full.narrow(0, 1, 0) + return full.narrow(0, 0, num_elements) + + # ------------------------------------------------------------------ + # BaseAllocator interface + # ------------------------------------------------------------------ + + def get_base_address(self) -> int: + """Return the base virtual address of this allocator's VA range.""" + return self.base_va + + def get_minimum_allocation_size(self) -> int: + """Minimum allocation size in bytes (one size-class / granule).""" + return self.granularity + + def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) -> torch.Tensor: + """ + Allocate a tensor on the power-of-two symmetric heap. + + If a block of the required size class is on the free list it is reused + (old physical memory released, fresh physical memory mapped at the same + VA); otherwise a new segment is mapped at the next available offset. + + Args: + num_elements: Number of tensor elements. + dtype: PyTorch data type. + alignment: Ignored; alignment is guaranteed by the size class. + + Returns: + A PyTorch tensor of shape ``(num_elements,)``. + + Raises: + RuntimeError: If the heap VA space is exhausted. 
+ """ + with self.lock: + self._process_pending_frees() + + elem_sz = _element_size(dtype) + size_bytes = num_elements * elem_sz + size_class = self._size_class(size_bytes) + + # Try the free list first; otherwise map a new segment. + if size_class in self.free_lists and self.free_lists[size_class]: + offset, va = self.free_lists[size_class].pop() + self._remap_free_block(offset, va, size_class) + else: + offset, va = self._map_new_segment(size_class) + + self.logical_allocations[va] = size_class + tensor = self._make_tensor_view(va, size_class, num_elements, dtype) + self._register_finalizer(tensor, va, size_class) + return tensor + + def free(self, tensor: torch.Tensor) -> None: + """ + Return a tensor's VA block to the free list. + + The block remains physically mapped so that the VA range stays + contiguous (required for the cumulative ``hipMemSetAccess`` call). + Physical memory is released and renewed when the block is next + reused from the free list. + + Zero-element tensors are silently ignored. + + Args: + tensor: A tensor previously returned by :meth:`allocate`. + + Raises: + ValueError: If the tensor was not allocated by this allocator or + was already freed. + """ + if tensor.numel() == 0: + return + + with self.lock: + self._process_pending_frees() + + va = tensor.data_ptr() + if va not in self.logical_allocations: + raise ValueError( + f"VMemPow2Allocator.free(): tensor at VA 0x{va:x} was not " + "allocated by this allocator (or was already freed)." + ) + size_class = self.logical_allocations.pop(va) + + # Detach the GC finalizer to prevent double-free. 
+ fin = self._finalizers.pop(va, None) + if fin is not None: + fin.detach() + + offset = va - self.base_va + self.free_lists.setdefault(size_class, []).append((offset, va)) + + def get_device(self) -> torch.device: + """Return the PyTorch device for this allocator.""" + return self.device + + def owns_tensor(self, tensor: torch.Tensor) -> bool: + """ + Return True if *tensor*'s data pointer lies within this heap's VA range. + + The check is purely address-based; zero-element tensors are checked by + pointer rather than being unconditionally claimed as owned, which would + incorrectly claim externally-created zero-element tensors. + + Args: + tensor: PyTorch tensor to check. + """ + if not tensor.is_cuda: + return False + ptr = tensor.data_ptr() + # data_ptr() returns 0 for tensors that have no storage (e.g. meta tensors) + # or for certain zero-element tensors on some backends. Such tensors are + # never part of this allocator's heap. + if ptr == 0: + return False + return self.base_va <= ptr < self.base_va + self.aligned_heap_size + + # ------------------------------------------------------------------ + # Symmetric-heap segment API (used by SymmetricHeap.refresh_peer_access) + # ------------------------------------------------------------------ + + def get_allocation_segments(self) -> List[Tuple[int, int, int, int]]: + """ + Return the ordered list of physical segments for DMA-BUF export. + + All tracked segments are included (both live and free-listed), since + free-listed blocks remain physically mapped. + + Each element is ``(offset, size, va, generation)`` where *generation* + is a monotonically increasing counter bumped each time the block is + remapped with fresh physical memory. The symmetric heap uses + ``(offset, size, generation)`` as the de-duplication key so that + remapped segments are recognised as new and peer ranks re-import them. + + Returns: + List of ``(offset, size, va, generation)`` tuples in allocation order. 
+ """ + segments = [] + for offset, size in self.allocation_order: + va = self.base_va + offset + generation = self._segment_generation.get(offset, 0) + segments.append((offset, size, va, generation)) + return segments + + # ------------------------------------------------------------------ + # as_symmetric() support + # ------------------------------------------------------------------ + + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Import an external PyTorch tensor into the symmetric heap. + + The returned tensor **shares physical memory** with the original; + changes to one are immediately visible in the other. + + Args: + external_tensor: A contiguous CUDA tensor allocated by PyTorch. + + Returns: + A tensor view in the symmetric heap that shares memory with + *external_tensor*. + + Raises: + RuntimeError: If the tensor is not a contiguous CUDA tensor or + the heap VA space is exhausted. + """ + with self.lock: + if not external_tensor.is_cuda: + raise RuntimeError("VMemPow2Allocator: can only import CUDA tensors.") + if not external_tensor.is_contiguous(): + raise RuntimeError( + "VMemPow2Allocator: only contiguous tensors can be imported; " + "call .contiguous() before as_symmetric()." + ) + + external_ptr = external_tensor.data_ptr() + alloc_base, alloc_size = get_address_range(external_ptr) + offset_in_alloc = external_ptr - alloc_base + + # Export first so we know the actual export_size before the OOM check. + dmabuf_fd, _export_base, export_size = export_dmabuf_handle(alloc_base, alloc_size) + aligned_export_size = (export_size + self.granularity - 1) & ~(self.granularity - 1) + aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) + + if aligned_offset + aligned_export_size > self.aligned_heap_size: + os.close(dmabuf_fd) + raise RuntimeError( + f"VMemPow2Allocator: out of VA space for import. 
" + f"Need {aligned_export_size} bytes at offset {aligned_offset}, " + f"heap size is {self.aligned_heap_size}." + ) + + try: + imported_handle = mem_import_from_shareable_handle(dmabuf_fd) + finally: + os.close(dmabuf_fd) + + target_va = self.base_va + aligned_offset + mem_map(target_va, aligned_export_size, 0, imported_handle) + + new_cumulative = aligned_offset + aligned_export_size + if new_cumulative > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + tensor_va = target_va + offset_in_alloc + self._track_allocation(aligned_offset, aligned_export_size, True, imported_handle, target_va) + self._segment_generation[aligned_offset] = 0 + self.current_offset = aligned_offset + aligned_export_size + + tensor_size = external_tensor.numel() * external_tensor.element_size() + cuda_array = _CUDAArrayInterface(tensor_va, tensor_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + return tensor_bytes.view(external_tensor.dtype).reshape(external_tensor.shape) + + # ------------------------------------------------------------------ + # Resource management + # ------------------------------------------------------------------ + + def close(self) -> None: + """Release all VMem resources (unmap, release handles, free VA range).""" + if getattr(self, "_closed", False): + return + + with self.lock: + # Detach all GC finalizers so they cannot fire after close(). + for fin in self._finalizers.values(): + fin.detach() + self._finalizers.clear() + self._pending_free.clear() + + # Release all physical segments (both live and free-listed). 
+ for offset, alloc_info in self.allocations.items(): + _size, _is_imported, handle, va = alloc_info + if handle is not None: + mem_unmap(va, _size) + mem_release(handle) + + self.allocations.clear() + self.free_lists.clear() + self.logical_allocations.clear() + + if getattr(self, "base_va", 0): + mem_address_free(self.base_va, self.va_size) + self.base_va = 0 + + self._closed = True + + def __del__(self) -> None: + """Cleanup VMem resources on garbage collection.""" + self.close() + + # ------------------------------------------------------------------ + # Diagnostics + # ------------------------------------------------------------------ + + def stats(self) -> dict: + """ + Return a snapshot of allocator statistics. + + Returns: + A dict with keys: + + * ``heap_size`` – requested heap size in bytes + * ``aligned_heap_size`` – actual VA reservation in bytes + * ``granularity`` – HIP allocation granularity in bytes + * ``current_offset`` – bytes consumed from VA space + * ``num_segments`` – number of physical segments ever mapped + * ``num_live_allocations`` – logical allocations currently in use + * ``free_list_counts`` – dict of {size_class: count} for free lists + """ + with self.lock: + return { + "heap_size": self.heap_size, + "aligned_heap_size": self.aligned_heap_size, + "granularity": self.granularity, + "current_offset": self.current_offset, + "num_segments": len(self.allocation_order), + "num_live_allocations": len(self.logical_allocations), + "free_list_counts": {sc: len(bl) for sc, bl in self.free_lists.items() if bl}, + } diff --git a/iris/iris.py b/iris/iris.py index f0effbb2..a6e9c891 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -78,7 +78,7 @@ class Iris: Args: heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) - allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem" + allocator_type (str): Type of allocator to use. 
Options: "torch" (default), "vmem", "vmem_pow2" Example: >>> ctx = iris.iris(heap_size=2**31) # 2GB heap with torch allocator @@ -87,6 +87,9 @@ class Iris: >>> # Use VMem allocator for memory oversubscription >>> ctx = iris.iris(heap_size=2**31, allocator_type="vmem") + + >>> # Use power-of-two VMem allocator for efficient memory reuse + >>> ctx = iris.iris(heap_size=2**31, allocator_type="vmem_pow2") """ def __init__(self, heap_size=1 << 30, allocator_type="torch"): @@ -2399,8 +2402,8 @@ def iris(heap_size=1 << 30, allocator_type="torch"): Args: heap_size (int): Size of the heap in bytes. Defaults to 1GB. - allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem". - Can be overridden with IRIS_ALLOCATOR environment variable. + allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem", + "vmem_pow2". Can be overridden with IRIS_ALLOCATOR environment variable. Returns: Iris: An initialized Iris instance. @@ -2413,5 +2416,9 @@ def iris(heap_size=1 << 30, allocator_type="torch"): >>> # Use VMem allocator >>> iris_ctx = iris.iris(2**30, allocator_type="vmem") >>> tensor = iris_ctx.zeros(1024, 1024) + + >>> # Use power-of-two VMem allocator for efficient memory reuse + >>> iris_ctx = iris.iris(2**30, allocator_type="vmem_pow2") + >>> tensor = iris_ctx.zeros(1024, 1024) """ return Iris(heap_size, allocator_type) diff --git a/iris/symmetric_heap.py b/iris/symmetric_heap.py index eef39197..1b7b13f8 100644 --- a/iris/symmetric_heap.py +++ b/iris/symmetric_heap.py @@ -12,7 +12,7 @@ import torch import os -from iris.allocators import TorchAllocator, VMemAllocator +from iris.allocators import TorchAllocator, VMemAllocator, VMemPow2Allocator from iris.fd_passing import setup_fd_infrastructure from iris._distributed_helpers import distributed_allgather @@ -24,7 +24,7 @@ class SymmetricHeap: Manages distributed memory with symmetric addressing across ranks, handling all allocator coordination and memory sharing internally. 
- Supports multiple allocator backends: 'torch' (default) and 'vmem'. + Supports multiple allocator backends: 'torch' (default), 'vmem', and 'vmem_pow2'. """ def __init__( @@ -43,7 +43,7 @@ def __init__( device_id: GPU device ID cur_rank: Current process rank num_ranks: Total number of ranks - allocator_type: Type of allocator ("torch" or "vmem"); default "torch" + allocator_type: Type of allocator ("torch", "vmem", or "vmem_pow2"); default "torch" Raises: ValueError: If allocator_type is not supported @@ -58,8 +58,10 @@ def __init__( self.allocator = TorchAllocator(heap_size, device_id, cur_rank, num_ranks) elif allocator_type == "vmem": self.allocator = VMemAllocator(heap_size, device_id, cur_rank, num_ranks) + elif allocator_type == "vmem_pow2": + self.allocator = VMemPow2Allocator(heap_size, device_id, cur_rank, num_ranks) else: - raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem'") + raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem', 'vmem_pow2'") self.fd_conns = setup_fd_infrastructure(cur_rank, num_ranks) device = self.allocator.get_device() @@ -160,6 +162,7 @@ def refresh_peer_access(self): export_dmabuf_handle, mem_import_from_shareable_handle, mem_map, + mem_unmap, mem_set_access, mem_address_reserve, hipMemAccessDesc, @@ -197,9 +200,11 @@ def refresh_peer_access(self): my_segments = self.allocator.get_allocation_segments() my_exported_fds = [] - for offset, size, va in my_segments: + for seg in my_segments: + offset, size, va = seg[0], seg[1], seg[2] + generation = seg[3] if len(seg) > 3 else 0 dmabuf_fd, export_base, export_size = export_dmabuf_handle(va, size) - my_exported_fds.append((dmabuf_fd, export_size, offset)) + my_exported_fds.append((dmabuf_fd, export_size, offset, generation)) access_desc = hipMemAccessDesc() access_desc.location.type = hipMemLocationTypeDevice @@ -220,7 +225,7 @@ def refresh_peer_access(self): peer_va_base = self._peer_va_ranges[peer] peer_fds = [] - 
for seg_idx, (my_fd, my_size, my_offset) in enumerate(my_exported_fds): + for my_fd, my_size, my_offset, my_gen in my_exported_fds: # Exchange FDs (higher rank sends first to avoid deadlock) if self.cur_rank > peer: send_fd(sock, my_fd) @@ -229,25 +234,38 @@ def refresh_peer_access(self): peer_fd, _ = recv_fd(sock) send_fd(sock, my_fd) - peer_fds.append((peer_fd, my_size, my_offset)) + peer_fds.append((peer_fd, my_size, my_offset, my_gen)) if not hasattr(self, "_peer_cumulative_sizes"): self._peer_cumulative_sizes = {} cumulative_size = self._peer_cumulative_sizes.get(peer, 0) - if not hasattr(self, "_peer_imported_segments"): - self._peer_imported_segments = {} - if peer not in self._peer_imported_segments: - self._peer_imported_segments[peer] = set() - - for peer_fd, segment_size, offset in peer_fds: - segment_key = (offset, segment_size) - if segment_key in self._peer_imported_segments[peer]: + # _peer_segment_generations maps (offset, size) -> generation for each + # segment already imported from this peer. When the generation changes + # the VA has been remapped with new physical memory and must be + # re-imported (unmap old handle, map new handle). + if not hasattr(self, "_peer_segment_generations"): + self._peer_segment_generations = {} + if peer not in self._peer_segment_generations: + self._peer_segment_generations[peer] = {} + + for peer_fd, segment_size, offset, generation in peer_fds: + seg_key = (offset, segment_size) + known_gen = self._peer_segment_generations[peer].get(seg_key) + + if known_gen == generation: + # Already imported this exact physical mapping; skip. import os os.close(peer_fd) continue + if known_gen is not None: + # Physical backing has changed (generation bumped after free+remap). + # Unmap the stale peer mapping before importing the new one. 
+ peer_va = peer_va_base + offset + mem_unmap(peer_va, segment_size) + imported_handle = mem_import_from_shareable_handle(peer_fd) import os @@ -255,17 +273,22 @@ def refresh_peer_access(self): peer_va = peer_va_base + offset mem_map(peer_va, segment_size, 0, imported_handle) - self._peer_imported_segments[peer].add(segment_key) + self._peer_segment_generations[peer][seg_key] = generation new_cumulative = offset + segment_size if new_cumulative > cumulative_size: cumulative_size = new_cumulative mem_set_access(peer_va_base, cumulative_size, access_desc) + elif known_gen is not None: + # Physical backing changed but VA offset is already within the + # cumulative range (no extension needed). Re-set access for + # just this segment so the new physical mapping is accessible. + mem_set_access(peer_va, segment_size, access_desc) self._peer_cumulative_sizes[peer] = cumulative_size self.heap_bases[peer] = peer_va_base - for fd, _, _ in my_exported_fds: + for fd, _sz, _off, _gen in my_exported_fds: import os os.close(fd) diff --git a/tests/unittests/test_vmem_pow2_allocator.py b/tests/unittests/test_vmem_pow2_allocator.py new file mode 100644 index 00000000..d0fc258a --- /dev/null +++ b/tests/unittests/test_vmem_pow2_allocator.py @@ -0,0 +1,509 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for the power-of-two VMem allocator (VMemPow2Allocator). 
+""" + +import gc +import threading + +import pytest +import torch + +import iris +from iris.allocators.vmem_pow2_allocator import _next_pow2 + + +# --------------------------------------------------------------------------- +# Unit tests for _next_pow2 helper +# --------------------------------------------------------------------------- + + +def test_next_pow2_small(): + assert _next_pow2(1) == 1 + assert _next_pow2(2) == 2 + assert _next_pow2(3) == 4 + assert _next_pow2(4) == 4 + assert _next_pow2(5) == 8 + assert _next_pow2(7) == 8 + assert _next_pow2(8) == 8 + assert _next_pow2(9) == 16 + + +def test_next_pow2_large(): + assert _next_pow2(1 << 20) == 1 << 20 + assert _next_pow2((1 << 20) + 1) == 1 << 21 + + +def test_next_pow2_one(): + assert _next_pow2(0) == 1 + + +# --------------------------------------------------------------------------- +# Allocator creation +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_allocator_creation(): + """VMemPow2Allocator can be created via the iris context.""" + ctx = iris.iris(4 << 20, allocator_type="vmem_pow2") + + assert ctx.cur_rank >= 0 + assert ctx.num_ranks >= 1 + assert ctx.heap_size == 4 << 20 + + from iris.allocators.vmem_pow2_allocator import VMemPow2Allocator + + assert isinstance(ctx.heap.allocator, VMemPow2Allocator) + + +# --------------------------------------------------------------------------- +# Basic allocation +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_basic_allocation(): + """Basic tensor allocation and write.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + + tensor = ctx.zeros(1024, dtype=torch.float32) + + assert tensor.shape == (1024,) + assert tensor.device.type == "cuda" + assert torch.all(tensor == 0) + + tensor.fill_(42.0) + assert torch.all(tensor == 42.0) + + +def test_vmem_pow2_multiple_allocations(): + """Multiple allocations from the same heap.""" + ctx = 
iris.iris(16 << 20, allocator_type="vmem_pow2") + + tensors = [] + for i in range(8): + t = ctx.zeros(256, dtype=torch.float32) + t.fill_(float(i)) + tensors.append(t) + + for i, t in enumerate(tensors): + assert torch.all(t == float(i)), f"Tensor {i} has wrong value." + + +def test_vmem_pow2_different_dtypes(): + """Allocations with different dtypes.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + + t_f32 = ctx.zeros(128, dtype=torch.float32) + t_f16 = ctx.zeros(128, dtype=torch.float16) + t_i32 = ctx.zeros(128, dtype=torch.int32) + + t_f32.fill_(1.0) + t_f16.fill_(2.0) + t_i32.fill_(3) + + assert torch.all(t_f32 == 1.0) + assert torch.all(t_f16 == 2.0) + assert torch.all(t_i32 == 3) + + +# --------------------------------------------------------------------------- +# owns_tensor +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_owns_tensor(): + """owns_tensor correctly identifies heap vs. non-heap tensors.""" + ctx = iris.iris(8 << 20, allocator_type="vmem_pow2") + + heap_tensor = ctx.zeros(100, dtype=torch.float32) + assert ctx.heap.allocator.owns_tensor(heap_tensor), "Heap tensor should be owned." + + external = torch.zeros(100, dtype=torch.float32, device=ctx.device) + assert not ctx.heap.allocator.owns_tensor(external), "External tensor should not be owned." + + # Zero-element tensors NOT from the heap must NOT be claimed as owned. + external_empty = torch.zeros(0, dtype=torch.float32, device=ctx.device) + assert not ctx.heap.allocator.owns_tensor(external_empty), ( + "External zero-element tensor must not be claimed as owned." 
+ ) + + del heap_tensor, external, external_empty + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Free-list reuse +# --------------------------------------------------------------------------- + + +def test_vmem_pow2_free_reuse(): + """After free(), the same VA is returned on the next allocate() of the same size class.""" + ctx = iris.iris(16 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + t1 = ctx.zeros(512, dtype=torch.float32) + ptr1 = t1.data_ptr() + + allocator.free(t1) + + t2 = ctx.zeros(512, dtype=torch.float32) + ptr2 = t2.data_ptr() + + assert ptr2 == ptr1, f"Expected VA reuse 0x{ptr1:x}, got 0x{ptr2:x}." + + +def test_vmem_pow2_free_reuse_multiple(): + """Multiple free + realloc cycles for different size classes all reuse VAs.""" + ctx = iris.iris(64 << 20, allocator_type="vmem_pow2") + allocator = ctx.heap.allocator + + for num_elems in [64, 256, 1024]: + t = ctx.zeros(num_elems, dtype=torch.float32) + ptr = t.data_ptr() + allocator.free(t) + + t2 = ctx.zeros(num_elems, dtype=torch.float32) + assert t2.data_ptr() == ptr, f"Expected VA reuse for {num_elems} elements." 
def test_vmem_pow2_free_wrong_tensor_raises():
    """Freeing a tensor not allocated by this allocator raises ValueError."""
    ctx = iris.iris(8 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    outsider = torch.zeros(64, dtype=torch.float32, device=ctx.device)
    with pytest.raises(ValueError):
        allocator.free(outsider)


# ---------------------------------------------------------------------------
# GC-based auto-free
# ---------------------------------------------------------------------------


def test_vmem_pow2_gc_auto_free():
    """Tensors that go out of scope are automatically returned to the free list."""
    ctx = iris.iris(16 << 20, allocator_type="vmem_pow2")

    def alloc_and_drop():
        # The tensor's only reference dies at function exit (CPython
        # refcount -> 0); only its VA survives as a plain integer.
        tensor = ctx.zeros(512, dtype=torch.float32)
        return tensor.data_ptr()

    ptr = alloc_and_drop()
    gc.collect()  # ensure finalizer has run across all Python implementations

    # The next allocation of the same size class must reuse the freed VA.
    t2 = ctx.zeros(512, dtype=torch.float32)
    assert t2.data_ptr() == ptr, f"Expected GC-freed VA 0x{ptr:x} to be reused, got 0x{t2.data_ptr():x}."
# ---------------------------------------------------------------------------
# Heap bases
# ---------------------------------------------------------------------------


def test_vmem_pow2_heap_bases():
    """Heap bases are properly initialised."""
    ctx = iris.iris(4 << 20, allocator_type="vmem_pow2")

    assert ctx.heap_bases.shape == (ctx.num_ranks,)
    my_base = int(ctx.heap_bases[ctx.cur_rank].item())
    assert my_base > 0

    if ctx.num_ranks > 1:
        # Every peer must publish a non-zero base distinct from ours.
        for peer in range(ctx.num_ranks):
            if peer == ctx.cur_rank:
                continue
            peer_base = int(ctx.heap_bases[peer].item())
            assert peer_base > 0
            assert peer_base != my_base


# ---------------------------------------------------------------------------
# Granularity alignment
# ---------------------------------------------------------------------------


def test_vmem_pow2_granularity_alignment():
    """The aligned heap size must be a multiple of the HIP granularity."""
    from iris.hip import get_allocation_granularity

    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    ctx = iris.iris(4 << 20, allocator_type="vmem_pow2")
    gran = get_allocation_granularity(ctx.gpu_id)

    assert ctx.heap.allocator.aligned_heap_size % gran == 0


# ---------------------------------------------------------------------------
# Size-class rounding
# ---------------------------------------------------------------------------


def test_vmem_pow2_size_class_rounding():
    """
    Two allocation sizes that map to the same power-of-two size class can
    interchangeably reuse each other's freed VA block.
    """
    ctx = iris.iris(64 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator
    granularity = allocator.granularity

    # Both byte sizes round up to 2*granularity, i.e. the same class.
    # With dtype=torch.int8 (element_size == 1), elements == bytes.
    size_a = granularity + 1
    size_b = granularity + granularity // 2

    t_a = allocator.allocate(size_a, torch.int8)
    ptr_a = t_a.data_ptr()
    allocator.free(t_a)

    t_b = allocator.allocate(size_b, torch.int8)
    ptr_b = t_b.data_ptr()

    assert ptr_b == ptr_a, f"Expected VA reuse: both sizes share size class. ptr_a=0x{ptr_a:x}, ptr_b=0x{ptr_b:x}"


# ---------------------------------------------------------------------------
# stats()
# ---------------------------------------------------------------------------


def test_vmem_pow2_stats():
    """stats() returns sensible values before, during, and after an allocation."""
    ctx = iris.iris(8 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    before = allocator.stats()
    assert before["heap_size"] == 8 << 20
    assert before["granularity"] > 0
    assert before["num_live_allocations"] == 0

    live = ctx.zeros(512, dtype=torch.float32)
    during = allocator.stats()
    assert during["num_live_allocations"] == 1

    allocator.free(live)
    after = allocator.stats()
    assert after["num_live_allocations"] == 0


# ---------------------------------------------------------------------------
# get_allocation_segments() and generation counter
# ---------------------------------------------------------------------------


def test_vmem_pow2_allocation_segments_grow():
    """
    get_allocation_segments() grows when new physical segments are mapped
    and does NOT grow when free-listed VAs are reused (remap same entry).
    """
    ctx = iris.iris(64 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    def segment_count():
        return len(allocator.get_allocation_segments())

    baseline = segment_count()

    # A brand-new allocation maps a fresh physical segment.
    first = ctx.zeros(512, dtype=torch.float32)
    after_alloc = segment_count()
    assert after_alloc == baseline + 1

    # Free + reallocate: the VA comes back from the free list, so the
    # segment table must not gain a new entry.
    allocator.free(first)
    _ = ctx.zeros(512, dtype=torch.float32)
    assert segment_count() == after_alloc, "Free-list reuse must not create a new segment entry."


def test_vmem_pow2_generation_increments_on_remap():
    """Generation counter increases when a freed VA is remapped."""
    ctx = iris.iris(16 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    tensor = ctx.zeros(512, dtype=torch.float32)
    offset = tensor.data_ptr() - allocator.base_va
    gen_before = allocator._segment_generation[offset]

    allocator.free(tensor)
    _ = ctx.zeros(512, dtype=torch.float32)  # remaps the same offset

    gen_after = allocator._segment_generation[offset]
    assert gen_after == gen_before + 1, "Generation must increment on remap."


# ---------------------------------------------------------------------------
# OOM / heap exhaustion
# ---------------------------------------------------------------------------


def test_vmem_pow2_oom():
    """Allocating beyond the VA space raises RuntimeError."""
    # A heap of exactly 4 granules; bootstrap uses 1, so we have 3 left.
    ctx = iris.iris(4 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    keep_alive = []
    with pytest.raises(RuntimeError, match="out of VA space"):
        while True:
            keep_alive.append(allocator.allocate(1, torch.int8))


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


def test_vmem_pow2_thread_safety():
    """Concurrent alloc/free from multiple threads does not corrupt state."""
    ctx = iris.iris(256 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator
    errors: list = []

    def worker():
        try:
            for _ in range(20):
                block = allocator.allocate(16, torch.float32)
                allocator.free(block)
        except Exception as exc:  # noqa: BLE001 – capture any error from any thread
            errors.append(exc)

    workers = [threading.Thread(target=worker) for _ in range(4)]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    assert not errors, f"Thread-safety errors: {errors}"


# ---------------------------------------------------------------------------
# close() and resource cleanup
# ---------------------------------------------------------------------------


def test_vmem_pow2_close():
    """close() releases all resources and is idempotent."""
    ctx = iris.iris(8 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    tensor = ctx.zeros(512, dtype=torch.float32)
    allocator.free(tensor)

    allocator.close()
    assert allocator._closed
    assert allocator.base_va == 0

    # A second close() must be a harmless no-op.
    allocator.close()
    assert allocator._closed


def test_vmem_pow2_close_disables_finalizers():
    """close() detaches GC finalizers so they cannot run after the allocator is gone."""
    ctx = iris.iris(8 << 20, allocator_type="vmem_pow2")
    allocator = ctx.heap.allocator

    tensor = ctx.zeros(512, dtype=torch.float32)
    assert tensor.data_ptr() in allocator._finalizers

    allocator.close()
    assert not allocator._finalizers, "All finalizers must be cleared by close()."


# ---------------------------------------------------------------------------
# as_symmetric() (import_external_tensor)
# ---------------------------------------------------------------------------


def test_vmem_pow2_import_external_tensor():
    """
    Importing an external tensor gives a symmetric-heap view that shares
    physical memory; the original tensor remains valid after ctx is destroyed.
    """
    ctx = iris.iris(8 << 20, allocator_type="vmem_pow2")

    original = torch.randn(64, dtype=torch.float32, device=ctx.device)
    snapshot = original.clone()

    imported = ctx.as_symmetric(original)
    assert torch.allclose(imported, snapshot), "Imported data should match original."

    # Writes through either alias must be visible through the other.
    imported.fill_(7.0)
    assert torch.all(original == 7.0), "Original should see changes through shared memory."

    original.fill_(13.0)
    assert torch.all(imported == 13.0), "Imported should see changes through shared memory."

    # Destroying the context must not invalidate the caller's tensor.
    del ctx, imported
    gc.collect()
    torch.cuda.synchronize()

    assert torch.all(original == 13.0), "Original tensor should survive ctx destruction."
    original.fill_(99.0)
    assert torch.all(original == 99.0), "Original tensor should still be writable."
# ---------------------------------------------------------------------------
# Multi-rank tests
# ---------------------------------------------------------------------------


def test_vmem_pow2_multirank_heap_bases():
    """Multi-rank: each rank sees all peers' heap bases after setup."""
    ctx = iris.iris(4 << 20, allocator_type="vmem_pow2")

    tensor = ctx.zeros(1024, dtype=torch.float32)
    tensor.fill_(float(ctx.cur_rank * 100))

    assert ctx.heap_bases.shape == (ctx.num_ranks,)
    assert int(ctx.heap_bases[ctx.cur_rank].item()) > 0

    if ctx.num_ranks > 1:
        # Every peer must publish a non-zero base distinct from ours.
        for peer in range(ctx.num_ranks):
            if peer != ctx.cur_rank:
                assert int(ctx.heap_bases[peer].item()) > 0
                assert int(ctx.heap_bases[peer].item()) != int(ctx.heap_bases[ctx.cur_rank].item())

    # Local writes must survive a collective barrier round-trip.
    ctx.barrier()
    tensor.fill_(float(ctx.cur_rank * 100))
    ctx.barrier()
    assert torch.all(tensor == float(ctx.cur_rank * 100))


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    test_next_pow2_small()
    test_next_pow2_large()
    test_next_pow2_one()
    test_vmem_pow2_allocator_creation()
    test_vmem_pow2_basic_allocation()
    test_vmem_pow2_multiple_allocations()
    test_vmem_pow2_different_dtypes()
    test_vmem_pow2_owns_tensor()
    test_vmem_pow2_free_reuse()
    test_vmem_pow2_free_reuse_multiple()
    # BUGFIX: this test was defined above but missing from the script-mode
    # call list, so `python test_vmem_pow2_allocator.py` silently skipped it.
    test_vmem_pow2_free_wrong_tensor_raises()
    test_vmem_pow2_gc_auto_free()
    test_vmem_pow2_heap_bases()
    test_vmem_pow2_granularity_alignment()
    test_vmem_pow2_size_class_rounding()
    test_vmem_pow2_stats()
    test_vmem_pow2_allocation_segments_grow()
    test_vmem_pow2_generation_increments_on_remap()
    test_vmem_pow2_oom()
    test_vmem_pow2_thread_safety()
    test_vmem_pow2_close()
    test_vmem_pow2_close_disables_finalizers()
    test_vmem_pow2_import_external_tensor()
    test_vmem_pow2_multirank_heap_bases()
    print("All VMemPow2Allocator tests passed.")