Skip to content

Commit 74150cb

Browse files
Authored by: George
[Oneshot refactor] Refactor initialize_model_from_path (#1109)
ORDER OF REVIEWS: 1. #1108, 2. #1103, 3. #1109 (this PR), 4. #1110.
SUMMARY: Refactor `initialize_model_from_path` to decouple the `training_args`-dependent logic from the oneshot (non-`training_args`) logic.
TEST PLAN: pass all tests; audit every call site of `initialize_model_from_path` found via `grep`.
1 parent b55ec42 commit 74150cb

File tree

1 file changed

+54
-42
lines changed

1 file changed

+54
-42
lines changed

src/llmcompressor/transformers/finetune/text_generation.py

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,8 @@ def parse_args(**kwargs):
167167

168168
def initialize_model_from_path(
169169
model_args: ModelArguments,
170-
training_args: TrainingArguments,
170+
training_args: Optional[TrainingArguments] = None,
171171
):
172-
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
173172
# Load pretrained model
174173
# The .from_pretrained methods guarantee that only one local process can
175174
# concurrently download model & vocab.
@@ -182,38 +181,70 @@ def initialize_model_from_path(
182181
tie_word_embeddings=model_args.tie_word_embeddings,
183182
trust_remote_code=model_args.trust_remote_code_model,
184183
)
185-
teacher_config = (
186-
AutoConfig.from_pretrained(
187-
model_args.distill_teacher,
188-
use_auth_token=True if model_args.use_auth_token else None,
189-
tie_word_embeddings=model_args.tie_word_embeddings,
190-
trust_remote_code=model_args.trust_remote_code_model,
184+
185+
last_checkpoint = None
186+
teacher = None
187+
188+
if training_args is not None:
189+
# Load teacher configuration if applicable
190+
teacher_config = (
191+
AutoConfig.from_pretrained(
192+
model_args.distill_teacher,
193+
use_auth_token=True if model_args.use_auth_token else None,
194+
tie_word_embeddings=model_args.tie_word_embeddings,
195+
trust_remote_code=model_args.trust_remote_code_model,
196+
)
197+
if model_args.distill_teacher
198+
else None
191199
)
192-
if model_args.distill_teacher
193-
else None
194-
)
200+
201+
# Detect last checkpoint
202+
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
203+
204+
# Set seed before initializing model
205+
set_seed(training_args.seed)
206+
207+
# Initialize teacher model if teacher path is provided
208+
if model_args.distill_teacher is not None:
209+
teacher_device_map = (
210+
None
211+
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
212+
else "auto"
213+
)
214+
teacher_kwargs = {
215+
"config": teacher_config,
216+
"cache_dir": model_args.cache_dir,
217+
"use_auth_token": True if model_args.use_auth_token else None,
218+
"torch_dtype": parse_dtype(model_args.precision),
219+
"device_map": teacher_device_map,
220+
"trust_remote_code": model_args.trust_remote_code_model,
221+
}
222+
223+
teacher = AutoModelForCausalLM.from_pretrained(
224+
model_args.distill_teacher,
225+
**teacher_kwargs,
226+
)
227+
if "sequence_length" in teacher_kwargs:
228+
teacher.seqlen = teacher_kwargs["sequence_length"]
195229

196230
model_path = (
197231
last_checkpoint or model_args.model
198232
if hasattr(model_args, "model")
199233
else model_args.model_name_or_path
200234
)
201235

202-
# Set seed before initializing model.
203-
set_seed(training_args.seed)
204-
205236
# Fallback to CPU if GPU requested and not available
206-
training_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
237+
model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
207238

208239
# Trainer handles device assignment for FSDP and training, don't do mapping here
209240
# if running oneshot outside of FSDP, apply user device settings
210-
device_map = None
241+
211242
fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
212-
if not fsdp_enabled and training_args.do_oneshot:
213-
device_map = training_args.oneshot_device
214-
logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
215-
elif not fsdp_enabled:
243+
244+
device_map = model_args.oneshot_device
245+
if not fsdp_enabled and training_args is not None and training_args.do_train:
216246
device_map = "auto"
247+
217248
model_kwargs = {
218249
"config": config,
219250
"cache_dir": model_args.cache_dir,
@@ -223,15 +254,7 @@ def initialize_model_from_path(
223254
"device_map": device_map,
224255
"trust_remote_code": model_args.trust_remote_code_model,
225256
}
226-
teacher_device_map = None if fsdp_enabled else "auto"
227-
teacher_kwargs = {
228-
"config": teacher_config,
229-
"cache_dir": model_args.cache_dir,
230-
"use_auth_token": True if model_args.use_auth_token else None,
231-
"torch_dtype": parse_dtype(model_args.precision),
232-
"device_map": teacher_device_map,
233-
"trust_remote_code": model_args.trust_remote_code_model,
234-
}
257+
235258
# this calls from_pretrained under the hood so should be FSDP safe
236259

237260
# optimized models must be decompressed to carry out oneshot/train/etc
@@ -247,18 +270,7 @@ def initialize_model_from_path(
247270
if "sequence_length" in model_kwargs:
248271
model.seqlen = model_kwargs["sequence_length"]
249272

250-
teacher = (
251-
AutoModelForCausalLM.from_pretrained(
252-
model_args.distill_teacher,
253-
**teacher_kwargs,
254-
)
255-
if model_args.distill_teacher is not None
256-
else None
257-
)
258-
if teacher is not None and "sequence_length" in teacher_kwargs:
259-
teacher.seqlen = teacher_kwargs["sequence_length"]
260-
261-
return teacher, model_path, model
273+
return model, teacher
262274

263275

264276
def initialize_processor_from_path(
@@ -357,7 +369,7 @@ def main(
357369

358370
model = model_args.model
359371
if isinstance(model, str) or isinstance(model, PosixPath):
360-
(teacher, _model_path, model) = initialize_model_from_path(
372+
model, teacher = initialize_model_from_path(
361373
model_args,
362374
training_args,
363375
)

0 commit comments

Comments (0)