Skip to content

Commit fde0b61

Browse files
authored
[Model] Decouple glm4v (#22751)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent d0a6301 commit fde0b61

File tree

3 files changed: 23 additions and 7 deletions.

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -615,7 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
615615
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
616616
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
617617
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
618-
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
618+
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ |
619619
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
620620
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
621621
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |

vllm/model_executor/models/glm4_1v.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,10 +1227,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
12271227
"k_proj",
12281228
"v_proj",
12291229
],
1230-
"gate_up_proj": [
1231-
"gate_proj",
1232-
"up_proj",
1233-
],
1230+
"gate_up_proj": ["gate_up_proj"]
12341231
}
12351232

12361233
# To ensure correct weight loading and mapping.
@@ -1567,7 +1564,26 @@ def get_mm_mapping(self) -> MultiModelKeys:
15671564
Get the module prefix in multimodal models
15681565
"""
15691566
return MultiModelKeys.from_string_field(
1570-
language_model="language_model",
1567+
language_model="language_model.model",
15711568
connector="visual.merger.",
15721569
tower_model="visual.",
15731570
)
1571+
1572+
1573+
@MULTIMODAL_REGISTRY.register_processor(
1574+
Glm4vMultiModalProcessor,
1575+
info=Glm4vProcessingInfo,
1576+
dummy_inputs=Glm4vDummyInputsBuilder,
1577+
)
1578+
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
1579+
packed_modules_mapping = {
1580+
"qkv_proj": [
1581+
"q_proj",
1582+
"k_proj",
1583+
"v_proj",
1584+
],
1585+
"gate_up_proj": [
1586+
"gate_proj",
1587+
"up_proj",
1588+
],
1589+
}

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@
208208
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
209209
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
210210
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
211-
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
211+
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
212212
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
213213
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
214214
"InternVLChatModel": ("internvl", "InternVLChatModel"),

0 commit comments

Comments
 (0)