34 changes: 33 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -437,6 +437,35 @@ def create_weights(
                 output_size_per_partition, input_size_per_partition, weight_loader
             )
         else:
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # add a counter to track how many elements we have updated
+                if not hasattr(param, "_loaded_numel"):
+                    param._loaded_numel = loaded_weight.numel()
+                else:
+                    param._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                if param._loaded_numel == param.numel():
+                    # This works for Linear without biases because there is only one
+                    # weight. It can be extended to more complicated modules with some
+                    # additional state - we have `layer`, so we can inspect all of its
+                    # parameters and count the updates on all of them to know when we
+                    # are done.
+                    self.process_weights_after_loading(layer)
+
+                    # Delete the bookkeeping
+                    del param._loaded_numel
+                    # Prevent the usual `process_weights_after_loading` call from
+                    # doing anything
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
             # For non-serialized checkpoints, use original dtype
             weight = ModelWeightParameter(
                 data=torch.empty(
@@ -446,7 +475,7 @@ def create_weights(
                 ),
                 input_dim=1,
                 output_dim=0,
-                weight_loader=weight_loader,
+                weight_loader=patched_weight_loader,
             )
         layer.register_parameter("weight", weight)

@@ -487,6 +516,9 @@ def create_weights(
         layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
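The comment in the first hunk notes that the element-counting trick extends to modules with more than one parameter: since the loader closes over `layer`, it can count loaded elements across all of the layer's parameters instead of just one. A minimal sketch of that extension, assuming the same `weight_loader`/`layer`/quant-method objects as the patch; the helper name `make_counting_loader` and the `_pending_numel` attribute are hypothetical and not part of this PR:

import torch
from torch import nn


def make_counting_loader(quant_method, layer: nn.Module, weight_loader):
    """Wrap ``weight_loader`` so ``process_weights_after_loading`` fires once
    every parameter of ``layer`` has been fully loaded, not just one."""

    def counting_loader(param, loaded_weight: torch.Tensor, *args, **kwargs):
        res = weight_loader(param, loaded_weight, *args, **kwargs)

        # One shared countdown on the layer instead of a per-parameter
        # counter; initialized lazily on the first chunk, assuming every
        # parameter is registered before loading starts.
        if not hasattr(layer, "_pending_numel"):
            layer._pending_numel = sum(p.numel() for p in layer.parameters())
        layer._pending_numel -= loaded_weight.numel()

        if layer._pending_numel == 0:
            # All parameters are complete: post-process now and disarm the
            # usual end-of-load call, mirroring the patch above.
            quant_method.process_weights_after_loading(layer)
            del layer._pending_numel
            layer._already_called_process_weights_after_loading = True

        return res

    return counting_loader

Counting elements rather than loader invocations is what makes this robust to sharded checkpoints: a parameter written in several slices decrements the countdown once per chunk, so the hook fires exactly once, when the last slice of the last parameter lands, regardless of load order.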