Skip to content

Commit 517eb73

Browse files
committed
diffusion/lora: reuse load_weights packed mapping
Signed-off-by: dongbo910220 <1275604947@qq.com>
1 parent de0a8b0 commit 517eb73

File tree

11 files changed

+68
-42
lines changed

11 files changed

+68
-42
lines changed

tests/diffusion/lora/test_lora_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,11 @@ def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule:
133133
monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule)
134134

135135
pipeline = torch.nn.Module()
136-
pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
136+
pipeline.stacked_params_mapping = [
137+
(".to_qkv", ".to_q", "q"),
138+
(".to_qkv", ".to_k", "k"),
139+
(".to_qkv", ".to_v", "v"),
140+
]
137141
pipeline.transformer = torch.nn.Module()
138142
pipeline.transformer.to_qkv = _FakeLinearBase()
139143

@@ -145,7 +149,7 @@ def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule:
145149
)
146150

147151
# Treat the dummy layer as a packed 3-slice projection so the manager uses
148-
# `packed_modules_mapping` to decide replacement based on target_modules.
152+
# `stacked_params_mapping` to decide replacement based on target_modules.
149153
monkeypatch.setattr(manager, "_get_packed_modules_list", lambda _module: ["q", "k", "v"])
150154

151155
peft_helper = type("_PH", (), {"r": 1, "target_modules": ["to_q"]})()
@@ -206,7 +210,11 @@ def test_lora_manager_activates_fused_lora_on_packed_layer():
206210

207211
def test_lora_manager_activates_packed_lora_from_sublayers():
208212
pipeline = torch.nn.Module()
209-
pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
213+
pipeline.stacked_params_mapping = [
214+
(".to_qkv", ".to_q", "q"),
215+
(".to_qkv", ".to_k", "k"),
216+
(".to_qkv", ".to_v", "v"),
217+
]
210218
manager = DiffusionLoRAManager(
211219
pipeline=pipeline,
212220
device=torch.device("cpu"),

vllm_omni/diffusion/lora/manager.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,43 @@ def _compute_supported_lora_modules(self) -> set[str]:
131131
def _compute_packed_modules_mapping(self) -> dict[str, list[str]]:
132132
"""Collect packed->sublayer mappings from the diffusion model.
133133
134-
vLLM models declare `packed_modules_mapping` on the model class. For
135-
diffusion pipelines, we attach the same mapping on the transformer
136-
module(s) that implement packed (fused) projections, so LoRA loading can
137-
accept checkpoints trained against the logical sub-projections.
134+
Diffusion models often use packed (fused) projections like `to_qkv` or
135+
`w13`, while LoRA checkpoints are typically saved against the logical
136+
sub-projections (e.g. `to_q`/`to_k`/`to_v`, `w1`/`w3`). Many diffusion
137+
model implementations already define these relationships in
138+
`load_weights()` via `stacked_params_mapping`. To avoid duplicating the
139+
mapping in multiple places, we derive packed→sublayer mappings from the
140+
model's `stacked_params_mapping`.
138141
"""
142+
143+
def _derive_from_stacked_params_mapping(stacked: object) -> dict[str, list[str]]:
144+
if not isinstance(stacked, (list, tuple)):
145+
return {}
146+
derived: dict[str, list[str]] = {}
147+
for item in stacked:
148+
if not isinstance(item, (list, tuple)) or len(item) < 2:
149+
continue
150+
packed_suffix, sub_suffix = item[0], item[1]
151+
if not isinstance(packed_suffix, str) or not packed_suffix:
152+
continue
153+
if not isinstance(sub_suffix, str) or not sub_suffix:
154+
continue
155+
# The mapping strings are usually suffix patterns (e.g. ".to_qkv"),
156+
# but some models scope them under submodules (e.g. ".attn1.to_qkv").
157+
# For LoRA we only care about the leaf module names.
158+
packed_name = packed_suffix.split(".")[-1]
159+
sub_name = sub_suffix.split(".")[-1]
160+
existing = derived.get(packed_name)
161+
if existing is None:
162+
derived[packed_name] = [sub_name]
163+
elif sub_name not in existing:
164+
existing.append(sub_name)
165+
return derived
166+
139167
mapping: dict[str, list[str]] = {}
140168
for module in self.pipeline.modules():
141-
packed = getattr(module, "packed_modules_mapping", None)
142-
if not isinstance(packed, dict):
143-
continue
144-
for packed_name, sub_names in packed.items():
169+
derived = _derive_from_stacked_params_mapping(getattr(module, "stacked_params_mapping", None))
170+
for packed_name, sub_names in derived.items():
145171
if not isinstance(packed_name, str) or not packed_name:
146172
continue
147173
if not isinstance(sub_names, (list, tuple)) or not all(isinstance(s, str) for s in sub_names):
@@ -155,7 +181,7 @@ def _compute_packed_modules_mapping(self) -> dict[str, list[str]]:
155181
mapping[packed_name] = sub_names_list
156182
elif existing != sub_names_list:
157183
logger.warning(
158-
"Conflicting packed_modules_mapping for %s: %s vs %s; using %s",
184+
"Conflicting packed module mapping for %s: %s vs %s; using %s",
159185
packed_name,
160186
existing,
161187
sub_names_list,
@@ -170,7 +196,7 @@ def _get_packed_sublayer_suffixes(self, packed_module_suffix: str, n_slices: int
170196
return None
171197
if len(sub_suffixes) != n_slices:
172198
logger.warning(
173-
"packed_modules_mapping[%s] has %d slices but layer expects %d; skipping sublayer lookup",
199+
"Packed module mapping[%s] has %d slices but layer expects %d; skipping sublayer lookup",
174200
packed_module_suffix,
175201
len(sub_suffixes),
176202
n_slices,

vllm_omni/diffusion/lora/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ def _expand_expected_modules_for_packed_layers(
3939
`supported_modules`, but the sublayer names are not. Expanding the set
4040
ensures these sublayer keys are not dropped when loading a LoRA checkpoint.
4141
42-
The packed→sublayer mapping is model-specific (see each diffusion model's
43-
`packed_modules_mapping`) so new packed layers are added alongside the model
44-
implementation rather than hard-coded in the LoRA framework.
42+
The packed→sublayer mapping is model-specific and is derived from each
43+
diffusion model's `stacked_params_mapping` (used by `load_weights()`), so
44+
new packed layers are added alongside the model implementation rather than
45+
hard-coded in the LoRA framework.
4546
"""
4647
expanded = set(supported_modules)
4748
if not packed_modules_mapping:

vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,6 @@ class Flux2Transformer2DModel(nn.Module):
559559
"""
560560

561561
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
562-
packed_modules_mapping = {
563-
"to_qkv": ["to_q", "to_k", "to_v"],
564-
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
565-
}
566562

567563
def __init__(
568564
self,
@@ -735,6 +731,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
735731
(".add_kv_proj", ".add_k_proj", "k"),
736732
(".add_kv_proj", ".add_v_proj", "v"),
737733
]
734+
# Expose packed shard mappings for LoRA handling of fused projections.
735+
self.stacked_params_mapping = stacked_params_mapping
738736

739737
params_dict = dict(self.named_parameters())
740738

vllm_omni/diffusion/models/glm_image/glm_image_transformer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,6 @@ class GlmImageTransformer2DModel(CachedTransformer):
551551
`od_config.tf_model_config`.
552552
"""
553553

554-
packed_modules_mapping = {
555-
"to_qkv": ["to_q", "to_k", "to_v"],
556-
}
557-
558554
def __init__(
559555
self,
560556
od_config: OmniDiffusionConfig,
@@ -724,6 +720,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
724720
(".to_qkv", ".to_k", "k"),
725721
(".to_qkv", ".to_v", "v"),
726722
]
723+
# Expose packed shard mappings for LoRA handling of fused projections.
724+
self.stacked_params_mapping = stacked_params_mapping
727725

728726
params_dict = dict(self.named_parameters())
729727

vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,11 +503,6 @@ class LongCatImageTransformer2DModel(nn.Module):
503503
Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
504504
"""
505505

506-
packed_modules_mapping = {
507-
"to_qkv": ["to_q", "to_k", "to_v"],
508-
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
509-
}
510-
511506
def __init__(
512507
self,
513508
od_config: OmniDiffusionConfig,
@@ -707,6 +702,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
707702
(".add_kv_proj", ".add_k_proj", "k"),
708703
(".add_kv_proj", ".add_v_proj", "v"),
709704
]
705+
# Expose packed shard mappings for LoRA handling of fused projections.
706+
self.stacked_params_mapping = stacked_params_mapping
710707

711708
params_dict = dict(self.named_parameters())
712709

vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,6 @@ class OvisImageTransformer2DModel(nn.Module):
366366
"""
367367

368368
_repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"]
369-
packed_modules_mapping = {
370-
"to_qkv": ["to_q", "to_k", "to_v"],
371-
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
372-
}
373369

374370
def __init__(
375371
self,
@@ -518,6 +514,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
518514
(".add_kv_proj", ".add_k_proj", "k"),
519515
(".add_kv_proj", ".add_v_proj", "v"),
520516
]
517+
# Expose packed shard mappings for LoRA handling of fused projections.
518+
self.stacked_params_mapping = stacked_params_mapping
521519

522520
params_dict = dict(self.named_parameters())
523521

vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
10531053
(".add_kv_proj", ".add_k_proj", "k"),
10541054
(".add_kv_proj", ".add_v_proj", "v"),
10551055
]
1056+
# Expose packed shard mappings for LoRA handling of fused projections.
1057+
self.stacked_params_mapping = stacked_params_mapping
10561058

10571059
params_dict = dict(self.named_parameters())
10581060

vllm_omni/diffusion/models/sd3/sd3_transformer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,6 @@ class SD3Transformer2DModel(nn.Module):
322322
"""
323323

324324
_repeated_blocks = ["SD3TransformerBlock"]
325-
packed_modules_mapping = {
326-
"to_qkv": ["to_q", "to_k", "to_v"],
327-
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
328-
}
329325

330326
def __init__(
331327
self,
@@ -454,6 +450,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
454450
(".add_kv_proj", ".add_k_proj", "k"),
455451
(".add_kv_proj", ".add_v_proj", "v"),
456452
]
453+
# Expose packed shard mappings for LoRA handling of fused projections.
454+
self.stacked_params_mapping = stacked_params_mapping
457455

458456
params_dict = dict(self.named_parameters())
459457

vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
725725
Returns:
726726
Set of parameter names that were successfully loaded.
727727
"""
728-
# Stacked params mapping for self-attention QKV fusion
728+
# Stacked params mapping for self-attention QKV fusion.
729729
# Format: (param_name, shard_name, shard_id)
730730
# Note: Only fuse attn1 (self-attention), NOT attn2 (cross-attention)
731731
stacked_params_mapping = [
@@ -734,6 +734,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
734734
(".attn1.to_qkv", ".attn1.to_k", "k"),
735735
(".attn1.to_qkv", ".attn1.to_v", "v"),
736736
]
737+
# Expose packed shard mappings for LoRA handling of fused projections.
738+
self.stacked_params_mapping = stacked_params_mapping
737739

738740
params_dict = dict(self.named_parameters())
739741
loaded_params: set[str] = set()

0 commit comments

Comments
 (0)