Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions src/llmcompressor/modeling/moe_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

import torch
import torch.distributed as dist
from compressed_tensors.offload import is_distributed
from compressed_tensors.offload import (
get_cache_init_kwargs,
is_distributed,
)
from compressed_tensors.offload.cache import OffloadCache
from compressed_tensors.offload.module import offload_module
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
from loguru import logger
from tqdm import tqdm
Expand Down Expand Up @@ -111,8 +116,18 @@ def moe_calibration_context(
config=model.config,
calibrate_all_experts=calibrate_all_experts,
)
# Apply the same offloading settings as the original module
_apply_offloading_to_replacement(module, replacement)

model.set_submodule(name, replacement)
replaced[name] = (module, replacement)

# Only store original if we need to restore it later
if replacement.is_permanent:
replaced[name] = (None, replacement)
del module
else:
replaced[name] = (module, replacement)

if is_distributed():
dist.barrier()

Expand Down Expand Up @@ -145,3 +160,42 @@ def moe_calibration_context(

def _is_registered(name: str, subclass: RegistryMixin):
    """Return True if *name* (after lookup-name standardization) is registered on *subclass*."""
    canonical = standardize_lookup_name(name)
    return canonical in subclass.registered_names()


def _find_ancestor_with_offload_cache(module):
    """
    Return the first module in *module*'s subtree whose ``_parameters``
    container is an ``OffloadCache``, or ``None`` if no such module exists.

    The search visits *module* itself first, then its submodules depth-first
    in pre-order -- the same order as the previous hand-rolled recursion.

    NOTE(review): despite the name, this searches *descendants* (the module
    and its children), not ancestors; the name is kept so existing callers
    and tests keep working.
    """
    # torch.nn.Module.modules() yields the module itself followed by all
    # submodules depth-first, deduplicating shared submodules -- which does
    # not change which match is found first.
    for submodule in module.modules():
        if isinstance(submodule._parameters, OffloadCache):
            return submodule
    return None


def _apply_offloading_to_replacement(
    original: torch.nn.Module, replacement: torch.nn.Module
):
    """
    Mirror the original module's offloading configuration onto the replacement.

    If the original module (or any module in its subtree) keeps its parameters
    in an ``OffloadCache``, every submodule of the replacement that owns
    parameters and is not already offloaded is offloaded with the same cache
    settings. If the original uses no ``OffloadCache`` anywhere, this is a
    no-op.
    """
    offloaded_source = _find_ancestor_with_offload_cache(original)
    if offloaded_source is None:
        # Original is not offloaded anywhere -> nothing to mirror.
        return

    cache_kwargs = get_cache_init_kwargs(offloaded_source)

    for submodule in replacement.modules():
        # Leave submodules that are already offloaded untouched.
        if isinstance(submodule._parameters, OffloadCache):
            continue
        # Skip submodules that own no parameters of their own.
        if next(submodule.parameters(recurse=False), None) is None:
            continue
        offload_module(submodule, **cache_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import wraps

import torch
import torch.distributed as dist
from compressed_tensors import ModelCompressor, SparsityCompressionConfig
from compressed_tensors.config import CompressionFormat
from compressed_tensors.offload import from_accelerate, is_rank0, to_accelerate
Expand Down Expand Up @@ -85,6 +86,9 @@ def save_pretrained_wrapper(
# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)

# synchronize before converting back from accelerate
if dist.is_initialized():
dist.barrier()
# convert back from accelerate to restore model to original form
from_accelerate(model)

Expand Down
54 changes: 54 additions & 0 deletions tests/llmcompressor/modeling/test_moe_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from unittest.mock import patch

import torch
from compressed_tensors.offload.cache import OffloadCache

from llmcompressor.modeling.moe_context import (
_apply_offloading_to_replacement,
_find_ancestor_with_offload_cache,
)


def test_find_ancestor_with_offload_cache():
    """_find_ancestor_with_offload_cache locates the first offloaded submodule."""
    # A plain linear layer keeps an ordinary _parameters dict -> no match.
    plain = torch.nn.Linear(10, 10)
    assert _find_ancestor_with_offload_cache(plain) is None

    # Swapping _parameters for an OffloadCache makes the module itself match.
    offloaded = torch.nn.Linear(10, 10)
    offloaded._parameters = OffloadCache()
    assert _find_ancestor_with_offload_cache(offloaded) is offloaded

    # The search recurses into children, so a wrapping parent finds it too.
    wrapper = torch.nn.Sequential(offloaded)
    assert _find_ancestor_with_offload_cache(wrapper) is offloaded


@patch("llmcompressor.modeling.moe_context.get_cache_init_kwargs")
@patch("llmcompressor.modeling.moe_context.offload_module")
def test_apply_offloading_to_replacement(mock_offload, mock_get_kwargs):
    """Offloading settings propagate from an offloaded original to its replacement."""
    mock_get_kwargs.return_value = {"device": "cpu"}

    # Original whose child keeps its parameters in an OffloadCache.
    source = torch.nn.Sequential(torch.nn.Linear(10, 10))
    source[0]._parameters = OffloadCache()

    # Replacement starts out with ordinary (non-offloaded) parameters.
    target = torch.nn.Sequential(torch.nn.Linear(10, 10))

    _apply_offloading_to_replacement(source, target)

    # The cache kwargs must have been read from the original, and the
    # replacement's parameterized submodule must have been offloaded.
    assert mock_get_kwargs.called
    assert mock_offload.called


def test_apply_offloading_no_cache():
    """When the original has no OffloadCache anywhere, nothing is applied."""
    plain_original = torch.nn.Linear(10, 10)
    plain_replacement = torch.nn.Linear(10, 10)

    # Early-return path: must complete without raising.
    _apply_offloading_to_replacement(plain_original, plain_replacement)
Loading