Skip to content

Commit 43add00

Browse files
committed
xegpu: relax transform op restrictions
1 parent 997830e commit 43add00

File tree

1 file changed

+3
-45
lines changed

1 file changed

+3
-45
lines changed

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -686,23 +686,14 @@ transform::GetDescOp::applyToOne(transform::TransformRewriter &rewriter,
686686
transform::ApplyToEachResultList &results,
687687
transform::TransformState &state) {
688688

689-
// For now only DPAS op is supported.
690-
auto targetOp = dyn_cast<xegpu::DpasOp>(target);
691-
if (!targetOp) {
692-
auto diag = emitSilenceableFailure(getLoc())
693-
<< "Expected a xegpu.dpas op, but got: " << target->getName();
694-
diag.attachNote(target->getLoc()) << "target op";
695-
return diag;
696-
}
697-
698689
int64_t operandIndex = getOperandIndex();
699-
if (operandIndex >= targetOp.getNumOperands()) {
690+
if (operandIndex >= target->getNumOperands()) {
700691
return emitSilenceableFailure(getLoc())
701692
<< "operandIndex exceeds the number of op operands.";
702693
}
703694

704-
Value opVec = targetOp.getOperation()->getOperand(operandIndex);
705-
auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation());
695+
Value opVec = target->getOperand(operandIndex);
696+
auto maybeDescOp = findDescriptorOp(opVec, target);
706697
if (!maybeDescOp) {
707698
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
708699
}
@@ -772,19 +763,6 @@ transform::SetResultLayoutOp::apply(transform::TransformRewriter &rewriter,
772763
if (!status.succeeded())
773764
return status;
774765

775-
if (sgLayout.size() != 2) {
776-
return emitSilenceableFailure(getLoc())
777-
<< "Expected sg_layout to be a 2D vector";
778-
}
779-
if (sgData.size() != 2) {
780-
return emitSilenceableFailure(getLoc())
781-
<< "Expected sg_data to be a 2D vector";
782-
}
783-
if (instData.size() != 2) {
784-
return emitSilenceableFailure(getLoc())
785-
<< "Expected inst_data to be a 2D vector";
786-
}
787-
788766
// For now only create_nd_desc op is supported.
789767
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
790768
if (!descOp) {
@@ -892,26 +870,6 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
892870
if (!status.succeeded())
893871
return status;
894872

895-
if (sgLayout.size() != 2) {
896-
return emitSilenceableFailure(getLoc())
897-
<< "Expected sg_layout to be a 2D vector";
898-
}
899-
if (sgData.size() != 2) {
900-
return emitSilenceableFailure(getLoc())
901-
<< "Expected sg_data to be a 2D vector";
902-
}
903-
if (instData.size() != 2) {
904-
return emitSilenceableFailure(getLoc())
905-
<< "Expected inst_data to be a 2D vector";
906-
}
907-
908-
// For now only dpas op is supported.
909-
if (!isa<xegpu::DpasOp>(target)) {
910-
auto diag = emitSilenceableFailure(getLoc())
911-
<< "Expected a xegpu.dpas op, but got: " << target->getName();
912-
diag.attachNote(target->getLoc()) << "target op";
913-
return diag;
914-
}
915873
auto layoutAttr =
916874
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
917875
// Set layout attribute for the op result or operand

0 commit comments

Comments
 (0)