@@ -2,7 +2,7 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 import torch
 from compressed_tensors import has_offloaded_params
@@ -23,7 +23,10 @@
 
 from .ast_helpers import autowrap_forwards
 
-__all__ = ["trace_subgraphs", "Subgraph", "get_targets_from_modifiers"]
+if TYPE_CHECKING:
+    from llmcompressor.args.dataset_arguments import DatasetArguments
+
+__all__ = ["trace_subgraphs", "Subgraph", "get_sequential_targets"]
 
 
 @dataclass
@@ -416,44 +419,59 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
     )
 
 
-def get_targets_from_modifiers(
-    modifiers: List[Modifier], model: PreTrainedModel
+def get_sequential_targets(
+    modifiers: List[Modifier], model: PreTrainedModel, args: "DatasetArguments"
 ) -> List[str]:
     """
-    Infer sequential targets from modifiers list
+    Infer sequential targets from modifiers list and dataset args
 
     :param model: model being calibrated
     :param modifiers: list of modifiers being applied during calibration
+    :param args: dataset arguments passed by user
     :return: list of sequential targets
     """
-    # avoid circular import
-    from llmcompressor.pipelines.registry import SEQUENTIAL_MODIFIERS
-
-    sequential_modifiers = [
-        modifier for modifier in modifiers if isinstance(modifier, SEQUENTIAL_MODIFIERS)
+    modifier_targets = [
+        (modifier, modifier.sequential_targets)
+        for modifier in modifiers
+        if getattr(modifier, "sequential_targets", None) is not None
     ]
 
-    if len(sequential_modifiers) >= 2:
-        types = [type(modifier) for modifier in sequential_modifiers]
+    # deprecation warning
+    if len(modifier_targets) >= 1:
         logger.warning(
+            "Passing sequential targets through modifiers is deprecated, "
+            "please use `oneshot(sequential_targets=...)`"
+        )
+
+    # cannot infer from multiple modifiers
+    if len(modifier_targets) >= 2:
+        types = [type(modifier) for modifier, _ in modifier_targets]
+        raise ValueError(
             "Cannot infer sequential targets from multiple sequential modifiers "
-            f"({types}). Defaulting to {types[0]}"
+            f"({types})"
         )
-    elif len(sequential_modifiers) <= 0:
-        types = [type(modifier) for modifier in modifiers]
-        raise ValueError(f"Cannot infer sequential targets from list of {types}")
 
-    modifier = sequential_modifiers[0]
+    # resolve single modifier
+    if len(modifier_targets) == 1:
+        if args.sequential_targets is not None:
+            raise ValueError(
+                f"Got sequential targets from both {type(modifier_targets[0][0])} "
+                "and dataset arguments `sequential_targets`"
+            )
+
+        sequential_targets = modifier_targets[0][1]
 
-    # infer sequential targets
-    if modifier.sequential_targets is None:
-        sequential_targets = get_no_split_params(model)
-    elif isinstance(modifier.sequential_targets, str):
-        sequential_targets = [modifier.sequential_targets]
+    # if no modifiers, use data args
     else:
-        sequential_targets = modifier.sequential_targets
+        sequential_targets = args.sequential_targets  # may be `None`
 
-    return sequential_targets
+    # validate and infer
+    if sequential_targets is None:
+        return get_no_split_params(model)
+    elif isinstance(sequential_targets, str):
+        return [sequential_targets]
+    else:
+        return sequential_targets
 
 
 def add_line_numbers(text: str) -> str:
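For callers, the deprecation warning above amounts to a small migration: stop attaching `sequential_targets` to individual modifiers and pass them to `oneshot` instead. Below is a minimal sketch of the two call styles, assuming the usual `GPTQModifier`/`oneshot` entry points; the model id, dataset name, and target module name are placeholders.

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Deprecated: targets attached to the modifier still resolve,
# but get_sequential_targets now logs a deprecation warning.
recipe_old = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    sequential_targets=["LlamaDecoderLayer"],  # placeholder layer class name
)

# Preferred: keep the modifier target-free and pass sequential targets
# through oneshot, where they arrive via the dataset arguments.
recipe_new = GPTQModifier(targets="Linear", scheme="W4A16")
oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
    dataset="ultrachat_200k",                  # placeholder calibration dataset
    recipe=recipe_new,
    sequential_targets=["LlamaDecoderLayer"],
)
```

Note that supplying targets in both places is rejected by the single-modifier branch above with a `ValueError`.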
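For completeness, here is a rough illustration of the resolution order the new helper implements: modifier-supplied targets win when the dataset arguments leave the field unset, a conflict raises, a bare string is normalized to a list, and `None` falls back to `get_no_split_params(model)`. This is a sketch under the assumption that the function is importable from `llmcompressor.pipelines.sequential.helpers`; the `SimpleNamespace` stand-ins are hypothetical and only mimic the attributes the function actually reads.

```python
from types import SimpleNamespace

from llmcompressor.pipelines.sequential.helpers import get_sequential_targets

# Duck-typed stand-ins: the helper only reads `modifier.sequential_targets`
# and `args.sequential_targets`; the model is untouched unless it has to
# fall back to get_no_split_params.
modifier = SimpleNamespace(sequential_targets=["LlamaDecoderLayer"])
unset_args = SimpleNamespace(sequential_targets=None)

# Modifier-supplied targets are returned (with a deprecation warning logged).
assert get_sequential_targets([modifier], model=None, args=unset_args) == ["LlamaDecoderLayer"]

# Supplying targets through both the modifier and the dataset args is ambiguous.
conflicting_args = SimpleNamespace(sequential_targets=["LlamaDecoderLayer"])
try:
    get_sequential_targets([modifier], model=None, args=conflicting_args)
except ValueError as err:
    print(f"conflict: {err}")

# A bare string from the dataset args is normalized to a single-element list.
string_args = SimpleNamespace(sequential_targets="LlamaDecoderLayer")
assert get_sequential_targets([], model=None, args=string_args) == ["LlamaDecoderLayer"]
```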