Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 1 addition & 99 deletions src/llmcompressor/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,18 @@
import json
import os
from typing import Any, Dict, List, Optional, Union
from typing import Optional, Union

import torch
from loguru import logger
from safetensors import safe_open
from torch.nn import Module
from transformers import PreTrainedModel

from llmcompressor.core import active_session
from llmcompressor.typing import Processor

COMPLETED_STAGES_FILENAME = "completed_stages.json"

__all__ = [
"copy_python_files_from_model_cache",
"parse_dtype",
"get_session_model",
"get_completed_stages",
"save_completed_stages",
"save_checkpoint",
]


def save_checkpoint(
    save_path: str,
    model: PreTrainedModel,
    processor: Optional[Processor] = None,
    save_safetensors: bool = True,
    save_compressed: bool = True,
    skip_sparsity_compression_stats: bool = False,
) -> None:
    """
    Save a model, processor, and recipe as a checkpoint, then decompress the
    in-memory model so it remains usable for further training/oneshot.

    :param save_path: Path used to save model and processor
    :param model: model to save
    :param processor: processor to save
    :param save_safetensors: save model checkpoint using safetensors file type
    :param save_compressed: save model checkpoint using compressed-tensors format
    :param skip_sparsity_compression_stats: skip computing sparsity statistics
        when building the model compressor
    """
    # function-scope import to avoid a circular import at module load time
    from llmcompressor.transformers.compression.compressed_tensors_utils import (
        get_model_compressor,  # avoid circular import
    )

    # used for decompression
    # unfortunately, if skip_sparsity_compression_stats==True, sparsity stats
    # are computed twice. In the future, track sparsity from recipe or
    # share recipe between compression and decompression
    compressor = get_model_compressor(
        model=model,
        save_compressed=save_compressed,
        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
    )

    # saving the model also saves the recipe
    model.save_pretrained(
        save_path,
        save_safetensors=save_safetensors,
        save_compressed=save_compressed,
        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
    )
    if processor is not None:
        processor.save_pretrained(save_path)

    # decompression: saving the model modifies the model structure
    # as this is only a checkpoint, decompress model to enable future training/oneshot
    if compressor is not None:
        compressor.decompress_model(model)


def parse_dtype(dtype_arg: Union[str, torch.dtype]) -> torch.dtype:
"""
:param dtype_arg: dtype or string to parse
Expand Down Expand Up @@ -100,47 +43,6 @@ def get_session_model() -> Optional[Module]:
return active_model


def get_completed_stages(checkpoint_dir: Any) -> List[str]:
    """
    Given a checkpoint directory for a staged run, get the list of stages that
    have completed in a prior run if the checkpoint_dir is a string

    :param checkpoint_dir: path to staged checkpoint
    :return: list of completed stage names, or [] when no record exists
    """
    if isinstance(checkpoint_dir, str):
        stage_path = os.path.join(checkpoint_dir, COMPLETED_STAGES_FILENAME)
        if os.path.exists(stage_path):
            with open(stage_path, encoding="utf-8") as stage_file:
                stage_data = json.load(stage_file)
            # tolerate a malformed record (missing "completed" key) instead of
            # raising KeyError; an unreadable record means "no stages completed"
            return stage_data.get("completed", [])

    return []


def save_completed_stages(checkpoint_dir: str, completed_stages: List[str]):
    """
    Record the stages of a staged run that have finished so a later run can
    resume from the checkpoint.

    :param checkpoint_dir: model checkpoint directory to save stages to
    :param completed_stages: list of stage names that have been run
    """
    record = {"completed": completed_stages}
    stage_path = os.path.join(checkpoint_dir, COMPLETED_STAGES_FILENAME)
    with open(stage_path, "w") as out_file:
        json.dump(record, out_file)


def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]:
    """
    Read every tensor stored in a safetensors file into CPU memory.

    :param file_path: path to the safetensors file
    :return: dictionary mapping tensor name to loaded tensor
    """
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for tensor_name in handle.keys():
            state_dict[tensor_name] = handle.get_tensor(tensor_name)
    return state_dict


def copy_python_files_from_model_cache(model, save_path: str):
config = model.config
cache_path = None
Expand Down
15 changes: 0 additions & 15 deletions src/llmcompressor/pytorch/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
Utility / helper functions
"""

import random
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Any

import numpy
import torch
from torch import Tensor
from torch.nn import Module

Expand All @@ -27,7 +25,6 @@
"tensors_module_forward",
"tensor_sparsity",
"get_quantized_layers",
"set_deterministic_seeds",
]


Expand Down Expand Up @@ -238,15 +235,3 @@ def get_quantized_layers(module: Module) -> list[tuple[str, Module]]:
quantized_layers.append((name, mod))

return quantized_layers


def set_deterministic_seeds(seed: int = 0):
    """
    Seed the random, numpy, and torch RNGs with the same value and force
    cuDNN into deterministic mode so repeated runs are reproducible.

    :param seed: the manual seed to use. Default is 0
    """
    # the three RNGs are independent, so seeding order does not matter
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
40 changes: 1 addition & 39 deletions src/llmcompressor/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
"""

import contextlib
import importlib.metadata
import importlib.util
import importlib
import re
from typing import Tuple, Union

import torch
from compressed_tensors.quantization import disable_quantization, enable_quantization
Expand All @@ -18,7 +16,6 @@
from llmcompressor.utils import get_embeddings

__all__ = [
"is_package_available",
"import_from_path",
"disable_cache",
"DisableQuantization",
Expand All @@ -30,41 +27,6 @@
]


def is_package_available(
    package_name: str,
    return_version: bool = False,
) -> Union[Tuple[bool, str], bool]:
    """
    A helper function to check if a package is available
    and optionally return its version. This function enforces
    a check that the package is available and is not
    just a directory/file with the same name as the package.

    inspired from:
    https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/utils/import_utils.py#L41

    :param package_name: The package name to check for
    :param return_version: True to return the version of
        the package if available
    :return: True if the package is available, False otherwise or a tuple of
        (bool, version) if return_version is True
    """
    # find_spec confirms an importable module spec exists on sys.path
    package_exists = importlib.util.find_spec(package_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            package_version = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # importable but not an installed distribution (e.g. a local
            # directory shadowing the name) -> treat as unavailable
            package_exists = False
    logger.debug(f"Detected {package_name} version {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists


def import_from_path(path: str) -> str:
"""
Import the module and the name of the function/class separated by :
Expand Down
Loading