diff --git a/comfy/model_management.py b/comfy/model_management.py index aabaed92d39c..6dac20dffcbb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -624,7 +624,7 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage - if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size: + if memory_to_free < model_loaded_size: partially_unload = True else: partially_unload = False diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 311b404eb53c..4f8d9c2d0896 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -30,6 +30,7 @@ import tempfile import weakref import gc +import mmap import comfy.float import comfy.hooks @@ -42,10 +43,12 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk -def need_mmap() -> bool: +_CUDA_GDS_AVAILABLE = hasattr(torch, "cuda") and hasattr(torch.cuda, "gds") and hasattr(torch.cuda.gds, "GdsFile") + +def need_mmap(offload_size: int = 0) -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() - if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: + if free_cpu_mem - offload_size < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True return False @@ -56,21 +59,33 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ # Create temporary file if filename is None: - temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1] + temp_file = tempfile.mkstemp(suffix='.bin', prefix='comfy_mmap_')[1] else: temp_file = filename - # Save tensor to file - cpu_tensor = t.cpu() - torch.save(cpu_tensor, temp_file) - - # If we created a CPU copy from other device, delete it to free memory - if not t.device.type == 'cpu': - del cpu_tensor + if _CUDA_GDS_AVAILABLE: + file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR) + file.save_storage(t.untyped_storage(), offset=0) + t_type = t.dtype + t_shape = t.shape + num = t.numel() * t.element_size() + del t gc.collect() + + with open(temp_file, "rb") as fo: + mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ) + mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu() + else: + cpu_tensor = t.cpu() + torch.save(cpu_tensor, temp_file) - # Load with mmap - this doesn't load all data into RAM - mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + # If we created a CPU copy from other device, delete it to free memory + if not t.device.type == 'cpu': + del cpu_tensor + gc.collect() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) # Register cleanup callback - will be called when tensor is garbage collected def _cleanup(): @@ -104,6 +119,11 @@ def model_to_mmap(model: torch.nn.Module): The same model with all tensors converted to memory-mapped format """ free_cpu_mem = get_free_memory(torch.device("cpu")) + free_disk_mem = get_free_disk() + model_mem = comfy.model_management.module_size(model) + if model_mem > free_disk_mem: + logging.error(f"Not enough free disk memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free disk memory: {free_disk_mem/(1024*1024*1024)} GB") + raise ValueError("Not enough free disk memory to convert model to mmap") logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): @@ -1022,9 +1042,13 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): if device_to is not None: - if need_mmap(): + if need_mmap(offload_size=self.model_size()): # offload to mmap - model_to_mmap(self.model) + try: + model_to_mmap(self.model) + except Exception as e: + logging.warning(f"Error occurred while offloading model to mmap: {e}, fall back to normal offload") + self.model.to(device_to) else: self.model.to(device_to) self.model.device = device_to @@ -1087,12 +1111,13 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - if need_mmap(): - if get_free_disk() < module_mem: - logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") - break - # offload to mmap - model_to_mmap(m) + if need_mmap(offload_size=module_mem): + try: + # offload to mmap + model_to_mmap(m) + except Exception as e: + logging.warning(f"Error occurred while offloading {n} to mmap: {e}, fall back to normal offload") + m.to(device_to) else: m.to(device_to) module_mem += move_weight_functions(m, device_to)