diff --git a/src/llmcompressor/modifiers/utils/helpers.py b/src/llmcompressor/modifiers/utils/helpers.py
index c7e1fc4bc2..cbba632d94 100644
--- a/src/llmcompressor/modifiers/utils/helpers.py
+++ b/src/llmcompressor/modifiers/utils/helpers.py
@@ -8,8 +8,8 @@
 """
 
 import torch
+from compressed_tensors.offload import align_modules, update_offload_parameter
 from compressed_tensors.quantization import QuantizationStrategy, is_attention_module
-from compressed_tensors.utils import align_modules, update_parameter_data
 from torch.nn import Linear, Module
 
 __all__ = ["update_fused_layer_weight_global_scales"]
@@ -80,9 +80,9 @@ def _valid_tensor_group_quant(layer_list: list[Linear]):
             )
         ).reshape([1])
 
-        update_parameter_data(submodule.k_proj, global_scale, "weight_global_scale")
-        update_parameter_data(submodule.q_proj, global_scale, "weight_global_scale")
-        update_parameter_data(submodule.v_proj, global_scale, "weight_global_scale")
+        update_offload_parameter(submodule.k_proj, "weight_global_scale", global_scale)
+        update_offload_parameter(submodule.q_proj, "weight_global_scale", global_scale)
+        update_offload_parameter(submodule.v_proj, "weight_global_scale", global_scale)
 
         del global_scale
 
@@ -100,7 +100,11 @@ def _valid_tensor_group_quant(layer_list: list[Linear]):
             )
         ).reshape([1])
 
-        update_parameter_data(submodule.gate_proj, global_scale, "weight_global_scale")
-        update_parameter_data(submodule.up_proj, global_scale, "weight_global_scale")
+        update_offload_parameter(
+            submodule.gate_proj,
+            "weight_global_scale",
+            global_scale,
+        )
+        update_offload_parameter(submodule.up_proj, "weight_global_scale", global_scale)
 
         del global_scale
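
Reviewer note: the migration changes both the import location and the argument order. update_parameter_data(module, data, name) from compressed_tensors.utils becomes update_offload_parameter(module, name, data) from compressed_tensors.offload, with the parameter name and tensor swapped. A minimal sketch of the new call pattern follows, using an illustrative Linear layer in place of e.g. submodule.k_proj; the pre-registered weight_global_scale parameter is an assumption about the quantization lifecycle, not something shown in this diff.

    import torch
    from torch.nn import Linear, Parameter

    from compressed_tensors.offload import update_offload_parameter

    # Illustrative stand-in for a fused attention projection such as
    # submodule.k_proj (hypothetical example, not from this patch).
    layer = Linear(16, 16)

    # Assumption: in llm-compressor the quantization lifecycle has already
    # registered this parameter before the helper runs.
    layer.register_parameter(
        "weight_global_scale", Parameter(torch.tensor([1.0]), requires_grad=False)
    )

    # New argument order: (module, parameter_name, new_value). Besides
    # writing the in-memory parameter, the helper is meant to keep any
    # offloaded copy (e.g. under an accelerate-managed device map) in sync,
    # which update_parameter_data did not guarantee.
    global_scale = torch.tensor([0.5]).reshape([1])
    update_offload_parameter(layer, "weight_global_scale", global_scale)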