src/llmcompressor/transformers/finetune/text_generation.py (96 changes: 54 additions, 42 deletions)
@@ -167,9 +167,8 @@ def parse_args(**kwargs):
 
 def initialize_model_from_path(
     model_args: ModelArguments,
-    training_args: TrainingArguments,
+    training_args: Optional[TrainingArguments] = None,
 ):
-    last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
     # Load pretrained model
     # The .from_pretrained methods guarantee that only one local process can
     # concurrently download model & vocab.
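With `training_args` now defaulting to `None`, the `Optional[TrainingArguments]` annotation becomes load-bearing. If `typing.Optional` is not already imported in this module, the new signature needs it (a one-line assumption, since the import block is outside this diff):

```python
from typing import Optional
```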
@@ -182,38 +181,70 @@ def initialize_model_from_path(
         tie_word_embeddings=model_args.tie_word_embeddings,
         trust_remote_code=model_args.trust_remote_code_model,
     )
-    teacher_config = (
-        AutoConfig.from_pretrained(
-            model_args.distill_teacher,
-            use_auth_token=True if model_args.use_auth_token else None,
-            tie_word_embeddings=model_args.tie_word_embeddings,
-            trust_remote_code=model_args.trust_remote_code_model,
-        )
-        if model_args.distill_teacher
-        else None
-    )
+
+    last_checkpoint = None
+    teacher = None
+
+    if training_args is not None:
+        # Load teacher configuration if applicable
+        teacher_config = (
+            AutoConfig.from_pretrained(
+                model_args.distill_teacher,
+                use_auth_token=True if model_args.use_auth_token else None,
+                tie_word_embeddings=model_args.tie_word_embeddings,
+                trust_remote_code=model_args.trust_remote_code_model,
+            )
+            if model_args.distill_teacher
+            else None
+        )
+
+        # Detect last checkpoint
+        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
+
+        # Set seed before initializing model
+        set_seed(training_args.seed)
+
+        # Initialize teacher model if teacher path is provided
+        if model_args.distill_teacher is not None:
+            teacher_device_map = (
+                None
+                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
+                else "auto"
+            )
+            teacher_kwargs = {
+                "config": teacher_config,
+                "cache_dir": model_args.cache_dir,
+                "use_auth_token": True if model_args.use_auth_token else None,
+                "torch_dtype": parse_dtype(model_args.precision),
+                "device_map": teacher_device_map,
+                "trust_remote_code": model_args.trust_remote_code_model,
+            }
+
+            teacher = AutoModelForCausalLM.from_pretrained(
+                model_args.distill_teacher,
+                **teacher_kwargs,
+            )
+            if "sequence_length" in teacher_kwargs:
+                teacher.seqlen = teacher_kwargs["sequence_length"]
 
     model_path = (
         last_checkpoint or model_args.model
         if hasattr(model_args, "model")
         else model_args.model_name_or_path
     )
 
-    # Set seed before initializing model.
-    set_seed(training_args.seed)
-
     # Fallback to CPU if GPU requested and not available
-    training_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
+    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
 
     # Trainer handles device assignment for FSDP and training, don't do mapping here
     # if running oneshot outside of FSDP, apply user device settings
-    device_map = None
 
     fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-    if not fsdp_enabled and training_args.do_oneshot:
-        device_map = training_args.oneshot_device
-        logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
-    elif not fsdp_enabled:
+
+    device_map = model_args.oneshot_device
+    if not fsdp_enabled and training_args is not None and training_args.do_train:
         device_map = "auto"
 
     model_kwargs = {
         "config": config,
         "cache_dir": model_args.cache_dir,
@@ -223,15 +254,7 @@ def initialize_model_from_path(
         "device_map": device_map,
         "trust_remote_code": model_args.trust_remote_code_model,
     }
-    teacher_device_map = None if fsdp_enabled else "auto"
-    teacher_kwargs = {
-        "config": teacher_config,
-        "cache_dir": model_args.cache_dir,
-        "use_auth_token": True if model_args.use_auth_token else None,
-        "torch_dtype": parse_dtype(model_args.precision),
-        "device_map": teacher_device_map,
-        "trust_remote_code": model_args.trust_remote_code_model,
-    }
 
     # this calls from_pretrained under the hood so should be FSDP safe
 
     # optimized models must be decompressed to carry out oneshot/train/etc
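The collapsed lines after these comments load the student model. Judging from the surrounding context (`model_kwargs`, the `model.seqlen` guard in the next hunk, and the "calls from_pretrained under the hood" note), the hidden call has roughly this shape; treat it as an assumption, not the verbatim source:

```python
# Assumed shape of the collapsed student load (not shown in this diff):
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    **model_kwargs,
)
```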
@@ -247,18 +270,7 @@ def initialize_model_from_path(
     if "sequence_length" in model_kwargs:
         model.seqlen = model_kwargs["sequence_length"]
 
-    teacher = (
-        AutoModelForCausalLM.from_pretrained(
-            model_args.distill_teacher,
-            **teacher_kwargs,
-        )
-        if model_args.distill_teacher is not None
-        else None
-    )
-    if teacher is not None and "sequence_length" in teacher_kwargs:
-        teacher.seqlen = teacher_kwargs["sequence_length"]
-
-    return teacher, model_path, model
+    return model, teacher
 
 
 def initialize_processor_from_path(
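Both before and after this refactor, the teacher is placed with `device_map=None` under FSDP (so the trainer owns placement) and `"auto"` otherwise; the change only moves that choice inside the `training_args is not None` branch. Condensed, the selection reads:

```python
import os

# None defers teacher placement to the FSDP-aware trainer; "auto" lets
# accelerate dispatch the teacher across whatever devices are available.
fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
teacher_device_map = None if fsdp_enabled else "auto"
```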
@@ -357,7 +369,7 @@ def main(
 
     model = model_args.model
     if isinstance(model, str) or isinstance(model, PosixPath):
-        (teacher, _model_path, model) = initialize_model_from_path(
+        model, teacher = initialize_model_from_path(
             model_args,
             training_args,
         )
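With the helper's return value trimmed from `(teacher, model_path, model)` to `(model, teacher)`, call sites unpack a pair, and oneshot-style flows can omit `training_args` entirely. A sketch of both call patterns under the new signature (the argument objects are assumed to be constructed as elsewhere in this file):

```python
from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
)

# Training / distillation path: TrainingArguments provided, so checkpoint
# detection, seeding, and teacher loading all run inside the helper.
model, teacher = initialize_model_from_path(model_args, training_args)

# Oneshot path: no TrainingArguments, so last_checkpoint and teacher both
# stay None and only the student model is loaded.
model, teacher = initialize_model_from_path(model_args)
```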