From af6034084b77de28675bb78d9830163becd4688a Mon Sep 17 00:00:00 2001 From: Ayush Date: Thu, 19 Mar 2026 21:59:05 -0500 Subject: [PATCH 1/2] feat: added targets_per_subgraph to allow multiple sequential targets per subgraph Signed-off-by: Ayush --- src/llmcompressor/args/dataset_arguments.py | 7 +++++++ .../pipelines/sequential/helpers.py | 16 ++++++++++++---- .../pipelines/sequential/pipeline.py | 7 ++++++- src/llmcompressor/transformers/tracing/debug.py | 6 +++++- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 60705744e2..b3fef53efe 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -242,6 +242,13 @@ class DatasetArguments(CustomDatasetArguments): "than one gpu. Default is cpu." }, ) + sequential_targets_per_subgraph: int = field( + default=1, + metadata={ + "help": "Number of sequential targets to include per subgraph. " + "Higher values use more VRAM but are faster. Default is 1." 
+ }, + ) quantization_aware_calibration: bool = field( default=True, metadata={ diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index fbbd7c9d51..70b796bf73 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -87,6 +87,7 @@ def trace_subgraphs( sample_input: dict[str, Any], sequential_targets: list[str], ignore: list[str], + targets_per_subgraph: int = 1 ) -> list[Subgraph]: """ Trace a model to produce subgraphs, where each sequential target belongs to exactly @@ -98,6 +99,8 @@ def trace_subgraphs( __len__, __bool__, and __contains__ values are assumed constant across batches :param sequential_targets: list of patterns matching sequential targets :param ignore: function and method names to skip during tracing + :param targets_per_subgraph: number of targets to include per subgraph + :return: a list of Subgraphs in order of execution """ # find modules @@ -152,7 +155,7 @@ def trace_subgraphs( graph.device = model.device # perform subgraph partition - partitions = topological_partition(graph, targets) + partitions = topological_partition(graph, targets, targets_per_subgraph) subgraphs = partition_graph(model, partitions) trace_consumed_names(subgraphs) @@ -260,7 +263,7 @@ def find_target_nodes(graph: GraphModule, targets: set[Module]) -> set[Node]: ) -def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list[Node]]: +def topological_partition(graph: GraphModule, targets: set[Module], targets_per_subgraph: int = 1) -> list[list[Node]]: """ Partition the graph into partitions such that each `target` belongs to exactly one partition and executing each partition depends only on intermediate values produced @@ -268,6 +271,7 @@ def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list :param graph: graph being partitioned :param targets: target modules which will be assigned to disjoint 
partitions + :param targets_per_subgraph: number of targets to include per subgraph :return: list of partitions, where each partition is a list of nodes belonging to that partition """ @@ -280,6 +284,7 @@ def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list for node in graph.graph.nodes } partition_index = 0 # global counter + targets_seen = 0 # global counter # start with graph input nodes, # but delay the `get_attr` nodes as long as possible @@ -296,8 +301,11 @@ def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list # guarantee targets are assigned to disjoint partitions if node in target_nodes: - partition_index += 1 - partitions.append([]) + targets_seen += 1 + + if(targets_seen % targets_per_subgraph == 0): + partition_index += 1 + partitions.append([]) # recurse on last indegree only in order to guarantee that # the node is assigned to maximal partition diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 600ba1061b..ad22a74539 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -98,7 +98,12 @@ def __call__( # trace subgraphs sample_input = next(iter(dataloader)) - subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) + subgraphs = trace_subgraphs( + model, + sample_input, + sequential_targets, + ignore, + dataset_args.sequential_targets_per_subgraph) num_subgraphs = len(subgraphs) LifecycleCallbacks.calibration_epoch_start() diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py index 0f6213815c..79056b74db 100644 --- a/src/llmcompressor/transformers/tracing/debug.py +++ b/src/llmcompressor/transformers/tracing/debug.py @@ -27,6 +27,7 @@ def parse_args(): parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501 
parser.add_argument("--skip_weights", type=bool, default=True, help="Whether to load the model with dummy weights") # noqa: E501 parser.add_argument("--device_map", type=str, default="cpu", help="Device to load model and inputs onto") # noqa: E501 + parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501 return parser.parse_args() @@ -39,6 +40,7 @@ def trace( trust_remote_code: bool = True, skip_weights: bool = True, device_map: str | dict = "cpu", + targets_per_subgraph: int = 1 ) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]: """ Debug traceability by tracing a pre-trained model into subgraphs @@ -51,6 +53,7 @@ def trace( :param ignore: patterns to ignore during tracing :param modality: data modality for dummy tracing data, defaults to 'text' :param trust_remote_code: trust remote model code + :param targets_per_subgraph: number of targets to include per subgraph Example usage from CLI llmcompressor.trace \ @@ -103,7 +106,7 @@ def trace( f" ignore={dataset_args.tracing_ignore}\n" ) subgraphs = trace_subgraphs( - model, sample, sequential_targets, dataset_args.tracing_ignore + model, sample, sequential_targets, dataset_args.tracing_ignore, targets_per_subgraph ) print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n") @@ -165,6 +168,7 @@ def main(): trust_remote_code=args.trust_remote_code, skip_weights=args.skip_weights, device_map=args.device_map, + targets_per_subgraph=DatasetArguments.sequential_targets_per_subgraph ) From 50e8b9a9fbe8dfacdd3ad5d5baf8a3167372a36f Mon Sep 17 00:00:00 2001 From: Ayush Date: Thu, 19 Mar 2026 22:44:35 -0500 Subject: [PATCH 2/2] feat: add input validation for targets_per_subgraph Signed-off-by: Ayush --- src/llmcompressor/pipelines/sequential/helpers.py | 10 ++++++---- src/llmcompressor/pipelines/sequential/pipeline.py | 3 ++- src/llmcompressor/transformers/tracing/debug.py | 10 +++++++--- 3 files changed, 15 
insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 70b796bf73..18a75073a4 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -100,7 +100,6 @@ def trace_subgraphs( :param sequential_targets: list of patterns matching sequential targets :param ignore: function and method names to skip during tracing :param targets_per_subgraph: number of targets to include per subgraph - :return: a list of Subgraphs in order of execution """ # find modules @@ -277,6 +276,9 @@ def topological_partition(graph: GraphModule, targets: set[Module], targets_per_ """ assert graph_is_well_formed(graph.graph) target_nodes = find_target_nodes(graph, targets) + + if targets_per_subgraph <= 0: + raise ValueError("targets_per_subgraph must be greater than or equal to 1") partitions: list[list[Node]] = [[]] remaining_indegrees = { @@ -284,8 +286,8 @@ def topological_partition(graph: GraphModule, targets: set[Module], targets_per_ for node in graph.graph.nodes } partition_index = 0 # global counter - targets_seen = 0 # global counter - + targets_seen = 0 # number of targets encountered so far + # start with graph input nodes, # but delay the `get_attr` nodes as long as possible queue = deque( @@ -303,7 +305,7 @@ def topological_partition(graph: GraphModule, targets: set[Module], targets_per_ if node in target_nodes: targets_seen += 1 - if(targets_seen % targets_per_subgraph == 0): + if targets_seen % targets_per_subgraph == 0: partition_index += 1 partitions.append([]) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index ad22a74539..8a560a6611 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -103,7 +103,8 @@ def __call__( sample_input, sequential_targets, ignore, - 
dataset_args.sequential_targets_per_subgraph) + dataset_args.sequential_targets_per_subgraph, + ) num_subgraphs = len(subgraphs) LifecycleCallbacks.calibration_epoch_start() diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py index 79056b74db..031c7c6862 100644 --- a/src/llmcompressor/transformers/tracing/debug.py +++ b/src/llmcompressor/transformers/tracing/debug.py @@ -27,7 +27,7 @@ def parse_args(): parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501 parser.add_argument("--skip_weights", type=bool, default=True, help="Whether to load the model with dummy weights") # noqa: E501 parser.add_argument("--device_map", type=str, default="cpu", help="Device to load model and inputs onto") # noqa: E501 - parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501 + parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501 return parser.parse_args() @@ -106,7 +106,11 @@ def trace( f" ignore={dataset_args.tracing_ignore}\n" ) subgraphs = trace_subgraphs( - model, sample, sequential_targets, dataset_args.tracing_ignore, targets_per_subgraph + model, + sample, + sequential_targets, + dataset_args.tracing_ignore, + targets_per_subgraph ) print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n") @@ -168,7 +172,7 @@ def main(): trust_remote_code=args.trust_remote_code, skip_weights=args.skip_weights, device_map=args.device_map, - targets_per_subgraph=DatasetArguments.sequential_targets_per_subgraph + targets_per_subgraph=args.targets_per_subgraph )