Commit a64a777

use remove_dispatch
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 06bb661 · commit a64a777

4 files changed: 15 additions, 12 deletions


src/llmcompressor/entrypoints/README.md
Lines changed: 2 additions & 7 deletions

@@ -29,9 +29,7 @@ from llmcompressor.modifiers.quantization import QuantizationModifier
 MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 # Load the model
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto"
-)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

@@ -204,9 +202,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
 oneshot_output_dir = "./oneshot_model"
 
 # Load the model
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID, device_map="auto", torch_dtype="auto"
-)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

@@ -226,7 +222,6 @@ from llmcompressor import create_session, train
 # Student model
 model = AutoModelForCausalLM.from_pretrained(
     oneshot_output_dir,
-    device_map="auto",
     quantization_config=CompressedTensorsConfig(run_compressed=False),
 )
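Both README snippets drop device_map="auto" from from_pretrained: the model is now loaded undispatched, and llmcompressor handles device placement itself. A minimal sketch of how the updated loading pattern fits into a oneshot run; the recipe and output path shown here are illustrative, not part of this diff:

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load undispatched; device placement is no longer the caller's job
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

# Illustrative recipe; oneshot dispatches the model internally
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
oneshot(model=model, recipe=recipe, output_dir="./oneshot_model")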

src/llmcompressor/entrypoints/utils.py
Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@
 from pathlib import PosixPath
 from typing import Optional, Tuple
 
+from compressed_tensors.utils import remove_dispatch
 from loguru import logger
 from torch.nn import Module
 from transformers import (

@@ -84,6 +85,9 @@ def post_process(
     Raises:
         ValueError: If saving fails due to an invalid `output_dir` or other issues.
     """
+    # remove any existing dispatches
+    remove_dispatch(model_args.model)
+
     if model_args is not None and output_dir is not None:
         if recipe_args is not None and getattr(recipe_args, "stage", None) is not None:
             output_dir = os.path.join(output_dir, recipe_args.stage)
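With this hunk, post_process tears down any existing dispatch before saving, so checkpoints are written from an un-dispatched model. remove_dispatch comes from compressed-tensors; as a rough sketch, it stands in for the accelerate-level teardown this commit removes elsewhere (this is not the compressed-tensors implementation, just the behavior being replaced):

from accelerate.hooks import remove_hook_from_module
from torch.nn import Module

def remove_dispatch_equivalent(model: Module) -> Module:
    # Recursively strip accelerate dispatch hooks (e.g. AlignDevicesHook)
    # so the model returns to a plain, un-dispatched state before saving
    # or re-dispatching
    remove_hook_from_module(model, recurse=True)
    return model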

src/llmcompressor/pipelines/sequential/helpers.py
Lines changed: 6 additions & 3 deletions

@@ -5,9 +5,12 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 import torch
-from accelerate.hooks import remove_hook_from_module
 from compressed_tensors.quantization import find_name_or_class_matches
-from compressed_tensors.utils import has_offloaded_params, offloaded_dispatch
+from compressed_tensors.utils import (
+    has_offloaded_params,
+    offloaded_dispatch,
+    remove_dispatch,
+)
 from loguru import logger
 from torch.fx import Graph, GraphModule, Node
 from torch.fx.graph import PythonCode

@@ -520,7 +523,7 @@ def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
     :param model: model to dispatch
     :return: dispatched model
     """
-    remove_hook_from_module(model, recurse=True)
+    remove_dispatch(model)
 
     if torch.cuda.is_available():
         offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
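Putting the hunk together, dispatch_for_sequential now resets placement with remove_dispatch before re-dispatching with CPU offloading onto a single GPU. A sketch reconstructed from the context above; the behavior when CUDA is unavailable is not shown in the hunk, so falling through to return the undispatched model is an assumption:

import torch
from compressed_tensors.utils import offloaded_dispatch, remove_dispatch
from transformers import PreTrainedModel

def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
    # Clear any existing dispatch (previously remove_hook_from_module)
    remove_dispatch(model)

    if torch.cuda.is_available():
        # Keep weights offloaded, executing each layer on the first GPU
        offloaded_dispatch(model, execution_device=torch.device("cuda:0"))

    # Assumption: without CUDA, the model is simply left undispatched
    return model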

src/llmcompressor/utils/dev.py
Lines changed: 3 additions & 2 deletions

@@ -6,8 +6,8 @@
 
 import torch
 from accelerate import dispatch_model, infer_auto_device_map
-from accelerate.hooks import remove_hook_from_module
 from accelerate.utils import get_balanced_memory
+from compressed_tensors.utils import remove_dispatch
 from huggingface_hub import snapshot_download
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel

@@ -124,7 +124,8 @@ def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
     :param model: model to dispatch
     :return: model which is dispatched
     """
-    remove_hook_from_module(model, recurse=True)
+    remove_dispatch(model)
+
     max_memory = get_balanced_memory(
         model,
         dtype=model.dtype,
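dispatch_for_generation gets the same treatment: remove_dispatch replaces the direct accelerate call before a fresh device map is computed. The hunk ends inside the get_balanced_memory call, so everything after it in this sketch (inferring a device map and dispatching) is an assumption based on the dispatch_model and infer_auto_device_map imports at the top of the file:

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from compressed_tensors.utils import remove_dispatch
from transformers import PreTrainedModel

def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
    # Tear down any previous dispatch before computing a fresh device map
    remove_dispatch(model)

    # Balance the model's memory footprint across available devices
    max_memory = get_balanced_memory(model, dtype=model.dtype)

    # Assumption: the remainder is not shown in the hunk; a typical
    # accelerate flow infers a device map and dispatches with it
    device_map = infer_auto_device_map(model, max_memory=max_memory, dtype=model.dtype)
    return dispatch_model(model, device_map=device_map)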
