Skip to content

Commit 7d7b00d

Browse files
committed
remove device arg from e2e
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8f71004 commit 7d7b00d

File tree

4 files changed

+7
-18
lines changed

4 files changed

+7
-18
lines changed

src/llmcompressor/entrypoints/README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ output_dir = "./oneshot_model"
112112
# The model to train
113113
model = AutoModelForCausalLM.from_pretrained(
114114
output_dir,
115-
device_map="auto",
116115
quantization_config=CompressedTensorsConfig(run_compressed=False),
117116
)
118117

@@ -146,7 +145,6 @@ Comparisons are defined in `/src/llmcompressor/modifiers/distillation/utils/pyto
146145
# Define the teacher model
147146
distill_teacher = AutoModelForCausalLM.from_pretrained(
148147
"meta-llama/Meta-Llama-3-8B-Instruct",
149-
device_map="auto",
150148
)
151149

152150
# Define the recipe, use knowledge distillation modifier and target the `model.layers` using a regex with
@@ -236,7 +234,6 @@ num_calibration_samples = 8 # The number of workers processing datasets in para
236234
# Define teacher model
237235
distill_teacher = AutoModelForCausalLM.from_pretrained(
238236
"meta-llama/Meta-Llama-3-8B-Instruct",
239-
device_map="auto",
240237
)
241238

242239
# Define the recipe, use knowledge distillation modifier and target the `model.layers` using a regex with

src/llmcompressor/entrypoints/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def post_process(
8686
ValueError: If saving fails due to an invalid `output_dir` or other issues.
8787
"""
8888
# remove any existing dispatches
89-
remove_dispatch(model_args.model)
89+
if model_args is not None and model_args.model is not None:
90+
remove_dispatch(model_args.model)
9091

9192
if model_args is not None and output_dir is not None:
9293
if recipe_args is not None and getattr(recipe_args, "stage", None) is not None:

tests/e2e/e2e_utils.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,21 @@
1414
def _load_model_and_processor(
1515
model: str,
1616
model_class: str,
17-
device: str,
1817
):
1918
pretrained_model_class = getattr(transformers, model_class)
20-
loaded_model = pretrained_model_class.from_pretrained(
21-
model, device_map=device, torch_dtype="auto"
22-
)
19+
loaded_model = pretrained_model_class.from_pretrained(model, torch_dtype="auto")
2320
processor = AutoProcessor.from_pretrained(model)
2421
return loaded_model, processor
2522

2623

2724
@log_time
28-
def _run_oneshot(device: str, **oneshot_kwargs):
29-
oneshot(
30-
**oneshot_kwargs,
31-
oneshot_device=device,
32-
)
25+
def _run_oneshot(**oneshot_kwargs):
26+
oneshot(**oneshot_kwargs)
3327

3428

3529
def run_oneshot_for_e2e_testing(
3630
model: str,
3731
model_class: str,
38-
device: str,
3932
num_calibration_samples: int,
4033
max_seq_length: int,
4134
dataset_id: str,
@@ -49,7 +42,7 @@ def run_oneshot_for_e2e_testing(
4942
oneshot_kwargs = {}
5043

5144
loaded_model, processor = _load_model_and_processor(
52-
model=model, model_class=model_class, device=device
45+
model=model, model_class=model_class
5346
)
5447

5548
if dataset_id:
@@ -86,6 +79,6 @@ def data_collator(batch):
8679

8780
# Apply quantization.
8881
logger.info("ONESHOT KWARGS", oneshot_kwargs)
89-
_run_oneshot(device=device, **oneshot_kwargs)
82+
_run_oneshot(**oneshot_kwargs)
9083

9184
return oneshot_kwargs["model"], processor

tests/e2e/vLLM/test_vllm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def set_up(self, test_data_file: str):
8888
logger.info("========== RUNNING ==============")
8989
logger.info(self.save_dir)
9090

91-
self.device = "cuda:0"
9291
self.prompts = [
9392
"The capital of France is",
9493
"The president of the US is",
@@ -105,7 +104,6 @@ def test_vllm(self, test_data_file: str):
105104
oneshot_model, tokenizer = run_oneshot_for_e2e_testing(
106105
model=self.model,
107106
model_class=self.model_class,
108-
device=self.device,
109107
num_calibration_samples=self.num_calibration_samples,
110108
max_seq_length=self.max_seq_length,
111109
scheme=self.scheme,

0 commit comments

Comments
 (0)