Skip to content

Commit c6873c4

Browse files
authored
[UX] Support nested dicts in hf_overrides (#25727)
Signed-off-by: mgoin <[email protected]>
1 parent 2111b46 commit c6873c4

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

tests/test_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,37 @@ def test_rope_customization():
292292
assert longchat_model_config.max_model_len == 4096
293293

294294

295+
def test_nested_hf_overrides():
296+
"""Test that nested hf_overrides work correctly."""
297+
# Test with a model that has text_config
298+
model_config = ModelConfig(
299+
"Qwen/Qwen2-VL-2B-Instruct",
300+
hf_overrides={
301+
"text_config": {
302+
"hidden_size": 1024,
303+
},
304+
},
305+
)
306+
assert model_config.hf_config.text_config.hidden_size == 1024
307+
308+
# Test with deeply nested overrides
309+
model_config = ModelConfig(
310+
"Qwen/Qwen2-VL-2B-Instruct",
311+
hf_overrides={
312+
"text_config": {
313+
"hidden_size": 2048,
314+
"num_attention_heads": 16,
315+
},
316+
"vision_config": {
317+
"hidden_size": 512,
318+
},
319+
},
320+
)
321+
assert model_config.hf_config.text_config.hidden_size == 2048
322+
assert model_config.hf_config.text_config.num_attention_heads == 16
323+
assert model_config.hf_config.vision_config.hidden_size == 512
324+
325+
295326
@pytest.mark.skipif(
296327
current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
297328
)

vllm/config/model.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,51 @@ def compute_hash(self) -> str:
367367
assert_hashable(str_factors)
368368
return hashlib.sha256(str(factors).encode()).hexdigest()
369369

370+
def _update_nested(
371+
self,
372+
target: Union["PretrainedConfig", dict[str, Any]],
373+
updates: dict[str, Any],
374+
) -> None:
375+
"""Recursively updates a config or dict with nested updates."""
376+
for key, value in updates.items():
377+
if isinstance(value, dict):
378+
# Get the nested target
379+
if isinstance(target, dict):
380+
nested_target = target.get(key)
381+
else:
382+
nested_target = getattr(target, key, None)
383+
384+
# If nested target exists and can be updated recursively
385+
if nested_target is not None and (
386+
isinstance(nested_target, dict)
387+
or hasattr(nested_target, "__dict__")
388+
):
389+
self._update_nested(nested_target, value)
390+
continue
391+
392+
# Set the value (base case)
393+
if isinstance(target, dict):
394+
target[key] = value
395+
else:
396+
setattr(target, key, value)
397+
398+
def _apply_dict_overrides(
399+
self,
400+
config: "PretrainedConfig",
401+
overrides: dict[str, Any],
402+
) -> None:
403+
"""Apply dict overrides, handling both nested configs and dict values."""
404+
from transformers import PretrainedConfig
405+
406+
for key, value in overrides.items():
407+
attr = getattr(config, key, None)
408+
if attr is not None and isinstance(attr, PretrainedConfig):
409+
# It's a nested config - recursively update it
410+
self._update_nested(attr, value)
411+
else:
412+
# It's a dict-valued parameter - set it directly
413+
setattr(config, key, value)
414+
370415
def __post_init__(
371416
self,
372417
# Multimodal config init vars
@@ -419,8 +464,17 @@ def __post_init__(
419464
if callable(self.hf_overrides):
420465
hf_overrides_kw = {}
421466
hf_overrides_fn = self.hf_overrides
467+
dict_overrides: dict[str, Any] = {}
422468
else:
423-
hf_overrides_kw = self.hf_overrides
469+
# Separate dict overrides from flat ones
470+
# We'll determine how to apply dict overrides after loading the config
471+
hf_overrides_kw = {}
472+
dict_overrides = {}
473+
for key, value in self.hf_overrides.items():
474+
if isinstance(value, dict):
475+
dict_overrides[key] = value
476+
else:
477+
hf_overrides_kw[key] = value
424478
hf_overrides_fn = None
425479

426480
if self.rope_scaling:
@@ -478,6 +532,8 @@ def __post_init__(
478532
)
479533

480534
self.hf_config = hf_config
535+
if dict_overrides:
536+
self._apply_dict_overrides(hf_config, dict_overrides)
481537
self.hf_text_config = get_hf_text_config(self.hf_config)
482538
self.attention_chunk_size = getattr(
483539
self.hf_text_config, "attention_chunk_size", None

0 commit comments

Comments
 (0)