[Feat] support for multi-block layerwise offloading #1486
RuixiangMa wants to merge 4 commits into vllm-project:main
Conversation
Signed-off-by: Lancer <maruixiang6688@gmail.com>
```python
# Handle multiple block types (_layerwise_offload_blocks_attrs)
if blocks_attr_name is None:
    blocks_attrs_names = getattr(model.__class__, "_layerwise_offload_blocks_attrs", None)
```
IMO having both _layerwise_offload_blocks_attrs and _layerwise_offload_blocks_attr is a little confusing. I think it would be cleaner to just have one attr that can also be a list, because the behavior is not well-defined if a module sets both attributes by mistake
Yeah, accounted for that, only kept the legacy path for compatibility, but can refactor if needed.
```python
def __init__(self):
    self.blocks = nn.ModuleList([...])  # Transformer blocks
```
This PR adds multi-block layerwise offloading but provides no test coverage. Add tests to verify: (1) multi-block offloading works correctly with different block types, (2) memory usage is reduced as expected, (3) output quality is maintained, and (4) edge cases like empty or invalid block attributes are handled.
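A minimal stand-in for what such a test could look like (the `gather_offload_blocks` helper and `FakeModel` are hypothetical stand-ins, not the PR's real API): it checks that multi-block gathering keeps per-attribute order and that empty or missing attributes contribute nothing.

```python
def gather_offload_blocks(model, attr_names):
    # Collect blocks from every declared attribute, in declaration order.
    gathered = []
    for name in attr_names:
        gathered.extend(getattr(model, name, []) or [])
    return gathered

class FakeModel:
    transformer_blocks = ["t0", "t1"]
    single_transformer_blocks = ["s0"]
    empty_blocks = []

m = FakeModel()
# (1) multiple block types are gathered in order
assert gather_offload_blocks(
    m, ["transformer_blocks", "single_transformer_blocks"]) == ["t0", "t1", "s0"]
# (4) empty or missing attributes yield no blocks
assert gather_offload_blocks(m, ["empty_blocks"]) == []
assert gather_offload_blocks(m, ["missing"]) == []
```

Memory-usage and output-quality checks would still need integration tests against real models.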
```python
if not blocks_attr_name or not blocks:
    if not blocks:
        logger.warning(
```
No validation for blocks_attr_names. What happens if an attribute name doesn't exist on the model? Add error handling to check that each attribute in _layerwise_offload_blocks_attrs exists and contains valid blocks, with clear error messages for misconfiguration.
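One possible shape for that validation (the `check_block_attrs` name is hypothetical; only the class-attribute name comes from the diff):

```python
def check_block_attrs(model, attr_names):
    """Fail fast with an actionable message instead of silently offloading nothing."""
    for name in attr_names:
        blocks = getattr(model, name, None)
        if blocks is None:
            raise ValueError(
                f"_layerwise_offload_blocks_attrs names {name!r}, but "
                f"{type(model).__name__} has no attribute {name!r}")
        if len(blocks) == 0:
            raise ValueError(
                f"{type(model).__name__}.{name} is empty; remove it from "
                f"_layerwise_offload_blocks_attrs or populate it")
```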
```python
m.to(self.device)

# Move top-level params/buffers to GPU (dit_module's own, not sub-modules)
for param in dit_module._parameters.values():
```
This changes the offloading behavior from single 'blocks' attribute to multiple block attributes. Verify backward compatibility - existing models with only 'blocks' should still work. Consider adding a deprecation warning if the old single-attribute pattern is detected.
The single-block model test has been verified. I'll supplement with the results.
lishunyang12 left a comment
Left a couple comments on the backend changes. The multi-block approach looks right for Flux-style models.
```python
m.to(self.device)
logger.debug(f"Moved {name} to device {self.device}")
if blocks_attr_names and name not in blocks_attr_names:
    m.to(self.device)
```
The old code had logger.debug calls here for skipped/moved modules. Dropping them makes offloading issues harder to debug — can you keep the logging?
```python
for param in dit_module._parameters.values():
    if param is not None:
        param.data = param.data.to(self.device, non_blocking=True)
```
Moving top-level params/buffers looks like a separate bug fix (previously they would stay on CPU). Worth calling out in the PR description so it does not get overlooked during review.
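A sketch of what the combined fix could look like, covering buffers as well as parameters (the function name `move_top_level_tensors` is hypothetical; the `_parameters` loop mirrors the diff):

```python
import torch

def move_top_level_tensors(module: torch.nn.Module, device) -> None:
    # Move only the module's own params/buffers; sub-modules (the blocks)
    # remain under layerwise offload control.
    for param in module._parameters.values():
        if param is not None:
            param.data = param.data.to(device, non_blocking=True)
    for name, buf in module._buffers.items():
        if buf is not None:
            module._buffers[name] = buf.to(device, non_blocking=True)
```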
```python
logger.debug(f"Skipped blocks module {name}")
continue
m.to(self.device)
logger.debug(f"Moved {name} to device {self.device}")
```
Nit: the `blocks_attr_names and` guard is redundant; we already `continue` above when `not blocks`, and a non-empty `blocks` implies `blocks_attr_names` is non-empty.
Z-Image is also supported in the PR to validate memory savings.
Purpose
Some diffusion models (e.g., Flux, LongCat, Ovis) have two types of transformer blocks (e.g., `transformer_blocks` and `single_transformer_blocks`). The previous implementation supported only a single block type, limiting layerwise offloading effectiveness for these models.
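To make the "two block types" concrete, here is a simplified Flux-style module (the sizes and layer types are illustrative, not the real Flux architecture; the class-attribute name follows this PR):

```python
import torch.nn as nn

class FluxLikeTransformer(nn.Module):
    # Declare both block lists for layerwise offloading.
    _layerwise_offload_blocks_attrs = ["transformer_blocks",
                                       "single_transformer_blocks"]

    def __init__(self, n_double: int = 2, n_single: int = 2, dim: int = 8):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(n_double))
        self.single_transformer_blocks = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(n_single))
```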
Test Plan
Test Result
NVIDIA 4090 (24 GB)

`vllm serve --model /data/models/black-forest-labs/FLUX* --omni --enable_layerwise_offload --port 8004`

Offload vs. no offload
Since FLUX.1-dev, FLUX.2-klein-9B, et al. incur OOM without layer offloading, we use FLUX.2-klein-4B and Z-Image as representative examples to illustrate memory usage:

| Model | No offload | Layerwise offload |
| --- | --- | --- |
| FLUX.2-klein-4B | 19.7 GB | 13.8 GB |
| Z-Image | 22.7 GB | 15.5 GB |