Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,17 @@ def test_quant_model_reload(format, dtype, tmp_path):
dense_tensor = og_state_dict[key].to(device)
reconstructed_tensor = reconstructed_state_dict[key].to(device)
assert dense_tensor.dtype == reconstructed_tensor.dtype
# Skip LM Head weight for now
# Note that the embedding is quantized
# TODO(@kylesayrs): this is a manifestation of not using a proper save context
if key == "lm_head.weight":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding "lm_head.weight" makes this test brittle and dependent on the specific model's architecture. To make it more robust, consider dynamically identifying the output embedding layer's weight key.

You could achieve this by adding the following logic before the loop:

lm_head_key = None
if hasattr(model, "get_output_embeddings"):
    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None and hasattr(output_embeddings, "weight"):
        for name, param in model.named_parameters():
            if param is output_embeddings.weight:
                lm_head_key = name
                break

And then using lm_head_key in the comparison:

if key == lm_head_key:
    continue

This would make the test more resilient to changes in the model architecture or if other models are used in this test in the future.

continue
if key.endswith("weight") and format != "dense":
# we don't expect an exact match for compressed
diff = torch.abs(dense_tensor - reconstructed_tensor)
assert not torch.any(diff > 0.01).item()
assert not torch.any(diff > 0.01).item(), key
else:
assert torch.equal(dense_tensor, reconstructed_tensor)
assert torch.equal(dense_tensor, reconstructed_tensor), key
if os.path.isdir(tmp_path):
shutil.rmtree(tmp_path)

Expand Down
Loading