@@ -367,6 +367,51 @@ def compute_hash(self) -> str:
367
367
assert_hashable (str_factors )
368
368
return hashlib .sha256 (str (factors ).encode ()).hexdigest ()
369
369
370
def _update_nested(
    self,
    target: Union["PretrainedConfig", dict[str, Any]],
    updates: dict[str, Any],
) -> None:
    """Merge ``updates`` into ``target`` in place, descending into nested values.

    ``target`` is written by item assignment when it is a ``dict`` and by
    ``setattr`` otherwise.

    For each ``(name, new_value)`` pair in ``updates``:

    * If ``new_value`` is a dict and ``target`` already holds a non-``None``
      entry under ``name`` that is either a dict or an object exposing
      ``__dict__``, the entry is updated recursively with ``new_value``.
    * Otherwise ``new_value`` is stored on ``target`` under ``name``,
      replacing whatever was there.

    Args:
        target: The dict or object to mutate.
        updates: Possibly nested mapping of names to new values.
    """
    # The container kind cannot change while we iterate, so decide once.
    target_is_mapping = isinstance(target, dict)

    for name, new_value in updates.items():
        # Only dict-valued updates are candidates for a recursive merge.
        child = None
        if isinstance(new_value, dict):
            if target_is_mapping:
                child = target.get(name)
            else:
                child = getattr(target, name, None)

        can_descend = child is not None and (
            isinstance(child, dict) or hasattr(child, "__dict__")
        )

        if can_descend:
            self._update_nested(child, new_value)
        elif target_is_mapping:
            target[name] = new_value
        else:
            setattr(target, name, new_value)
397
+
398
def _apply_dict_overrides(
    self,
    config: "PretrainedConfig",
    overrides: dict[str, Any],
) -> None:
    """Apply dict-valued overrides onto ``config`` in place.

    For each ``(name, override)`` pair, the current attribute of ``config``
    is looked up. If it is a ``PretrainedConfig`` instance, ``override`` is
    merged into it via ``_update_nested``. In every other case (attribute
    missing, ``None``, or any other type) the attribute is overwritten with
    ``override`` as-is.

    Args:
        config: The config object to mutate.
        overrides: Mapping of attribute names to override values.
    """
    from transformers import PretrainedConfig

    for name, override in overrides.items():
        current = getattr(config, name, None)

        # Nested sub-config: merge into it rather than replacing it.
        if isinstance(current, PretrainedConfig):
            self._update_nested(current, override)
            continue

        # Anything else is treated as a plain dict-valued parameter.
        setattr(config, name, override)
414
+
370
415
def __post_init__ (
371
416
self ,
372
417
# Multimodal config init vars
@@ -419,8 +464,17 @@ def __post_init__(
419
464
if callable (self .hf_overrides ):
420
465
hf_overrides_kw = {}
421
466
hf_overrides_fn = self .hf_overrides
467
+ dict_overrides : dict [str , Any ] = {}
422
468
else :
423
- hf_overrides_kw = self .hf_overrides
469
+ # Separate dict overrides from flat ones
470
+ # We'll determine how to apply dict overrides after loading the config
471
+ hf_overrides_kw = {}
472
+ dict_overrides = {}
473
+ for key , value in self .hf_overrides .items ():
474
+ if isinstance (value , dict ):
475
+ dict_overrides [key ] = value
476
+ else :
477
+ hf_overrides_kw [key ] = value
424
478
hf_overrides_fn = None
425
479
426
480
if self .rope_scaling :
@@ -478,6 +532,8 @@ def __post_init__(
478
532
)
479
533
480
534
self .hf_config = hf_config
535
+ if dict_overrides :
536
+ self ._apply_dict_overrides (hf_config , dict_overrides )
481
537
self .hf_text_config = get_hf_text_config (self .hf_config )
482
538
self .attention_chunk_size = getattr (
483
539
self .hf_text_config , "attention_chunk_size" , None
0 commit comments