From da215d3e22dce9dbeff97c47f0bf4543e4f15a28 Mon Sep 17 00:00:00 2001
From: NJX-njx <3771829673@qq.com>
Date: Wed, 4 Mar 2026 18:31:33 +0800
Subject: [PATCH] feat: add KL-divergence evaluation tool for quantized models

Ref #2031

Implement a tool to evaluate how well quantized models preserve the
original model's probability distribution using KL divergence.

Features:
- Computes KLD(base || target) and KLD(target || base) since KL
  divergence is asymmetric
- Reports per-sample and aggregate statistics
- Supports any HuggingFace causal LM and dataset
- Python API via evaluate_kl_divergence() function
- CLI via python -m llmcompressor.evaluation.kl_divergence
- Memory-efficient: logits moved to CPU immediately to free GPU VRAM

Files added:
- src/llmcompressor/evaluation/__init__.py
- src/llmcompressor/evaluation/__main__.py
- src/llmcompressor/evaluation/kl_divergence.py
- tests/llmcompressor/evaluation/__init__.py
- tests/llmcompressor/evaluation/test_kl_divergence.py
---
 src/llmcompressor/evaluation/__init__.py      |  10 +
 src/llmcompressor/evaluation/__main__.py      |   5 +
 src/llmcompressor/evaluation/kl_divergence.py | 397 ++++++++++++++++++
 tests/llmcompressor/evaluation/__init__.py    |   0
 .../evaluation/test_kl_divergence.py          | 181 ++++++++
 5 files changed, 593 insertions(+)
 create mode 100644 src/llmcompressor/evaluation/__init__.py
 create mode 100644 src/llmcompressor/evaluation/__main__.py
 create mode 100644 src/llmcompressor/evaluation/kl_divergence.py
 create mode 100644 tests/llmcompressor/evaluation/__init__.py
 create mode 100644 tests/llmcompressor/evaluation/test_kl_divergence.py

diff --git a/src/llmcompressor/evaluation/__init__.py b/src/llmcompressor/evaluation/__init__.py
new file mode 100644
index 0000000000..34a8b3c7d0
--- /dev/null
+++ b/src/llmcompressor/evaluation/__init__.py
@@ -0,0 +1,10 @@
+"""
+Evaluation utilities for assessing the quality of compressed/quantized models.
+"""
+
+from llmcompressor.evaluation.kl_divergence import (
+    KLDivergenceResult,
+    evaluate_kl_divergence,
+)
+
+__all__ = ["evaluate_kl_divergence", "KLDivergenceResult"]
diff --git a/src/llmcompressor/evaluation/__main__.py b/src/llmcompressor/evaluation/__main__.py
new file mode 100644
index 0000000000..dd88f56ddb
--- /dev/null
+++ b/src/llmcompressor/evaluation/__main__.py
@@ -0,0 +1,5 @@
+"""Allow running the KL-divergence evaluation as ``python -m llmcompressor.evaluation``."""
+
+from llmcompressor.evaluation.kl_divergence import main
+
+main()
diff --git a/src/llmcompressor/evaluation/kl_divergence.py b/src/llmcompressor/evaluation/kl_divergence.py
new file mode 100644
index 0000000000..46d1ceef40
--- /dev/null
+++ b/src/llmcompressor/evaluation/kl_divergence.py
@@ -0,0 +1,397 @@
+"""
+KL-divergence evaluation tool for comparing quantized models against
+their base (unquantized) counterparts.
+
+Computes forward KLD(base || quant) and reverse KLD(quant || base) to
+measure how well a quantized model preserves the original probability
+distribution. KL divergence is asymmetric, so both directions are
+reported.
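+
+For a base distribution P and a target distribution Q over the vocabulary,
+KLD(P || Q) = sum_x P(x) * log(P(x) / Q(x)); it is zero exactly when the
+target reproduces the base distribution, and grows as the two drift apart.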
+
+Usage (CLI)::
+
+    python -m llmcompressor.evaluation.kl_divergence \\
+        --base-model meta-llama/Llama-3.1-8B \\
+        --target-model quantized-model-path \\
+        --dataset wikitext --dataset-config wikitext-2-raw-v1 \\
+        --num-samples 128 --max-seq-length 512
+
+Usage (Python API)::
+
+    from llmcompressor.evaluation.kl_divergence import evaluate_kl_divergence
+
+    results = evaluate_kl_divergence(
+        base_model="meta-llama/Llama-3.1-8B",
+        target_model="quantized-model-path",
+        dataset_id="wikitext",
+        dataset_config="wikitext-2-raw-v1",
+        num_samples=128,
+        max_seq_length=512,
+    )
+    print(results)
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+from dataclasses import asdict, dataclass, field
+
+import torch
+import torch.nn.functional as F
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["evaluate_kl_divergence", "KLDivergenceResult"]
+
+
+# ---------------------------------------------------------------------------
+# Result container
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class KLDivergenceResult:
+    """Stores per-sample and aggregate KL-divergence statistics."""
+
+    # Aggregate
+    forward_kld_mean: float = 0.0  # KLD(base || target), averaged over tokens
+    reverse_kld_mean: float = 0.0  # KLD(target || base), averaged over tokens
+    symmetric_kld_mean: float = 0.0  # (forward + reverse) / 2
+    num_samples: int = 0
+    num_tokens: int = 0
+
+    # Per-sample lists
+    forward_kld_per_sample: list[float] = field(default_factory=list)
+    reverse_kld_per_sample: list[float] = field(default_factory=list)
+
+    def to_dict(self) -> dict:
+        return asdict(self)
+
+    def summary(self) -> str:
+        return (
+            f"KL-Divergence Evaluation ({self.num_samples} samples, "
+            f"{self.num_tokens} tokens)\n"
+            f"  KLD(base || target): {self.forward_kld_mean:.6f}\n"
+            f"  KLD(target || base): {self.reverse_kld_mean:.6f}\n"
+            f"  Symmetric KLD:       {self.symmetric_kld_mean:.6f}"
+        )
+
+
+# ---------------------------------------------------------------------------
+# Core computation
+# ---------------------------------------------------------------------------
+
+
+def _kl_divergence_per_token(
+    log_probs_p: torch.Tensor,
+    log_probs_q: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Compute token-level KL(P || Q) from log-probabilities.
+
+    KL(P || Q) = sum_x P(x) * (log P(x) - log Q(x))
+
+    :param log_probs_p: shape (seq_len, vocab_size), log-softmax of P
+    :param log_probs_q: shape (seq_len, vocab_size), log-softmax of Q
+    :return: shape (seq_len,), per-token KL divergence
+    """
+    p = log_probs_p.exp()
+    kl = (p * (log_probs_p - log_probs_q)).sum(dim=-1)
+    return kl
+
+
+@torch.no_grad()
+def _collect_logits(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """
+    Run a forward pass and return logits (on CPU to save GPU memory).
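+
+    Logits are upcast to float32 before being moved off-device so that the
+    subsequent log-softmax runs at full precision regardless of the model's
+    native dtype.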
+
+    :param model: the causal LM
+    :param input_ids: shape (1, seq_len)
+    :param attention_mask: shape (1, seq_len) or None
+    :return: logits tensor of shape (seq_len, vocab_size)
+    """
+    outputs = model(
+        input_ids=input_ids.to(model.device),
+        attention_mask=(
+            attention_mask.to(model.device) if attention_mask is not None else None
+        ),
+    )
+    # Move to CPU immediately to free GPU memory
+    return outputs.logits[0].float().cpu()
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def evaluate_kl_divergence(
+    base_model: str | torch.nn.Module,
+    target_model: str | torch.nn.Module,
+    dataset_id: str = "wikitext",
+    dataset_config: str | None = "wikitext-2-raw-v1",
+    dataset_split: str = "test",
+    text_column: str = "text",
+    num_samples: int = 128,
+    max_seq_length: int = 512,
+    batch_size: int = 1,
+    device: str | None = None,
+    base_model_kwargs: dict | None = None,
+    target_model_kwargs: dict | None = None,
+) -> KLDivergenceResult:
+    """
+    Evaluate KL divergence between a base model and a target (quantized) model.
+
+    :param base_model: HuggingFace model ID or an already-loaded model
+    :param target_model: HuggingFace model ID or an already-loaded model
+    :param dataset_id: HuggingFace dataset ID for evaluation
+    :param dataset_config: dataset configuration name (e.g. "wikitext-2-raw-v1")
+    :param dataset_split: dataset split to use
+    :param text_column: name of the text column in the dataset
+    :param num_samples: number of samples to evaluate
+    :param max_seq_length: maximum token sequence length
+    :param batch_size: not used yet (reserved for future batched evaluation)
+    :param device: device to run on ("cuda", "cpu", "auto"). Defaults to "auto"
+    :param base_model_kwargs: additional kwargs for AutoModelForCausalLM.from_pretrained
+    :param target_model_kwargs: additional kwargs for AutoModelForCausalLM.from_pretrained
+    :return: KLDivergenceResult with per-sample and aggregate statistics
+    """
+    if device is None:
+        device = "auto"
+
+    base_model_kwargs = base_model_kwargs or {}
+    target_model_kwargs = target_model_kwargs or {}
+
+    # --- Load models ---
+    if isinstance(base_model, str):
+        logger.info("Loading base model: %s", base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+        base_model_obj = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            torch_dtype="auto",
+            device_map=device,
+            **base_model_kwargs,
+        )
+        base_model_id = base_model
+    else:
+        base_model_obj = base_model
+        base_model_id = getattr(base_model, "name_or_path", "base_model")
+        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+
+    if isinstance(target_model, str):
+        logger.info("Loading target model: %s", target_model)
+        target_model_obj = AutoModelForCausalLM.from_pretrained(
+            target_model,
+            torch_dtype="auto",
+            device_map=device,
+            **target_model_kwargs,
+        )
+    else:
+        target_model_obj = target_model
+
+    base_model_obj.eval()
+    target_model_obj.eval()
+
+    # --- Load and tokenize dataset ---
+    logger.info(
+        "Loading dataset: %s (config=%s, split=%s)",
+        dataset_id,
+        dataset_config,
+        dataset_split,
+    )
+
+    ds_kwargs = {"split": dataset_split}
+    if dataset_config:
+        ds = load_dataset(dataset_id, dataset_config, **ds_kwargs)
+    else:
+        ds = load_dataset(dataset_id, **ds_kwargs)
+
+    # Filter out empty texts
+    ds = ds.filter(lambda x: len(x[text_column].strip()) > 0)
+
+    if len(ds) > num_samples:
+        ds = ds.select(range(num_samples))
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # --- Evaluate ---
+    result = KLDivergenceResult()
+    total_forward_kld = 0.0
+    total_reverse_kld = 0.0
+    total_tokens = 0
+
+    for sample in tqdm(ds, desc="Evaluating KL divergence"):
+        text = sample[text_column]
+        inputs = tokenizer(
+            text,
+            return_tensors="pt",
+            max_length=max_seq_length,
+            truncation=True,
+            padding=False,
+        )
+
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs.get("attention_mask")
+
+        if input_ids.shape[1] < 2:
+            continue
+
+        # Get logits from both models
+        base_logits = _collect_logits(base_model_obj, input_ids, attention_mask)
+        target_logits = _collect_logits(
+            target_model_obj, input_ids, attention_mask
+        )
+
+        # Convert to log-probabilities
+        base_log_probs = F.log_softmax(base_logits, dim=-1)
+        target_log_probs = F.log_softmax(target_logits, dim=-1)
+
+        # Compute KL divergence per token. Unlike perplexity, comparing
+        # output distributions needs no next-token labels, so every
+        # position in the sequence contributes.
+        forward_kld = _kl_divergence_per_token(base_log_probs, target_log_probs)
+        reverse_kld = _kl_divergence_per_token(target_log_probs, base_log_probs)
+
+        # Clamp to avoid negative values from numerical imprecision
+        forward_kld = forward_kld.clamp(min=0.0)
+        reverse_kld = reverse_kld.clamp(min=0.0)
+
+        seq_len = forward_kld.shape[0]
+        sample_fwd = forward_kld.mean().item()
+        sample_rev = reverse_kld.mean().item()
+
+        result.forward_kld_per_sample.append(sample_fwd)
+        result.reverse_kld_per_sample.append(sample_rev)
+
+        total_forward_kld += forward_kld.sum().item()
+        total_reverse_kld += reverse_kld.sum().item()
+        total_tokens += seq_len
+
+    # Aggregate (num_samples counts only samples actually scored)
+    result.num_samples = len(result.forward_kld_per_sample)
+    result.num_tokens = total_tokens
+    if total_tokens > 0:
+        result.forward_kld_mean = total_forward_kld / total_tokens
+        result.reverse_kld_mean = total_reverse_kld / total_tokens
+        result.symmetric_kld_mean = (
+            result.forward_kld_mean + result.reverse_kld_mean
+        ) / 2.0
+
+    return result
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate KL divergence between a base model and a "
+        "quantized/target model.",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument(
+        "--base-model",
+        type=str,
+        required=True,
+        help="HuggingFace model ID or path for the base (unquantized) model",
+    )
+    parser.add_argument(
+        "--target-model",
+        type=str,
+        required=True,
+        help="HuggingFace model ID or path for the target (quantized) model",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="wikitext",
+        help="HuggingFace dataset ID (default: wikitext)",
+    )
+    parser.add_argument(
+        "--dataset-config",
+        type=str,
+        default="wikitext-2-raw-v1",
+        help="Dataset configuration (default: wikitext-2-raw-v1)",
+    )
+    parser.add_argument(
+        "--dataset-split",
+        type=str,
+        default="test",
+        help="Dataset split (default: test)",
+    )
+    parser.add_argument(
+        "--text-column",
+        type=str,
+        default="text",
+        help="Name of the text column in the dataset (default: text)",
+    )
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=128,
+        help="Number of samples to evaluate (default: 128)",
+    )
+    parser.add_argument(
+        "--max-seq-length",
+        type=int,
+        default=512,
+        help="Maximum sequence length (default: 512)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        help="Device to run on (default: auto)",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default=None,
+        help="Path to save JSON results (optional)",
+    )
+
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
+
+    result = evaluate_kl_divergence(
+        base_model=args.base_model,
+        target_model=args.target_model,
+        dataset_id=args.dataset,
+        dataset_config=args.dataset_config,
+        dataset_split=args.dataset_split,
+        text_column=args.text_column,
+        num_samples=args.num_samples,
+        max_seq_length=args.max_seq_length,
+        device=args.device,
+    )
+
+    print("\n" + result.summary())
+
+    if args.output:
+        with open(args.output, "w") as f:
+            json.dump(result.to_dict(), f, indent=2)
+        print(f"\nResults saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/llmcompressor/evaluation/__init__.py b/tests/llmcompressor/evaluation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/llmcompressor/evaluation/test_kl_divergence.py b/tests/llmcompressor/evaluation/test_kl_divergence.py
new file mode 100644
index 0000000000..e68320d472
--- /dev/null
+++ b/tests/llmcompressor/evaluation/test_kl_divergence.py
@@ -0,0 +1,181 @@
+"""
+Unit tests for the KL-divergence evaluation tool.
+
+Tests core computation logic using small synthetic models and tensors.
+Does not require GPU or large model downloads.
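+Network-dependent checks skip themselves when the tiny model config or the
+wikitext dataset cannot be downloaded.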
+"""
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from llmcompressor.evaluation.kl_divergence import (
+    KLDivergenceResult,
+    _kl_divergence_per_token,
+    evaluate_kl_divergence,
+)
+
+
+class TestKLDivergencePerToken:
+    """Tests for the core KL divergence computation."""
+
+    def test_identical_distributions(self):
+        """KL divergence of identical distributions should be zero."""
+        logits = torch.randn(10, 100)
+        log_probs = F.log_softmax(logits, dim=-1)
+
+        kl = _kl_divergence_per_token(log_probs, log_probs)
+        assert kl.shape == (10,)
+        assert torch.allclose(kl, torch.zeros(10), atol=1e-5)
+
+    def test_different_distributions(self):
+        """KL divergence of different distributions should be positive."""
+        torch.manual_seed(42)
+        logits_p = torch.randn(5, 50)
+        logits_q = torch.randn(5, 50)
+
+        log_probs_p = F.log_softmax(logits_p, dim=-1)
+        log_probs_q = F.log_softmax(logits_q, dim=-1)
+
+        forward_kl = _kl_divergence_per_token(log_probs_p, log_probs_q)
+        reverse_kl = _kl_divergence_per_token(log_probs_q, log_probs_p)
+
+        # KL should be non-negative (allow tiny numerical slack)
+        assert (forward_kl >= -1e-6).all()
+        assert (reverse_kl >= -1e-6).all()
+
+        # KL is asymmetric, so the two directions generally differ
+        assert not torch.allclose(forward_kl, reverse_kl, atol=1e-3)
+
+    def test_known_value(self):
+        """Test against manually computed KL divergence."""
+        # P = [0.5, 0.5], Q = [0.25, 0.75]
+        # KL(P||Q) = 0.5*log(0.5/0.25) + 0.5*log(0.5/0.75)
+        #          = 0.5*log(2) + 0.5*log(2/3)
+        #          ≈ 0.5*0.6931 + 0.5*(-0.4055) ≈ 0.1438
+        import math
+
+        expected_kl = 0.5 * math.log(2) + 0.5 * math.log(2 / 3)
+
+        p = torch.tensor([[0.5, 0.5]])
+        q = torch.tensor([[0.25, 0.75]])
+        log_p = p.log()
+        log_q = q.log()
+
+        kl = _kl_divergence_per_token(log_p, log_q)
+        assert kl.shape == (1,)
+        assert abs(kl.item() - expected_kl) < 1e-5
+
+
+class TestKLDivergenceResult:
+    """Tests for the result dataclass."""
+
+    def test_summary(self):
+        result = KLDivergenceResult(
+            forward_kld_mean=0.05,
+            reverse_kld_mean=0.08,
+            symmetric_kld_mean=0.065,
+            num_samples=10,
+            num_tokens=500,
+        )
+        summary = result.summary()
+        assert "0.050000" in summary
+        assert "0.080000" in summary
+        assert "10 samples" in summary
+        assert "500 tokens" in summary
+
+    def test_to_dict(self):
+        result = KLDivergenceResult(
+            forward_kld_mean=0.1,
+            reverse_kld_mean=0.2,
+            num_samples=5,
+            num_tokens=100,
+        )
+        d = result.to_dict()
+        assert d["forward_kld_mean"] == 0.1
+        assert d["reverse_kld_mean"] == 0.2
+        assert d["num_samples"] == 5
+        assert isinstance(d, dict)
+
+
+class TestEvaluateKLDivergenceWithMocks:
+    """Integration-style tests using tiny models."""
+
+    @pytest.fixture
+    def tiny_models(self):
+        """Create two tiny randomly-initialized models for testing."""
+        try:
+            from transformers import AutoConfig, AutoModelForCausalLM
+
+            config = AutoConfig.from_pretrained(
+                "hf-internal-testing/tiny-random-LlamaForCausalLM"
+            )
+        except Exception:
+            pytest.skip("Cannot load tiny model config (no network or HF access)")
+
+        model_a = AutoModelForCausalLM.from_config(config)
+        model_b = AutoModelForCausalLM.from_config(config)
+        model_a.eval()
+        model_b.eval()
+        return model_a, model_b
+
+    def test_evaluate_with_preloaded_models(self, tiny_models):
+        """Test that evaluation runs end-to-end with preloaded models."""
+        model_a, model_b = tiny_models
+
+        try:
+            result = evaluate_kl_divergence(
+                base_model=model_a,
+                target_model=model_b,
+                dataset_id="wikitext",
+                dataset_config="wikitext-2-raw-v1",
+                dataset_split="test",
+                num_samples=4,
+                max_seq_length=32,
+                device="cpu",
+            )
+        except Exception:
+            pytest.skip("Cannot load dataset (no network)")
+
+        assert isinstance(result, KLDivergenceResult)
+        assert result.num_tokens > 0
+        assert result.forward_kld_mean >= 0
+        assert result.reverse_kld_mean >= 0
+        assert len(result.forward_kld_per_sample) > 0
+
+    def test_same_model_gives_zero_kld(self, tiny_models):
+        """When evaluating the same model against itself, KLD should be ~0."""
+        model_a, _ = tiny_models
+
+        try:
+            result = evaluate_kl_divergence(
+                base_model=model_a,
+                target_model=model_a,
+                dataset_id="wikitext",
+                dataset_config="wikitext-2-raw-v1",
+                dataset_split="test",
+                num_samples=2,
+                max_seq_length=32,
+                device="cpu",
+            )
+        except Exception:
+            pytest.skip("Cannot load dataset (no network)")
+
+        assert result.forward_kld_mean < 1e-4
+        assert result.reverse_kld_mean < 1e-4
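+
+
+# Closed-form sanity check: against a uniform target U over V classes,
+# KL(P || U) = log(V) - H(P), where H is the Shannon entropy of P.
+def test_uniform_target_closed_form():
+    import math
+
+    torch.manual_seed(0)
+    log_p = F.log_softmax(torch.randn(3, 8), dim=-1)
+    log_u = torch.full((3, 8), -math.log(8))
+    entropy = -(log_p.exp() * log_p).sum(dim=-1)
+    expected = math.log(8) - entropy
+    kl = _kl_divergence_per_token(log_p, log_u)
+    assert torch.allclose(kl, expected, atol=1e-5)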