@@ -167,9 +167,8 @@ def parse_args(**kwargs):
167167
168168def initialize_model_from_path (
169169 model_args : ModelArguments ,
170- training_args : TrainingArguments ,
170+ training_args : Optional [ TrainingArguments ] = None ,
171171):
172- last_checkpoint = detect_last_checkpoint (training_args , model_args = model_args )
173172 # Load pretrained model
174173 # The .from_pretrained methods guarantee that only one local process can
175174 # concurrently download model & vocab.
@@ -182,38 +181,70 @@ def initialize_model_from_path(
182181 tie_word_embeddings = model_args .tie_word_embeddings ,
183182 trust_remote_code = model_args .trust_remote_code_model ,
184183 )
185- teacher_config = (
186- AutoConfig .from_pretrained (
187- model_args .distill_teacher ,
188- use_auth_token = True if model_args .use_auth_token else None ,
189- tie_word_embeddings = model_args .tie_word_embeddings ,
190- trust_remote_code = model_args .trust_remote_code_model ,
184+
185+ last_checkpoint = None
186+ teacher = None
187+
188+ if training_args is not None :
189+ # Load teacher configuration if applicable
190+ teacher_config = (
191+ AutoConfig .from_pretrained (
192+ model_args .distill_teacher ,
193+ use_auth_token = True if model_args .use_auth_token else None ,
194+ tie_word_embeddings = model_args .tie_word_embeddings ,
195+ trust_remote_code = model_args .trust_remote_code_model ,
196+ )
197+ if model_args .distill_teacher
198+ else None
191199 )
192- if model_args .distill_teacher
193- else None
194- )
200+
201+ # Detect last checkpoint
202+ last_checkpoint = detect_last_checkpoint (training_args , model_args = model_args )
203+
204+ # Set seed before initializing model
205+ set_seed (training_args .seed )
206+
207+ # Initialize teacher model if teacher path is provided
208+ if model_args .distill_teacher is not None :
209+ teacher_device_map = (
210+ None
211+ if os .environ .get ("ACCELERATE_USE_FSDP" , "false" ) == "true"
212+ else "auto"
213+ )
214+ teacher_kwargs = {
215+ "config" : teacher_config ,
216+ "cache_dir" : model_args .cache_dir ,
217+ "use_auth_token" : True if model_args .use_auth_token else None ,
218+ "torch_dtype" : parse_dtype (model_args .precision ),
219+ "device_map" : teacher_device_map ,
220+ "trust_remote_code" : model_args .trust_remote_code_model ,
221+ }
222+
223+ teacher = AutoModelForCausalLM .from_pretrained (
224+ model_args .distill_teacher ,
225+ ** teacher_kwargs ,
226+ )
227+ if "sequence_length" in teacher_kwargs :
228+ teacher .seqlen = teacher_kwargs ["sequence_length" ]
195229
196230 model_path = (
197231 last_checkpoint or model_args .model
198232 if hasattr (model_args , "model" )
199233 else model_args .model_name_or_path
200234 )
201235
202- # Set seed before initializing model.
203- set_seed (training_args .seed )
204-
205236 # Fallback to CPU if GPU requested and not available
206- training_args .oneshot_device = fallback_to_cpu (model_args .oneshot_device )
237+ model_args .oneshot_device = fallback_to_cpu (model_args .oneshot_device )
207238
208239 # Trainer handles device assignment for FSDP and training, don't do mapping here
209240 # if running oneshot outside of FSDP, apply user device settings
210- device_map = None
241+
211242 fsdp_enabled = os .environ .get ("ACCELERATE_USE_FSDP" , "false" ) == "true"
212- if not fsdp_enabled and training_args .do_oneshot :
213- device_map = training_args .oneshot_device
214- logger .warning (f"Moving { model_path } to device { device_map } for One-Shot" )
215- elif not fsdp_enabled :
243+
244+ device_map = model_args .oneshot_device
245+ if not fsdp_enabled and training_args is not None and training_args .do_train :
216246 device_map = "auto"
247+
217248 model_kwargs = {
218249 "config" : config ,
219250 "cache_dir" : model_args .cache_dir ,
@@ -223,15 +254,7 @@ def initialize_model_from_path(
223254 "device_map" : device_map ,
224255 "trust_remote_code" : model_args .trust_remote_code_model ,
225256 }
226- teacher_device_map = None if fsdp_enabled else "auto"
227- teacher_kwargs = {
228- "config" : teacher_config ,
229- "cache_dir" : model_args .cache_dir ,
230- "use_auth_token" : True if model_args .use_auth_token else None ,
231- "torch_dtype" : parse_dtype (model_args .precision ),
232- "device_map" : teacher_device_map ,
233- "trust_remote_code" : model_args .trust_remote_code_model ,
234- }
257+
235258 # this calls from_pretrained under the hood so should be FSDP safe
236259
237260 # optimized models must be decompressed to carry out oneshot/train/etc
@@ -247,18 +270,7 @@ def initialize_model_from_path(
247270 if "sequence_length" in model_kwargs :
248271 model .seqlen = model_kwargs ["sequence_length" ]
249272
250- teacher = (
251- AutoModelForCausalLM .from_pretrained (
252- model_args .distill_teacher ,
253- ** teacher_kwargs ,
254- )
255- if model_args .distill_teacher is not None
256- else None
257- )
258- if teacher is not None and "sequence_length" in teacher_kwargs :
259- teacher .seqlen = teacher_kwargs ["sequence_length" ]
260-
261- return teacher , model_path , model
273+ return model , teacher
262274
263275
264276def initialize_processor_from_path (
@@ -357,7 +369,7 @@ def main(
357369
358370 model = model_args .model
359371 if isinstance (model , str ) or isinstance (model , PosixPath ):
360- ( teacher , _model_path , model ) = initialize_model_from_path (
372+ model , teacher = initialize_model_from_path (
361373 model_args ,
362374 training_args ,
363375 )