Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True):
logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")

mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size:
if memory_to_free < model_loaded_size:
partially_unload = True
else:
partially_unload = False
Expand Down
65 changes: 45 additions & 20 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tempfile
import weakref
import gc
import mmap

import comfy.float
import comfy.hooks
Expand All @@ -42,10 +43,12 @@
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk

def need_mmap() -> bool:
_CUDA_GDS_AVAILABLE = hasattr(torch, "cuda") and hasattr(torch.cuda, "gds") and hasattr(torch.cuda.gds, "GdsFile")

def need_mmap(offload_size: int = 0) -> bool:
free_cpu_mem = get_free_memory(torch.device("cpu"))
mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
if free_cpu_mem - offload_size < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
return True
return False
Expand All @@ -56,21 +59,33 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
# Create temporary file
if filename is None:
temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
temp_file = tempfile.mkstemp(suffix='.bin', prefix='comfy_mmap_')[1]
else:
temp_file = filename

# Save tensor to file
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)

# If we created a CPU copy from other device, delete it to free memory
if not t.device.type == 'cpu':
del cpu_tensor
if _CUDA_GDS_AVAILABLE:
file = torch.cuda.gds.GdsFile(temp_file, os.O_CREAT | os.O_RDWR)
file.save_storage(t.untyped_storage(), offset=0)
t_type = t.dtype
t_shape = t.shape
num = t.numel() * t.element_size()
del t
gc.collect()

with open(temp_file, "rb") as fo:
mm = mmap.mmap(fo.fileno(), length=num, access=mmap.ACCESS_READ)
mmap_tensor = torch.frombuffer(mm, dtype=t_type).reshape(t_shape).cpu()
else:
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
# If we created a CPU copy from other device, delete it to free memory
if not t.device.type == 'cpu':
del cpu_tensor
gc.collect()

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

# Register cleanup callback - will be called when tensor is garbage collected
def _cleanup():
Expand Down Expand Up @@ -104,6 +119,11 @@ def model_to_mmap(model: torch.nn.Module):
The same model with all tensors converted to memory-mapped format
"""
free_cpu_mem = get_free_memory(torch.device("cpu"))
free_disk_mem = get_free_disk()
model_mem = comfy.model_management.module_size(model)
if model_mem > free_disk_mem:
logging.error(f"Not enough free disk memory to convert model to mmap. Model size: {model_mem/(1024*1024*1024)} GB, free disk memory: {free_disk_mem/(1024*1024*1024)} GB")
raise ValueError("Not enough free disk memory to convert model to mmap")
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

def convert_fn(t):
Expand Down Expand Up @@ -1022,9 +1042,13 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):


if device_to is not None:
if need_mmap():
if need_mmap(offload_size=self.model_size()):
# offload to mmap
model_to_mmap(self.model)
try:
model_to_mmap(self.model)
except Exception as e:
logging.warning(f"Error occurred while offloading model to mmap: {e}, fall back to normal offload")
self.model.to(device_to)
else:
self.model.to(device_to)
self.model.device = device_to
Expand Down Expand Up @@ -1087,12 +1111,13 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
if need_mmap():
if get_free_disk() < module_mem:
logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
break
# offload to mmap
model_to_mmap(m)
if need_mmap(offload_size=module_mem):
try:
# offload to mmap
model_to_mmap(m)
except Exception as e:
logging.warning(f"Error occurred while offloading {n} to mmap: {e}, fall back to normal offload")
m.to(device_to)
else:
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
Expand Down
Loading