22
33import torch
44import torch .utils .data .dataloader
5- from compressed_tensors .utils import get_execution_device
65from tqdm import tqdm
6+ from transformers import PreTrainedModel
77
88from llmcompressor .modifiers .utils .hooks import HooksMixin
99from llmcompressor .pipelines .cache import IntermediatesCache
10- from llmcompressor .pipelines .sequential .helpers import trace_subgraphs
10+ from llmcompressor .pipelines .sequential .helpers import (
11+ infer_oneshot_device ,
12+ trace_subgraphs ,
13+ )
1114from llmcompressor .utils .helpers import align_modules , calibration_forward_context
1215
1316if TYPE_CHECKING :
1720
1821
1922def run_pipeline (
20- model : torch . nn . Module ,
23+ model : PreTrainedModel ,
2124 dataloader : torch .utils .data .DataLoader ,
2225 sequential_targets : List [str ],
2326 ignore : List [str ],
27+ oneshot_device : Optional [torch .device ],
2428 callback_modifier : Optional ["Modifier" ] = None ,
2529):
2630 """
@@ -45,16 +49,22 @@ def run_pipeline(
4549 :param dataloader: loads data for calibration
4650 :param sequential_targets: patterns which match to the layer modules of the model
4751 :param ignore: patterns which match to modules which should be ignored by tracing
52+ :param oneshot_device: device to onload layers onto, uses device_map if None
53+ :param callback_modifier: Temporary HACK which should be replaced by event callback
4854 """
55+ # if the model is dispatched, use the dispatch to determine onloading, return None
56+ # otherwise, infer a oneshot device (either user passed or the first available gpu)
57+ oneshot_device = infer_oneshot_device (model , oneshot_device )
58+
4959 # trace subgraphs
5060 sample_input = next (iter (dataloader ))
5161 subgraphs = trace_subgraphs (model , sample_input , sequential_targets , ignore )
5262
53- with calibration_forward_context (model ):
54- # prepare intermediates cache
55- model_device = get_execution_device (model )
56- intermediates = IntermediatesCache .from_dataloader (dataloader , model_device )
63+ # prepare intermediates cache
64+ model_device = oneshot_device or model .device
65+ intermediates = IntermediatesCache .from_dataloader (dataloader , model_device )
5766
67+ with calibration_forward_context (model ):
5868 num_subgraphs = len (subgraphs )
5969 for subgraph_index , subgraph in enumerate (subgraphs ):
6070 # prepare tqdm description texts
@@ -64,7 +74,7 @@ def run_pipeline(
6474 # compile subgraph forward function
6575 forward_function = subgraph .compile_forward ()
6676
67- with align_modules (subgraph .modules ):
77+ with align_modules (subgraph .modules , oneshot_device ):
6878 # do a preliminary pass to trigger modifier hooks
6979 for batch_index in tqdm (range (len (dataloader )), desc = calib_desc ):
7080 inputs = intermediates .fetch (batch_index , subgraph .input_names )
0 commit comments