Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def save_pretrained_compressed(save_pretrained_method):
model_class = model_ref().__class__
del save_pretrained_method

# hotfix: create a weak reference to the model to avoid a circular reference
# TODO: determine why the circular reference is not garbage-collected and how to clean up this function
model_ref = weakref.ref(model)

@wraps(original_save_pretrained)
def save_pretrained_wrapper(
save_directory: str,
Expand Down Expand Up @@ -95,11 +99,11 @@ def save_pretrained_wrapper(
state_dict = kwargs.pop("state_dict", None)
if state_dict is None:
logger.info("Fetching state_dict - this may take some time")
state_dict = get_state_dict_offloaded_model(model)
state_dict = get_state_dict_offloaded_model(model_ref())

logger.info("Fetching compressor")
compressor = get_model_compressor(
model=model,
model=model_ref(),
sparsity_config=sparsity_config,
quantization_format=quantization_format,
save_compressed=save_compressed,
Expand All @@ -111,7 +115,7 @@ def save_pretrained_wrapper(
if compressor is None:
# model is not compressed or quantized, save as normal
original_save_pretrained_func = original_save_pretrained.__get__(
model, model_class
model_ref(), model_class
)
original_save_pretrained_func(
save_directory, state_dict=state_dict, **kwargs
Expand All @@ -121,10 +125,10 @@ def save_pretrained_wrapper(
# make sure we're on the main process when saving
if state_dict is not None and len(state_dict) > 0:
compressed_state_dict = compressor.compress(
model, state_dict, show_progress=True
model_ref(), state_dict, show_progress=True
)
logger.info("Saving compressed model to disk")
original_save_pretrained.__get__(model, model_class)(
original_save_pretrained.__get__(model_ref(), model_class)(
save_directory,
state_dict=compressed_state_dict,
safe_serialization=safe_serialization,
Expand All @@ -133,10 +137,10 @@ def save_pretrained_wrapper(
compressor.update_config(save_directory)

# update existing recipe
update_and_save_recipe(model.name_or_path, save_directory)
update_and_save_recipe(model_ref().name_or_path, save_directory)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_directory)
copy_python_files_from_model_cache(model_ref(), save_directory)

save_pretrained_wrapper._overridden = True
return save_pretrained_wrapper
Expand Down