Commit d6a3551
[Bug] [Bagel] Fix kv transfer bug (#1437)
Signed-off-by: Ding Zuhao <e1583181@u.nus.edu>
Co-authored-by: Wang Zhipeng (princepride) <wangzhipeng628@gmail.com>
1 parent 40f72bb commit d6a3551

File tree

1 file changed: +8 −2 lines

vllm_omni/diffusion/models/bagel/bagel_transformer.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -1286,8 +1286,14 @@ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
     @staticmethod
     def _merge_naive_caches(caches: list) -> NaiveCache:
         """Merge multiple NaiveCache objects by concatenating KV tensors per layer."""
-        merged = NaiveCache(caches[0].num_layers)
-        for layer_idx in range(merged.num_layers):
+        if not caches:
+            # Handle empty list case gracefully if desired,
+            # though original code also crashed on this.
+            return NaiveCache(0)
+
+        num_layers = len(caches[0].key_cache)
+        merged = NaiveCache(num_layers)
+        for layer_idx in range(num_layers):
             merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
             merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
         return merged
```
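For context, below is a minimal, self-contained sketch of the patched merge. The `NaiveCache` mock here is an assumption: it only mirrors the `num_layers`, `key_cache`, and `value_cache` fields the diff touches (the real class lives in the Bagel codebase), and the tensor shapes are illustrative. The point of the fix is that the layer count is derived from the populated `key_cache` rather than from `num_layers`, which can disagree with the actual cache contents after a KV transfer.

```python
import torch


class NaiveCache:
    """Assumed stand-in for Bagel's NaiveCache; per-layer KV tensors keyed by layer index."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        self.key_cache = {i: None for i in range(num_layers)}
        self.value_cache = {i: None for i in range(num_layers)}


def _merge_naive_caches(caches: list) -> NaiveCache:
    """Merge multiple NaiveCache objects by concatenating KV tensors per layer."""
    if not caches:
        return NaiveCache(0)
    # Derive the layer count from the populated key_cache rather than
    # trusting num_layers, which may be stale after a KV transfer.
    num_layers = len(caches[0].key_cache)
    merged = NaiveCache(num_layers)
    for layer_idx in range(num_layers):
        merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
        merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
    return merged


# Hypothetical usage: two single-request caches with 2 layers each,
# concatenated along the batch dimension (dim=0).
a, b = NaiveCache(2), NaiveCache(2)
for layer in range(2):
    a.key_cache[layer] = torch.zeros(1, 4, 8)   # (batch, seq, head_dim)
    a.value_cache[layer] = torch.zeros(1, 4, 8)
    b.key_cache[layer] = torch.ones(1, 4, 8)
    b.value_cache[layer] = torch.ones(1, 4, 8)

merged = _merge_naive_caches([a, b])
assert merged.key_cache[0].shape == (2, 4, 8)
```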
