[DOC] Clarify lhs_scale and rhs_scale requirements in dot_scaled (#8433)
Addresses confusion in #8431
Update `dot_scaled` docstring to clarify:
- `lhs_scale` shape: `[M, K//group_size]` when lhs is `[M, K]`
- `rhs_scale` shape: `[N, K//group_size]` when rhs is `[K, N]`
The compiler handles any required scale transposition internally, so scales should
be passed in their original orientation.
```python
# Correct
acc = tl.dot_scaled(a, None, "bf16", b.T, b_scale, "e4m3", acc=acc)
# Incorrect
acc = tl.dot_scaled(a, None, "bf16", b.T, b_scale.T, "e4m3", acc=acc)
```
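The shape requirements above can be sanity-checked in plain Python. This is a minimal sketch, not Triton API; `expected_scale_shapes` is a hypothetical helper, and `group_size=32` matches the `e8m0` microscaling block size described in this PR:

```python
# Hypothetical helper: computes the scale shapes dot_scaled expects,
# per the docstring change in this PR (group_size = 32 for e8m0 scales).
GROUP_SIZE = 32

def expected_scale_shapes(M, N, K, group_size=GROUP_SIZE):
    """Return (lhs_scale shape, rhs_scale shape) for lhs [M, K] and rhs [K, N]."""
    assert K % group_size == 0, "K must be a multiple of group_size"
    lhs_scale_shape = (M, K // group_size)  # one scale per group along K
    rhs_scale_shape = (N, K // group_size)  # note: N first, even though rhs is [K, N]
    return lhs_scale_shape, rhs_scale_shape

print(expected_scale_shapes(128, 256, 64))  # ((128, 2), (256, 2))
```

Note that `rhs_scale` has `N` as its leading dimension even though `rhs` is `[K, N]` — which is exactly why transposing it (as in the "Incorrect" example above) goes wrong.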
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
- [ ] I have added tests.
- `/test` for `lit` tests
- `/unittest` for C++ tests
- `/python/test` for end-to-end tests
- [x] This PR does not need a test because `it's a doc change`.
- Select one of the following.
- [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python code
    and using the instructions it generates is not minimal.)
```diff
 :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
-:param lhs_scale: Scale factor for lhs tensor.
-:type lhs_scale: e8m0 type represented as an uint8 tensor.
+:param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 when the scale type is `e8m0`.
+:type lhs_scale: e8m0 type represented as a uint8 tensor, or None.
 :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
 :type lhs_format: str
 :param rhs: The second tensor to be multiplied.
 :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
-:param rhs_scale: Scale factor for rhs tensor.
-:type rhs_scale: e8m0 type represented as an uint8 tensor.
+:param rhs_scale: Scale factor for rhs tensor. Shape should be [N, K//group_size] when rhs is [K, N].
+    Important: Do NOT transpose rhs_scale.
+:type rhs_scale: e8m0 type represented as a uint8 tensor, or None.
 :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
 :type rhs_format: str
 :param acc: The accumulator tensor. If not None, the result is added to this tensor.
```
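The fp4 packing convention the docstring mentions (two fp4 elements per uint8, first element in the lower bits) can be sketched in plain Python. This is an illustrative helper, not part of the Triton API; `pack_fp4` is a hypothetical name:

```python
def pack_fp4(nibbles):
    """Pack pairs of 4-bit values into bytes, first element in the low bits.

    Hypothetical sketch of the layout described in the dot_scaled docstring:
    element 0 occupies bits [3:0] of byte 0, element 1 occupies bits [7:4].
    """
    assert len(nibbles) % 2 == 0, "need an even number of fp4 elements"
    return bytes(
        ((hi & 0xF) << 4) | (lo & 0xF)
        for lo, hi in zip(nibbles[0::2], nibbles[1::2])
    )

packed = pack_fp4([0x1, 0x2, 0x3, 0x4])
print([hex(b) for b in packed])  # ['0x21', '0x43']
```

Because of this packing, an fp4 operand of logical shape `[M, K]` is stored as a uint8 tensor of shape `[M, K//2]`, while its scales still follow the `K//group_size` rule above in logical elements.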