
Commit e45315f

[DOC] Clarify lhs_scale and rhs_scale requirements in dot_scaled (#8433)
Addresses confusion in #8431.

Update `dot_scaled` docstring to clarify:
- `lhs_scale` shape: `[M, K//group_size]` when lhs is `[M, K]`
- `rhs_scale` shape: `[N, K//group_size]` when rhs is `[K, N]`

The compiler internally handles scale transposition, so scales should remain in their original orientation.

```python
# Correct
acc = tl.dot_scaled(a, None, "bf16", b.T, b_scale, "e4m3", acc=acc)

# Incorrect
acc = tl.dot_scaled(a, None, "bf16", b.T, b_scale.T, "e4m3", acc=acc)
```

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `it's a doc change`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 626ff18 commit e45315f

File tree

1 file changed: +5 −4 lines


python/triton/language/core.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -2039,14 +2039,15 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
     :param lhs: The first tensor to be multiplied.
     :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
-    :param lhs_scale: Scale factor for lhs tensor.
-    :type lhs_scale: e8m0 type represented as an uint8 tensor.
+    :param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 when the scale type is `e8m0`.
+    :type lhs_scale: e8m0 type represented as an uint8 tensor, or None.
     :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
     :type lhs_format: str
     :param rhs: The second tensor to be multiplied.
     :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
-    :param rhs_scale: Scale factor for rhs tensor.
-    :type rhs_scale: e8m0 type represented as an uint8 tensor.
+    :param rhs_scale: Scale factor for rhs tensor. Shape should be [N, K//group_size] when rhs is [K, N].
+        Important: Do NOT transpose rhs_scale; the compiler handles scale transposition internally.
+    :type rhs_scale: e8m0 type represented as an uint8 tensor, or None.
     :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
     :type rhs_format: str
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
```
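The shape rules in the docstring above can be sketched in plain Python. This is a minimal illustration, not Triton code: `expected_scale_shapes` is a hypothetical helper, and `group_size=32` follows the `e8m0` case the docstring describes.

```python
# Hypothetical helper illustrating the scale shapes dot_scaled expects.
# Not part of Triton; group_size=32 matches the e8m0 case described above.
def expected_scale_shapes(M, N, K, group_size=32):
    lhs_scale = (M, K // group_size)  # lhs is [M, K]
    # rhs is [K, N], but its scale is laid out [N, K // group_size]:
    # do NOT transpose it; the compiler handles that internally.
    rhs_scale = (N, K // group_size)
    return lhs_scale, rhs_scale

print(expected_scale_shapes(M=128, N=256, K=64))  # ((128, 2), (256, 2))
```

Note that both scales put the non-K dimension first, which is why passing `b_scale.T` in the "Incorrect" example above fails: the scale is already in the orientation the compiler expects.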
