Skip to content

Commit 9b8a495

Browse files
committed
xegpu: add temporary expand_result_vector op
1 parent 43add00 commit 9b8a495

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,4 +353,30 @@ def SetGPULaunchThreadsOp
353353
}];
354354
}
355355

356+
def ExpandResultVectorOp : Op<Transform_Dialect, "xegpu.expand_result_vector", [
357+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
358+
TransformOpInterface
359+
]> {
360+
361+
let summary = "Adds a singleton dimension to the op's return vector.";
362+
let description = [{
363+
Adds a singleton dimension to the op's return vector.
364+
}];
365+
366+
let arguments = (ins TransformHandleTypeInterface : $target);
367+
let results = (outs TransformHandleTypeInterface : $transformed);
368+
369+
let assemblyFormat = [{
370+
$target attr-dict `:` functional-type(operands, results)
371+
}];
372+
373+
let extraClassDeclaration = [{
374+
::mlir::DiagnosedSilenceableFailure apply(
375+
::mlir::transform::TransformRewriter &rewriter,
376+
::mlir::transform::TransformResults &transformResults,
377+
::mlir::transform::TransformState &state);
378+
}];
379+
}
380+
381+
356382
#endif // XEGPU_EXTENSION

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,3 +1056,50 @@ void transform::SetGPULaunchThreadsOp::getEffects(
10561056
onlyReadsHandle(getLaunchOpMutable(), effects);
10571057
modifiesPayload(effects);
10581058
}
1059+
1060+
DiagnosedSilenceableFailure
1061+
transform::ExpandResultVectorOp::apply(transform::TransformRewriter &rewriter,
1062+
transform::TransformResults &results,
1063+
transform::TransformState &state) {
1064+
1065+
auto targetOps = state.getPayloadOps(getTarget());
1066+
if (!llvm::hasSingleElement(targetOps)) {
1067+
return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
1068+
<< llvm::range_size(targetOps) << ")";
1069+
}
1070+
Operation *target = *targetOps.begin();
1071+
1072+
// Check that target is a vector.transfer_read op.
1073+
if (!isa<vector::TransferReadOp>(target)) {
1074+
return emitDefiniteFailure() << "expected a vector.transfer_read op, but got: "
1075+
<< target->getName();
1076+
}
1077+
auto readOp = dyn_cast<vector::TransferReadOp>(target);
1078+
1079+
// Replace transfer_read op with new op whose return vector's dimension
1080+
// has been extended by a singleton dim in the leading dimension.
1081+
auto vecType = cast<VectorType>(target->getResult(0).getType());
1082+
auto oldShape = vecType.getShape();
1083+
SmallVector<int64_t> newShape{1};
1084+
newShape.append(oldShape.begin(), oldShape.end());
1085+
auto newType = VectorType::get(newShape, vecType.getElementType());
1086+
rewriter.setInsertionPointAfter(readOp);
1087+
// TODO clone read op retaining attributes (if any)
1088+
auto inBounds = SmallVector<bool>{true, true};
1089+
auto newOp = rewriter.create<vector::TransferReadOp>(
1090+
target->getLoc(), newType, readOp.getBase(), ValueRange{readOp.getIndices()},
1091+
std::nullopt, inBounds);
1092+
rewriter.replaceOp(target, newOp);
1093+
1094+
// Map result handles.
1095+
results.set(cast<OpResult>(getTransformed()), {newOp.getOperation()});
1096+
1097+
return DiagnosedSilenceableFailure::success();
1098+
}
1099+
1100+
void transform::ExpandResultVectorOp::getEffects(
1101+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1102+
consumesHandle(getTargetMutable(), effects);
1103+
producesHandle(getOperation()->getOpResults(), effects);
1104+
modifiesPayload(effects);
1105+
}

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,24 @@ def __init__(
261261
loc=loc,
262262
ip=ip
263263
)
264+
265+
266+
@_ods_cext.register_operation(_Dialect, replace=True)
267+
class ExpandResultVectorOp(ExpandResultVectorOp):
268+
"""Specialization for ExpandResultVectorOp class."""
269+
270+
def __init__(
271+
self,
272+
target: Union[Operation, Value],
273+
*,
274+
loc=None,
275+
ip=None,
276+
):
277+
target_value = _get_op_result_or_value(target)
278+
279+
super().__init__(
280+
target_value.type,
281+
target_value,
282+
loc=loc,
283+
ip=ip
284+
)

0 commit comments

Comments
 (0)