You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
2042
-
:param lhs_scale: Scale factor for lhs tensor.
2043
-
:type lhs_scale: e8m0 type represented as an uint8 tensor.
2042
+
:param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 if scales type are `e8m0`.
2043
+
:type lhs_scale: e8m0 type represented as an uint8 tensor, or None.
2044
2044
:param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
2045
2045
:type lhs_format: str
2046
2046
:param rhs: The second tensor to be multiplied.
2047
2047
:type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
2048
-
:param rhs_scale: Scale factor for rhs tensor.
2049
-
:type rhs_scale: e8m0 type represented as an uint8 tensor.
2048
+
:param rhs_scale: Scale factor for rhs tensor. Shape should be [N, K//group_size] where rhs is [K, N].
2049
+
Important: Do NOT transpose rhs_scale
2050
+
:type rhs_scale: e8m0 type represented as an uint8 tensor, or None.
2050
2051
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
2051
2052
:type rhs_format: str
2052
2053
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
0 commit comments