From 66be54738624321df099a2e99c4606dac79d5453 Mon Sep 17 00:00:00 2001 From: sfiisf Date: Wed, 15 Apr 2026 16:45:48 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20partially=5Funload=20?= =?UTF-8?q?=E4=B8=8E=20need=5Fmmap=20=E5=8F=8A=20model=5Fto=5Fmmap=20?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/model_management.py | 2 +- comfy/model_patcher.py | 30 +++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index aabaed92d39c..6dac20dffcbb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -624,7 +624,7 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage - if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size: + if memory_to_free < model_loaded_size: partially_unload = True else: partially_unload = False diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 311b404eb53c..7ddea07b11b6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -42,10 +42,10 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk -def need_mmap() -> bool: +def need_mmap(offload_size: int = 0) -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() - if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: + if free_cpu_mem - offload_size < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True return False @@ -104,6 +104,14 @@ def model_to_mmap(model: torch.nn.Module): The same model with all tensors converted to memory-mapped format """ free_cpu_mem = get_free_memory(torch.device("cpu")) + free_disk_mem = get_free_disk() + model_mem = comfy.model_management.module_size(model) + if model_mem > free_cpu_mem: + logging.error(f"Not enough free CPU memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free CPU memory: {free_cpu_mem/(1024*1024*1024)} GB") + raise ValueError("Not enough free CPU memory to convert model to mmap") + if model_mem > free_disk_mem: + logging.error(f"Not enough free disk memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free disk memory: {free_disk_mem/(1024*1024*1024)} GB") + raise ValueError("Not enough free disk memory to convert model to mmap") logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): @@ -1022,9 +1030,12 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): if device_to is not None: - if need_mmap(): + if need_mmap(offload_size=self.model_size()): # offload to mmap - model_to_mmap(self.model) + try: + model_to_mmap(self.model) + except Exception as e: + pass # todo else: self.model.to(device_to) self.model.device = device_to @@ -1087,12 +1098,13 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - if need_mmap(): - if get_free_disk() < module_mem: - logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") + if need_mmap(offload_size=module_mem): + try: + # offload to mmap + model_to_mmap(m) + except Exception as e: + logging.error(f"Error occurred while offloading {n} to mmap: {e}") break - # offload to mmap - model_to_mmap(m) else: m.to(device_to) module_mem += move_weight_functions(m, device_to) From 00c0028c7b0952c4f6128334870f06e1f4e9dc03 Mon Sep 17 00:00:00 2001 From: sfiisf Date: Thu, 23 Apr 2026 17:07:59 +0800 Subject: [PATCH 2/4] use gds to mmap offload --- comfy/model_patcher.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7ddea07b11b6..da5fdbd6492c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -30,6 +30,7 @@ import tempfile import weakref import gc +import mmap import comfy.float import comfy.hooks @@ -56,21 +57,21 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ # Create temporary file if filename is None: - temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1] + temp_file = tempfile.mkstemp(suffix='.bin', prefix='comfy_mmap_')[1] else: temp_file = filename - # Save tensor to file - cpu_tensor = t.cpu() - torch.save(cpu_tensor, temp_file) - - # If we created a CPU copy from other device, delete it to free memory - if not t.device.type == 'cpu': - del cpu_tensor - gc.collect() - - # Load with mmap - this doesn't load all data into RAM - mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR) + file.save_storage(t.untyped_storage(), offset=0) + t_type = t.dtype + t_shape = t.shape + num = t.numel() * t.element_size() + del t + gc.collect() + + fo = open(temp_file, "rb") + mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ) + mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu() # Register cleanup callback - will be called when tensor is garbage collected def _cleanup(): @@ -106,9 +107,6 @@ def model_to_mmap(model: torch.nn.Module): free_cpu_mem = get_free_memory(torch.device("cpu")) free_disk_mem = get_free_disk() model_mem = comfy.model_management.module_size(model) - if model_mem > free_cpu_mem: - logging.error(f"Not enough free CPU memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free CPU memory: {free_cpu_mem/(1024*1024*1024)} GB") - raise ValueError("Not enough free CPU memory to convert model to mmap") if model_mem > free_disk_mem: logging.error(f"Not enough free disk memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free disk memory: {free_disk_mem/(1024*1024*1024)} GB") raise ValueError("Not enough free disk memory to convert model to mmap") @@ -1035,7 +1033,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): try: model_to_mmap(self.model) except Exception as e: - pass # todo + logging.warning(f"Error occurred while offloading model to mmap: {e}") + # todo: 回退 然后 partially_unload else: self.model.to(device_to) self.model.device = device_to From 46591dfe8783458a544961e87f5eda554c215064 Mon Sep 17 00:00:00 2001 From: sfiisf Date: Fri, 24 Apr 2026 14:29:33 +0800 Subject: [PATCH 3/4] =?UTF-8?q?mmap=20=E5=A4=B1=E8=B4=A5=E6=97=B6=E5=9B=9E?= =?UTF-8?q?=E9=80=80=E5=88=B0=20normal=20offload?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/model_patcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index da5fdbd6492c..50eeea28cb45 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1033,8 +1033,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): try: model_to_mmap(self.model) except Exception as e: - logging.warning(f"Error occurred while offloading model to mmap: {e}") - # todo: 回退 然后 partially_unload + logging.warning(f"Error occurred while offloading model to mmap: {e}, fall back to normal offload") + self.model.to(device_to) else: self.model.to(device_to) self.model.device = device_to @@ -1102,8 +1102,8 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals # offload to mmap model_to_mmap(m) except Exception as e: - logging.error(f"Error occurred while offloading {n} to mmap: {e}") - break + logging.warning(f"Error occurred while offloading {n} to mmap: {e}, fall back to normal offload") + m.to(device_to) else: m.to(device_to) module_mem += move_weight_functions(m, device_to) From 0b5b8fad3f1ff0fc75e549b5ca7409370a9ca896 Mon Sep 17 00:00:00 2001 From: sfiisf Date: Tue, 28 Apr 2026 14:48:59 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=89=8D=E7=BD=AE?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E6=A3=80=E6=9F=A5&=E5=85=B3=E9=97=ADfo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy/model_patcher.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 50eeea28cb45..4f8d9c2d0896 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -43,6 +43,8 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk +_CUDA_GDS_AVAILABLE = hasattr(torch, "cuda") and hasattr(torch.cuda, "gds") and hasattr(torch.cuda.gds, "GdsFile") + def need_mmap(offload_size: int = 0) -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() @@ -61,17 +63,29 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: else: temp_file = filename - file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR) - file.save_storage(t.untyped_storage(), offset=0) - t_type = t.dtype - t_shape = t.shape - num = t.numel() * t.element_size() - del t - gc.collect() - - fo = open(temp_file, "rb") - mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ) - mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu() + if _CUDA_GDS_AVAILABLE: + file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR) + file.save_storage(t.untyped_storage(), offset=0) + t_type = t.dtype + t_shape = t.shape + num = t.numel() * t.element_size() + del t + gc.collect() + + with open(temp_file, "rb") as fo: + mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ) + mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu() + else: + cpu_tensor = t.cpu() + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from other device, delete it to free memory + if not t.device.type == 'cpu': + del cpu_tensor + gc.collect() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) # Register cleanup callback - will be called when tensor is garbage collected def _cleanup():