diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py
index a44cd042f..deee4117f 100644
--- a/tests/lmeval/test_lmeval.py
+++ b/tests/lmeval/test_lmeval.py
@@ -15,7 +15,7 @@
 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
-from tests.testing_utils import requires_gpu
+from tests.testing_utils import cached_lm_eval_run, requires_gpu
 
 
 class LmEvalConfig(BaseModel):
@@ -149,8 +149,9 @@ def test_lm_eval(self, test_data_file: str):
         self.tear_down()
 
     @log_time
+    @cached_lm_eval_run
    def _eval_base_model(self):
-        """Evaluate the base (uncompressed) model."""
+        """Evaluate the base (uncompressed) model with caching."""
         model_args = {**self.lmeval.model_args, "pretrained": self.model}
 
         results = lm_eval.simple_evaluate(
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 4ce6a5de6..cb63b4384 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -1,20 +1,51 @@
 import dataclasses
 import enum
+import hashlib
+import json
 import logging
 import os
 from dataclasses import dataclass
 from enum import Enum
+from functools import wraps
 from pathlib import Path
 from subprocess import PIPE, STDOUT, run
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, NamedTuple, Optional, Union
 
 import pytest
 import torch
 import yaml
 from datasets import Dataset
+from loguru import logger
 from transformers import ProcessorMixin
 
 TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", None)
+DISABLE_LMEVAL_CACHE = os.environ.get("DISABLE_LMEVAL_CACHE", "").lower() in (
+    "1",
+    "true",
+    "yes",
+)
+
+# Module-level cache - persists for the duration of the Python process
+_LMEVAL_CACHE: dict = {}
+
+
+class LMEvalCacheKey(NamedTuple):
+    """Hashable cache key for a base model evaluation.
+
+    :param model: HuggingFace model identifier
+    :param task: LM-Eval task name
+    :param num_fewshot: Number of few-shot examples
+    :param limit: Maximum number of samples to evaluate
+    :param batch_size: Batch size for evaluation
+    :param model_args_hash: SHA256 hash of the model_args dict
+    """
+
+    model: str
+    task: str
+    num_fewshot: int
+    limit: int
+    batch_size: int
+    model_args_hash: str
 
 
 # TODO: maybe test type as decorators?
@@ -300,3 +331,51 @@ def requires_cadence(cadence: Union[str, List[str]]) -> Callable:
     return pytest.mark.skipif(
         (current_cadence not in cadence), reason="cadence mismatch"
     )
+
+
+def _make_lmeval_cache_key(test_instance: Any) -> LMEvalCacheKey:
+    """Create a hashable cache key from a TestLMEval instance.
+
+    :param test_instance: Instance with model, lmeval attributes
+    :return: LMEvalCacheKey for this evaluation configuration
+    :raises AttributeError: If required attributes are missing
+    """
+    model_args = test_instance.lmeval.model_args
+    args_str = json.dumps(model_args, sort_keys=True)
+    args_hash = hashlib.sha256(args_str.encode()).hexdigest()
+
+    return LMEvalCacheKey(
+        model=test_instance.model,
+        task=test_instance.lmeval.task,
+        num_fewshot=test_instance.lmeval.num_fewshot,
+        limit=test_instance.lmeval.limit,
+        batch_size=test_instance.lmeval.batch_size,
+        model_args_hash=args_hash,
+    )
+
+
+def cached_lm_eval_run(func: Callable) -> Callable:
+    """Decorator that caches LM-Eval results for instance methods.
+
+    Caching can be disabled by setting the DISABLE_LMEVAL_CACHE environment
+    variable.
+    """
+    @wraps(func)
+    def cached(self, *args, **kwargs):
+        if DISABLE_LMEVAL_CACHE:
+            logger.info("LM-Eval cache disabled via DISABLE_LMEVAL_CACHE")
+            return func(self, *args, **kwargs)
+
+        key = _make_lmeval_cache_key(self)
+        cached_result = _LMEVAL_CACHE.get(key)
+        if cached_result is not None:
+            logger.info(f"LM-Eval cache HIT: {key}")
+            return cached_result
+
+        logger.info(f"LM-Eval cache MISS: {key}")
+        result = func(self, *args, **kwargs)
+        _LMEVAL_CACHE[key] = result
+        logger.info(f"LM-Eval cache WRITE: {key} ({len(_LMEVAL_CACHE)} entries)")
+        return result
+
+    return cached
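# --------------------------------------------------------------------------
# Usage sketch (editor's illustration, not part of the diff above). It shows
# how cached_lm_eval_run is expected to behave on an instance exposing the
# same attributes TestLMEval does. `FakeEval`, its attribute values, and the
# returned metrics are hypothetical stand-ins; only the fields read by
# _make_lmeval_cache_key (model, lmeval.task, lmeval.num_fewshot,
# lmeval.limit, lmeval.batch_size, lmeval.model_args) matter for the cache
# key. Assumes DISABLE_LMEVAL_CACHE is unset and the repo's tests package is
# importable.
from types import SimpleNamespace

from tests.testing_utils import cached_lm_eval_run


class FakeEval:
    """Hypothetical stand-in for a TestLMEval instance."""

    def __init__(self):
        self.model = "some-org/some-base-model"  # placeholder model id
        self.lmeval = SimpleNamespace(
            task="gsm8k",  # placeholder task/config values
            num_fewshot=5,
            limit=100,
            batch_size=8,
            model_args={"dtype": "bfloat16"},
        )
        self.calls = 0

    @cached_lm_eval_run
    def _eval_base_model(self):
        # Stands in for the expensive lm_eval.simple_evaluate(...) call.
        self.calls += 1
        return {"results": {"gsm8k": {"placeholder_metric": 0.0}}}


harness = FakeEval()
first = harness._eval_base_model()   # cache MISS: the method body runs
second = harness._eval_base_model()  # cache HIT: the stored dict is returned
assert first is second and harness.calls == 1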