Skip to content

Commit 972f59f

Browse files
committed
add targets field
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7d34cca commit 972f59f

File tree

1 file changed

+5
-4
lines changed
  • src/llmcompressor/modifiers/transform/quip

1 file changed

+5
-4
lines changed

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class QuIPModifier(Modifier):
4747
transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
4848
default="random-hadamard"
4949
)
50+
targets: Union[List[str], str] = Field(default="str")
5051
randomize: bool = Field(default=False)
5152
learnable: bool = Field(default=False)
5253
precision: TorchDtype = Field(default=torch.float64)
@@ -102,12 +103,12 @@ def _create_config(self) -> TransformConfig:
102103
type=self.transform_type,
103104
apply=[
104105
TransformArgs(
105-
targets=["Linear"],
106+
targets=self.targets,
106107
location="input", # non-mergable
107108
ignore=self.ignore,
108109
),
109110
TransformArgs(
110-
targets=["Linear"],
111+
targets=self.targets,
111112
location="weight_input",
112113
inverse=True,
113114
ignore=self.ignore,
@@ -121,12 +122,12 @@ def _create_config(self) -> TransformConfig:
121122
type=self.transform_type,
122123
apply=[
123124
TransformArgs(
124-
targets=["Linear"],
125+
targets=self.targets,
125126
location="weight_output",
126127
ignore=self.ignore,
127128
),
128129
TransformArgs(
129-
targets=["Linear"],
130+
targets=self.targets,
130131
location="output", # non-mergable
131132
inverse=True,
132133
ignore=self.ignore,

0 commit comments

Comments
 (0)