diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 92fbdd709348..18774f4c694e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -437,6 +437,35 @@ def create_weights(
                 output_size_per_partition, input_size_per_partition, weight_loader
             )
         else:
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # Load the current weight chunk.
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # Add a counter to track how many elements we have updated.
+                if not hasattr(param, "_loaded_numel"):
+                    param._loaded_numel = loaded_weight.numel()
+                else:
+                    param._loaded_numel += loaded_weight.numel()
+
+                # If we have loaded all of the elements, call
+                # process_weights_after_loading.
+                if param._loaded_numel == param.numel():
+                    # This works for Linear without biases because there is only one
+                    # weight. It can be extended to more complicated modules with some
+                    # additional state - we have `layer`, so we can inspect all of its
+                    # parameters and count the updates on all of them to know when we
+                    # are done.
+                    self.process_weights_after_loading(layer)
+
+                    # Delete the bookkeeping.
+                    del param._loaded_numel
+                    # Prevent the usual `process_weights_after_loading` call from doing
+                    # anything.
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
             # For non-serialized checkpoints, use original dtype
             weight = ModelWeightParameter(
                 data=torch.empty(
@@ -446,7 +475,7 @@ def create_weights(
                 ),
                 input_dim=1,
                 output_dim=0,
-                weight_loader=weight_loader,
+                weight_loader=patched_weight_loader,
             )
 
             layer.register_parameter("weight", weight)
@@ -487,6 +516,9 @@ def create_weights(
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
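
For reviewers, below is a minimal, self-contained sketch of the element-counting pattern this diff relies on, outside of vLLM. `Layer`, `base_loader`, and `finalize` are hypothetical stand-ins for the layer module, `weight_loader`, and `process_weights_after_loading`; only the bookkeeping mirrors the patch. It demonstrates that the finalize hook fires exactly once, after the last shard lands, regardless of shard order.

import torch


class Layer:
    """Hypothetical stand-in for the vLLM layer that owns the weight."""

    def __init__(self, out_features, in_features):
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features), requires_grad=False
        )
        self.finalized = False


def base_loader(param, loaded_weight, shard_offset):
    # Stand-in for the original weight_loader: copy one row shard into place.
    rows = loaded_weight.shape[0]
    param.data[shard_offset:shard_offset + rows].copy_(loaded_weight)


def finalize(layer):
    # Stand-in for process_weights_after_loading (e.g. FP8 requantization).
    layer.finalized = True


def make_patched_loader(layer):
    def patched_loader(param, loaded_weight, *args, **kwargs):
        res = base_loader(param, loaded_weight, *args, **kwargs)

        # Count elements written so far; shards may arrive in any order.
        if not hasattr(param, "_loaded_numel"):
            param._loaded_numel = loaded_weight.numel()
        else:
            param._loaded_numel += loaded_weight.numel()

        # Once every element is in place, finalize exactly once.
        if param._loaded_numel == param.numel():
            finalize(layer)
            del param._loaded_numel
            layer._already_called_finalize = True
        return res

    return patched_loader


layer = Layer(4, 8)
loader = make_patched_loader(layer)
loader(layer.weight, torch.randn(2, 8), 0)  # first shard: no finalize yet
assert not layer.finalized
loader(layer.weight, torch.randn(2, 8), 2)  # last shard: finalize fires
assert layer.finalized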