Commit b1e637e

better registry logic

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8542f8d commit b1e637e

4 files changed: +40 -13 lines changed

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 5 additions & 7 deletions

@@ -15,8 +15,8 @@
 from llmcompressor.modeling import fuse_norm_linears, normalize_embedding
 from llmcompressor.modifiers import Modifier
 
-from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings
-from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping
+from .mappings import SpinQuantMapping, infer_mapping_from_model
+from .norm_mappings import NormMapping, infer_norm_mapping_from_model
 
 
 class SpinquantRotation(str, Enum):
@@ -36,9 +36,7 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
 
     # norm mappings separate from spinquant mappings to allow users to
     # override spinquant mappings with transform_config without overriding norms
-    # we can combine these mappings, but it requires some more validation logic
-    # maybe there's a reason to keep if other modifiers want norm fusing, idk
-    mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True)
+    mappings: Optional[SpinQuantMapping] = Field(default=None, exclude=True)
     norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True)
 
     # optional override for more fine-grained control
@@ -53,8 +51,8 @@ def validate_rotations(cls, value):
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         # TODO: more validation
-        self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__]
-        self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__]
+        self.mappings = infer_mapping_from_model(state.model)
+        self.norm_mappings = infer_norm_mapping_from_model(state.model)
 
         if self.transform_config is not None:
             if self.mappings is not None:
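
The practical effect of the `on_initialize` change above: direct dict indexing raises `KeyError` for any architecture missing from the registry, while the new `infer_*` helpers log a notice and fall back to the defaults. A minimal, self-contained sketch of the difference, using a toy registry and a hypothetical `MistralForCausalLM` stand-in rather than the real llmcompressor objects:

```python
# Toy registry; the string values stand in for the real mapping objects.
REGISTRY = {"LlamaForCausalLM": "llama-defaults"}


class MistralForCausalLM:  # hypothetical, unregistered architecture
    pass


model = MistralForCausalLM()

# Old behavior: direct indexing fails for unregistered architectures.
try:
    mappings = REGISTRY[model.__class__.__name__]
except KeyError:
    mappings = None

# New behavior: log a notice, then fall back to the default mappings.
def infer_mapping(model) -> str:
    architecture = model.__class__.__name__
    if architecture not in REGISTRY:
        print(f"Unrecognized model architecture {architecture}. "
              "Falling back to default mappings")
    return REGISTRY.get(architecture, "llama-defaults")


assert infer_mapping(model) == "llama-defaults"
```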

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 18 additions & 3 deletions

@@ -1,9 +1,13 @@
 from typing import Dict, List, Optional
 
+from loguru import logger
 from pydantic import BaseModel, Field, field_validator
+from transformers import PreTrainedModel
 
+__all__ = ["SpinQuantMapping", "infer_mapping_from_model"]
 
-class SpinQuantMappings(BaseModel):
+
+class SpinQuantMapping(BaseModel):
     embedding: str
 
     attn_q: str
@@ -25,7 +29,7 @@ def cast_to_list(cls, value):
         return value
 
 
-_default_mappings = SpinQuantMappings(
+_default_mappings = SpinQuantMapping(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
@@ -37,6 +41,17 @@ def cast_to_list(cls, value):
 )
 
 
-SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = {
+SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = {
     "LlamaForCausalLM": _default_mappings,
 }
+
+
+def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping:
+    architecture = model.__class__.__name__
+    if architecture not in SPINQUANT_MAPPING_REGISTRY:
+        logger.info(
+            f"Unrecognized model architecture {architecture}. "
+            "Falling back to default mappings"
+        )
+
+    return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings)
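
Because the registry stays a module-level dict, an architecture that reuses Llama-style submodule names can be registered up front so the lookup succeeds without the fallback notice. A hedged usage sketch: the import path follows the file location above, `MyLlamaForCausalLM` is a hypothetical class, and reaching into the private `_default_mappings` is purely illustrative:

```python
from llmcompressor.modifiers.transform.spinquant.mappings import (
    SPINQUANT_MAPPING_REGISTRY,
    _default_mappings,
    infer_mapping_from_model,
)


class MyLlamaForCausalLM:  # stand-in; only __class__.__name__ is inspected
    pass


# Point the hypothetical architecture at the existing Llama-style defaults
# so infer_mapping_from_model resolves it without logging the fallback.
SPINQUANT_MAPPING_REGISTRY["MyLlamaForCausalLM"] = _default_mappings

mapping = infer_mapping_from_model(MyLlamaForCausalLM())
assert mapping is _default_mappings
print(mapping.embedding)  # "re:.*embed_tokens$" per the defaults above
```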

src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py

Lines changed: 17 additions & 2 deletions

@@ -1,6 +1,10 @@
 from typing import Dict, List
 
+from loguru import logger
 from pydantic import BaseModel, field_validator
+from transformers import PreTrainedModel
+
+__all__ = ["infer_norm_mapping_from_model"]
 
 
 class NormMapping(BaseModel):
@@ -15,7 +19,7 @@ def cast_to_list(cls, value):
         return value
 
 
-_default_norm_mappings = [
+_default_mappings = [
    NormMapping(
        norm="re:.*input_layernorm$",
        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
@@ -31,5 +35,16 @@ def cast_to_list(cls, value):
 ]
 
 NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = {
-    "LlamaForCausalLM": _default_norm_mappings,
+    "LlamaForCausalLM": _default_mappings,
 }
+
+
+def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]:
+    architecture = model.__class__.__name__
+    if architecture not in NORM_MAPPING_REGISTRY:
+        logger.info(
+            f"Unrecognized model architecture {architecture}. "
+            "Falling back to default mappings"
+        )
+
+    return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings)
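
The norm side mirrors the same pattern, except each registry value is a list of `NormMapping` entries pairing a norm with the linear layers it feeds (`input_layernorm` with `q_proj`/`k_proj`/`v_proj` above). A small sketch of the fallback path, assuming llmcompressor is installed and using a hypothetical unregistered model class:

```python
from llmcompressor.modifiers.transform.spinquant.norm_mappings import (
    infer_norm_mapping_from_model,
)


class TinyCustomModel:  # hypothetical architecture absent from the registry
    pass


# Logs "Unrecognized model architecture TinyCustomModel. Falling back to
# default mappings" and returns the Llama-style default list.
norm_mappings = infer_norm_mapping_from_model(TinyCustomModel())
for mapping in norm_mappings:
    print(mapping.norm, "->", mapping.linears)
```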

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@
     register_offload_parameter,
 )
 from loguru import logger
-from safetensors.torch import storage_ptr
 from transformers import PreTrainedModel
 
 from llmcompressor.core import active_session
