Commit 5326892
[not for land] online fp8 quant with streaming weight post-processing
Summary: not for land, just a demo.

1. During weight loading, keep track of how many elements of each weight have been loaded.
2. Once all of the elements have been loaded, call weight post-processing immediately.

This can be used to run weight post-processing in a streaming fashion to minimize GPU memory usage. It will only work if we can assume each weight chunk is loaded exactly once.

Test Plan: tested locally with facebook/opt-125m and `fp8` online quantization

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: <[email protected]>
1 parent b34129b commit 5326892
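The counting scheme described in the summary can be illustrated outside of vLLM. Below is a minimal, framework-free sketch of the idea; `make_streaming_loader`, `row_loader`, and `quantize_to_fp8` are names invented for this example (not part of the vLLM API), and it assumes each chunk of a parameter is loaded exactly once and that the PyTorch build has the fp8 dtype.

```python
import torch


def make_streaming_loader(inner_loader, post_process):
    """Wrap a weight loader so post-processing runs as soon as a parameter
    has received all of its elements.

    Assumes each chunk of a given parameter is loaded exactly once.
    """

    def streaming_loader(param, loaded_weight, *args, **kwargs):
        # Load the current weight chunk with the original loader.
        res = inner_loader(param, loaded_weight, *args, **kwargs)

        # Count how many elements of this parameter have arrived so far.
        param._loaded_numel = getattr(param, "_loaded_numel", 0) + loaded_weight.numel()

        # Once the whole parameter is populated, post-process it immediately
        # and drop the bookkeeping attribute.
        if param._loaded_numel == param.numel():
            post_process(param)
            del param._loaded_numel
        return res

    return streaming_loader


if __name__ == "__main__":
    weight = torch.nn.Parameter(torch.zeros(4, 8), requires_grad=False)

    def row_loader(param, chunk, row):
        # Copy one row of the "checkpoint" into the parameter.
        param.data[row].copy_(chunk)

    def quantize_to_fp8(param):
        # Stand-in for process_weights_after_loading: cast to fp8 in place.
        param.data = param.data.to(torch.float8_e4m3fn)

    loader = make_streaming_loader(row_loader, quantize_to_fp8)
    for i in range(4):
        loader(weight, torch.randn(8), i)

    assert weight.dtype == torch.float8_e4m3fn
```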

File tree

1 file changed (+33, -1 lines):
  • vllm/model_executor/layers/quantization/fp8.py


vllm/model_executor/layers/quantization/fp8.py

Lines changed: 33 additions & 1 deletion
@@ -437,6 +437,35 @@ def create_weights(
                 output_size_per_partition, input_size_per_partition, weight_loader
             )
         else:
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # add a counter to track how many elements we have updated
+                if not hasattr(param, "_loaded_numel"):
+                    param._loaded_numel = loaded_weight.numel()
+                else:
+                    param._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                if param._loaded_numel == param.numel():
+                    # This works for Linear without biases because there is only one
+                    # weight. It can be extended to more complicated modules with some
+                    # additional state - we have `layer`, so we can inspect all of its
+                    # parameters and count the updates on all of them to know when we
+                    # are done.
+                    self.process_weights_after_loading(layer)
+
+                    # Delete the bookkeeping
+                    del param._loaded_numel
+                    # Prevent the usual `process_weights_after_loading` call from doing
+                    # anything
+                    self._already_called_process_weights_after_loading = True
+
+                return res
+
             # For non-serialized checkpoints, use original dtype
             weight = ModelWeightParameter(
                 data=torch.empty(
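The comment in the hunk above notes that the single-counter trick only covers modules with one weight, and that `layer` could be inspected to generalize it. A hypothetical sketch of that extension (not part of this commit; `layer_fully_loaded` is an invented helper) might look like:

```python
def layer_fully_loaded(layer) -> bool:
    # True only once every parameter registered on `layer` has received all
    # of its elements, using the same per-parameter _loaded_numel counters.
    return all(
        getattr(param, "_loaded_numel", 0) == param.numel()
        for param in layer.parameters()
    )
```

`patched_weight_loader` would then call `self.process_weights_after_loading(layer)` only when `layer_fully_loaded(layer)` returns True.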
@@ -446,7 +475,7 @@ def create_weights(
                 ),
                 input_dim=1,
                 output_dim=0,
-                weight_loader=weight_loader,
+                weight_loader=patched_weight_loader,
             )
             layer.register_parameter("weight", weight)

@@ -487,6 +516,9 @@ def create_weights(
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(self, "_already_called_process_weights_after_loading", False):
+            return
+
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
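The test plan exercises this path with facebook/opt-125m and online fp8 quantization. A sketch of how that is typically invoked through the vLLM offline API (assuming a build that accepts `quantization="fp8"`):

```python
from vllm import LLM, SamplingParams

# Online fp8 quantization: the checkpoint is loaded in its original dtype and
# quantized to fp8 by process_weights_after_loading, which this commit calls
# from the (streaming) weight loader instead of after the full load.
llm = LLM(model="facebook/opt-125m", quantization="fp8")

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```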
