Skip to content

Commit 0c5c514

Browse files
unit test fixes
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent f6c797e commit 0c5c514

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/llmcompressor/modeling/fuse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def normalize_embedding(embedding: torch.nn.Module):
3232
else:
3333
raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")
3434

35+
3536
def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
3637
"""
3738
Fuse a norm layer into subsequent linear layers. This useful for ensuring transform
@@ -42,7 +43,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])
4243
:param norm: norm layer whose weight will be fused into subsequent linears
4344
:param linears: linear layers which directly follow the norm layer
4445
"""
45-
if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
46+
if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)):
4647
for linear in linears:
4748
# NOTE: spinquant does this op in float64
4849
exec_device = get_execution_device(norm)

tests/llmcompressor/modeling/test_fuse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import pytest
22
import torch
33

4-
from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears
4+
from llmcompressor.modeling.fuse import normalize_embedding, fuse_norm_linears
55

66

77
@pytest.mark.unit
8-
def test_center_embeddings():
8+
def test_normalize_embedding():
99
embedding = torch.nn.Embedding(10, 10)
10-
center_embeddings(embedding)
10+
normalize_embedding(embedding)
1111

1212
assert torch.allclose(
1313
embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5

0 commit comments

Comments (0)