diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index a44cd042f..deee4117f 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -15,7 +15,7 @@ from llmcompressor.core import active_session from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing from tests.test_timer.timer_utils import get_singleton_manager, log_time -from tests.testing_utils import requires_gpu +from tests.testing_utils import cached_lm_eval_run, requires_gpu class LmEvalConfig(BaseModel): @@ -149,8 +149,9 @@ def test_lm_eval(self, test_data_file: str): self.tear_down() @log_time + @cached_lm_eval_run def _eval_base_model(self): - """Evaluate the base (uncompressed) model.""" + """Evaluate the base (uncompressed) model with caching.""" model_args = {**self.lmeval.model_args, "pretrained": self.model} results = lm_eval.simple_evaluate( diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 4ce6a5de6..001ff96c2 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,20 +1,87 @@ import dataclasses import enum +import hashlib +import json import logging import os from dataclasses import dataclass from enum import Enum +from functools import wraps from pathlib import Path from subprocess import PIPE, STDOUT, run -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import pytest import torch import yaml from datasets import Dataset +from loguru import logger from transformers import ProcessorMixin TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", None) +DISABLE_LMEVAL_CACHE = os.environ.get("DISABLE_LMEVAL_CACHE", "").lower() in ( + "1", + "true", + "yes", +) +LMEVAL_CACHE_DIR = Path(os.environ.get("LMEVAL_CACHE_DIR", ".lmeval_cache")) + + +def _sha256_hash(text: str, length: Optional[int] = None) -> str: + hash_result = hashlib.sha256(text.encode()).hexdigest() + return hash_result[:length] if length else hash_result + + +@dataclass(frozen=True) +class LMEvalCacheKey: + model: str + task: str + num_fewshot: int + limit: int + batch_size: int + model_args_hash: str + + @classmethod + def from_test_instance(cls, test_instance: Any) -> "LMEvalCacheKey": + lmeval = test_instance.lmeval + model_args_json = json.dumps(lmeval.model_args, sort_keys=True) + + return cls( + model=test_instance.model, + task=lmeval.task, + num_fewshot=lmeval.num_fewshot, + limit=lmeval.limit, + batch_size=lmeval.batch_size, + model_args_hash=_sha256_hash(model_args_json), + ) + + @property + def cache_filepath(self) -> Path: + key_components = ( + f"{self.model}_{self.task}_{self.num_fewshot}_" + f"{self.limit}_{self.batch_size}_{self.model_args_hash}" + ) + return LMEVAL_CACHE_DIR / f"{_sha256_hash(key_components, 16)}.json" + + def get_cached_result(self) -> Optional[dict]: + if not self.cache_filepath.exists(): + return + + try: + with open(self.cache_filepath) as f: + return json.load(f) + except Exception as e: + logger.warning(f"Failed to load cache from {self.cache_filepath}: {e}") + return + + def store_result(self, result: dict) -> None: + try: + self.cache_filepath.parent.mkdir(parents=True, exist_ok=True) + with open(self.cache_filepath, "w") as f: + json.dump(result, f, default=str) + logger.info(f"LM-Eval cache WRITE: {self.cache_filepath}") + except Exception as e: + logger.warning(f"Failed to save cache to {self.cache_filepath}: {e}") # TODO: maybe test type as decorators? @@ -300,3 +367,24 @@ def requires_cadence(cadence: Union[str, List[str]]) -> Callable: return pytest.mark.skipif( (current_cadence not in cadence), reason="cadence mismatch" ) + + +def cached_lm_eval_run(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs): + if DISABLE_LMEVAL_CACHE: + logger.info("LM-Eval cache disabled") + return func(self, *args, **kwargs) + + cache_key = LMEvalCacheKey.from_test_instance(self) + + if (cached_result := cache_key.get_cached_result()) is not None: + logger.info(f"LM-Eval cache HIT: {cache_key.cache_filepath}") + return cached_result + + logger.info(f"LM-Eval cache MISS: {cache_key.cache_filepath}") + result = func(self, *args, **kwargs) + cache_key.store_result(result) + return result + + return wrapper