
Commit a979f8a

break into separate files
Signed-off-by: Kyle Sayers <[email protected]>
1 parent fce83be

File tree

4 files changed: +93 -69 lines changed

src/llmcompressor/modifiers/transform/spinquant/__init__.py
src/llmcompressor/modifiers/transform/spinquant/base.py
src/llmcompressor/modifiers/transform/spinquant/mappings.py
src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py

src/llmcompressor/modifiers/transform/spinquant/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1 +1,3 @@
+# flake8: noqa
+
 from .base import *

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 14 additions & 69 deletions

@@ -1,80 +1,22 @@
 from enum import Enum
 from typing import Iterable, List, Literal, Optional

-from compressed_tensors import match_named_modules, is_match
+from compressed_tensors import is_match, match_named_modules
 from compressed_tensors.transform import (
     TransformArgs,
     TransformConfig,
     TransformScheme,
     apply_transform_config,
 )
-from pydantic import BaseModel, Field, field_validator
+from pydantic import Field, field_validator
 from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State
-from llmcompressor.modeling import normalize_embedding, fuse_norm_linears
+from llmcompressor.modeling import fuse_norm_linears, normalize_embedding
 from llmcompressor.modifiers import Modifier

-
-class SpinQuantMappings(BaseModel):
-    embedding: str
-
-    attn_q: str
-    attn_k: str
-    attn_v: str
-    attn_o: str
-    attn_head_dim: Optional[int] = Field(default=None)
-
-    mlp_in: List[str]  # up_proj, gate_proj
-    mlp_out: List[str]  # down_proj
-
-    lm_head: str
-
-    @field_validator("mlp_in", "mlp_out", mode="before")
-    def cast_to_list(cls, value):
-        if isinstance(value, str):
-            return [value]
-
-        return value
-
-
-class NormMapping(BaseModel):
-    norm: str
-    linears: List[str]
-
-    @field_validator("linears", mode="before")
-    def cast_to_list(cls, value):
-        if isinstance(value, str):
-            return [value]
-
-        return value
-
-
-llama_spinquant = SpinQuantMappings(
-    embedding="re:.*embed_tokens$",
-    attn_q="re:.*q_proj$",
-    attn_k="re:.*k_proj$",
-    attn_v="re:.*v_proj$",
-    attn_o="re:.*o_proj$",
-    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
-    mlp_out="re:.*down_proj$",
-    lm_head="lm_head",
-)
-
-llama_norm_mappings = [
-    NormMapping(
-        norm="re:.*input_layernorm$",
-        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
-    ),
-    NormMapping(
-        norm="re:.*post_attention_layernorm$",
-        linears=["re:.*up_proj$", "re:.*gate_proj$"],
-    ),
-    NormMapping(
-        norm="model.norm",
-        linears=["lm_head"],
-    ),
-]
+from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings
+from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping


 class SpinquantRotation(Enum):

@@ -92,12 +34,15 @@ class SpinQuantModifier(Modifier):
     randomize: bool = Field(default=False)
     learnable: bool = Field(default=False)

+    # norm mappings are kept separate from the spinquant mappings so that users
+    # can override the spinquant mappings via transform_config without also
+    # overriding the norms; the two could be combined, but that would require
+    # more validation logic, and other modifiers may want norm fusing on its own
     mappings: Optional[SpinQuantMappings] = None
     norm_mappings: Optional[List[NormMapping]] = None

-    transform_config: Optional[TransformConfig] = (
-        None  # optional override for more fine-grained control
-    )
+    # optional override for more fine-grained control
+    transform_config: Optional[TransformConfig] = None

     @field_validator("rotations", mode="before")
     def validate_rotations(cls, value):

@@ -106,9 +51,9 @@ def validate_rotations(cls, value):
         return value

     def on_initialize(self, state: State, **kwargs) -> bool:
-        # HARDCODE
-        self.mappings = llama_spinquant
-        self.norm_mappings = llama_norm_mappings
+        # TODO: more validation
+        self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__]
+        self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__]

         if self.transform_config is not None:
             if self.mappings is not None:
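The substantive change above is in on_initialize: rather than hardcoding the Llama mappings, the modifier now looks up mappings by the model's class name. A minimal sketch of how a new architecture could be registered under this pattern; the CustomForCausalLM name and its patterns are hypothetical, not part of this commit:

from llmcompressor.modifiers.transform.spinquant.mappings import (
    SPINQUANT_MAPPING_REGISTRY,
    SpinQuantMappings,
)

# register mappings for a hypothetical architecture; on_initialize will then
# resolve them via SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__]
SPINQUANT_MAPPING_REGISTRY["CustomForCausalLM"] = SpinQuantMappings(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)

Note that an unregistered model class would raise a KeyError at the lookup, which the "# TODO: more validation" comment presumably anticipates.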
src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class SpinQuantMappings(BaseModel):
+    embedding: str
+
+    attn_q: str
+    attn_k: str
+    attn_v: str
+    attn_o: str
+    attn_head_dim: Optional[int] = Field(default=None)
+
+    mlp_in: List[str]  # up_proj, gate_proj
+    mlp_out: List[str]  # down_proj
+
+    lm_head: str
+
+    @field_validator("mlp_in", "mlp_out", mode="before")
+    def cast_to_list(cls, value):
+        if isinstance(value, str):
+            return [value]
+
+        return value
+
+
+_default_mappings = SpinQuantMappings(
+    embedding="re:.*embed_tokens$",
+    attn_q="re:.*q_proj$",
+    attn_k="re:.*k_proj$",
+    attn_v="re:.*v_proj$",
+    attn_o="re:.*o_proj$",
+    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
+    mlp_out="re:.*down_proj$",
+    lm_head="lm_head",
+)
+
+
+SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = {
+    "LlamaForCausalLM": _default_mappings,
+}
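The mapping values follow the compressed-tensors target convention: strings prefixed with "re:" are treated as regular expressions over full module names, while bare strings like "lm_head" must match a module name exactly. A small self-contained sketch of that convention; matches_target is an illustrative helper only, the real matching is done by is_match / match_named_modules imported in base.py:

import re

def matches_target(name: str, target: str) -> bool:
    # "re:<pattern>" is matched as a regex against the module's full name;
    # any other string must equal the module name exactly
    if target.startswith("re:"):
        return re.match(target[3:], name) is not None
    return name == target

assert matches_target("model.layers.0.self_attn.q_proj", "re:.*q_proj$")
assert matches_target("lm_head", "lm_head")
assert not matches_target("model.embed_tokens", "re:.*down_proj$")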
src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from typing import Dict, List
+
+from pydantic import BaseModel, field_validator
+
+
+class NormMapping(BaseModel):
+    norm: str
+    linears: List[str]
+
+    @field_validator("linears", mode="before")
+    def cast_to_list(cls, value):
+        if isinstance(value, str):
+            return [value]
+
+        return value
+
+
+_default_norm_mappings = [
+    NormMapping(
+        norm="re:.*input_layernorm$",
+        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
+    ),
+    NormMapping(
+        norm="re:.*post_attention_layernorm$",
+        linears=["re:.*up_proj$", "re:.*gate_proj$"],
+    ),
+    NormMapping(
+        norm="model.norm",
+        linears=["lm_head"],
+    ),
+]
+
+NORM_MAPPING_REGISTRY: Dict[str, List[NormMapping]] = {
+    "LlamaForCausalLM": _default_norm_mappings,
+}
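Each NormMapping pairs a norm with the linear layers that directly consume its output, which is what fuse_norm_linears (imported in base.py) needs in order to fold the norm's elementwise weight into those linears before rotations are applied. A minimal sketch of that fusion, assuming RMSNorm-style norms (y = normalize(x) * weight); this mirrors the idea, not the actual implementation in llmcompressor.modeling:

import torch

@torch.no_grad()
def fuse_norm_into_linears(norm, linears):
    for linear in linears:
        # scaling each weight column (input feature) by the norm weight is
        # equivalent to applying the norm weight before the linear
        linear.weight.mul_(norm.weight)
    # the norm is left as a pure normalization with unit weight
    norm.weight.fill_(1.0)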
