diff --git a/examples/evaluate_precision/qwen.yaml b/examples/evaluate_precision/qwen.yaml new file mode 100644 index 000000000..9a28e270d --- /dev/null +++ b/examples/evaluate_precision/qwen.yaml @@ -0,0 +1,25 @@ +# Precision-evaluation config on Qwen2.5-0.5B — the model used for the Fast-LLM vs DeepSpeed +# precision-pattern comparison (DeepSpeed side: tools/evaluate_precision_deepspeed.py). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen.yaml +pretrained: + path: Qwen/Qwen2.5-0.5B + format: qwen2 +output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_features +sequence_length: 2048 +variants: + # Maps to the DeepSpeed harness's `bf16_head_bf16` (compute bf16, lm head in compute dtype). + bf16: + model.distributed.compute_dtype: bfloat16 + # Maps to the DeepSpeed harness's `bf16` (compute bf16, fp32 lm head — the stack default). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Maps to the DeepSpeed harness's `fp16_head_fp16`. + fp16: + model.distributed.compute_dtype: float16 + # Maps to the DeepSpeed harness's `fp16`. + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/sample_text.txt b/examples/evaluate_precision/sample_text.txt new file mode 100644 index 000000000..8b207173d --- /dev/null +++ b/examples/evaluate_precision/sample_text.txt @@ -0,0 +1,23 @@ +The history of computing is often told as a story of ever-smaller and ever-faster machines, but the more interesting thread is the slow accumulation of good abstractions. Early programmers spoke directly to the hardware, toggling switches and rewiring panels, and every problem had to be solved in the vocabulary of the machine in front of them. The arrival of assembly language, and then of compiled languages, did not make the computers any faster; it made the programmers faster, because it let them think in terms closer to the problem and further from the circuitry. Each new layer hid a mess of detail beneath a clean interface, and each clean interface freed the people above it to build something larger than the layer below could have imagined. + +Numerical computation followed the same pattern, though its abstractions were mathematical rather than mechanical. The first scientific programs tracked every digit by hand, and a single rounding decision could quietly ruin a long calculation. Floating point arithmetic was a hard-won compromise: it traded a little accuracy for an enormous gain in range and convenience, and it came with rules subtle enough that careful engineers spent entire careers studying them. The promise was never that the answers would be exact, only that the errors would be small and, more importantly, predictable. A method whose errors stay bounded and behave smoothly is far more useful than one that is occasionally perfect and occasionally catastrophic, because predictability is what lets you reason about a system you cannot fully observe. + +This distinction between bounded error and occasional disaster runs through the whole of engineering. A bridge is not designed to bear exactly the load it will encounter; it is designed with margins, so that the inevitable surprises fall inside a region the designer has already considered. Software that processes real data is no different. The inputs will be messier than the specification promised, the edge cases will arrive in combinations nobody enumerated, and the only durable defense is to build systems whose failure modes are gentle. A program that degrades gracefully under unexpected input is worth more than one that is flawless on the cases its author happened to imagine, because the world is under no obligation to supply only imaginable cases. + +Modern machine learning lives squarely inside this tradition, even when its practitioners do not describe it that way. Training a large model means multiplying enormous matrices billions of times, and the precision of each multiplication is a design choice rather than a fixed fact of nature. Lower precision means smaller numbers to move and faster hardware to move them, but it also means coarser rounding, and the central question is always whether that rounding stays in the harmless regime or crosses into the dangerous one. The answer depends on the model, the data, and the particular sequence of operations involved, which is exactly why it has to be measured rather than assumed. Intuition about numerical behavior is notoriously unreliable at scale, where quantities interact in ways that small examples never reveal. + +Consider what happens to a single number as it flows through a deep network. It begins as an input, is scaled and shifted and combined with thousands of its neighbors, passes through a nonlinearity, and emerges as part of the input to the next layer, where the whole process repeats. By the time it reaches the final layer it has been transformed dozens of times, and any error introduced early has had dozens of opportunities to grow or shrink. Sometimes these errors cancel, averaging out across many independent contributions; sometimes they reinforce, when the same systematic bias is applied at every step. The difference between these two fates is the difference between a model that trains stably and one that diverges for reasons its authors struggle to explain. + +The output layer deserves special attention, because it is where the model finally commits to a prediction. Up to that point the internal representations are abstract and somewhat forgiving; small perturbations shift them a little without changing their meaning. But the final projection turns those representations into concrete scores over a large vocabulary, and those scores are then exponentiated and normalized into probabilities. Exponentiation is unforgiving of additive error: a small shift in a score becomes a multiplicative change in a probability, and a small change in a probability can flip a decision. This is why the precision of the last step is often discussed out of proportion to its share of the total computation. It is not that the last matrix multiply is expensive; it is that it sits at the most sensitive point in the pipeline. + +Yet sensitivity at a single point does not automatically translate into importance for the whole. If the representation arriving at that point already carries substantial error from everything upstream, then cleaning up only the final step yields little, because the dominant error was introduced earlier and is simply passed through. The benefit of high precision at the output is largest exactly when the rest of the pipeline is already clean, and smallest when the upstream is noisy. This is a general principle of error analysis that beginners frequently miss: the value of fixing one stage depends entirely on whether that stage is the bottleneck, and the bottleneck is rarely where attention is first drawn. + +There is a further subtlety, which is that the magnitude of the quantities involved changes how much a fixed rounding error matters in relative terms. When a model is confident, the score it assigns to the chosen outcome is close to the maximum, the corresponding log probability is close to zero, and a small absolute error in that log probability is a large fraction of its tiny value. When a model is uncertain, spreading its belief across many outcomes, the same log probability is a large negative number, and the identical absolute error is a negligible fraction of it. The relative importance of a rounding step therefore depends not only on where it sits in the pipeline but on the regime the model is operating in, which is set by the data it happens to be processing at that moment. + +This is why measurements that look contradictory are often perfectly consistent once the regime is accounted for. A change that appears to make no difference on one dataset can make a clear difference on another, not because the underlying arithmetic changed, but because the quantities being rounded shifted from one regime to the other. An honest investigation reports both results and the condition that distinguishes them, rather than picking whichever supports a tidy story. The condition is the finding; the individual numbers are only evidence for it. + +Reinforcement learning from human feedback adds yet another layer to this picture, because it compares the behavior of two systems rather than examining one in isolation. A model generates text under one implementation and is then evaluated under another, and the learning signal depends on the ratio between the probabilities the two implementations assign to the same tokens. If the two implementations agree, the ratio is near one and the signal is clean; if they disagree systematically, the ratio carries a bias that no amount of careful optimization can remove, because it is baked into the comparison itself. The danger here is not random noise, which averages away over many samples, but systematic disagreement, which does not. Two correct-looking systems can still disagree in a way that quietly corrupts everything built on top of their comparison. + +The practical lesson is that matching matters more than absolute accuracy in this setting. It is better for two systems to be wrong in the same way than for one to be right and the other wrong, because a shared error cancels in the ratio while an unshared one does not. This inverts the usual intuition, which prizes accuracy above all. It explains why engineers sometimes deliberately make a fast system reproduce the quirks of a slow one, rather than improving it, and why a change that improves a system in isolation can hurt the larger pipeline it lives in if it breaks an agreement that other parts relied upon. Consistency is a feature, even when it is consistency in imperfection. + +All of this argues for a particular discipline: measure the thing you actually care about, under the conditions it will actually face, and report the conditions alongside the numbers. Good measurement, like a good abstraction, is what lets us trust the layers we cannot see. It does not eliminate uncertainty, but it bounds it, and a bounded uncertainty is something an engineer can build on. The goal is never to pretend the errors are gone. The goal is to know how large they are, where they come from, and whether they stay in the gentle regime or threaten to cross into the steep one where small causes produce large and unwelcome effects. diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml new file mode 100644 index 000000000..cc17c19e0 --- /dev/null +++ b/examples/evaluate_precision/smol.yaml @@ -0,0 +1,59 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M. +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +output_dir: /tmp/fast_llm_tests/evaluate_precision/features +sequence_length: 2048 +variants: + # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). + bf16: + model.distributed.compute_dtype: bfloat16 + # Turn ON full-precision residual stream. + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + # Turn ON fp32 LM head matmul (PR #526). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Both stability features on (most precise bf16-compute configuration). + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true + # Diagnostic: enable bf16 reduced-precision reductions in cuBLAS GEMMs. Tests whether the + # within-engine bf16-vs-fp32 gap is sensitive to the partial-sum reduction precision (the + # MMA accumulator is fp32 by hardware on H100/A100; this flag affects split-K reductions). + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + # Diagnostic: simulate a "bf16 inputs, fp32 output" lm-head matmul kernel. fp32_lm_head=True + # upcasts inputs+weights to fp32, then matmul_precision='medium' runs the matmul through + # bf16 Tensor Cores anyway, then logits stay fp32. Tests whether fp32_lm_head's gain comes + # from input precision or from skipping the bf16 output cast. + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + # fp16 sweep: probes whether the precision-vs-noise picture (rms noise ~0.1 nats per token + # for bf16) shrinks ~8× for fp16 (10 mantissa bits vs 7), as the literature's "switch to + # fp16" recommendation implies. Default dynamic grad-scaler (initial 2^16) is uniform + # across variants, so relative comparisons stay meaningful. + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml new file mode 100644 index 000000000..b0e8e319d --- /dev/null +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -0,0 +1,52 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M +# with the GSPO policy-gradient loss (uses advantages and old log-probabilities). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol_gspo.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +model: + base_model: + head: + losses: + gspo: + type: gspo +output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo +data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data +sequence_length: 2048 +variants: + bf16: + model.distributed.compute_dtype: bfloat16 + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index a90bcdebc..fbfe60ac3 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -80,6 +80,12 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + output_hidden_states: list[str] = Field( + default_factory=list, + desc="Regex patterns to add to each model input's `output_hidden_states` set." + " Matching `_debug`-named tensors get populated into `kwargs[hidden_states]`" + " and (when running under a `Run` context) emitted into `tensor_logs`.", + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 16114cb80..000fcc01d 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -161,6 +161,13 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis self._set_target_inputs(model_inputs, config) + if config.output_hidden_states: + import re + + patterns = {re.compile(pattern) for pattern in config.output_hidden_states} + for model_input in model_inputs: + model_input.output_hidden_states.update(patterns) + return model_inputs def _set_target_inputs( diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index c055a7f2c..4c99798c5 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -100,6 +100,18 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name + @classmethod + def _resolve_path(cls, path: pathlib.Path) -> pathlib.Path: + """Resolve a local directory or HF Hub model id (e.g. ``meta-llama/Llama-3.2-1B``) to a + local snapshot directory. Local directories pass through unchanged; everything else is + materialized via :func:`huggingface_hub.snapshot_download` (cached on subsequent calls). + """ + if path.is_dir(): + return path + import huggingface_hub + + return pathlib.Path(huggingface_hub.snapshot_download(str(path))) + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: @@ -128,20 +140,32 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: { # transformers PretrainedConfig "_name_or_path", + "add_cross_attention", "architectures", "auto_map", "chunk_size_feed_forward", + "cross_attention_hidden_size", "dtype", + "finetuning_task", "id2label", + "is_decoder", "is_encoder_decoder", "label2id", "model_type", "output_attentions", "output_hidden_states", + "prefix", "problem_type", + "pruned_heads", "return_dict", + "task_specific_params", + "tf_legacy_loss", + "tie_encoder_decoder", + "tokenizer_class", "torch_dtype", + "torchscript", "transformers_version", + "use_bfloat16", "use_cache", # Token ids — generation/inference, not architecture. "bos_token_id", @@ -149,10 +173,39 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: "eos_token_id", "pad_token_id", "sep_token_id", + # Generation defaults — never architecture. + "bad_words_ids", + "begin_suppress_tokens", + "diversity_penalty", + "do_sample", + "early_stopping", + "encoder_no_repeat_ngram_size", + "exponential_decay_length_penalty", + "forced_bos_token_id", + "forced_eos_token_id", + "length_penalty", + "max_length", + "min_length", + "no_repeat_ngram_size", + "num_beam_groups", + "num_beams", + "num_return_sequences", + "output_scores", + "remove_invalid_values", + "repetition_penalty", + "return_dict_in_generate", + "suppress_tokens", + "temperature", + "top_k", + "top_p", + "typical_p", # Initialization / pretraining metadata Fast-LLM does not consume. "initializer_range", "max_position_embeddings", "pretraining_tp", + # Family markers / default-valued knobs serialized by recent transformers versions. + "is_llama_config", + "rope_interleaved", } ) @@ -181,28 +234,29 @@ def _load_weights( import transformers Assert.eq(self.get_shard_names(config), ("weights",)) - if (config.path / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.SAFE_WEIGHTS_NAME} - elif (config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") + directory = self._resolve_path(config.path) + if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} + elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } - elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.WEIGHTS_NAME} - elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") + elif (directory / transformers.utils.WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.WEIGHTS_NAME} + elif (directory / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } else: - raise FileNotFoundError(f"No compatible checkpoint found in {config.path}") + raise FileNotFoundError(f"No compatible checkpoint found in {directory}") for path in paths: logger.info(f"Loading from {path}") diff --git a/tests/utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py similarity index 69% rename from tests/utils/compare_tensor_logs.py rename to fast_llm/engine/config_utils/compare_tensor_logs.py index f02d62c79..dbad78a25 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -87,6 +87,52 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): # Avoid set to preserve ordering. return [key for key in dict_test if key in dict_ref] + def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict | None: + # Returns per-tensor error metrics, or None on shape/sampling mismatch. + if tensor_ref["shape"] != tensor_test["shape"]: + return None + if tensor_ref["step"] != tensor_test["step"]: + return None + sub_config = self._get_sub_config(step_name, tensor_name) + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale + scale_unreg = (samples_ref**2).mean() ** 0.5 + rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 + diff = samples_test - samples_ref + rms = (diff**2).mean() ** 0.5 + max_diff = diff.abs().max() + bias = diff.mean() + # Linear-regression decomposition: `test ≈ slope * ref + intercept + residual`. + # Useful for separating systematic distortion (slope ≠ 1) from per-position decorrelated + # noise (residual). For RL importance ratios, slope ≠ 1 indicates likely-token-dependent + # bias which is more dangerous than a uniform shift. + centered_test = samples_test - samples_test.mean() + centered_ref = samples_ref - samples_ref.mean() + var_ref = (centered_ref**2).mean() + var_test = (centered_test**2).mean() + cov = (centered_test * centered_ref).mean() + denom = (var_test * var_ref) ** 0.5 + correlation = (cov / denom).item() if denom > 0 else float("nan") + slope = (cov / var_ref).item() if var_ref > 0 else float("nan") + residual_var = (var_test - cov**2 / var_ref).clamp(min=0.0) if var_ref > 0 else var_test + residual_rms = residual_var**0.5 + return { + "rms_abs": rms.item(), + "rms_rel": (rms / rms_scale).item(), + "max_abs": max_diff.item(), + "max_rel": (max_diff / rms_scale).item(), + "ref_scale": scale_unreg.item(), + "ref_scale_regularized": rms_scale.item(), + "bias_abs": bias.item(), + "bias_rel": (bias / rms_scale).item(), + "correlation": correlation, + "slope": slope, + "residual_rms_abs": residual_rms.item(), + "residual_rms_rel": (residual_rms / rms_scale).item(), + } + def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): sub_config = self._get_sub_config(step_name, tensor_name) if tensor_ref["shape"] != tensor_test["shape"]: @@ -108,34 +154,33 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam ) return - samples_ref = tensor_ref["samples"].flatten().float() - samples_test = tensor_test["samples"].flatten().float() - if sub_config.scale != 1.0: - samples_test = samples_test / sub_config.scale - scale_unreg = (samples_ref**2).mean() ** 0.5 - rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + metrics = self._compute_diff(tensor_ref, tensor_test, step_name, tensor_name) + rms_scale = metrics["ref_scale_regularized"] + scale_unreg = metrics["ref_scale"] tensor_errors = [] - if rms > sub_config.rms_abs_tolerance: - tensor_errors.append(f" * RMS diff absolute = {rms} > {sub_config.rms_abs_tolerance}") + if metrics["rms_abs"] > sub_config.rms_abs_tolerance: + tensor_errors.append(f" * RMS diff absolute = {metrics['rms_abs']} > {sub_config.rms_abs_tolerance}") - if rms / rms_scale > sub_config.rms_rel_tolerance: + if metrics["rms_rel"] > sub_config.rms_rel_tolerance: tensor_errors.append( - f" * RMS diff scaled = {rms / rms_scale} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * RMS diff scaled = {metrics['rms_rel']} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) - if max_diff > sub_config.max_abs_tolerance: - tensor_errors.append(f" * Max diff absolute = {max_diff} > {sub_config.max_abs_tolerance}") + if metrics["max_abs"] > sub_config.max_abs_tolerance: + tensor_errors.append(f" * Max diff absolute = {metrics['max_abs']} > {sub_config.max_abs_tolerance}") - if max_diff / rms_scale > sub_config.max_rel_tolerance: + if metrics["max_rel"] > sub_config.max_rel_tolerance: tensor_errors.append( - f" * Max diff scaled = {max_diff / rms_scale} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * Max diff scaled = {metrics['max_rel']} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) if tensor_errors: + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale tensor_errors.extend( [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 32deb4562..b82d4c847 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -76,6 +76,15 @@ class TensorLogsConfig(Config): valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.") + sample_level_overrides: dict[str, int] = Field( + default_factory=dict, + desc="Per-tensor sample-density overrides (regex pattern -> level)." + " For tensors whose logged name matches a pattern, the effective `log_tensor` level is" + " raised to the matching override (samples = 2 ** (level - 3))." + " Useful for sparse tensors like embedding-weight gradients where the default sampling" + " stride misses most non-zero rows.", + hint=FieldHint.logging, + ) class TensorLogs: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 958a3d228..96cb52f09 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -139,6 +139,14 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead", hint=FieldHint.logging, ) + debug_hidden_states_log: list[str] = Field( + default_factory=list, + desc="Regex patterns for `_debug`-named tensors (`.`, e.g. `head.logits`," + " `decoder.0.norm_1`) to log to `tensor_logs`. Patterns are appended to each model" + " input's `output_hidden_states` set, so matching tensors are both populated into" + " `kwargs[hidden_states]` for downstream consumers and emitted into `tensor_logs`.", + hint=FieldHint.logging, + ) @config_class() diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 805eae1e5..0476a8107 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) +# Verbosity used for `output_hidden_states`-driven tensor logging. `log_tensor` collects sampled +# tensor values only at level >= 3; 13 matches the convention in the layer-comparison tests +# (1024 sampled values per tensor). +_HIDDEN_STATE_LOG_LEVEL = 13 + + class DebugLayer: """ A debugging utility for blocks. @@ -55,11 +61,14 @@ def __call__( if level > 1: log_pipeline_parallel_main_rank(lambda: log_memory_usage(name, str)) - if level > 0 and tensor is not None: + # `output_hidden_state` requests full-fidelity capture even when `model_debug_level` is + # off — clamp the log level so samples are saved alongside summary stats. + log_level = max(level, _HIDDEN_STATE_LOG_LEVEL) if output_hidden_state else level + if log_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, - level=level, + level=log_level, meta=meta, **logging_kwargs, ) @@ -67,7 +76,7 @@ def __call__( log_distributed_grad( "", tensor, - level=level, + level=log_level, meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bde33f297..6a0bfcfd6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 22c750082..8dd511480 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -285,12 +293,38 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - if grad is not None and self._config.final_logit_softcap is not None: + if grad is not None: + # `logits` has `requires_grad=False` (custom-autograd), so the existing + # `_debug(logits, ...)` can't auto-capture the gradient. Log it explicitly here + # so `output_hidden_states` patterns covering `head.logits` also catch the grad. + self._debug( + grad, + f"logits.grad{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._vocab_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) + + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + if self._config.final_logit_softcap is not None: grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ diff --git a/fast_llm/layers/language_model/loss/chosen_logprob.py b/fast_llm/layers/language_model/loss/chosen_logprob.py new file mode 100644 index 000000000..cb99e7c17 --- /dev/null +++ b/fast_llm/layers/language_model/loss/chosen_logprob.py @@ -0,0 +1,41 @@ +import math +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelChosenLogprobLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.logging import log_tensor + + +class LanguageModelChosenLogprobLoss[ConfigType: LanguageModelChosenLogprobLossConfig](LanguageModelLoss[ConfigType]): + """Logs log π(label) per position via the tensor-log pipeline; contributes nothing to gradients.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Don't surface a "chosen_logprob: 0" line in the training metrics. + self._do_register_loss = False + + def _forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + if self._vocab_parallel: + raise NotImplementedError("chosen_logprob loss does not support vocab parallel") + labels = self._get_labels(kwargs, split_index).reshape(-1).long() + with torch.no_grad(): + log_probs = torch.log_softmax(logits.float() * self._logits_scale_factor, dim=-1) + # Mask out-of-range labels (e.g. -100 for prompt tokens in RL data) before gather to + # avoid CUDA assert. Fast-LLM convention: any label < 0 is masked. + valid = labels >= 0 + safe_labels = labels.clamp(min=0) + chosen_logprob = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) + chosen_logprob = chosen_logprob[valid] + # Capture the full tensor: bias is the mean over all positions, not a sampled subset. + level = math.ceil(math.log2(max(chosen_logprob.numel(), 1))) + 3 + log_tensor(f"Global : {self._name}", chosen_logprob, level=level) + return torch.zeros((), dtype=logits.dtype, device=logits.device), grad_logits diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 9a220aacf..aa05fbb9a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -9,6 +9,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss from fast_llm.layers.language_model.loss.entropy_loss import ( LanguageModelDistillationLoss, @@ -186,6 +187,30 @@ def get_reference_models(self) -> set[str]: return {self.reference_model} +@config_class(dynamic_type={LanguageModelLossConfig: "chosen_logprob"}) +class LanguageModelChosenLogprobLossConfig(LanguageModelLossConfig): + """No-gradient diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + + The chosen-token log-prob is the scalar that policy-gradient importance ratios depend on, + so its precision drift is a more direct signal than bulk-logit RMS. + """ + + _abstract: typing.ClassVar[bool] = False + + weight: float = Field( + default=0.0, + hint=FieldHint.derived, + desc="Forced to 0: this loss has no gradient contribution.", + valid=check_field(Assert.eq, 0.0), + ) + + @property + def loss_class(self) -> "type[LanguageModelChosenLogprobLoss]": + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss + + return LanguageModelChosenLogprobLoss + + @config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) class LanguageModelZLossConfig(LanguageModelLossConfig): """Z-loss regularization to prevent overconfidence.""" diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 2619883d6..6326e7e4b 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -131,6 +131,15 @@ def log_tensor[T]( ) -> T | None: if level < 1: return + # Per-tensor sample-density override: lets users boost the effective level for specific + # tensors (e.g. sparse embedding-weight gradients) via `TensorLogsConfig`. + overrides = TensorLogs.config.sample_level_overrides if TensorLogs.config else None + if overrides: + import re + + for pattern, override in overrides.items(): + if re.search(pattern, name): + level = max(level, override) tensor = tensor.detach() if tensor.ndim == 0: tensor = tensor[None] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2e9b4365b..f4d4b286a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -112,6 +112,7 @@ def get_preprocessing_config( return LanguageModelBatchPreprocessingConfig( phase=phase, micro_batch_splits=micro_batch_splits, + output_hidden_states=list(self._config.multi_stage.debug_hidden_states_log), **self._base_model.get_preprocessing_config(), ) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 184294551..04a24e2ae 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -2,13 +2,13 @@ from fast_llm.data.preparation.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.global_variables import TOKENIZER_PATH @pytest.fixture(scope="session") def common_tokenizer() -> Tokenizer: - download_santacoder_tokenizer() + download_test_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 0b4dbafc1..f3febae4b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -18,9 +18,9 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.utils import Assert, header -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 7ae26c2d6..c8b5fd004 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,7 +3,7 @@ import pytest -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup @@ -15,7 +15,7 @@ @pytest.fixture(scope="module") def tokenizer_path(): - download_santacoder_tokenizer() + download_test_tokenizer() return TOKENIZER_PATH diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 03ebac757..3c95d0dea 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -18,9 +18,9 @@ from fast_llm.data.dataset.sampled import logger from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preparation.tokenizer import TokenizerConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a2ea2f46e..e7b206cf5 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -14,7 +14,7 @@ from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH -def download_santacoder_tokenizer(): +def download_test_tokenizer(): if not TOKENIZER_FILE.is_file(): import transformers @@ -218,7 +218,7 @@ def _get_test_dataset( if has_grpo_data: source_schema["advantages"] = "advantages" - download_santacoder_tokenizer() + download_test_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( { "dataset": { diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f3bbbac8d..d08b023b9 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -4,7 +4,7 @@ import torch -from tests.utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py new file mode 100644 index 000000000..9d63a54c8 --- /dev/null +++ b/tools/evaluate_precision.py @@ -0,0 +1,707 @@ +import json +import logging +import math +import pathlib +import shutil +import statistics +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig + +# Populate the trainer dynamic-type registry. +import fast_llm.data.auto # noqa: F401 # isort:skip +import fast_llm.engine.checkpoint.convert # noqa: F401 # isort:skip +import fast_llm.models.auto # noqa: F401 # isort:skip + +logger = logging.getLogger(__name__) + + +_REFERENCE_NAME = "reference" +_MODEL_TYPE = "gpt" +# Embedding-weight gradients are row-sparse (only input-token rows non-zero), so a +# uniformly-spaced sample of vocab_size entries usually misses all of them. The pattern +# is applied via `TensorLogsConfig.sample_level_overrides` and picked up inside +# `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor). +_SPARSE_GRAD_LEVEL = 23 +_SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} +_CHOSEN_LOGPROB_NAME = "chosen_logprob" +# Seed for the random-token fixed input when no input text file is given. +_INPUT_SEED = 0 +# Auto-calibration of the constant gradient scaler. Each variant runs a calibration pass at +# `scale=1` (no overflow risk), then the actual run uses the largest power-of-2 scale that +# keeps logged gradient magnitudes (and a small safety factor for hidden in-kernel +# intermediates like norm partial sums) within fp16's representable range. Per-variant +# unscaling at compare time lets different variants pick different scales without polluting +# the relative metrics. +_HIDDEN_INTERMEDIATE_HEADROOM = 4.0 # safety factor for fused-kernel partial sums we don't log +_CALIBRATION_SUBDIR_PREFIX = ".calibration_" +# Variant-override keys starting with this prefix are interpreted as `torch.backends.` and +# applied before each run. Used for diagnostics (e.g. enabling bf16 reduced-precision reductions); +# entries are listed in `_TORCH_BACKEND_DEFAULTS` and reset to their defaults before applying. +_TORCH_BACKEND_PREFIX = "_torch_backend." +_TORCH_BACKEND_DEFAULTS = { + "cuda.matmul.allow_bf16_reduced_precision_reduction": False, +} +_TORCH_MATMUL_PRECISION_KEY = "_torch_matmul_precision" + + +@config_class() +class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): + """Evaluate layer-wise numerical-error propagation against an fp32 reference. + + Inherits `model` and `pretrained` from `PretrainedGPTModelConfig`: either or both + can be set in the YAML. The tool runs one fp32 reference + one trainer invocation + per variant, captures per-layer forward activations and input gradients via the + standard tensor-logs pipeline, and reports per-tensor RMS / max diffs. + """ + + _abstract = False + variants: dict[str, typing.Any] = Field( + desc="Named override bundles to evaluate against the fp32 reference." + " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", + hint=FieldHint.core, + ) + output_dir: pathlib.Path = Field( + desc="Directory for per-run tensor-log artifacts and the final JSON report.", + hint=FieldHint.core, + ) + num_samples: int = Field( + default=8192, + desc="Number of sampled values stored per logged tensor (rounded up to next power of 2)." + " Sparse tensors (e.g. embedding-weight gradients) get a higher level via" + " `TensorLogsConfig.sample_level_overrides`.", + hint=FieldHint.feature, + ) + sequence_length: int = Field( + default=2048, + desc="Sequence length per micro-batch sample. Drives both `data.micro_batch_size` (the" + " per-sample token count, despite the name) and `data.maximum_document_length`.", + hint=FieldHint.feature, + ) + input_text_file: pathlib.Path | None = Field( + default=None, + desc="If set, tokenize this text file (via the pretrained tokenizer) to build the fixed model" + " input, tiled/truncated to `sequence_length`. If unset, the input is uniform-random token ids." + " The exact input tensor is saved to `output_dir/input_ids.pt` so the DeepSpeed-side tool" + " (`tools/evaluate_precision_deepspeed.py`) can consume the identical model input.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() + assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline." + for name, overrides in self.variants.items(): + assert isinstance(overrides, dict) and all( + isinstance(k, str) for k in overrides + ), f"Variant {name!r} must be a flat dict of dotted-path string keys." + + def run(self) -> None: + self.output_dir.mkdir(parents=True, exist_ok=True) + input_ids = self._prepare_input_ids() + runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} + runs.update(self.variants) + scales: dict[str, float] = {} + for name, variant_overrides in runs.items(): + scales[name] = self._calibrate_and_run(name, variant_overrides, input_ids) + + ref_artifacts = self._artifact_path(_REFERENCE_NAME) + results = { + name: self._compare(ref_artifacts, self._artifact_path(name), scales[_REFERENCE_NAME], scales[name]) + for name in self.variants + } + + report_path = self.output_dir / "precision_report.json" + report_path.write_text(json.dumps({"scales": scales, "variants": results}, indent=2)) + logger.info(f"Wrote report to {report_path}") + logger.info(f"Per-variant gradient scales: {scales}") + + for name, rows in results.items(): + _print_table(name, rows) + _print_summary(results) + + def _calibrate_and_run( + self, name: str, variant_overrides: dict[str, typing.Any], input_ids: "torch.Tensor" + ) -> float: + """Pick a power-of-2 gradient scale for this variant via a calibration pass, then run with it. + + Calibration runs with `constant=1.0` so no overflow is possible; scanning logged gradients + then gives us `max_unscaled`. The largest safe power of 2 keeps `scale * max_unscaled` below + `fp16_max / hidden_intermediate_budget`, where the budget reserves headroom for partial sums + inside fused kernels (e.g. norm-weight grads sum over the sequence dimension). + """ + import torch + + cal_dir = self.output_dir / f"{_CALIBRATION_SUBDIR_PREFIX}{name}" + self._run_one(name, variant_overrides, input_ids, constant_scale=1.0, experiment_dir=cal_dir) + max_unscaled = _scan_max_grad(cal_dir / "runs" / "0" / "artifacts") + shutil.rmtree(cal_dir) + if max_unscaled <= 0.0: + scale = 1.0 + logger.warning(f"[{name}] calibration found no nonzero gradient — falling back to scale=1.0") + else: + fp16_max = torch.finfo(torch.float16).max + optimal_unrounded = fp16_max / max_unscaled / _HIDDEN_INTERMEDIATE_HEADROOM + scale = float(2 ** max(0, math.floor(math.log2(optimal_unrounded)))) + logger.info(f"[{name}] calibration: max_unscaled={max_unscaled:.4e} -> gradient_scaler.constant={scale:g}") + self._run_one(name, variant_overrides, input_ids, constant_scale=scale) + return scale + + def _prepare_input_ids(self) -> "torch.Tensor": + """Build the fixed model input once and save it so the DeepSpeed-side tool feeds the exact + same tokens. Going through Fast-LLM's data pipeline would re-randomize the model input + (shuffle/packing), so the input is constructed directly here and fed verbatim to the runner.""" + import torch + + vocab_size = self.model.base_model.embeddings.vocab_size + if self.input_text_file is not None: + import transformers + + tokenizer = transformers.AutoTokenizer.from_pretrained(str(self.pretrained.path)) + ids = tokenizer(self.input_text_file.read_text(), return_tensors="pt").input_ids[0] + if ids.numel() < self.sequence_length: + ids = ids.repeat((self.sequence_length + ids.numel() - 1) // ids.numel()) + ids = ids[: self.sequence_length].to(torch.int64) + else: + generator = torch.Generator().manual_seed(_INPUT_SEED) + ids = torch.randint(0, vocab_size, (self.sequence_length,), generator=generator, dtype=torch.int64) + input_ids = ids.unsqueeze(0) + path = self.output_dir / "input_ids.pt" + torch.save(input_ids, path) + logger.info(f"Shared model input: {tuple(input_ids.shape)} saved to {path}") + return input_ids + + def _artifact_path(self, name: str) -> pathlib.Path: + return self.output_dir / name / "runs" / "0" / "artifacts" + + def _run_one( + self, + name: str, + variant_overrides: dict[str, typing.Any], + input_ids: "torch.Tensor", + *, + constant_scale: float | None = None, + experiment_dir: pathlib.Path | None = None, + ) -> None: + # The trainer's Run picks the next `runs/` subdir based on what already exists; wipe + # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be + # read by `_artifact_path` below. + if experiment_dir is None: + experiment_dir = self.output_dir / name + if experiment_dir.exists(): + shutil.rmtree(experiment_dir) + # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. + # Forced fp32 on the reference baseline lives in here too so a variant can override it. + optimizer_config: dict[str, typing.Any] = { + "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, + } + if constant_scale is not None: + optimizer_config["gradient_scaler"] = {"constant": float(constant_scale)} + base_dict: dict[str, typing.Any] = { + "pretrained": self.pretrained.to_dict(), + "model": self.model.to_dict(), + "training": { + "train_iters": 1, + "num_workers": 0, + "logs": {"interval": 1}, + }, + "optimizer": optimizer_config, + # The lean runner feeds a fixed input directly and ignores this dataset; it's only here so + # the TrainerConfig validates. Despite the name, `data.micro_batch_size` is the per-sample + # sequence length, not the batch dimension. + "data": { + "datasets": {"training": {"type": "random"}}, + "micro_batch_size": self.sequence_length, + "maximum_document_length": self.sequence_length, + }, + "run": { + "experiment_dir": str(experiment_dir.resolve()), + "tensor_logs": { + "save": True, + "show": False, + "sample_level_overrides": _SPARSE_GRAD_OVERRIDES, + }, + }, + } + # Translate `num_samples` to a `log_tensor` level: 2**(level-3) = samples. + log_level = math.ceil(math.log2(max(self.num_samples, 1))) + 3 + fp32_dtypes = { + ("model", "distributed", "compute_dtype"): "float32", + ("model", "distributed", "optimization_dtype"): "float32", + } + # Split off torch-backend overrides before passing the rest to Fast-LLM's config system. + backend_overrides = { + key[len(_TORCH_BACKEND_PREFIX) :]: value + for key, value in variant_overrides.items() + if key.startswith(_TORCH_BACKEND_PREFIX) + } + _apply_torch_backend_overrides(backend_overrides) + matmul_precision = variant_overrides.get(_TORCH_MATMUL_PRECISION_KEY, "highest") + _apply_torch_matmul_precision(matmul_precision) + variant_updates = { + tuple(key.split(".")): value + for key, value in variant_overrides.items() + if not key.startswith(_TORCH_BACKEND_PREFIX) and key != _TORCH_MATMUL_PRECISION_KEY + } + # Tool-required overrides win over variants — a variant must not silently disable tensor logging. + tool_overrides: dict[tuple[str, ...], typing.Any] = { + ("model", "multi_stage", "debug_layer_outputs"): log_level, + ("model", "multi_stage", "debug_layer_gradients"): log_level, + ("model", "multi_stage", "debug_all_param_gradients"): log_level, + # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's + # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. + ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], + # Diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + # Contributes no gradient (weight=0); the comparison code picks it up by name. + ("model", "base_model", "head", "losses", _CHOSEN_LOGPROB_NAME): {"type": "chosen_logprob"}, + } + # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a + # loss explicitly suppresses that default, so re-add it so gradients still flow. + if not (self.model.base_model.head.losses or {}): + tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"} + logger.info(f"=== Running {name!r} ===") + if variant_overrides: + logger.info(f"Variant overrides: {variant_overrides}") + trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE) + trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides) + trainer_config.configure_logging() + _run_fixed_input(trainer_config, input_ids, self.sequence_length) + + def _compare( + self, + ref_path: pathlib.Path, + test_path: pathlib.Path, + ref_scale: float, + test_scale: float, + ) -> list[dict[str, typing.Any]]: + compare_config = CompareConfig() + errors: list[str] = [] + ref_logs = compare_config._extract_tensor_logs(ref_path, errors) + test_logs = compare_config._extract_tensor_logs(test_path, errors) + for error in errors: + logger.warning(error) + # Each variant's gradient logs are scaled by its own `constant` factor (auto-calibrated). + # Undo per-variant scaling so the relative comparison reflects unscaled gradient diffs. + _unscale_gradients_in_place(ref_logs, ref_scale) + _unscale_gradients_in_place(test_logs, test_scale) + rows: list[dict[str, typing.Any]] = [] + for step_name in sorted(ref_logs): + if step_name not in test_logs: + logger.warning(f"Step {step_name!r} missing from test logs") + continue + step_ref = ref_logs[step_name] + step_test = test_logs[step_name] + for tensor_name, ref in step_ref.items(): + if tensor_name not in step_test: + continue + metrics = compare_config._compute_diff(ref, step_test[tensor_name], step_name, tensor_name) + if metrics is None: + continue + rows.append( + { + "step": step_name, + "tensor_name": tensor_name, + "kind": _classify(tensor_name), + "shape": ref["shape"], + **metrics, + } + ) + return rows + + +def _run_fixed_input(config, input_ids, sequence_length: int) -> None: + """Lean forward+backward on a fixed, already-preprocessed input — like `InferenceRunner` but with a + training-phase schedule + an (lr-0) optimizer so `run_step` runs the backward and the existing + chosen-logprob loss / `debug_all_param_gradients` logging captures everything. Replaces the trainer + + data pipeline so the model sees exactly `input_ids` (the pipeline would re-randomize it) and so the + tool stops paying for training/data-loading infrastructure it doesn't need.""" + import gc + + import torch + + from fast_llm.data.document.language_model import LanguageModelBatch + from fast_llm.engine.distributed.config import PhaseType + from fast_llm.engine.distributed.distributed import Distributed + from fast_llm.engine.multi_stage.config import StageMode + from fast_llm.engine.optimizer.config import ParamGroup + from fast_llm.engine.schedule.runner import ScheduleRunner + from fast_llm.engine.schedule.schedule import Schedule + + distributed = Distributed(config.model.distributed) + run = config.get_run(distributed) + with run: + multi_stage = config.model.get_model_class()( + config.model, optimizer_state_names=config.optimizer.state_names() + ) + with torch.no_grad(): + multi_stage.setup(distributed, mode=StageMode.training) + if config.pretrained.path is not None and config.pretrained.model_weights: + multi_stage.load_checkpoint(config.pretrained) + else: + multi_stage.initialize_weights() + param_groups, grads_for_norm = multi_stage.get_param_groups(ParamGroup) + optimizer = config.optimizer.optimizer_cls( + config.optimizer, param_groups=param_groups, grads_for_norm=grads_for_norm, distributed=distributed + ) + optimizer.reset_state() + runner = ScheduleRunner( + config=config.schedule, multi_stage=multi_stage, distributed_config=config.model.distributed + ) + with torch.no_grad(): + runner.setup(distributed, optimizer) + preprocessing_config = multi_stage.get_preprocessing_config( + PhaseType.training, config.schedule.micro_batch_splits + ) + # `get_model_inputs` splits off `num_labels` tokens for the shifted next-token labels, so the + # actual model input is `len(tokens) - num_labels`. The schedule meta must match that length. + schedule = Schedule( + config=config.schedule, + multi_stage=multi_stage, + batch_meta=preprocessing_config.get_input_meta(sequence_length - preprocessing_config.num_labels), + distributed_config=config.model.distributed, + phase=PhaseType.training, + ) + tokens = input_ids.flatten().to(device=distributed.device, dtype=torch.int64) + batch = LanguageModelBatch(tokens=tokens, lengths=[tokens.numel()]) + model_inputs = batch.get_model_inputs(preprocessing_config) + runner.run_step(iter((tuple(model_inputs),)), schedule, iteration=1) + # Break the trainer/model/runner reference cycles so each variant's GPU memory is reclaimed. + del multi_stage, optimizer, runner, schedule, distributed, run + gc.collect() + torch.cuda.empty_cache() + + +def _is_gradient_like(tensor_name: str) -> bool: + # Anything affected by the loss-scaling multiplier: parameter gradients from `Fsdp.log_shard`, + # backward activations from layer hooks, and explicit `.grad` debug entries (e.g. logits.grad). + return ("gradient:" in tensor_name) or (" bw" in tensor_name) or (".grad" in tensor_name) + + +def _scan_max_grad(artifact_path: pathlib.Path) -> float: + max_abs = 0.0 + compare_config = CompareConfig() + errors: list[str] = [] + logs = compare_config._extract_tensor_logs(artifact_path, errors) + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + # Saved stats include min/max; fall back to samples if absent. + if "max" in entry and "min" in entry: + value = max(abs(float(entry["max"])), abs(float(entry["min"]))) + else: + value = float(entry["samples"].abs().max().item()) + if math.isfinite(value) and value > max_abs: + max_abs = value + return max_abs + + +def _unscale_gradients_in_place(logs: dict, scale: float) -> None: + if scale == 1.0: + return + inv = 1.0 / scale + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + entry["samples"] = entry["samples"].float() * inv + for key in ("min", "max", "mu", "std"): + if key in entry and entry[key] is not None: + entry[key] = float(entry[key]) * inv + + +def _apply_torch_backend_overrides(overrides: dict[str, typing.Any]) -> None: + import torch + + unknown = set(overrides) - set(_TORCH_BACKEND_DEFAULTS) + if unknown: + logger.warning(f"Unknown torch backend overrides (ignored): {sorted(unknown)}") + for path, default in _TORCH_BACKEND_DEFAULTS.items(): + value = overrides.get(path, default) + obj: typing.Any = torch.backends + parts = path.split(".") + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +def _apply_torch_matmul_precision(precision: str) -> None: + import torch + + torch.set_float32_matmul_precision(precision) + + +def _layer_name(tensor_name: str) -> str: + # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; + # Fsdp.log_shard names weight gradients `Global gradient: `. + prefix = tensor_name.split(":", 1)[0].strip().split() + if prefix == ["Global", "gradient"]: + param = tensor_name.split(":", 1)[1].strip() + return param.split(".")[0] + if prefix and prefix[0] == "Global": + prefix = prefix[1:] + if prefix and prefix[-1] in ("fw", "bw"): + prefix = prefix[:-1] + return " ".join(prefix) if prefix else "?" + + +def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing.Any] | None: + return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) + + +_LM_HEAD_NAME = "head.output_weights" +_EMBEDDINGS_NAME = "embeddings.word_embeddings_weight" + + +def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + sample = next(iter(results.values())) + has_fw_logits = _named_row(sample, "head.logits") is not None + has_bw_logits = _named_row(sample, "head.logits.grad") is not None + has_bias = any( + r["kind"] == "grad" and r["tensor_name"].split(":", 1)[-1].strip().endswith(".bias") for r in sample + ) + # Each kind's aggregation columns are listed chronologically (left-to-right matches + # the order tensors are logged). Logits show up via `output_hidden_states` on the + # fw/bw boundary; weight gradients have no logits hook. + fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) + bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") + grad_aggs = ( + ("lm_head", "linear_med", "linear_max", "norm_med", "norm_max") + + (("bias_med", "bias_max") if has_bias else ()) + + ("embeddings",) + ) + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} + for kind in ("fw", "bw", "grad"): + _print_summary_table(results, kind, aggs_per_kind[kind]) + if _named_row(sample, _CHOSEN_LOGPROB_NAME) is not None: + _print_chosen_logprob_summary(results) + + +def _print_chosen_logprob_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + rows_by_variant = {name: _named_row(rows, _CHOSEN_LOGPROB_NAME) for name, rows in results.items()} + # log π(label) is the scalar that policy-gradient importance ratios depend on. Bias persists + # under per-document averaging where RMS shrinks ~1/√T, so for RL stability it's the more + # informative signal — surface it alongside RMS, slope and residual. + rms_rel_decimals = _column_decimals((r["rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + bias_rel_decimals = _column_decimals((r["bias_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + resid_rel_decimals = _column_decimals( + (r["residual_rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5 + ) + name_width = max((len(name) for name in results), default=7) + 1 + cols = [ + ("RMS rel", lambda r: f"{r['rms_rel'] * 100:.{rms_rel_decimals}f}%"), + ("Bias rel", lambda r: f"{r['bias_rel'] * 100:+.{bias_rel_decimals}f}%"), + ("Resid rel", lambda r: f"{r['residual_rms_rel'] * 100:.{resid_rel_decimals}f}%"), + ("Corr", lambda r: f"{r['correlation']:.5f}"), + ("Slope", lambda r: f"{r['slope']:+.5f}"), + ("Max abs", lambda r: f"{r['max_abs']:.4g}"), + ("Scale", lambda r: f"{r['ref_scale']:.4g}"), + ] + widths = [max(len(label), max(len(fn(r)) for r in rows_by_variant.values())) for label, fn in cols] + print(f"\n=== Summary: chosen_logprob (per-token) ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, row in rows_by_variant.items(): + cells = [fn(row) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) + + +def _grad_category(tensor_name: str) -> str: + name = tensor_name.split(":", 1)[-1].strip() + if name.endswith(".bias"): + return "bias" + if ".norm_" in name or name.endswith(".norm.weight"): + return "norm" + return "linear" + + +def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: + sample = next(iter(results.values())) + group = [r for r in sample if r["kind"] == kind] + if not group: + return + endpoint_labels = { + "first": _layer_name(group[0]["tensor_name"]), + "last": _layer_name(group[-1]["tensor_name"]), + } + mid_labels = { + "median": "mid med", + "max": "mid max", + "logits": "logits", + "lm_head": "lm head", + "embeddings": "embeddings", + "linear_med": "linear med", + "linear_max": "linear max", + "norm_med": "norm med", + "norm_max": "norm max", + "bias_med": "bias med", + "bias_max": "bias max", + } + + def _label(agg: str) -> str: + return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] + + name_width = max((len(name) for name in results), default=7) + 1 + cell_width = max(len(_label(a)) for a in aggs) + cell_sep = " " + raw: dict[str, dict[str, float | None]] = {} + for name, rows in results.items(): + logits_fw = _named_row(rows, "head.logits") + logits_bw = _named_row(rows, "head.logits.grad") + logits_value = { + "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), + "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), + } + kind_rows = [r for r in rows if r["kind"] == kind] + values = [r["rms_rel"] for r in kind_rows] + if kind == "grad": + decoder_rows = [r for r in kind_rows if r["tensor_name"].split(":", 1)[-1].strip().startswith("decoder.")] + category_values: dict[str, list[float]] = {"linear": [], "norm": [], "bias": []} + for r in decoder_rows: + category_values[_grad_category(r["tensor_name"])].append(r["rms_rel"]) + lm_head_row = _named_row(kind_rows, _LM_HEAD_NAME) + embeddings_row = _named_row(kind_rows, _EMBEDDINGS_NAME) + else: + category_values = {} + lm_head_row = embeddings_row = None + intermediate = values[1:-1] or values + cells: dict[str, float | None] = {} + for agg in aggs: + if agg == "first": + cells[agg] = values[0] if values else None + elif agg == "last": + cells[agg] = values[-1] if values else None + elif agg == "logits": + cells[agg] = logits_value[kind] + elif agg == "lm_head": + cells[agg] = lm_head_row["rms_rel"] if lm_head_row else None + elif agg == "embeddings": + cells[agg] = embeddings_row["rms_rel"] if embeddings_row else None + elif "_" in agg and agg.split("_", 1)[0] in category_values: + cat, stat = agg.split("_", 1) + cat_values = category_values[cat] + if not cat_values: + cells[agg] = None + elif stat == "max": + cells[agg] = max(cat_values) + else: + cells[agg] = statistics.median(cat_values) + elif agg == "max": + cells[agg] = max(intermediate) if intermediate else None + else: + cells[agg] = statistics.median(intermediate) if intermediate else None + raw[name] = cells + + column_decimals = { + agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs + } + if kind == "grad": + subtitle = " (Relative %)" + else: + subtitle = " (Relative %; mid = excluding first/last)" + print(f"\n=== Summary: {kind}{subtitle} ===") + header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) + print(header) + print("-" * len(header)) + for name, cells in raw.items(): + formatted = [ + f"{cells[agg] * 100:.{column_decimals[agg]}f}%" if cells[agg] is not None else "n/a" for agg in aggs + ] + print(f"{name:<{name_width}}" + cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) + + +def _column_decimals( + values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3, max_decimals: int | None = None +) -> int: + # Keep the default precision, but bump up so the smallest non-zero value carries at least + # `min_sig_figs` significant digits when formatted as percent. `max_decimals` caps the + # bump so a single tiny noisy value doesn't widen the whole column. + smallest = min((abs(v) * 100 for v in values if v != 0), default=None) + if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): + result = default + else: + result = max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + return min(result, max_decimals) if max_decimals is not None else result + + +def _display_group(row: dict[str, typing.Any]) -> str: + # Map each row to one of "fw"/"bw"/"grad" for the per-variant table, independent + # of `kind`: head.logits is a forward activation, head.logits.grad is a backward + # quantity, parameter gradients are their own group. + if row["kind"] == "grad": + return "grad" + if row["kind"] == "bw" or row["tensor_name"].endswith(".grad"): + return "bw" + return "fw" + + +def _classify(tensor_name: str) -> str: + # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" + # and " bw[, mb=…]"; log_distributed_tensor may prefix the name + # with "Global " and append a ": " suffix when reconstructing a + # tensor-parallel-global tensor. Per-parameter gradient logs come from + # `Fsdp.log_shard(name="gradient", ...)` and are tagged "grad" so they appear + # in the per-variant table but stay out of the fw/bw summary aggregation. + # Other entries (e.g. `Global : head.logits`, `Global : head.logits.grad`) come + # from the `_debug` / `output_hidden_states` path and are surfaced via dedicated + # logits columns in the summary. + if "gradient:" in tensor_name: + return "grad" + for kind in ("fw", "bw"): + if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): + return kind + return "other" + + +def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: + print(f"\n=== Variant: {name} ===") + if not rows: + print("(no matching tensors)") + return + name_fn = lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})" + name_width = max(len("Tensor"), max(len(name_fn(r)) for r in rows)) + # Adaptive precision for the relative column: bump decimals so small but real values + # (typical for weight gradients) stay legible, capped at 5 to bound column width. + relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) + relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + bias_decimals = _column_decimals((r["bias_rel"] for r in rows), default=2, max_decimals=5) + bias_fn = lambda r: f"{r['bias_rel'] * 100:+.{bias_decimals}f}%" + relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) + bias_width = max(len("Bias"), max(len(bias_fn(r)) for r in rows)) + columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ + ("Tensor", name_width, name_fn), + ("Relative", relative_width, relative_fn), + ("Bias", bias_width, bias_fn), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), + ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), + ] + header = " ".join(f"{title:<{width}}" for title, width, _ in columns) + print(header) + print("-" * len(header)) + # Display grouping (fw / bw / grad) separates the chronologically-interleaved + # backward and reduce_gradients hooks. Independent of `kind` so the summary + # aggregation isn't affected. + groups = ("fw", "bw", "grad") + grouped: dict[str, list[dict[str, typing.Any]]] = {g: [] for g in groups} + for row in rows: + grouped[_display_group(row)].append(row) + first = True + for group in groups: + if not grouped[group]: + continue + if not first: + print() + first = False + for row in grouped[group]: + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) + + +if __name__ == "__main__": + EvaluatePrecisionConfig.parse_and_run() diff --git a/tools/evaluate_precision_deepspeed.py b/tools/evaluate_precision_deepspeed.py new file mode 100644 index 000000000..d9cf102dc --- /dev/null +++ b/tools/evaluate_precision_deepspeed.py @@ -0,0 +1,299 @@ +"""Within-engine numerical-precision sweep for the HF-transformers + DeepSpeed stack. + +This is the DeepSpeed-side counterpart to `tools/evaluate_precision.py` (which measures the +same thing inside Fast-LLM). It loads a HF checkpoint, runs one forward + backward per precision +variant through a DeepSpeed engine, and reports two quantities against the fp32 reference, using +the same metrics (`CompareConfig._compute_diff`: RMS / bias / correlation / slope / residual): + + * chosen-token log-probability per position (the RL importance-ratio input); + * parameter gradients, aggregated by category (embedding/head, linear, norm, bias). + +The point is to check whether Fast-LLM's bf16 loses precision the *same way* DeepSpeed's bf16 +does — each measured against its own fp32 reference. + +The log-π computation and the fp32 LM-head mechanism mirror PipelineRL's DeepSpeed trainer +(`pipelinerl/finetune/rl/__init__.py` and `pipelinerl/finetune/checkpoints.py`) so the numbers +reflect the proven baseline rather than a bespoke path. `param.grad` is populated and already +unscaled after `engine.backward` (verified for both bf16 and fp16), so gradients are read directly. + +Run where transformers + deepspeed are installed (e.g. the PipelineRL stack image): + + python -m tools.evaluate_precision_deepspeed --model Qwen/Qwen2.5-0.5B --sequence-length 2048 +""" + +import argparse +import functools +import logging +import os +import statistics +import typing + +import torch + +logger = logging.getLogger(__name__) + +_REFERENCE_NAME = "fp32" +# (name, compute dtype, fp32 lm head). Reference is fp32 + fp32 head. `*_head_` variants +# turn the fp32 head OFF (head runs in compute dtype) to reproduce the within-engine +# "fp32 lm head has ~no effect" finding on the DeepSpeed side. +_VARIANTS: list[tuple[str, torch.dtype, bool]] = [ + (_REFERENCE_NAME, torch.float32, True), + ("bf16", torch.bfloat16, True), + ("bf16_head_bf16", torch.bfloat16, False), + ("fp16", torch.float16, True), + ("fp16_head_fp16", torch.float16, False), +] + +_FIXED_TEXT = ( + "The numerical precision of large language model training depends on the dtype used for " + "matrix multiplications, the accumulation precision of the hardware, and whether the output " + "projection is kept in full precision. In reinforcement learning from human feedback, the " + "importance ratio between the new and old policy is the exponential of the difference of " + "log-probabilities, so even small per-token errors in the log-probability can compound. " + "We compute the chosen-token log-probability as the log-softmax of the logits evaluated at " + "the next token, and we compare bfloat16 and float16 against a float32 reference. " +) + + +def apply_fp32_lm_head(model: torch.nn.Module, layer_prefix: str = "lm_head") -> torch.nn.Module: + """Cast the LM head to fp32 at compute time. Mirrors PipelineRL `apply_fp32_lm_head`. + + For tied embeddings (e.g. Qwen2.5-0.5B) the weight storage stays in the model dtype and is + upcast only for the head matmul; for untied heads the storage itself is moved to fp32. + """ + lm_head = model.get_output_embeddings() + if lm_head is None or not isinstance(lm_head, torch.nn.Linear): + raise RuntimeError(f"Could not find an nn.Linear LM head via get_output_embeddings(): {lm_head!r}") + tied = False + inp_emb = model.get_input_embeddings() + if inp_emb is not None and hasattr(inp_emb, "weight"): + tied = lm_head.weight is inp_emb.weight + if not tied and lm_head.weight.dtype != torch.float32: + lm_head.to(dtype=torch.float32) + original_forward = lm_head.forward + + @functools.wraps(original_forward) + def fp32_forward(x: torch.Tensor) -> torch.Tensor: + x32 = x if x.dtype == torch.float32 else x.float() + w = lm_head.weight + w32 = w if w.dtype == torch.float32 else w.float() + b = lm_head.bias + b32 = b.float() if (b is not None and b.dtype != torch.float32) else b + return torch.nn.functional.linear(x32, w32, b32) + + lm_head.forward = fp32_forward + logger.info(f"Applied fp32 lm head (tied={tied})") + return model + + +def chosen_logprob(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + """log π(next token) per position. Mirrors PipelineRL `rl/__init__.py:203-208`.""" + logits = logits[:, :-1, :].float() / temperature + next_ids = input_ids[:, 1:].unsqueeze(-1) + selected = torch.gather(logits, 2, next_ids).squeeze(-1) + log_norm = torch.logsumexp(logits, dim=-1) + return (selected - log_norm).reshape(-1) + + +def build_input_ids(tokenizer, vocab_size: int, sequence_length: int, device: torch.device, mode: str) -> torch.Tensor: + if mode == "random": + # Match Fast-LLM's random dataset (uniform token ids over the model vocab) so both engines + # see the same input distribution. The relative metrics depend strongly on it: on random + # tokens the model is maximally surprised (|log π| large), on realistic text |log π| ≈ 0, + # which shifts the relative RMS by several-fold even at identical absolute precision. + generator = torch.Generator().manual_seed(0) + ids = torch.randint(0, vocab_size, (sequence_length,), generator=generator) + else: + ids = tokenizer(_FIXED_TEXT, return_tensors="pt").input_ids[0] + repeats = (sequence_length + ids.numel() - 1) // ids.numel() + ids = ids.repeat(repeats)[:sequence_length] + return ids.unsqueeze(0).to(device) + + +def _ds_config(dtype: torch.dtype) -> dict[str, typing.Any]: + config: dict[str, typing.Any] = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": {"type": "Adam", "params": {"lr": 1e-6}}, + } + if dtype == torch.bfloat16: + config["bf16"] = {"enabled": True} + elif dtype == torch.float16: + config["fp16"] = {"enabled": True, "initial_scale_power": 16} + return config + + +def grad_category(name: str) -> str: + if name.endswith(".bias"): + return "bias" + if "layernorm" in name or name.endswith("norm.weight"): + return "norm" + if "embed_tokens" in name or "lm_head" in name: + return "embed_head" + return "linear" + + +def capture_variant( + model_id: str, + dtype: torch.dtype, + fp32_head: bool, + input_ids: torch.Tensor, + attn_implementation: str, + random_init: bool = False, +) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Forward + backward one variant through a DeepSpeed engine. Returns (chosen_logprob, + {param_name: gradient}), both on CPU in fp32.""" + import deepspeed + import transformers + + if random_init: + model = transformers.AutoModelForCausalLM.from_config( + transformers.AutoConfig.from_pretrained(model_id), dtype=dtype, attn_implementation=attn_implementation + ) + else: + model = transformers.AutoModelForCausalLM.from_pretrained( + model_id, dtype=dtype, attn_implementation=attn_implementation + ) + if fp32_head: + apply_fp32_lm_head(model) + engine, *_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=_ds_config(dtype)) + outputs = engine(input_ids) + logprob = chosen_logprob(outputs.logits, input_ids) + # fp16's narrow exponent range underflows small gradients; scale the loss up before backward and + # divide it back out (loss scaling, as in fp16 training). bf16/fp32 have fp32 range, no scaling. + # engine.backward leaves param.grad unscaled, so dividing by our own loss_scale recovers the true + # gradient computed with extra headroom against underflow. + loss_scale = 256.0 if dtype == torch.float16 else 1.0 + engine.backward(-logprob.mean() * loss_scale) + grads = { + name: (p.grad.detach().float() / loss_scale).cpu() + for name, p in model.named_parameters() + if p.grad is not None + } + logprob = logprob.detach().float().cpu() + del engine, model, outputs + torch.cuda.empty_cache() + return logprob, grads + + +def _entry(tensor: torch.Tensor) -> dict[str, typing.Any]: + return {"shape": list(tensor.shape), "step": 1, "samples": tensor} + + +def _print_logprob_summary(metrics_by_variant: dict[str, dict[str, typing.Any]]) -> None: + cols = [ + ("RMS rel", lambda m: f"{m['rms_rel'] * 100:.4f}%"), + ("Bias rel", lambda m: f"{m['bias_rel'] * 100:+.4f}%"), + ("Resid rel", lambda m: f"{m['residual_rms_rel'] * 100:.4f}%"), + ("Corr", lambda m: f"{m['correlation']:.5f}"), + ("Slope", lambda m: f"{m['slope']:+.5f}"), + ("Max abs", lambda m: f"{m['max_abs']:.4g}"), + ("Scale", lambda m: f"{m['ref_scale']:.4g}"), + ] + _print_table("chosen_logprob (per-token) vs fp32 reference", metrics_by_variant, cols) + + +def _print_grad_summary(grad_metrics_by_variant: dict[str, dict[str, list[float]]]) -> None: + # Per-category aggregation of gradient RMS-rel, mirroring tools/evaluate_precision.py's grad table. + def med(values: list[float]) -> str: + return f"{statistics.median(values) * 100:.4f}%" if values else "n/a" + + def mx(values: list[float]) -> str: + return f"{max(values) * 100:.4f}%" if values else "n/a" + + cols = [ + ("embed_head", lambda c: med(c.get("embed_head", []))), + ("linear med", lambda c: med(c.get("linear", []))), + ("linear max", lambda c: mx(c.get("linear", []))), + ("norm med", lambda c: med(c.get("norm", []))), + ("norm max", lambda c: mx(c.get("norm", []))), + ("bias med", lambda c: med(c.get("bias", []))), + ("bias max", lambda c: mx(c.get("bias", []))), + ] + _print_table("gradient RMS-rel by category vs fp32 reference", grad_metrics_by_variant, cols) + + +def _print_table(title: str, by_variant: dict, cols: list[tuple[str, typing.Callable]]) -> None: + name_width = max((len(n) for n in by_variant), default=7) + 1 + widths = [max(len(label), max((len(fn(v)) for v in by_variant.values()), default=0)) for label, fn in cols] + print(f"\n=== DeepSpeed/HF: {title} ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, value in by_variant.items(): + cells = [fn(value) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B") + parser.add_argument("--sequence-length", type=int, default=2048) + parser.add_argument("--attn-implementation", default="sdpa") + parser.add_argument("--input-mode", choices=["random", "text"], default="random") + parser.add_argument( + "--input-file", + default=None, + help="Path to an input_ids.pt saved by tools/evaluate_precision.py. When set, feeds that exact" + " model input (so Fast-LLM and DeepSpeed see byte-identical tokens); --input-mode is ignored.", + ) + parser.add_argument( + "--random-init", + action="store_true", + help="Build the model from config with random weights instead of loading the pretrained" + " checkpoint (contrast with the pretrained run; weights won't match Fast-LLM's random init).", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + for key, value in ( + ("RANK", "0"), + ("LOCAL_RANK", "0"), + ("WORLD_SIZE", "1"), + ("MASTER_ADDR", "127.0.0.1"), + ("MASTER_PORT", "29555"), + ): + os.environ.setdefault(key, value) + + import transformers + + from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig + + device = torch.device("cuda:0") + if args.input_file is not None: + input_ids = torch.load(args.input_file).to(device=device, dtype=torch.int64) + logger.info(f"Loaded shared model input {tuple(input_ids.shape)} from {args.input_file}") + else: + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model) + vocab_size = transformers.AutoConfig.from_pretrained(args.model).vocab_size + input_ids = build_input_ids(tokenizer, vocab_size, args.sequence_length, device, args.input_mode) + logger.info(f"input_ids shape {tuple(input_ids.shape)}") + + compare = CompareConfig() + ref_logprob: torch.Tensor | None = None + ref_grads: dict[str, torch.Tensor] = {} + logprob_metrics: dict[str, dict[str, typing.Any]] = {} + grad_metrics: dict[str, dict[str, list[float]]] = {} + for name, dtype, fp32_head in _VARIANTS: + logger.info(f"=== variant {name} (dtype={dtype}, fp32_head={fp32_head}) ===") + logprob, grads = capture_variant( + args.model, dtype, fp32_head, input_ids, args.attn_implementation, args.random_init + ) + if name == _REFERENCE_NAME: + ref_logprob, ref_grads = logprob, grads + logprob_metrics[name] = compare._compute_diff(_entry(ref_logprob), _entry(logprob), "step", "chosen_logprob") + by_category: dict[str, list[float]] = {} + for param_name, grad in grads.items(): + if param_name not in ref_grads: + continue + metrics = compare._compute_diff(_entry(ref_grads[param_name]), _entry(grad), "step", param_name) + by_category.setdefault(grad_category(param_name), []).append(metrics["rms_rel"]) + grad_metrics[name] = by_category + + _print_logprob_summary(logprob_metrics) + _print_grad_summary(grad_metrics) + + +if __name__ == "__main__": + main()