Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,13 @@ class DatasetArguments(CustomDatasetArguments):
"than one gpu. Default is cpu."
},
)
# Number of sequential targets grouped into each traced subgraph.
# The sequential pipeline forwards this value to `trace_subgraphs`, which hands
# it to `topological_partition`; that function raises ValueError for values <= 0.
sequential_targets_per_subgraph: int = field(
default=1,
metadata={
"help": "Number of sequential targets to include per subgraph. "
"Higher values use more VRAM but are faster. Default is 1."
},
)
quantization_aware_calibration: bool = field(
default=True,
metadata={
Expand Down
20 changes: 15 additions & 5 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def trace_subgraphs(
sample_input: dict[str, Any],
sequential_targets: list[str],
ignore: list[str],
targets_per_subgraph: int = 1
) -> list[Subgraph]:
"""
Trace a model to produce subgraphs, where each sequential target belongs to exactly
Expand All @@ -98,6 +99,7 @@ def trace_subgraphs(
__len__, __bool__, and __contains__ values are assumed constant across batches
:param sequential_targets: list of patterns matching sequential targets
:param ignore: function and method names to skip during tracing
:param targets_per_subgraph: number of targets to include per subgraph
:return: a list of Subgraphs in order of execution
"""
# find modules
Expand Down Expand Up @@ -152,7 +154,7 @@ def trace_subgraphs(
graph.device = model.device

# perform subgraph partition
partitions = topological_partition(graph, targets)
partitions = topological_partition(graph, targets, targets_per_subgraph)
subgraphs = partition_graph(model, partitions)
trace_consumed_names(subgraphs)

Expand Down Expand Up @@ -260,27 +262,32 @@ def find_target_nodes(graph: GraphModule, targets: set[Module]) -> set[Node]:
)


def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list[Node]]:
def topological_partition(graph: GraphModule, targets: set[Module], targets_per_subgraph: int = 1) -> list[list[Node]]:
"""
Partition the graph into partitions such that each `target` belongs to exactly one
partition and executing each partition depends only on intermediate values produced
by executing the partitions before it.

:param graph: graph being partitioned
:param targets: target modules which will be assigned to disjoint partitions
:param targets_per_subgraph: number of targets to include per subgraph
:return: list of partitions, where each partition is a list of nodes belonging to
that partition
"""
assert graph_is_well_formed(graph.graph)
target_nodes = find_target_nodes(graph, targets)

if(targets_per_subgraph <= 0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(targets_per_subgraph <= 0):
if targets_per_subgraph <= 0:

raise ValueError("targets_per_subgraph is required to be greater than or equal to one")

partitions: list[list[Node]] = [[]]
remaining_indegrees = {
node: len([node for node in node.all_input_nodes if node.op != "get_attr"])
for node in graph.graph.nodes
}
partition_index = 0 # global counter

targets_seen = 0 # number of targets encountered so far
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
targets_seen = 0 # number of targets encountered so far
targets_seen = 0 # number of targets encountered so far


# start with graph input nodes,
# but delay the `get_attr` nodes as long as possible
queue = deque(
Expand All @@ -296,8 +303,11 @@ def topological_partition(graph: GraphModule, targets: set[Module]) -> list[list

# guarantee targets are assigned to disjoint partitions
if node in target_nodes:
partition_index += 1
partitions.append([])
targets_seen += 1

if targets_seen % targets_per_subgraph == 0:
partition_index += 1
partitions.append([])

# recurse on last indegree only in order to guarantee that
# the node is assigned to maximal partition
Expand Down
8 changes: 7 additions & 1 deletion src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ def __call__(

# trace subgraphs
sample_input = next(iter(dataloader))
subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
subgraphs = trace_subgraphs(
model,
sample_input,
sequential_targets,
ignore,
dataset_args.sequential_targets_per_subgraph,
)
num_subgraphs = len(subgraphs)

LifecycleCallbacks.calibration_epoch_start()
Expand Down
10 changes: 9 additions & 1 deletion src/llmcompressor/transformers/tracing/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def parse_args():
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
parser.add_argument("--skip_weights", type=bool, default=True, help="Whether to load the model with dummy weights") # noqa: E501
parser.add_argument("--device_map", type=str, default="cpu", help="Device to load model and inputs onto") # noqa: E501
parser.add_argument("--targets_per_subgraph", type=int, default=1, help="Number of sequential targets to include per subgraph") # noqa: E501
return parser.parse_args()


Expand All @@ -39,6 +40,7 @@ def trace(
trust_remote_code: bool = True,
skip_weights: bool = True,
device_map: str | dict = "cpu",
targets_per_subgraph: int = 1
) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]:
"""
Debug traceability by tracing a pre-trained model into subgraphs
Expand All @@ -51,6 +53,7 @@ def trace(
:param ignore: patterns to ignore during tracing
:param modality: data modality for dummy tracing data, defaults to 'text'
:param trust_remote_code: trust remote model code
:param targets_per_subgraph: number of targets to include per subgraph

Example usage from CLI
llmcompressor.trace \
Expand Down Expand Up @@ -103,7 +106,11 @@ def trace(
f" ignore={dataset_args.tracing_ignore}\n"
)
subgraphs = trace_subgraphs(
model, sample, sequential_targets, dataset_args.tracing_ignore
model,
sample,
sequential_targets,
dataset_args.tracing_ignore,
targets_per_subgraph
)
print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")

Expand Down Expand Up @@ -165,6 +172,7 @@ def main():
trust_remote_code=args.trust_remote_code,
skip_weights=args.skip_weights,
device_map=args.device_map,
targets_per_subgraph=args.sequential_targets_per_subgraph
)


Expand Down
Loading