Skip to content

Commit 1b9e6b3

Browse files
committed
add tests
Summary Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
1 parent a1f1335 commit 1b9e6b3

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import pytest
3+
from unittest.mock import MagicMock, patch
4+
from compressed_tensors.offload.cache import OffloadCache
5+
6+
from llmcompressor.modeling.moe_context import (
7+
_find_ancestor_with_offload_cache,
8+
_apply_offloading_to_replacement,
9+
)
10+
11+
12+
def test_find_ancestor_with_offload_cache():
13+
"""Test finding ancestor modules with OffloadCache."""
14+
# Module without offload cache
15+
module_no_cache = torch.nn.Linear(10, 10)
16+
assert _find_ancestor_with_offload_cache(module_no_cache) is None
17+
18+
# Module with offload cache
19+
module_with_cache = torch.nn.Linear(10, 10)
20+
module_with_cache._parameters = OffloadCache()
21+
assert _find_ancestor_with_offload_cache(module_with_cache) is module_with_cache
22+
23+
# Parent with child that has cache
24+
parent = torch.nn.Sequential(module_with_cache)
25+
assert _find_ancestor_with_offload_cache(parent) is module_with_cache
26+
27+
28+
@patch("llmcompressor.modeling.moe_context.get_cache_init_kwargs")
29+
@patch("llmcompressor.modeling.moe_context.offload_module")
30+
def test_apply_offloading_to_replacement(mock_offload, mock_get_kwargs):
31+
"""Test offloading is applied from original to replacement."""
32+
mock_get_kwargs.return_value = {"device": "cpu"}
33+
34+
# Original with offload cache
35+
original = torch.nn.Sequential(torch.nn.Linear(10, 10))
36+
original[0]._parameters = OffloadCache()
37+
38+
# Replacement without cache
39+
replacement = torch.nn.Sequential(torch.nn.Linear(10, 10))
40+
41+
_apply_offloading_to_replacement(original, replacement)
42+
43+
# Should call offload_module for the child linear layer
44+
assert mock_offload.called
45+
assert mock_get_kwargs.called
46+
47+
48+
def test_apply_offloading_no_cache():
49+
"""Test no offloading applied when original has no cache."""
50+
original = torch.nn.Linear(10, 10)
51+
replacement = torch.nn.Linear(10, 10)
52+
53+
# Should not raise, just return early
54+
_apply_offloading_to_replacement(original, replacement)

0 commit comments

Comments
 (0)