diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
index d6595111bf..a78dcad523 100644
--- a/src/llmcompressor/modeling/moe_context.py
+++ b/src/llmcompressor/modeling/moe_context.py
@@ -16,7 +16,12 @@
 
 import torch
 import torch.distributed as dist
-from compressed_tensors.offload import is_distributed
+from compressed_tensors.offload import (
+    get_cache_init_kwargs,
+    is_distributed,
+)
+from compressed_tensors.offload.cache import OffloadCache
+from compressed_tensors.offload.module import offload_module
 from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
 from loguru import logger
 from tqdm import tqdm
@@ -111,8 +116,18 @@
                 config=model.config,
                 calibrate_all_experts=calibrate_all_experts,
             )
+            # Apply the same offloading settings as the original module
+            _apply_offloading_to_replacement(module, replacement)
+
             model.set_submodule(name, replacement)
-            replaced[name] = (module, replacement)
+
+            # Only store original if we need to restore it later
+            if replacement.is_permanent:
+                replaced[name] = (None, replacement)
+                del module
+            else:
+                replaced[name] = (module, replacement)
+
 
     if is_distributed():
         dist.barrier()
@@ -145,3 +160,42 @@
 
 def _is_registered(name: str, subclass: RegistryMixin):
     return standardize_lookup_name(name) in subclass.registered_names()
+
+
+def _find_ancestor_with_offload_cache(module):
+    if isinstance(module._parameters, OffloadCache):
+        return module
+
+    for child in module.children():
+        child_val = _find_ancestor_with_offload_cache(child)
+        if child_val is not None:
+            return child_val
+    return None
+
+
+def _apply_offloading_to_replacement(
+    original: torch.nn.Module, replacement: torch.nn.Module
+):
+    """
+    Apply the same offloading configuration from original to replacement module.
+
+    If the original module or ANY of its children use OffloadCache, this recursively
+    applies the same offloading settings to all submodules of the replacement that
+    have parameters.
+    """
+
+    module_with_cache = _find_ancestor_with_offload_cache(original)
+    if module_with_cache is None:
+        return
+
+    kwargs = get_cache_init_kwargs(module_with_cache)
+
+    # Apply offloading to all submodules that have parameters
+    # and are not already offloaded
+    for module in replacement.modules():
+        if isinstance(module._parameters, OffloadCache):
+            continue
+        if len(list(module.parameters(recurse=False))) == 0:
+            continue
+
+        offload_module(module, **kwargs)
diff --git a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py
index 1055b4fed7..bb2894eb03 100644
--- a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py
@@ -3,6 +3,7 @@
 from functools import wraps
 
 import torch
+import torch.distributed as dist
 from compressed_tensors import ModelCompressor, SparsityCompressionConfig
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.offload import from_accelerate, is_rank0, to_accelerate
@@ -85,6 +86,9 @@
         # copy python files from cache dir to save_path if any
         copy_python_files_from_model_cache(model, save_directory)
 
+        # synchronize before converting back from accelerate
+        if dist.is_initialized():
+            dist.barrier()
         # convert back from accelerate to restore model to original form
         from_accelerate(model)
 
diff --git a/tests/llmcompressor/modeling/test_moe_context.py b/tests/llmcompressor/modeling/test_moe_context.py
new file mode 100644
index 0000000000..be2cd26ee4
--- /dev/null
+++ b/tests/llmcompressor/modeling/test_moe_context.py
@@ -0,0 +1,54 @@
+from unittest.mock import patch
+
+import torch
+from compressed_tensors.offload.cache import OffloadCache
+
+from llmcompressor.modeling.moe_context import (
+    _apply_offloading_to_replacement,
+    _find_ancestor_with_offload_cache,
+)
+
+
+def test_find_ancestor_with_offload_cache():
+    """Test finding ancestor modules with OffloadCache."""
+    # Module without offload cache
+    module_no_cache = torch.nn.Linear(10, 10)
+    assert _find_ancestor_with_offload_cache(module_no_cache) is None
+
+    # Module with offload cache
+    module_with_cache = torch.nn.Linear(10, 10)
+    module_with_cache._parameters = OffloadCache()
+    assert _find_ancestor_with_offload_cache(module_with_cache) is module_with_cache
+
+    # Parent with child that has cache
+    parent = torch.nn.Sequential(module_with_cache)
+    assert _find_ancestor_with_offload_cache(parent) is module_with_cache
+
+
+@patch("llmcompressor.modeling.moe_context.get_cache_init_kwargs")
+@patch("llmcompressor.modeling.moe_context.offload_module")
+def test_apply_offloading_to_replacement(mock_offload, mock_get_kwargs):
+    """Test offloading is applied from original to replacement."""
+    mock_get_kwargs.return_value = {"device": "cpu"}
+
+    # Original with offload cache
+    original = torch.nn.Sequential(torch.nn.Linear(10, 10))
+    original[0]._parameters = OffloadCache()
+
+    # Replacement without cache
+    replacement = torch.nn.Sequential(torch.nn.Linear(10, 10))
+
+    _apply_offloading_to_replacement(original, replacement)
+
+    # Should call offload_module for the child linear layer
+    assert mock_offload.called
+    assert mock_get_kwargs.called
+
+
+def test_apply_offloading_no_cache():
+    """Test no offloading applied when original has no cache."""
+    original = torch.nn.Linear(10, 10)
+    replacement = torch.nn.Linear(10, 10)
+
+    # Should not raise, just return early
+    _apply_offloading_to_replacement(original, replacement)