
Commit 23c9ec1

[Frontend][Backend] Implement support for scale_dot(-, bf16) (#4996)
In passing, this also improves a few other things:
- `scaled_dot` now accepts fp8/bf16 inputs either as uint8/uint16 or as the corresponding fp8/bf16 types (previously inputs had to be cast to uint8, which was awkward when extending the op to bf16).
- `scaled_dot` is added to the docs, and the docs are improved overall (they have not been rendered yet and might need a few further tweaks).
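
As a hedged illustration of the new surface (not code from this commit; pointer arithmetic and shapes are made up for the example), a kernel can now feed `tl.dot_scaled` a plain bf16 rhs while the lhs stays packed:

```python
import triton
import triton.language as tl


@triton.jit
def scaled_dot_bf16_kernel(a_ptr,        # uint8: e2m1 (fp4) values packed two per byte
                           a_scale_ptr,  # uint8: one e8m0 scale per 32 lhs elements along K
                           b_ptr,        # bfloat16: rhs needs no packing and (for now) no scale
                           out_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # lhs is packed two fp4 values per byte, so only BLOCK_K // 2 bytes along K.
    a = tl.load(a_ptr + offs_m[:, None] * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)[None, :])
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32) + tl.arange(0, BLOCK_K // 32)[None, :])
    b = tl.load(b_ptr + tl.arange(0, BLOCK_K)[:, None] * BLOCK_N + offs_n[None, :])
    # rhs_scale must currently be None; "bf16" is the newly supported rhs format.
    c = tl.dot_scaled(a, a_scale, "e2m1", b, None, "bf16")
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)
```
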
1 parent 0591b37 commit 23c9ec1

14 files changed: 128 additions & 114 deletions

docs/python-api/triton.language.rst

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ Linear Algebra Ops
     :nosignatures:

     dot
+    dot_scaled


 Memory/Pointer Ops

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 5 additions & 4 deletions
@@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }
 
-// Type for F8F6F4 kind of floats.
-def TT_F8F6F4TypeAttr : I32EnumAttr<
-  "F8F6F4Type", "",
+// Type for ScaleDotElemType kind of floats.
+def TT_ScaleDotElemTypeAttr : I32EnumAttr<
+  "ScaleDotElemType", "",
   [
     I32EnumAttrCase<"E4M3", 0, "e4m3">,
     I32EnumAttrCase<"E5M2", 1, "e5m2">,
     I32EnumAttrCase<"E2M3", 2, "e2m3">,
     I32EnumAttrCase<"E3M2", 3, "e3m2">,
-    I32EnumAttrCase<"E2M1", 4, "e2m1">
+    I32EnumAttrCase<"E2M1", 4, "e2m1">,
+    I32EnumAttrCase<"BF16", 5, "bf16">
 
   ]>{
   let cppNamespace = "::mlir::triton";

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 8 deletions
@@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
 
   let arguments = (
     ins
-    // inputs are integer types as they are packed types and we currently
-    // don't have a representation for those.
-    TT_IntTensor:$lhs,
-    TT_IntTensor:$rhs,
+    // inputs are floats if we have a type for them, otherwise (fp4),
+    // they are packed in pairs in an I8Tensor
+    RankedTensorOf<[TT_Float,I8]>:$lhs,
+    RankedTensorOf<[TT_Float,I8]>:$rhs,
     TT_FloatTensor:$c,
-    TT_IntTensor:$lhs_scale,
-    Optional<TT_IntTensor>:$rhs_scale,
-    TT_F8F6F4TypeAttr:$lhs_type,
-    TT_F8F6F4TypeAttr:$rhs_type
+    RankedTensorOf<[I8]>:$lhs_scale,
+    Optional<RankedTensorOf<[I8]>>:$rhs_scale,
+    TT_ScaleDotElemTypeAttr:$lhs_type,
+    TT_ScaleDotElemTypeAttr:$rhs_type
   );
 
   let results = (outs TT_FloatTensor:$d);

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
   let arguments = (ins
     TT_Tensor:$src,
     TT_Tensor:$scale,
-    TT_F8F6F4TypeAttr:$fp_type);
+    TT_ScaleDotElemTypeAttr:$fp_type);
   let results = (outs TT_Tensor:$result);
 
   let assemblyFormat = [{

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 4 additions & 4 deletions
@@ -34,13 +34,13 @@ LogicalResult UpcastMXFPOp::verify() {
         "operands must have the same number of dimensions, at least 2");
   }
 
-  if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
-        fpType == F8F6F4Type::E5M2)) {
+  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
+        fpType == ScaleDotElemType::E5M2)) {
     return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
   }
 
   // Change to support fp8 types
-  const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;
+  const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
 
   if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
     return emitOpError("last dimension of first operand must be 16 times "
@@ -93,7 +93,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
     return emitOptionalError(loc, "expected a dotOperand encoding");
   }
 
-  if (typeEncoded == F8F6F4Type::E2M1) {
+  if (typeEncoded == ScaleDotElemType::E2M1) {
     auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
     auto newVEncoding = DotOperandEncodingAttr::get(
         ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
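
As a hedged aside (plain Python, not from this commit), the verifier's shape constraint above reduces to simple arithmetic: one e8m0 scale covers 32 unpacked elements, and e2m1 packs two elements per byte:

```python
# Sketch of the relation UpcastMXFPOp::verify() checks between the packed
# operand and its scale tensor (illustration only).
def expected_operand_last_dim(scale_last_dim: int, fp_type: str) -> int:
    elems_packed = 2 if fp_type == "e2m1" else 1  # fp4: two values per uint8
    return (32 // elems_packed) * scale_last_dim

# 64 unpacked K elements -> 2 scale bytes; stored as 32 bytes for e2m1, 64 for fp8.
assert expected_operand_last_dim(2, "e2m1") == 32
assert expected_operand_last_dim(2, "e4m3") == 64
```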

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 26 additions & 35 deletions
@@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2
     auto aType = dotOp.getLhsType();
     auto bType = dotOp.getRhsType();
 
-    auto enumToType = [&rewriter](F8F6F4Type type) {
-      switch (type) {
-      case F8F6F4Type::E4M3:
-        return rewriter.getFloat8E4M3FNType();
-      case F8F6F4Type::E5M2:
-        return rewriter.getFloat8E5M2Type();
-      default:
-        llvm_unreachable("unexpected type");
-      }
-    };
-
-    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
-            aType == F8F6F4Type::E2M1) &&
+    assert((aType == ScaleDotElemType::E4M3 ||
+            aType == ScaleDotElemType::E5M2 ||
+            aType == ScaleDotElemType::E2M1) &&
            "NYI: lhs supports fp4 or fp8");
-    assert(bType == F8F6F4Type::E4M3 ||
-           bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
+    assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 ||
+           bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16");
 
     // TODO run accelerate matmul on A and B first to choose their layouts
     // Set return type
@@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
 
-    auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
-                         TypedValue<RankedTensorType> v, int idx,
-                         F8F6F4Type type) -> TypedValue<RankedTensorType> {
+    auto toMMABf16 =
+        [&newRetType, &rewriter,
+         &ctx](TypedValue<RankedTensorType> v, int idx,
+               ScaleDotElemType type) -> TypedValue<RankedTensorType> {
       auto vType = v.getType();
-      if (type == F8F6F4Type::E2M1) {
+      if (type == ScaleDotElemType::E2M1) {
         // A bit too dynamically typed...
         // perhaps return ints in both cases?
 
@@ -469,23 +460,23 @@ class ScaledBlockedToMMAv2
             vType.getShape(), vType.getElementType(), newVEncoding);
         return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
       } else {
-        assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
+        assert(type == ScaleDotElemType::E5M2 ||
+               type == ScaleDotElemType::E4M3 ||
+               type == ScaleDotElemType::BF16);
         auto newVEncoding = DotOperandEncodingAttr::get(
             ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
         auto newVType = RankedTensorType::get(
             vType.getShape(), vType.getElementType(), newVEncoding);
         v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
 
-        // Bitcast
-        auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
-                                              enumToType(type), newVEncoding);
-        v = cast<TypedValue<RankedTensorType>>(
-            rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-
-        // Convert to bf16
-        auto vTypeBf16 = RankedTensorType::get(
-            vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-        return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+        if (type == ScaleDotElemType::BF16) {
+          return v;
+        } else {
+          // Convert to bf16
+          auto vTypeBf16 = RankedTensorType::get(
+              vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+          return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+        }
       }
     };
     a = toMMABf16(a, 0, aType);
@@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2
     auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
         ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout);
 
-    auto newScaleType = RankedTensorType::get(scale.getType().getShape(),
-                                              scale.getType().getElementType(),
-                                              newScaleEncoding);
-    scale =
-        rewriter.create<ConvertLayoutOp>(scale.getLoc(), newScaleType, scale);
+    auto newScaleDotElemType = RankedTensorType::get(
+        scale.getType().getShape(), scale.getType().getElementType(),
+        newScaleEncoding);
+    scale = rewriter.create<ConvertLayoutOp>(scale.getLoc(),
+                                             newScaleDotElemType, scale);
 
     auto scaledA = rewriter.create<triton::gpu::UpcastMXFPOp>(
         dotOp.getLoc(), a, scale, dotOp.getLhsType());
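
To spell out the operand handling after this change (a hedged summary in plain Python, not code from the commit): e2m1 stays packed, fp8 still gets upcast to bf16 with `FpToFpOp`, and a bf16 operand now passes through untouched, which is also why the old `enumToType`/`BitcastOp` step could be dropped.

```python
# Illustration only: the per-operand decision made by the toMMABf16 lambda.
def operand_action(fmt: str) -> str:
    if fmt == "e2m1":
        return "convert layout; keep packed uint8 (upcasting is handled by upcast_mxfp)"
    if fmt == "bf16":
        return "convert layout; already bf16, no FpToFp needed"
    if fmt in ("e4m3", "e5m2"):
        return "convert layout; upcast fp8 -> bf16 with FpToFpOp"
    raise AssertionError(f"NYI: unsupported scaled_dot element type: {fmt}")
```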

python/src/ir.cc

Lines changed: 10 additions & 9 deletions
@@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) {
       .value("IEEE", InputPrecision::IEEE)
       .export_values();
 
-  py::enum_<F8F6F4Type>(m, "F8F6F4TY", py::module_local())
-      .value("E4M3", F8F6F4Type::E4M3)
-      .value("E5M2", F8F6F4Type::E5M2)
-      .value("E2M3", F8F6F4Type::E2M3)
-      .value("E3M2", F8F6F4Type::E3M2)
-      .value("E2M1", F8F6F4Type::E2M1)
+  py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
+      .value("E4M3", ScaleDotElemType::E4M3)
+      .value("E5M2", ScaleDotElemType::E5M2)
+      .value("E2M3", ScaleDotElemType::E2M3)
+      .value("E3M2", ScaleDotElemType::E3M2)
+      .value("E2M1", ScaleDotElemType::E2M1)
+      .value("BF16", ScaleDotElemType::BF16)
       .export_values();
 
   py::class_<MLIRContext>(m, "context", py::module_local())
@@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) {
            })
       .def("create_dot_scaled",
           [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale,
-             F8F6F4Type lhs_format, mlir::Value &rhs,
-             std::optional<mlir::Value> &rhs_scale, F8F6F4Type rhs_format,
-             mlir::Value &c) -> mlir::Value {
+             ScaleDotElemType lhs_format, mlir::Value &rhs,
+             std::optional<mlir::Value> &rhs_scale,
+             ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value {
              return self.create<DotScaledOp>(
                  c.getType(), lhs, rhs, c, lhs_scale,
                  rhs_scale.value_or(Value()), lhs_format, rhs_format);

python/test/unit/language/test_core.py

Lines changed: 16 additions & 15 deletions
@@ -3330,7 +3330,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
                           for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2"]
+                          for type_b in ["e4m3", "e5m2", "bf16"]
                           for mma in ([32, 16] if is_hip() else [16])
                           for kpack in ([1, 2] if is_hip() else [1])])
 def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
@@ -3351,7 +3351,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                          type_b: tl.constexpr):
-        tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8")
+        tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
         IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
         DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
         PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
@@ -3442,7 +3442,7 @@ def mxfp_to_bf16_kernel(
 
     def dot_scale_ref(x, scale, y, type_x, type_y):
         e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
-        type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]
+        type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]
 
         comp_dtype = torch.bfloat16
 
@@ -3455,7 +3455,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
         assert x_upcast.isfinite().all()
 
-        y_upcast = y.view(type_fp8_y).to(comp_dtype)
+        y_upcast = y.view(type_y).to(comp_dtype)
 
         class AccumulateInFp32:
 
@@ -3467,28 +3467,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value
 
         with AccumulateInFp32():
-            return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))
+            return torch.matmul(x_upcast, y_upcast)
 
     torch.manual_seed(0)
 
-    def create_uint8(shape, col_major=False, max_val=255):
+    def make_arg(shape, ty, col_major=False, max_val=255):
         if col_major:
             shape = shape[:-2] + (shape[-1], shape[-2])
-        ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
+        if ty == "bf16":
+            ret = torch.randn(shape, dtype=torch.bfloat16, device=device)
+            # Clamp to avoid relative error issues
+            ret.clamp_(-2**15, 2**15 - 1)
+        else:
+            ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
         if col_major:
             ret = ret.mT
         return ret
 
     DIV_FACTOR = 2 if type_a == "e2m1" else 1
-    x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
-    y = create_uint8((K, N), col_major=col_b)
+    x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
+    y = make_arg((K, N), type_b, col_major=col_b)
 
     # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
-    # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow
-    m_bytes = int(type_a[1])
-    bias_type_a = 1 << (m_bytes - 1) - 1
-    max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a
-    scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
+    # Max scale= 2**15
+    scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)
 
     def make_finite(x, dtype):
         # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
@@ -3513,7 +3515,6 @@ def make_finite(x, dtype):
 
     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
 
-    # generous rtol as we are sampling the whole range of floats
     torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)
 
     # make sure ld/st are vectorized
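
A hedged note on the new scale sampling (the bias-127 decoding below follows the OCP microscaling spec linked from the `dot_scaled` docstring, not code shown in this diff): an e8m0 byte s encodes the factor 2**(s - 127), so capping `max_val` at 127 + 15 caps the sampled scales at 2**15, matching the `# Max scale= 2**15` comment.

```python
# Sketch: decode an e8m0 scale byte and check the cap used by make_arg above.
def e8m0_to_float(s: int) -> float:
    return 2.0 ** (s - 127)  # assumed bias-127 power-of-two encoding

assert e8m0_to_float(127 + 15) == 2.0 ** 15   # largest sampled scale
assert e8m0_to_float(127) == 1.0              # a byte of 127 means "no scaling"
```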

python/triton/language/core.py

Lines changed: 8 additions & 6 deletions
@@ -1556,15 +1556,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
     lhs and rhs use microscaling formats described here:
     https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     :param lhs: The first tensor to be multiplied.
-    :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+    :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs.
     :param lhs_scale: Scale factor for lhs tensor.
-    :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
-    :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+    :type lhs_scale: e8m0 type represented as an uint8 tensor.
+    :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code: `e5m2`}.
+    :type lhs_format: str
     :param rhs: The second tensor to be multiplied.
-    :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
+    :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs.
     :param rhs_scale: Scale factor for rhs tensor.
-    :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
-    :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
+    :type rhs_scale: e8m0 type represented as an uint8 tensor.
+    :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code: `e5m2`, :code:`bf16`}.
+    :type rhs_format: str
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
     """
     out_dtype = _constexpr_to_value(out_dtype)
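
Complementing the updated docstring, a hedged host-side sketch (tensor names and shapes are illustrative, not from the commit) of the input dtypes `dot_scaled` now accepts, with a bf16 rhs passed directly rather than through an integer view:

```python
import torch

M, N, K = 128, 128, 128
device = "cuda"

# lhs in e2m1: two fp4 values per byte, so K // 2 bytes along the inner dimension.
a = torch.randint(256, (M, K // 2), dtype=torch.uint8, device=device)
# One e8m0 scale byte per 32 lhs elements along K.
a_scale = torch.randint(256, (M, K // 32), dtype=torch.uint8, device=device)
# rhs in bf16: a regular bfloat16 tensor, no uint8 cast needed anymore.
b = torch.randn(K, N, dtype=torch.bfloat16, device=device)
```
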

python/triton/language/semantic.py

Lines changed: 32 additions & 17 deletions
@@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
                      ret_ty)
 
 
-def _str_to_fp_type(float_format: Optional[str]):
-    if float_format == 'e4m3':
-        return ir.F8F6F4TY.E4M3
-    if float_format == 'e5m2':
-        return ir.F8F6F4TY.E5M2
-    if float_format == 'e2m3':
-        return ir.F8F6F4TY.E2M3
-    if float_format == 'e3m2':
-        return ir.F8F6F4TY.E3M2
-    if float_format == 'e2m1':
-        return ir.F8F6F4TY.E2M1
-    raise ValueError(f"Invalid float format: {float_format}.")
-
-
-def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
-               rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+def _str_to_fp_type(float_format: str):
+    ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
+    if ty_enum is None:
+        raise ValueError(f"Invalid float format: {float_format}.")
+    return ty_enum
+
+
+def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
+    """
+    If float_format is subbyte, make sure it's packed as uint8 and return it.
+    Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
+    """
+    triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format)
+    if triton_ty is None:
+        assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
+        assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
+        return val
+    if val.dtype == triton_ty:
+        return val
+    else:
+        unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format]
+        assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
+        return bitcast(val, triton_ty, builder)
+
+
+def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
+               rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
     assert lhs.type.is_block() and rhs.type.is_block()
     #TODO: validate types.
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+    lhs_format: str = lhs_format.value
+    rhs_format: str = rhs_format.value
     lhs_format_enum = _str_to_fp_type(lhs_format)
     rhs_format_enum = _str_to_fp_type(rhs_format)
     assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
-    assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
+    assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}"
     rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
     assert rhs_scale_is_none, "NYI: rhs_scale not supported"
+    lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
+    rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)
 
     M = lhs.type.shape[-2]
     K, N = rhs.type.shape[-2:]
