Skip to content

Commit ecb7723

Browse files
committed
feat: add input validation for targets_per_subgraph
1 parent 21f185a commit ecb7723

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def trace_subgraphs(
100100
:param sequential_targets: list of patterns matching sequential targets
101101
:param ignore: function and method names to skip during tracing
102102
:param targets_per_subgraph: number of targets to include per subgraph
103-
104103
:return: a list of Subgraphs in order of execution
105104
"""
106105
# find modules
@@ -277,15 +276,18 @@ def topological_partition(graph: GraphModule, targets: set[Module], targets_per_
277276
"""
278277
assert graph_is_well_formed(graph.graph)
279278
target_nodes = find_target_nodes(graph, targets)
279+
280+
if(targets_per_subgraph <= 0):
281+
raise ValueError("targets_per_subgraph is required to be greater than or equal to one")
280282

281283
partitions: list[list[Node]] = [[]]
282284
remaining_indegrees = {
283285
node: len([node for node in node.all_input_nodes if node.op != "get_attr"])
284286
for node in graph.graph.nodes
285287
}
286288
partition_index = 0 # global counter
287-
targets_seen = 0 # global counter
288-
289+
targets_seen = 0 # number of targets encountered so far
290+
289291
# start with graph input nodes,
290292
# but delay the `get_attr` nodes as long as possible
291293
queue = deque(
@@ -303,7 +305,7 @@ def topological_partition(graph: GraphModule, targets: set[Module], targets_per_
303305
if node in target_nodes:
304306
targets_seen += 1
305307

306-
if(targets_seen % targets_per_subgraph == 0):
308+
if targets_seen % targets_per_subgraph == 0:
307309
partition_index += 1
308310
partitions.append([])
309311

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def __call__(
103103
sample_input,
104104
sequential_targets,
105105
ignore,
106-
dataset_args.sequential_targets_per_subgraph)
106+
dataset_args.sequential_targets_per_subgraph,
107+
)
107108
num_subgraphs = len(subgraphs)
108109

109110
LifecycleCallbacks.calibration_epoch_start()

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def parse_args():
2727
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
2828
parser.add_argument("--skip_weights", type=bool, default=True, help="Whether to load the model with dummy weights") # noqa: E501
2929
parser.add_argument("--device_map", type=str, default="cpu", help="Device to load model and inputs onto") # noqa: E501
30-
parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501
30+
parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501
3131
return parser.parse_args()
3232

3333

@@ -106,7 +106,11 @@ def trace(
106106
f" ignore={dataset_args.tracing_ignore}\n"
107107
)
108108
subgraphs = trace_subgraphs(
109-
model, sample, sequential_targets, dataset_args.tracing_ignore, targets_per_subgraph
109+
model,
110+
sample,
111+
sequential_targets,
112+
dataset_args.tracing_ignore,
113+
targets_per_subgraph
110114
)
111115
print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")
112116

@@ -168,7 +172,7 @@ def main():
168172
trust_remote_code=args.trust_remote_code,
169173
skip_weights=args.skip_weights,
170174
device_map=args.device_map,
171-
targets_per_subgraph=DatasetArguments.sequential_targets_per_subgraph
175+
targets_per_subgraph=args.sequential_targets_per_subgraph
172176
)
173177

174178

0 commit comments

Comments
 (0)