Skip to content

Commit a9e2488

Browse files
authored
[DDP][GPTQ] Fixes for big models (#2400)
Big models were failing with DDP for a few reasons, primarily related to overloading shared memory or having too many mmaps. This was mainly an issue with DDP + CPU offloading, but even with disk offloading, the MoE context code would not use the same offloading as the original module — it would revert to CPU offloading and cause problems. Additionally, storing all the original modules could still cause mmap issues, so those are now stored only when needed. Finally, when saving, I saw situations where one thread would do the saving while another thread would go past it and then time out, so I added a barrier there. This PR depends on vllm-project/compressed-tensors#650. Test plan: 596c6b7 --------- Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent 1690a4c commit a9e2488

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

src/llmcompressor/modeling/moe_context.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
import torch
1818
import torch.distributed as dist
19-
from compressed_tensors.offload import is_distributed
19+
from compressed_tensors.offload import (
20+
get_cache_init_kwargs,
21+
is_distributed,
22+
)
23+
from compressed_tensors.offload.cache import OffloadCache
24+
from compressed_tensors.offload.module import offload_module
2025
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
2126
from loguru import logger
2227
from tqdm import tqdm
@@ -111,8 +116,18 @@ def moe_calibration_context(
111116
config=model.config,
112117
calibrate_all_experts=calibrate_all_experts,
113118
)
119+
# Apply the same offloading settings as the original module
120+
_apply_offloading_to_replacement(module, replacement)
121+
114122
model.set_submodule(name, replacement)
115-
replaced[name] = (module, replacement)
123+
124+
# Only store original if we need to restore it later
125+
if replacement.is_permanent:
126+
replaced[name] = (None, replacement)
127+
del module
128+
else:
129+
replaced[name] = (module, replacement)
130+
116131
if is_distributed():
117132
dist.barrier()
118133

@@ -145,3 +160,42 @@ def moe_calibration_context(
145160

146161
def _is_registered(name: str, subclass: RegistryMixin):
147162
return standardize_lookup_name(name) in subclass.registered_names()
163+
164+
165+
def _find_ancestor_with_offload_cache(module):
    """
    Depth-first search for the first module whose parameters are offloaded.

    Returns the first module (``module`` itself or any descendant, pre-order)
    whose ``_parameters`` container is an ``OffloadCache``, or ``None`` when
    no such module exists.

    NOTE(review): despite the "ancestor" in the name, this searches the module
    and its *descendants*; the name is kept for caller compatibility.
    """
    # ``modules()`` yields the module itself first, then all descendants in
    # pre-order — the same visit order as an explicit recursion over children.
    for candidate in module.modules():
        if isinstance(candidate._parameters, OffloadCache):
            return candidate
    return None
174+
175+
176+
def _apply_offloading_to_replacement(
    original: torch.nn.Module, replacement: torch.nn.Module
):
    """
    Copy the offloading configuration of ``original`` onto ``replacement``.

    If the original module or ANY of its submodules stores its parameters in
    an ``OffloadCache``, the same cache settings are applied to every
    submodule of ``replacement`` that directly owns parameters and is not
    already offloaded. If the original uses no offloading, this is a no-op.

    :param original: module being replaced; source of the offload settings
    :param replacement: module that should inherit those settings
    """
    module_with_cache = _find_ancestor_with_offload_cache(original)
    if module_with_cache is None:
        # Original is not offloaded anywhere; leave the replacement untouched
        return

    # Reuse the exact cache settings (e.g. offload target) of the original
    kwargs = get_cache_init_kwargs(module_with_cache)

    # Apply offloading to all submodules that have parameters
    # and are not already offloaded
    for module in replacement.modules():
        if isinstance(module._parameters, OffloadCache):
            continue
        # Probe the first directly-owned parameter instead of materializing
        # the whole list just to test for emptiness
        if next(iter(module.parameters(recurse=False)), None) is None:
            continue

        offload_module(module, **kwargs)

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import wraps
44

55
import torch
6+
import torch.distributed as dist
67
from compressed_tensors import ModelCompressor, SparsityCompressionConfig
78
from compressed_tensors.config import CompressionFormat
89
from compressed_tensors.offload import from_accelerate, is_rank0, to_accelerate
@@ -85,6 +86,9 @@ def save_pretrained_wrapper(
8586
# copy python files from cache dir to save_path if any
8687
copy_python_files_from_model_cache(model, save_directory)
8788

89+
# synchronize before converting back from accelerate
90+
if dist.is_initialized():
91+
dist.barrier()
8892
# convert back from accelerate to restore model to original form
8993
from_accelerate(model)
9094

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from unittest.mock import patch
2+
3+
import torch
4+
from compressed_tensors.offload.cache import OffloadCache
5+
6+
from llmcompressor.modeling.moe_context import (
7+
_apply_offloading_to_replacement,
8+
_find_ancestor_with_offload_cache,
9+
)
10+
11+
12+
class MockOffloadCache(OffloadCache):
    """In-memory stand-in for OffloadCache: tensors are never actually moved."""

    offload_device = "cpu"

    def onload(self, offloaded):
        # No-op: the "offloaded" tensor is already directly usable.
        return offloaded

    def offload(self, tensor):
        # No-op: keep the tensor exactly where it is.
        return tensor

    def update_offload(self, offloaded, data):
        # Copy in place, but only when both sides are present.
        if offloaded is None or data is None:
            return
        offloaded.copy_(data)
26+
27+
28+
def test_find_ancestor_with_offload_cache():
    """Cover the three search outcomes: no cache, self match, child match."""
    # A plain module yields no match
    plain = torch.nn.Linear(10, 10)
    assert _find_ancestor_with_offload_cache(plain) is None

    # A module whose own parameters live in an OffloadCache matches itself
    cached = torch.nn.Linear(10, 10)
    cached._parameters = MockOffloadCache(onload_device="cpu")
    assert _find_ancestor_with_offload_cache(cached) is cached

    # The search descends into children and returns the cached child
    wrapper = torch.nn.Sequential(cached)
    assert _find_ancestor_with_offload_cache(wrapper) is cached
42+
43+
44+
@patch("llmcompressor.modeling.moe_context.get_cache_init_kwargs")
@patch("llmcompressor.modeling.moe_context.offload_module")
def test_apply_offloading_to_replacement(mock_offload, mock_get_kwargs):
    """Offload settings propagate from an offloaded original to its replacement."""
    mock_get_kwargs.return_value = {"device": "cpu"}

    # Original whose child keeps its parameters in an OffloadCache
    source = torch.nn.Sequential(torch.nn.Linear(10, 10))
    source[0]._parameters = MockOffloadCache(onload_device="cpu")

    # Fresh replacement with ordinary (non-offloaded) parameters
    target = torch.nn.Sequential(torch.nn.Linear(10, 10))

    _apply_offloading_to_replacement(source, target)

    # Cache settings were read and offloading was applied to the child
    assert mock_get_kwargs.called
    assert mock_offload.called
62+
63+
64+
def test_apply_offloading_no_cache():
    """A non-offloaded original leaves the replacement untouched."""
    source = torch.nn.Linear(10, 10)
    target = torch.nn.Linear(10, 10)

    # Exercises the early-return path: must complete without raising
    _apply_offloading_to_replacement(source, target)

0 commit comments

Comments
 (0)