Skip to content

Commit 8615c62

Browse files
Fix rms_norm fp16/bf16 (Lightning-AI#1751)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1506dae commit 8615c62

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

thunder/tests/opinfos.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7967,16 +7967,11 @@ def rms_norm_error_generator(op, device, **kwargs):
             dtypes=(datatypes.float16,),
             devicetypes=(devices.DeviceType.CPU,),
         ),
-        # See issue - https://github.com/Lightning-AI/lightning-thunder/issues/1395
         DecorateInfo(
-            custom_comparator(partial(assert_close, atol=1e-2, rtol=1e-2)),
-            dtypes=(datatypes.float16,),
+            pytest.mark.xfail,
+            dtypes=(datatypes.float16, datatypes.bfloat16),
             devicetypes=(devices.DeviceType.CUDA,),
-        ),
-        DecorateInfo(
-            pytest.mark.skip(reason="Flaky. See https://github.com/Lightning-AI/lightning-thunder/issues/1678"),
-            "test_core_vs_torch_consistency",
-            dtypes=(datatypes.bfloat16,),
+            active_if=LooseVersion(torch.__version__) < "2.7",
         ),
     ),
 )

thunder/torch/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3800,13 +3800,22 @@ def rms_norm(
     weight: None | TensorLike = None,
     eps: None | float = None,
 ):
+    input_dtype = a.dtype
+
+    if a.dtype in (thunder.float16, thunder.bfloat16):
+        a = clang.maybe_convert_to_dtype(a, thunder.float32, enforce_safe_casting=True)
+
     if eps is None:
         eps = torch.finfo(to_torch_dtype(a.dtype)).eps
+
     reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight)
-    norm_a = mean(a * a, dim=reduction_dims, keepdim=True)
+    norm_a = mean(a * a, dim=reduction_dims, keepdim=True, dtype=None)
     a_normed = a * rsqrt(norm_a + eps)
+
     if weight is not None:
         a_normed = a_normed * weight
+
+    a_normed = clang.maybe_convert_to_dtype(a_normed, input_dtype, enforce_safe_casting=True)
     return a_normed
38113820

38123821

0 commit comments

Comments (0)