|
# Debug / experimental command-line options.
parser.add_argument(
    "-img", "--image_dump",
    action = "store_true",
    help = "Save model tensors as images (saved to working directory)",
)
parser.add_argument(
    "-mcg", "--mcg_multiplier",
    type = str,
    default = None,
    help = "MCG multiplier - EXPERIMENTAL, DO NOT USE",
)
parser.add_argument(
    "-mul1", "--mul1_multiplier",
    type = str,
    default = None,
    help = "MUL1 multiplier - EXPERIMENTAL, DO NOT USE",
)
# Per-layer bit-width modifier string; it is parsed by mod_strategy().
parser.add_argument(
    "-strat", "--strategy",
    type = str,
    default = None,
    help = "Modifiers for quantization strategy - EXPERIMENTAL",
)

# Options registered on this group cannot be combined on one command line.
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--out_scales",
    dest = "out_scales_",
    action = "store_true",
    help = "Always enable out channel scales (for debug purposes)",
)
@@ -154,6 +155,7 @@ def override(arg, can_override, default): |
154 | 155 | ("device_ratios", True, None), |
155 | 156 | ("mcg_multiplier", True, ""), |
156 | 157 | ("mul1_multiplier", True, ""), |
| 158 | + ("strategy", False, ""), |
157 | 159 | ]: |
158 | 160 | override(arg_, can_override if not args.override_anyway else True, default) |
159 | 161 |
|
@@ -233,6 +235,32 @@ def get_state_error(x, ref): |
233 | 235 | return err.item(), cos, sq |
234 | 236 |
|
235 | 237 |
|
def mod_strategy(args, module, strategy, idx):
    """
    Apply the user-supplied bit-width modifiers to one layer's quantization strategy.

    The "strategy" argument is a string of ";"-separated layer specs. The first spec applies
    to idx 1, the second to idx 2, and so on; idx 0 never receives modifiers. Each spec is a
    run of two-character pairs: a one-character modifier key followed by a one-digit amount.
    For example "q1v2" adds 1 bit to every tensor whose submodule has qbits_mod_key == "q"
    and 2 bits to those with qbits_mod_key == "v".

    :param args:
        Mapping of job arguments; only args.get("strategy") is read.

    :param module:
        Object providing find_module(key). The submodule it returns for each strategy key
        must expose a qbits_mod_key attribute.

    :param strategy:
        Dict mapping submodule key to its assigned number of bits.

    :param idx:
        Index used to select the layer spec.

    :return:
        The strategy object itself when no modifier string is set or idx lies past the last
        spec. Otherwise a new dict with the same keys, each value being bits + modifier,
        capped at 8.

    :raises ValueError:
        If the selected spec has an odd length or an amount that is not a digit.
    """
    mod_arg = args.get("strategy")
    if not mod_arg:
        return strategy

    # Leading empty spec shifts the user's first spec to idx 1
    s_layers = [""] + mod_arg.split(";")
    if idx >= len(s_layers):
        return strategy

    spec = s_layers[idx]
    if len(spec) % 2:
        raise ValueError(
            f"Malformed strategy modifier {spec!r} for layer index {idx}: "
            f"expected pairs of one key character followed by one digit"
        )

    # Later pairs override earlier ones using the same key
    mod = {}
    for pos in range(0, len(spec), 2):
        mod_key, amount = spec[pos], spec[pos + 1]
        try:
            mod[mod_key] = int(amount)
        except ValueError as err:
            raise ValueError(
                f"Malformed strategy modifier {spec!r} for layer index {idx}: "
                f"{amount!r} after key {mod_key!r} is not a digit"
            ) from err

    # Submodules whose key is not mentioned in the spec keep their bits (still capped at 8)
    new_strategy = {
        key: min(bits + mod.get(module.find_module(key).qbits_mod_key, 0), 8)
        for key, bits in strategy.items()
    }

    # TODO: Automate this, also calculate overall increase in bitrate, track in job.json across resumes
    return new_strategy
| 262 | + |
| 263 | + |
236 | 264 | @torch.inference_mode() |
237 | 265 | def main(args, job_state): |
238 | 266 |
|
@@ -281,6 +309,7 @@ def main(args, job_state): |
281 | 309 | }, |
282 | 310 | job_state["surplus_bits"], |
283 | 311 | ) |
| 312 | + strategy = mod_strategy(args, module, strategy, idx) |
284 | 313 | job_state["surplus_bits"] = surplus |
285 | 314 |
|
286 | 315 | # Slice module if necessary |
|
0 commit comments