5 changes: 3 additions & 2 deletions tests/lmeval/test_lmeval.py
@@ -15,7 +15,7 @@
from llmcompressor.core import active_session
from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
from tests.test_timer.timer_utils import get_singleton_manager, log_time
from tests.testing_utils import requires_gpu
from tests.testing_utils import cached_lm_eval_run, requires_gpu


class LmEvalConfig(BaseModel):
@@ -149,8 +149,9 @@ def test_lm_eval(self, test_data_file: str):
        self.tear_down()

    @log_time
    @cached_lm_eval_run
    def _eval_base_model(self):
        """Evaluate the base (uncompressed) model."""
        """Evaluate the base (uncompressed) model with caching."""
        model_args = {**self.lmeval.model_args, "pretrained": self.model}

        results = lm_eval.simple_evaluate(
90 changes: 89 additions & 1 deletion tests/testing_utils.py
@@ -1,20 +1,87 @@
import dataclasses
import enum
import hashlib
import json
import logging
import os
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from pathlib import Path
from subprocess import PIPE, STDOUT, run
from typing import Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union

import pytest
import torch
import yaml
from datasets import Dataset
from loguru import logger
from transformers import ProcessorMixin

TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", None)
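# Base-model eval results are cached on disk so repeated runs can skip re-evaluation;
# set DISABLE_LMEVAL_CACHE=1 (or "true"/"yes") to bypass the cache.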
DISABLE_LMEVAL_CACHE = os.environ.get("DISABLE_LMEVAL_CACHE", "").lower() in (
"1",
"true",
"yes",
)
LMEVAL_CACHE_DIR = Path(os.environ.get("LMEVAL_CACHE_DIR", ".lmeval_cache"))


def _sha256_hash(text: str, length: Optional[int] = None) -> str:
    hash_result = hashlib.sha256(text.encode()).hexdigest()
    return hash_result[:length] if length else hash_result


@dataclass(frozen=True)
class LMEvalCacheKey:
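    """Cache key for one lm-eval run; equal keys reuse the same cached results file."""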
    model: str
    task: str
    num_fewshot: int
    limit: int
    batch_size: int
    model_args_hash: str

    @classmethod
    def from_test_instance(cls, test_instance: Any) -> "LMEvalCacheKey":
        lmeval = test_instance.lmeval
        model_args_json = json.dumps(lmeval.model_args, sort_keys=True)

        return cls(
            model=test_instance.model,
            task=lmeval.task,
            num_fewshot=lmeval.num_fewshot,
            limit=lmeval.limit,
            batch_size=lmeval.batch_size,
            model_args_hash=_sha256_hash(model_args_json),
        )

    @property
    def cache_filepath(self) -> Path:
        key_components = (
            f"{self.model}_{self.task}_{self.num_fewshot}_"
            f"{self.limit}_{self.batch_size}_{self.model_args_hash}"
        )
        return LMEVAL_CACHE_DIR / f"{_sha256_hash(key_components, 16)}.json"

    def get_cached_result(self) -> Optional[dict]:
        if not self.cache_filepath.exists():
            return None

        try:
            with open(self.cache_filepath) as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load cache from {self.cache_filepath}: {e}")
            return None

    def store_result(self, result: dict) -> None:
        try:
            self.cache_filepath.parent.mkdir(parents=True, exist_ok=True)
            with open(self.cache_filepath, "w") as f:
                json.dump(result, f, default=str)
            logger.info(f"LM-Eval cache WRITE: {self.cache_filepath}")
        except Exception as e:
            logger.warning(f"Failed to save cache to {self.cache_filepath}: {e}")


# TODO: maybe test type as decorators?
@@ -300,3 +367,24 @@ def requires_cadence(cadence: Union[str, List[str]]) -> Callable:
    return pytest.mark.skipif(
        (current_cadence not in cadence), reason="cadence mismatch"
    )


def cached_lm_eval_run(func: Callable) -> Callable:
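    """Cache eval results on disk, keyed by model, task, and eval settings."""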
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if DISABLE_LMEVAL_CACHE:
            logger.info("LM-Eval cache disabled")
            return func(self, *args, **kwargs)

        cache_key = LMEvalCacheKey.from_test_instance(self)

        if (cached_result := cache_key.get_cached_result()) is not None:
            logger.info(f"LM-Eval cache HIT: {cache_key.cache_filepath}")
            return cached_result

        logger.info(f"LM-Eval cache MISS: {cache_key.cache_filepath}")
        result = func(self, *args, **kwargs)
        cache_key.store_result(result)
        return result

    return wrapper
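
Below is a minimal usage sketch of the caching path added here, not part of the diff. It assumes the repo root is on `PYTHONPATH` so `tests.testing_utils` is importable; the `FakeLmEvalTest` class, its task/limit/batch-size/`model_args` values, and the stubbed result dict are hypothetical stand-ins for the real test class in `tests/lmeval/test_lmeval.py`.

```python
from types import SimpleNamespace

from tests.testing_utils import cached_lm_eval_run


class FakeLmEvalTest:
    """Stand-in exposing the attributes the cache key reads."""

    model = "stub/base-model"
    lmeval = SimpleNamespace(
        task="gsm8k",
        num_fewshot=5,
        limit=100,
        batch_size=8,
        model_args={"dtype": "bfloat16"},
    )

    @cached_lm_eval_run
    def _eval_base_model(self):
        # Stand-in for the real lm_eval.simple_evaluate(...) call.
        return {"results": {"gsm8k": {"exact_match": 0.42}}}


test = FakeLmEvalTest()
test._eval_base_model()  # cache MISS: runs the body, writes .lmeval_cache/<hash>.json
test._eval_base_model()  # cache HIT: returns the stored JSON without re-running
```

The first call misses the cache and writes a JSON file under `.lmeval_cache/` (or `LMEVAL_CACHE_DIR` if set); the second call with identical model, task, and eval settings returns the stored result without re-evaluating. Setting `DISABLE_LMEVAL_CACHE=1` (or `true`/`yes`) bypasses the lookup entirely.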