42 commits
6e33ee3  debug error (strint, Oct 16, 2025)
fa19dd4  debug offload (strint, Oct 16, 2025)
f40e00c  add detail debug (strint, Oct 16, 2025)
2b22296  add debug log (strint, Oct 16, 2025)
c1eac55  add debug log (strint, Oct 16, 2025)
9352987  add log (strint, Oct 16, 2025)
a207301  rm useless log (strint, Oct 16, 2025)
71b23d1  rm useless log (strint, Oct 16, 2025)
e5ff6a1  refine log (strint, Oct 16, 2025)
5c3c6c0  add debug log of cpu load (strint, Oct 17, 2025)
6583cc0  debug load mem (strint, Oct 17, 2025)
49597bf  load remains mmap (strint, Oct 17, 2025)
21ebcad  debug free mem (strint, Oct 20, 2025)
4ac827d  unload partial (strint, Oct 20, 2025)
e9e1d2f  add mmap tensor (strint, Oct 20, 2025)
4956178  fix log (strint, Oct 20, 2025)
8aeebbf  fix to (strint, Oct 20, 2025)
05c2518  refact mmap (strint, Oct 20, 2025)
2f0d566  refine code (strint, Oct 21, 2025)
2d010f5  refine code (strint, Oct 21, 2025)
fff56de  fix format (strint, Oct 21, 2025)
08e094e  use native mmap (strint, Oct 21, 2025)
8038393  lazy rm file (strint, Oct 21, 2025)
98ba311  add env (strint, Oct 21, 2025)
f3c673d  Merge branch 'master' of https://github.com/siliconflow/ComfyUI into … (strint, Oct 22, 2025)
aab0e24  fix MMAP_MEM_THRESHOLD_GB default (strint, Oct 23, 2025)
58d28ed  no limit for offload size (strint, Oct 23, 2025)
c312733  refine log (strint, Oct 23, 2025)
dc7c77e  better partial unload (strint, Oct 23, 2025)
5c5fbdd  debug mmap (strint, Nov 17, 2025)
d28093f  Merge branch 'master' into refine_offload (doombeaker, Nov 26, 2025)
96c7f18  Merge branch 'master' into refine_offload (doombeaker, Nov 27, 2025)
7733d51  try fix flux2 (#9) (strint, Dec 4, 2025)
211fa31  Merge branch 'master' into refine_offload (doombeaker, Dec 8, 2025)
1122cd0  allow offload quant (#10) (strint, Dec 9, 2025)
532eb01  rm comment (strint, Dec 9, 2025)
a511d0d  Merge remote-tracking branch 'upstream/master' into refine_offload (yiquanfeng, Dec 12, 2025)
f61871b  Merge branch 'master' into refine_offload (doombeaker, Dec 12, 2025)
407dab1  merge master (strint, Dec 22, 2025)
e2ddb7a  merge master (strint, Jan 9, 2026)
d5412e2  skip quant tensor (strint, Jan 9, 2026)
68ceb5a  refine model_unload (#14) (sfiisf, Jan 12, 2026)
10 changes: 9 additions & 1 deletion comfy/model_base.py
@@ -62,6 +62,7 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_management import get_free_memory

class ModelType(Enum):
EPS = 1
@@ -305,8 +306,15 @@ def load_model_weights(self, sd, unet_prefix=""):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}")
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
logging.debug(f"load model {self.model_config} weights: state dict processing done")
# replace tensor with mmap tensor by assign
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

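
A note for reviewers on the change above: the important part is `assign=True`. With the default `load_state_dict`, values are copied into the module's pre-allocated parameter storage, which discards any mmap backing the checkpoint tensors had; with `assign=True` (PyTorch 2.1+), the module's parameters are re-pointed at the state-dict tensors themselves. A minimal sketch of the difference, using a made-up `torch.nn.Linear` stand-in rather than the PR's `diffusion_model`:

```python
import torch

# Tiny stand-in module; the PR applies the same idea to self.diffusion_model.
model = torch.nn.Linear(4, 4)

# Pretend these tensors came from torch.load(..., mmap=True) or safetensors,
# i.e. they are file-backed rather than ordinary RAM allocations.
sd = {k: v.clone() for k, v in model.state_dict().items()}

# Default: values are copied into the module's existing parameter storage,
# so any mmap backing of the source tensors is lost.
model.load_state_dict(sd, strict=False)
print(model.weight.data_ptr() == sd["weight"].data_ptr())  # False: separate storage

# assign=True: parameters are re-pointed at the state-dict tensors themselves,
# so file-backed storage is kept and no extra RAM copy is made.
model.load_state_dict(sd, strict=False, assign=True)
print(model.weight.data_ptr() == sd["weight"].data_ptr())  # True: shared storage
```
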
68 changes: 58 additions & 10 deletions comfy/model_management.py
@@ -27,6 +27,17 @@
import gc
import os

from functools import lru_cache

@lru_cache(maxsize=1)
def get_mmap_mem_threshold_gb():
mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
return mmap_mem_threshold_gb

def get_free_disk():
return psutil.disk_usage("/").free

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -535,16 +546,50 @@ def should_reload_model(self, force_patch_weights=False):
return False

def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
model_loaded_size = self.model.loaded_size()
if memory_to_free is None:
# free the full model
memory_to_free = model_loaded_size

logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
logging.debug(f"unpatch_weights: {unpatch_weights}")
logging.debug(f"loaded_size: {model_loaded_size/(1024*1024*1024)} GB")
logging.debug(f"offload_device: {self.model.offload_device}")

available_memory = get_free_memory(self.model.offload_device)
logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")

mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size:
partially_unload = True
else:
partially_unload = False

if partially_unload:
logging.debug("Do partial unload")
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
if freed < memory_to_free:
logging.warning(f"Partial unload freed only {freed/(1024*1024*1024)} GB of the requested {memory_to_free/(1024*1024*1024)} GB")
if freed == model_loaded_size:
partially_unload = False
else:
logging.debug("Do full unload")
self.model.detach(unpatch_weights)
logging.debug("Full unload done")
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None

available_memory = get_free_memory(self.model.offload_device)
logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")

if partially_unload:
return False
else:
return True


def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
@@ -593,6 +638,7 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def free_memory(memory_required, device, keep_loaded=[]):
logging.debug("start to free mem")
cleanup_models_gc()
unloaded_model = []
can_unload = []
Expand Down Expand Up @@ -630,6 +676,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
logging.debug("start to load models")
cleanup_models_gc()
global vram_state

@@ -651,6 +698,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models_to_load = []

for x in models:
logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)
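
For reviewers, the rewritten `model_unload` above boils down to one predicate: do a partial unload whenever the caller asked to free less than the whole model, or when the offload device cannot absorb the move once the `MMAP_MEM_THRESHOLD_GB` reserve is subtracted; otherwise detach the model fully. A standalone sketch of that predicate with hypothetical numbers (names and values are illustrative, not taken from the PR):

```python
GiB = 1024 ** 3

def should_partially_unload(memory_to_free, model_loaded_size,
                            available_memory, mmap_mem_threshold_gb):
    """Sketch of the decision made in LoadedModel.model_unload."""
    reserve = mmap_mem_threshold_gb * GiB  # memory kept back for the rest of the system
    to_move = min(memory_to_free, model_loaded_size)
    # Partial unload if the offload target cannot hold what would be moved,
    # or if only part of the model was asked to be freed.
    return to_move > available_memory - reserve or memory_to_free < model_loaded_size

# Hypothetical case: a 12 GB model, 12 GB requested to be freed, but the CPU has
# only 10 GB free and 4 GB is reserved via MMAP_MEM_THRESHOLD_GB.
print(should_partially_unload(12 * GiB, 12 * GiB, 10 * GiB, 4))  # True -> partial unload
```
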
103 changes: 101 additions & 2 deletions comfy/model_patcher.py
@@ -27,6 +27,10 @@
from typing import Callable, Optional

import torch
import os
import tempfile
import weakref
import gc

import comfy.float
import comfy.hooks
@@ -37,6 +41,87 @@
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk

def need_mmap() -> bool:
free_cpu_mem = get_free_memory(torch.device("cpu"))
mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
return True
return False

def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
"""
# Create temporary file
if filename is None:
fd, temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')
os.close(fd)  # mkstemp returns an open file descriptor; close it so it is not leaked
else:
temp_file = filename

# Save tensor to file
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)

# If we created a CPU copy from other device, delete it to free memory
if t.device.type != 'cpu':
del cpu_tensor
gc.collect()

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

# Register cleanup callback - will be called when tensor is garbage collected
def _cleanup():
try:
if os.path.exists(temp_file):
os.remove(temp_file)
logging.debug(f"Cleaned up mmap file: {temp_file}")
except Exception:
pass

weakref.finalize(mmap_tensor, _cleanup)

return mmap_tensor

def model_to_mmap(model: torch.nn.Module):
"""Convert all parameters and buffers to memory-mapped tensors

This function mimics PyTorch's Module.to() behavior but converts
tensors to memory-mapped format instead, using _apply() method.

Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244

Note: Parameters are rebuilt as torch.nn.Parameter objects wrapping the
mmap-backed tensor (requires_grad is preserved). For buffers, _apply()
will automatically update the reference.

Args:
model: PyTorch module to convert

Returns:
The same model with all tensors converted to memory-mapped format
"""
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

def convert_fn(t):
if isinstance(t, QuantizedTensor):
logging.debug(f"QuantizedTensor detected, mmap skipped, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
return t
elif isinstance(t, torch.nn.Parameter):
new_tensor = to_mmap(t.detach())
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
elif isinstance(t, torch.Tensor):
return to_mmap(t)
return t

new_model = model._apply(convert_fn)
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
return new_model


def string_to_seed(data):
@@ -506,6 +591,7 @@ def get_model_object(self, name: str) -> torch.nn.Module:
return comfy.utils.get_attr(self.model, name)

def model_patches_to(self, device):
# TODO(sf): to mmap
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
@@ -855,9 +941,15 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
self.model.current_weight_patches_uuid = None
self.backup.clear()


if device_to is not None:
self.model.to(device_to)
if need_mmap():
# offload to mmap
model_to_mmap(self.model)
else:
self.model.to(device_to)
self.model.device = device_to

self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0

@@ -916,7 +1008,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
if need_mmap():
if get_free_disk() < module_mem:
logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
break
# offload to mmap
model_to_mmap(m)
else:
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
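
The core of `to_mmap` above is a save/reload round trip: `torch.save` writes the tensor to a temporary file and `torch.load(..., mmap=True)` (PyTorch 2.1+) returns a tensor whose storage is backed by that file, so resident RAM drops to whatever pages the OS keeps cached. A self-contained sketch of the same pattern (function and file names are illustrative, and the PR's weakref.finalize cleanup of the temp file is omitted):

```python
import os
import tempfile

import torch

def mmap_roundtrip(t: torch.Tensor) -> torch.Tensor:
    """Save a tensor and reload it as a file-backed (mmap) CPU tensor."""
    fd, path = tempfile.mkstemp(suffix=".pt", prefix="mmap_demo_")
    os.close(fd)  # torch.save reopens the path itself
    torch.save(t.cpu(), path)
    # mmap=True maps the saved file instead of reading it fully into RAM.
    return torch.load(path, map_location="cpu", mmap=True, weights_only=False)

x = torch.randn(1024, 1024)
y = mmap_roundtrip(x)
print(torch.equal(x, y))  # True: same values, but y's storage is file-backed
```
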
1 change: 1 addition & 0 deletions comfy/sd.py
@@ -1516,6 +1516,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
logging.warning("{} {}".format(diffusers_keys[k], k))

offload_device = model_management.unet_offload_device()
logging.debug(f"loader load model to offload device: {offload_device}")
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
weight_dtype = None
3 changes: 3 additions & 0 deletions comfy/utils.py
@@ -61,6 +61,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
if not DISABLE_MMAP:
logging.debug("load_torch_file: loading safetensors with mmap enabled")
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
@@ -81,6 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
else:
torch_args = {}
if MMAP_TORCH_FILES:
logging.debug("load_torch_file: loading torch checkpoint with mmap enabled")
torch_args["mmap"] = True

if safe_load or ALWAYS_SAFE_LOAD:
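
Both code paths touched here already rely on memory mapping: `safetensors.safe_open` maps the checkpoint and materializes tensors only when a key is read, while the plain-checkpoint branch passes `mmap=True` to `torch.load` when `MMAP_TORCH_FILES` is set. A hedged sketch of the safetensors path, mirroring the loop in `load_torch_file` (the file path is a placeholder):

```python
import safetensors

# Placeholder path; any .safetensors checkpoint behaves the same way.
ckpt = "model.safetensors"

sd = {}
with safetensors.safe_open(ckpt, framework="pt", device="cpu") as f:
    # The file is memory-mapped, so there is no large up-front read;
    # each tensor is materialized only when its key is requested.
    for k in f.keys():
        sd[k] = f.get_tensor(k)
```
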