From 85e647834a55ac3fe8144509f053c2a982f7822b Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 1 Jul 2025 18:58:52 +0300 Subject: [PATCH 1/7] add xegpu transform ops --- .../include/mlir/Dialect/XeGPU/CMakeLists.txt | 1 + .../Dialect/XeGPU/TransformOps/CMakeLists.txt | 6 + .../XeGPU/TransformOps/XeGPUTransformOps.h | 29 + .../XeGPU/TransformOps/XeGPUTransformOps.td | 132 ++++ mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 + .../Dialect/XeGPU/TransformOps/CMakeLists.txt | 15 + .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 712 ++++++++++++++++++ mlir/test/Dialect/XeGPU/transform-ops.mlir | 213 ++++++ 9 files changed, 1111 insertions(+) create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp create mode 100644 mlir/test/Dialect/XeGPU/transform-ops.mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..5924606402a02 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td) +mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls) +mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen) + +add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h new file mode 100644 index 0000000000000..25b58273b95b2 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h @@ -0,0 +1,29 @@ +//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H +#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +#define GET_OP_CLASSES +#include + +namespace mlir { +class DialectRegistry; + +namespace xegpu { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td new file mode 100644 index 0000000000000..25d9565ef10a5 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -0,0 +1,132 @@ +//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef XEGPU_EXTENSION +#define XEGPU_EXTENSION + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def XeGPUHoistDescOp : Op +]> { + + let summary = "Hoists xegpu tile descriptor ops outside the containing loop"; + let description = [{ + Hoists `xepu.create_nd_tdesc` out of the loop. If the + descriptor's offset is loop dependent, a `xegpu.update_nd_offset` op is + inserted in the loop to increment the offset. + }]; + + let arguments = (ins TransformHandleTypeInterface : $loop); + let results = (outs TransformHandleTypeInterface : $transformed); + + let assemblyFormat = "$loop attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +def XeGPUSetDPASLayoutOp : Op +]> { + + let summary = "Set xegpu.layout attribute to an DPAS op operand."; + let description = [{ + Given a `xegpu.dpas` operation, this transform adds `xegpu.layout` + attribute to it's operand's tensor descriptor. The target operand is + defined by the `tileIndex` argument. The layout is defined by the + `sg_layout`, `sg_data` and `inst_data` attributes. The `load_data` + attribute defines the tile size used for loading the data. It must be a + multiple of the `inst_data` size. + }]; + + let arguments = (ins TransformHandleTypeInterface : $dpasOp, + I64Attr : $tileIndex, + DenseI32ArrayAttr : $sgLayout, + DenseI32ArrayAttr : $sgData, + OptionalAttr : $loadData, + DenseI32ArrayAttr : $instData); + + let results = (outs); + + let assemblyFormat = + "$dpasOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` " + "$sgData (`load_data` `=` $loadData^)? 
`inst_data` `=` $instData attr-dict `:` type($dpasOp)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +def XeGPUInsertPrefetchOp : Op]> { + + let summary = "Adds xegpu prefetch ops to matmul operand tiles."; + let description = [{ + Given a `xegpu.dpas` operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes. + }]; + + let arguments = (ins TransformHandleTypeInterface : $dpasOp, + TransformHandleTypeInterface : $loopOp, + I64Attr : $tileIndex, + DenseI32ArrayAttr : $sgLayout, + DenseI32ArrayAttr : $sgData); + + let results = (outs TransformHandleTypeInterface : $transformedDpasOp, + TransformHandleTypeInterface : $transformedLoopOp); + + let assemblyFormat = + "$dpasOp $loopOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` " + "$sgData attr-dict `:` functional-type(operands, results)"; +} + +// TODO this should be handled with gpu transform ops. +// Add gpu mapping to scf.forall op and use something like +// transform.gpu.map_forall_to_blocks to convert to gpu.launch op. +def XeGPUSetGPULaunchThreadsOp + : Op + ]> { + + let summary = "Set number of threads for a given gpu.launch operation"; + let description = [{Set number of threads for a given gpu.launch operation}]; + + let arguments = (ins TransformHandleTypeInterface + : $launchOp, DenseI32ArrayAttr + : $threads); + let results = (outs); + let assemblyFormat = + "$launchOp `threads` `=` $threads attr-dict `:` type($launchOp)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +#endif // XEGPU_EXTENSION diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c3aeba7..8adac87014486 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -55,6 +55,7 @@ #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -114,6 +115,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); + xegpu::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). 
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6af908b..46b8251a57797 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..63245148938a5 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000000000..9756fe6180e47 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,712 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include + +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "xegpu-transforms" + +using namespace mlir; + +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include + >(); +} + +#define GET_OP_CLASSES +#include + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions(); +} + +/// Recurse operands and collect all producer ops in the given region. 
+void collectProducerOps(Operation *op, Region &inRegion, + SmallVector &ops) { + for (auto val : op->getOperands()) { + if (const auto definingOp = val.getDefiningOp(); + definingOp && definingOp->getParentRegion() == &inRegion) { + ops.push_back(definingOp); + collectProducerOps(definingOp, inRegion, ops); + } + } +} + +/// Returns all producer ops in the given region +SmallVector getProducerOpsInRegion(Operation *op, Region &inRegion, + bool includeOp = true) { + SmallVector producerOps; + if (includeOp) { + producerOps.push_back(op); + } + collectProducerOps(op, inRegion, producerOps); + return producerOps; +} + +/// Find xegpu.create_nd_desc op for the given operand value. +static std::optional +findDescriptorOp(Value operandValue, Operation *userOp) { + // FIXME more generic way of finding desc op that may be outside the loop + Value currentValue = operandValue; + if (!currentValue.getDefiningOp()) { + // Desc op may reside outside a loop. + auto forOp = userOp->getParentOfType(); + if (!forOp) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to find operand desc op, def op not a loop."); + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast(currentValue)) { + auto numInductionVars = forOp.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = forOp.getInits()[iterArgIdx]; + } else { + LLVM_DEBUG(llvm::dbgs() + << "Failed to find operand desc op, def op not an init val."); + return std::nullopt; + } + } + auto findDescOp = [](Value val) -> std::optional { + Operation *producerOp = val.getDefiningOp(); + while (producerOp) { + if (auto maybeDescOp = dyn_cast(producerOp)) { + return maybeDescOp; + } + if (producerOp->getNumOperands() == 0) + break; + producerOp = producerOp->getOperand(0).getDefiningOp(); + } + return std::nullopt; + }; + return findDescOp(currentValue); +} + +// Get user of type T in immediate users of the value. +template +static std::optional getUserOfType(Value value) { + auto users = value.getUsers(); + auto it = llvm::find_if(users, [&](Operation *op) { return isa(op); }); + if (it != users.end()) { + return cast(*it); + } + return std::nullopt; +} + +/// Add offset update op after create desc op if tile is updated in the loop. +xegpu::CreateNdDescOp insertUpdateOp(transform::TransformRewriter &rewriter, + scf::ForOp parentLoopOp, + xegpu::CreateNdDescOp descOp) { + // DescOp offset is an affine map with loop dependent and independent + // components. The new desc op will be loop independent, i.e. it uses the + // constant offset. The remainder offset, 'offset - constant', will be used + // in the update offset op. + + // Compute the constant offset. + // Clone desc op producers and replace loop variable with lower bound. + rewriter.setInsertionPointAfter(descOp); + auto loc = descOp.getLoc(); + IRMapping mapping; + SmallVector clonedOps; + auto producers = getProducerOpsInRegion(descOp.getOperation(), + parentLoopOp.getRegion(), true); + for (auto &op : llvm::reverse(producers)) { + auto newOp = rewriter.clone(*op, mapping); + clonedOps.push_back(newOp); + } + // Replace loop induction variable. + rewriter.replaceUsesWithIf(parentLoopOp.getInductionVar(), + parentLoopOp.getLowerBound(), [&](OpOperand &use) { + return ::llvm::is_contained(clonedOps, + use.getOwner()); + }); + auto newDescOp = cast(clonedOps.back()); + + // Compute offset for update operation: original offset - constant offset. 
+ llvm::SmallVector origDynamicOffsets, constDynamicOffsets, + dynamicOffsets; + llvm::SmallVector origStaticOffsets, constStaticOffsets, + staticOffsets; + dispatchIndexOpFoldResults(descOp.getMixedOffsets(), origDynamicOffsets, + origStaticOffsets); + dispatchIndexOpFoldResults(newDescOp.getMixedOffsets(), constDynamicOffsets, + constStaticOffsets); + int64_t dynIndex = 0; + for (auto [i, origStaticOffset] : llvm::enumerate(origStaticOffsets)) { + if (origStaticOffset != ShapedType::kDynamic) { + // Original offset was a constant, difference must be 0. + staticOffsets.push_back(0); + } else { + auto origDynOffset = origDynamicOffsets[dynIndex]; + auto cstDynOffset = constDynamicOffsets[dynIndex]; + auto subValue = rewriter.createOrFold( + loc, origDynOffset.getType(), origDynOffset, cstDynOffset); + auto maybeIntValue = getConstantIntValue(subValue); + if (maybeIntValue) { + // Folded to a constant int. + staticOffsets.push_back(*maybeIntValue); + } else { + // Dynamic offset. + dynamicOffsets.push_back(subValue); + staticOffsets.push_back(ShapedType::kDynamic); + } + dynIndex++; + } + } + + // Insert an offset update op if non-trivial offset. + bool allZeros = llvm::all_of(staticOffsets, [](int64_t s) { return s == 0; }); + if (!dynamicOffsets.empty() || !allZeros) { + auto tile = newDescOp.getResult(); + auto offsetOp = rewriter.create( + loc, tile.getType(), tile, dynamicOffsets, staticOffsets); + // replace subsequent uses of the descriptor with the offset descriptor + rewriter.replaceUsesWithIf( + descOp.getResult(), offsetOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != offsetOp.getOperation(); + }); + } + rewriter.replaceOp(descOp, newDescOp); + return newDescOp; +} + +/// Add offset update ops after create desc ops in the loop body. +LogicalResult insertOffsetUpdateOps(transform::TransformRewriter &rewriter, + scf::ForOp loopOp) { + // Find all create desc operations in the loop body + SmallVector createDescOps; + for (auto &op : loopOp.getBody()->getOperations()) { + if (isa(op)) { + createDescOps.push_back(&op); + } + } + if (createDescOps.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No xegpu.create_nd_desc ops found in the loop body.\n"); + return failure(); + } + // Split to desc and offset update ops. + for (auto &op : createDescOps) { + auto descOp = cast(op); + insertUpdateOp(rewriter, loopOp, descOp); + } + return success(); +} + +/// Check if an op can be hoisted out of the loop. +static bool canBeHoisted(Operation *op, LoopLikeOpInterface &loopLike) { + return llvm::all_of(op->getOperands(), [&](Value value) { + return loopLike.isDefinedOutsideOfLoop(value); + }); +} + +/// Hoist create desc ops out of the loop. +/// If offset update ops exist, add values to loop iter_args and yield +FailureOr hoistDescOps(transform::TransformRewriter &rewriter, + scf::ForOp loopOp) { + SmallVector descOps; + auto loopLike = cast(loopOp.getOperation()); + for (auto &op : loopOp.getBody()->getOperations()) { + if (auto descOp = dyn_cast(op)) { + if (canBeHoisted(descOp.getOperation(), loopLike)) { + descOps.push_back(descOp); + } + } + } + if (descOps.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No hoistable create_nd_desc ops found in the loop body.\n"); + return loopOp; + } + + SmallVector initValues, yieldValues; + for (auto &descOp : descOps) { + // We assume tensor desc is used by an offset update op, find it. + auto maybeOffsetOp = getUserOfType(descOp.getResult()); + if (!maybeOffsetOp) { + continue; + } + auto offsetOp = *maybeOffsetOp; + + // Hoist desc op. 
+ auto producers = + getProducerOpsInRegion(descOp.getOperation(), loopOp.getRegion(), true); + for (auto &op : llvm::reverse(producers)) { + rewriter.moveOpBefore(op, loopOp); + } + + // Offset update op must be converted to increment the offset, instead of + // defining an absolute offset wrt the original descriptor tile. + // In offset update producer ops, replace loop variable with step size. + auto offsetProducerOps = + getProducerOpsInRegion(offsetOp.getOperation(), loopOp.getRegion()); + rewriter.replaceUsesWithIf( + loopOp.getInductionVar(), loopOp.getStep(), [&](OpOperand &use) { + return llvm::is_contained(offsetProducerOps, use.getOwner()); + }); + // Offsetted desc now points to next tile, users must use the current tile + rewriter.replaceAllUsesWith(offsetOp.getResult(), offsetOp.getTensorDesc()); + // Add to loop init/yield values. + initValues.push_back(descOp.getResult()); + yieldValues.push_back(offsetOp.getResult()); + } + // Rewrite loop with new init/yield values. + NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, + llvm::ArrayRef newBBArgs) { + return yieldValues; + }; + auto maybeNewLoop = loopOp.replaceWithAdditionalYields( + rewriter, initValues, + /*replaceInitOperandUsesInLoop=*/true, yieldFn); + if (failed(maybeNewLoop)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to generate a new loop.\n"); + return failure(); + } + return cast(*maybeNewLoop); +} + +/// Create a layout attribute from the given parameters. +xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef sgLayout, + ArrayRef sgData, + std::optional> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. +xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, + xegpu::LayoutAttr layout) { + auto ctx = rewriter.getContext(); + auto oldTensorDesc = descOp.getResult(); + auto descShapedType = cast(oldTensorDesc.getType()); + // This discards any block_tdesc_attr attributes. + auto descType = xegpu::TensorDescType::get(ctx, descShapedType.getShape(), + descShapedType.getElementType(), + /*encoding=*/nullptr, + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp( + descOp, descType, descOp.getSource(), descOp.getMixedOffsets(), + descOp.getMixedSizes(), descOp.getMixedStrides()); + + return newDescOp; +} + +/// Fuse two scf.for loops into one. Keeps track of source operations to their +/// cloned targets. Returns the new fused loop. +scf::ForOp fuseForLoops(scf::ForOp target, scf::ForOp source, + RewriterBase &rewriter, + SmallVector &sourceOps, + SmallVector &targetOps) { + // This method is modified from mlir::fuseIndependentSiblingForLoops to + // trace the source ops to their cloned targets. + + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused init_args, with target's init_args before source's init_args. + SmallVector fusedInitArgs; + llvm::append_range(fusedInitArgs, target.getInitArgs()); + llvm::append_range(fusedInitArgs, source.getInitArgs()); + + // Create a new scf.for op after the source loop (with scf.yield terminator + // (without arguments) only in case its init_args is empty). 
+ rewriter.setInsertionPointAfter(source); + scf::ForOp fusedLoop = rewriter.create( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), fusedInitArgs); + + // Map original induction variables and operands to those of the fused loop. + IRMapping mapping; + mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Merge target's body into the new (fused) for loop and then source's body. + rewriter.setInsertionPointToStart(fusedLoop.getBody()); + IRMapping clonedOpsMapping; + for (Operation &op : target.getBody()->without_terminator()) { + auto newOp = rewriter.clone(op, mapping); + clonedOpsMapping.map(&op, newOp); + } + for (Operation &op : source.getBody()->without_terminator()) { + auto newOp = rewriter.clone(op, mapping); + clonedOpsMapping.map(&op, newOp); + } + // Map the given source operations to their cloned targets. + auto opsMap = clonedOpsMapping.getOperationMap(); + for (Operation *op : sourceOps) { + auto it = opsMap.find(op); + if (it != opsMap.end()) { + targetOps.push_back(it->second); + } else { + targetOps.push_back(nullptr); + } + } + + // Build fused yield results by appropriately mapping original yield operands. + SmallVector yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create(source.getLoc(), yieldResults); + + // Replace old loops by substituting their uses by results of the fused loop. 
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + + return fusedLoop; +} + +DiagnosedSilenceableFailure transform::XeGPUHoistDescOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + auto loopOp = dyn_cast(target); + if (!loopOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a scf.for op, but got: " << target->getName(); + } + + if (failed(insertOffsetUpdateOps(rewriter, loopOp))) { + return emitSilenceableFailure(getLoc()) + << "No desc ops found in the loop body " << target->getName(); + } + auto newLoopOp = hoistDescOps(rewriter, loopOp); + if (failed(newLoopOp)) { + auto diag = emitSilenceableFailure(getLoc()) + << "Failed to hoist xegpu.create_nd_desc ops"; + diag.attachNote(loopOp.getLoc()) << "loop op"; + return diag; + } + loopOp = *newLoopOp; + results.push_back(loopOp.getOperation()); + return DiagnosedSilenceableFailure::success(); +} + +void transform::XeGPUHoistDescOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + consumesHandle(getLoopMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::XeGPUInsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + + auto dpasOps = state.getPayloadOps(getDpasOp()); + auto loopOps = state.getPayloadOps(getLoopOp()); + + if (!llvm::hasSingleElement(dpasOps)) { + return emitDefiniteFailure() << "requires exactly one dpasOp handle (got " + << llvm::range_size(dpasOps) << ")"; + } + if (!llvm::hasSingleElement(loopOps)) { + return emitDefiniteFailure() << "requires exactly one loopOp handle (got " + << llvm::range_size(loopOps) << ")"; + } + + Operation *dpasPtr = *dpasOps.begin(); + auto dpasOp = dyn_cast(dpasPtr); + if (!dpasOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a xegpu.dpas op, but got: " << dpasPtr->getName(); + } + + Operation *loopPtr = *loopOps.begin(); + auto forOp = dyn_cast(loopPtr); + if (!forOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a scf.for op, but got: " << loopPtr->getName(); + } + + auto parentLoop = dpasOp->getParentOfType(); + if (!parentLoop || parentLoop != forOp) { + return emitSilenceableFailure(getLoc()) + << "dpasOp is not contained in the given scf.for loop."; + } + + int64_t tileIndex = getTileIndex(); + if (tileIndex >= dpasOp.getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "tileIndex exceeds the number of op operands."; + } + + auto sgLayout = getSgLayout(); + if (sgLayout.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_layout to be a 2D vector"; + } + + auto sgData = getSgData(); + if (sgData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_data to be a 2D vector"; + } + + // Find descriptor op of the operand. + Value opVec = dpasOp.getOperation()->getOperand(tileIndex); + auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + } + auto descOp = *maybeDescOp; + + // Clone reduction loop. + rewriter.setInsertionPoint(forOp); + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep()); + // Clone desc op into it. 
+ rewriter.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + auto newDescOp = cast( + rewriter.clone(*descOp.getOperation(), mapping)); + // Set desc op layout. + auto layout = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, + /*instData=*/std::nullopt); + newDescOp = setDescLayout(rewriter, newDescOp, layout); + + // Insert prefetch op. + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + rewriter.create(newDescOp.getLoc(), + newDescOp.getResult(), readCacheHint, + readCacheHint, readCacheHint); + + // Insert offset update op. + insertUpdateOp(rewriter, newForOp, newDescOp); + // Hoist descriptor op out of the loop. + auto maybenewForOp = hoistDescOps(rewriter, newForOp); + if (failed(maybenewForOp)) { + auto diag = emitSilenceableFailure(getLoc()) + << "Failed to hoist xegpu.create_nd_desc ops"; + diag.attachNote(newForOp.getLoc()) << "loop op"; + return diag; + } + newForOp = *maybenewForOp; + + // Peel first iteration of the loop and reset lower bound to original value. + scf::ForOp firstLoopOp; + if (failed(scf::peelForLoopFirstIteration(rewriter, newForOp, firstLoopOp))) { + auto diag = emitSilenceableFailure(getLoc()) << "Failed to peel the loop"; + } + newForOp.setLowerBound(forOp.getLowerBound()); + + // Fuse with the original loop, keep track of cloned ops. + SmallVector sourceOps{dpasOp.getOperation()}, targetOps; + auto fusedLoop = + fuseForLoops(newForOp, forOp, rewriter, sourceOps, targetOps); + assert(fusedLoop && "failed to fuse loops"); + + // Get the cloned dpas op. + auto clonedDpasOp = targetOps[0]; + if (!clonedDpasOp) { + return emitSilenceableFailure(getLoc()) + << "Failed to find cloned dpas op in the fused loop."; + } + + // Map result handles. 
+ results.set(cast(getTransformedLoopOp()), {fusedLoop}); + results.set(cast(getTransformedDpasOp()), {clonedDpasOp}); + + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure transform::XeGPUSetDPASLayoutOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + auto dpasOp = dyn_cast(target); + if (!dpasOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.dpas op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + int64_t tileIndex = getTileIndex(); + if (tileIndex >= dpasOp.getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "tileIndex exceeds the number of op operands."; + } + + auto sgLayout = getSgLayout(); + if (sgLayout.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_layout to be a 2D vector"; + } + + auto sgData = getSgData(); + if (sgData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_data to be a 2D vector"; + } + + auto instData = getInstData(); + if (instData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected inst_data to be a 2D vector"; + } + + llvm::ArrayRef loadData = instData; + if (getLoadData().has_value()) { + loadData = getLoadData().value(); + if (loadData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected load_data to be a 2D vector"; + } + if (loadData[0] < instData[0] || loadData[1] < instData[1]) { + return emitSilenceableFailure(getLoc()) + << "load_data size must be larger or equal to inst_data size"; + } + if (loadData[0] % instData[0] != 0 || loadData[1] % instData[1] != 0) { + return emitSilenceableFailure(getLoc()) + << "load_data must be evenly divisible by inst_data"; + } + } + + // Replace descriptor op using layout attribute. + Value opVec = dpasOp.getOperation()->getOperand(tileIndex); + auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + } + auto descOp = *maybeDescOp; + // Layout for the load op. + auto loadLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, loadData); + descOp = setDescLayout(rewriter, descOp, loadLayoutAttr); + // Layout for the instruction. + auto instLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); + if (tileIndex == 2) { + // C operand: set layout attribute for the dpas op result + xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], instLayoutAttr); + } + + if (loadLayoutAttr != instLayoutAttr) { + // Insert convert layout op after load op. + auto maybeLoadOp = getUserOfType(descOp.getResult()); + if (!maybeLoadOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a xegpu.load_nd op as a user of the descriptor op."; + } + auto loadOp = *maybeLoadOp; + rewriter.setInsertionPointAfter(loadOp.getOperation()); + auto source = loadOp.getResult(); + auto convLayoutOp = rewriter.create( + loadOp.getLoc(), source.getType(), source, + loadLayoutAttr, instLayoutAttr); + // Replace load op result with the converted layout. 
+ rewriter.replaceUsesWithIf( + source, convLayoutOp.getResult(), + [&](OpOperand &use) { + return use.getOwner() != convLayoutOp.getOperation(); + }); + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::XeGPUSetDPASLayoutOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getDpasOpMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure transform::XeGPUSetGPULaunchThreadsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + auto launchOp = dyn_cast(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + auto threads = getThreads(); + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads to be a 3D vector"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return rewriter.create(launchOp.getLoc(), value); + }; + + // Replace threads in-place. + launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + return DiagnosedSilenceableFailure::success(); +} + +void transform::XeGPUSetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getLaunchOpMutable(), effects); + modifiesPayload(effects); +} diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir new file mode 100644 index 0000000000000..11d2b000b03b1 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -0,0 +1,213 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @hoist_desc_ops +func.func @hoist_desc_ops(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-NEXT: scf.for + scf.for %arg1 = %c0 to %c4096 step %c32 { + // CHECK: xegpu.update_nd_offset + // CHECK: xegpu.load_nd + %0 = xegpu.create_nd_tdesc %arg0[0, %arg1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: = transform.xegpu.hoist_desc_ops %{{.*}} + %1 = transform.xegpu.hoist_desc_ops %0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_dpas_layout_a +func.func @set_dpas_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2 : 
!xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V1]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_dpas_layout %{{.*}} + transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_dpas_layout_b +func.func @set_dpas_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: = xegpu.create_nd_tdesc + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: = xegpu.load_nd + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg1 + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK-SAME: #xegpu.layout> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %1, %[[V1]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_dpas_layout %{{.*}} + transform.xegpu.set_dpas_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_dpas_layout_c +func.func @set_dpas_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: = xegpu.create_nd_tdesc + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: = xegpu.load_nd + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = xegpu.create_nd_tdesc + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: = xegpu.load_nd + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg2 + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + // CHECK-SAME: #xegpu.layout> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %1, %3, %[[V1]] {layout_result_0 = #xegpu.layout} + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + 
+module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_dpas_layout %{{.*}} + transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] load_data = [8, 16] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_dpas_layout_load_a +func.func @set_dpas_layout_load_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] + // CHECK-SAME: resMap = #xegpu.layout + // CHECK-SAME: srcMap = #xegpu.layout + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V2]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_dpas_layout %{{.*}} + transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] load_data = [32, 16] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a +func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: %[[C32:.+]] = arith.constant 32 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16, #xegpu.layout> + // Peeled first iteration of the loop, canonicalization drops the scf.for + // CHECK: %[[V2:.+]] = scf.for + // CHECK-SAME: iter_args(%[[V1:.+]] = %[[V0]]) + // CHECK: = xegpu.update_nd_offset %[[V1]], [0, %[[C32]]] + // CHECK: xegpu.prefetch_nd %[[V1]] + // Reduction loop + // CHECK: scf.for + // CHECK-SAME: iter_args(%[[V3:.+]] = %[[V2]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: = xegpu.update_nd_offset %[[V3]], [0, %[[C32]]] + // CHECK: xegpu.prefetch_nd %[[V3]] + %3 = xegpu.create_nd_tdesc %arg0[0, %arg3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %5 = xegpu.create_nd_tdesc %arg1[%arg3, 0] : memref<4096x4096xf16> -> 
!xegpu.tensor_desc<32x256xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %4, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["xegpu.dpas"]} in %0 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.insert_prefetch %{{.*}} %{{.*}} + %2, %3 = transform.xegpu.insert_prefetch %1 %0 index = 0 sg_layout = [32, 1] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads +func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } {SCFToGPU_visited} + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} From 98d5f0a2282ab9b4686e51766adc8fce7793f760 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 09:28:10 +0300 Subject: [PATCH 2/7] xegpu: add xegpu transform op python bindinds --- .../Dialect/XeGPU/TransformOps/CMakeLists.txt | 2 + mlir/python/CMakeLists.txt | 9 ++ .../python/mlir/dialects/XeGPUTransformOps.td | 19 +++ mlir/python/mlir/dialects/transform/xegpu.py | 112 ++++++++++++++++++ 4 files changed, 142 insertions(+) create mode 100644 mlir/python/mlir/dialects/XeGPUTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/xegpu.py diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt index 63245148938a5..48fe841afaa83 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRXeGPUTransformOps MLIRXeGPUTransformOpsIncGen LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms MLIRIR MLIRTransformDialect MLIRFuncDialect diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb8200..a6a2d6666b327 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -303,6 +303,15 @@ declare_mlir_dialect_extension_python_bindings( 
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" ) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/XeGPUTransformOps.td + SOURCES + dialects/transform/xegpu.py + DIALECT_NAME transform + EXTENSION_NAME xegpu_transform) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td new file mode 100644 index 0000000000000..5a5e7b912c4a5 --- /dev/null +++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td @@ -0,0 +1,19 @@ +//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the XeGPU transform ops. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS + +include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td" + +#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py new file mode 100644 index 0000000000000..42ee59f70ce7f --- /dev/null +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -0,0 +1,112 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._xegpu_transform_ops_gen import * +from .._xegpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext + from .._ods_common import get_op_result_or_value as _get_op_result_or_value +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +@_ods_cext.register_operation(_Dialect, replace=True) +class XeGPUSetDPASLayoutOp(XeGPUSetDPASLayoutOp): + """Specialization for XeGPUSetDPASLayoutOp class.""" + + def __init__( + self, + dpas_op: Union[Operation, Value], + tile_index: Union[int, Attribute], + sg_layout: Union[Sequence[int], Attribute], + sg_data: Union[Sequence[int], Attribute], + inst_data: Union[Sequence[int], Attribute], + *, + load_data: Optional[Union[Sequence[int], Attribute]] = None, + loc=None, + ip=None, + ): + super().__init__( + dpas_op, + tile_index, + sg_layout, + sg_data, + inst_data, + loadData=load_data, + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class XeGPUInsertPrefetchOp(XeGPUInsertPrefetchOp): + """Specialization for XeGPUInsertPrefetchOp class.""" + + def __init__( + self, + dpas_op: Union[Operation, Value], + loop_op: Union[Operation, Value], + tile_index: Union[int, Attribute], + sg_layout: Union[Sequence[int], Attribute], + sg_data: Union[Sequence[int], Attribute], + loc=None, + ip=None, + ): + # results = get_op_result_or_op_results(dpas_op, loop_op) + transformed_dpas_type = transform.AnyOpType.get() + transformed_loop_type = transform.AnyOpType.get() + super().__init__( + transformed_dpas_type, + transformed_loop_type, + _get_op_result_or_value(dpas_op), + _get_op_result_or_value(loop_op), + tile_index, + sg_layout, + sg_data, + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class XeGPUHoistDescOp(XeGPUHoistDescOp): + """Specialization for XeGPUHoistDescOp class.""" + + def __init__( + self, + loop_op: Union[Operation, Value], + loc=None, + ip=None, + ): + transformed_loop_type = transform.AnyOpType.get() + super().__init__( + transformed_loop_type, + _get_op_result_or_value(loop_op), + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class XeGPUSetGPULaunchThreadsOp(XeGPUSetGPULaunchThreadsOp): + """Specialization for XeGPUSetGPULaunchThreadsOp class.""" + + def __init__( + self, + launch_op: Union[Operation, Value], + threads: Union[int, Attribute], + loc=None, + ip=None, + ): + super().__init__( + _get_op_result_or_value(launch_op), + threads, + loc=loc, + ip=ip + ) From d73ef0de4a6e67e91367ee1a953ffc54b9d61c01 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 09:41:01 +0300 Subject: [PATCH 3/7] xegpu: drop XeGPU prefix from transform op names --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 8 ++++---- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 14 +++++++------- mlir/python/mlir/dialects/transform/xegpu.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index 25d9565ef10a5..cdcb48716efa4 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -15,7 +15,7 @@ include 
"mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -def XeGPUHoistDescOp : Op ]> { @@ -41,7 +41,7 @@ def XeGPUHoistDescOp : Op ]> { @@ -78,7 +78,7 @@ def XeGPUSetDPASLayoutOp : Op]> { @@ -104,7 +104,7 @@ def XeGPUInsertPrefetchOp : Op diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 9756fe6180e47..8198cfef8c10d 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -413,7 +413,7 @@ scf::ForOp fuseForLoops(scf::ForOp target, scf::ForOp source, return fusedLoop; } -DiagnosedSilenceableFailure transform::XeGPUHoistDescOp::applyToOne( +DiagnosedSilenceableFailure transform::HoistDescOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { @@ -440,7 +440,7 @@ DiagnosedSilenceableFailure transform::XeGPUHoistDescOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -void transform::XeGPUHoistDescOp::getEffects( +void transform::HoistDescOp::getEffects( ::llvm::SmallVectorImpl &effects) { consumesHandle(getLoopMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); @@ -448,7 +448,7 @@ void transform::XeGPUHoistDescOp::getEffects( } DiagnosedSilenceableFailure -transform::XeGPUInsertPrefetchOp::apply(transform::TransformRewriter &rewriter, +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { @@ -573,7 +573,7 @@ transform::XeGPUInsertPrefetchOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::XeGPUSetDPASLayoutOp::applyToOne( +DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { @@ -668,13 +668,13 @@ DiagnosedSilenceableFailure transform::XeGPUSetDPASLayoutOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -void transform::XeGPUSetDPASLayoutOp::getEffects( +void transform::SetDPASLayoutOp::getEffects( ::llvm::SmallVectorImpl &effects) { onlyReadsHandle(getDpasOpMutable(), effects); modifiesPayload(effects); } -DiagnosedSilenceableFailure transform::XeGPUSetGPULaunchThreadsOp::applyToOne( +DiagnosedSilenceableFailure transform::SetGPULaunchThreadsOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { @@ -705,7 +705,7 @@ DiagnosedSilenceableFailure transform::XeGPUSetGPULaunchThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } -void transform::XeGPUSetGPULaunchThreadsOp::getEffects( +void transform::SetGPULaunchThreadsOp::getEffects( ::llvm::SmallVectorImpl &effects) { onlyReadsHandle(getLaunchOpMutable(), effects); modifiesPayload(effects); diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 42ee59f70ce7f..00034db30ee0c 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -17,8 +17,8 @@ @_ods_cext.register_operation(_Dialect, replace=True) -class XeGPUSetDPASLayoutOp(XeGPUSetDPASLayoutOp): - """Specialization for XeGPUSetDPASLayoutOp class.""" +class 
SetDPASLayoutOp(SetDPASLayoutOp): + """Specialization for SetDPASLayoutOp class.""" def __init__( self, @@ -45,8 +45,8 @@ def __init__( @_ods_cext.register_operation(_Dialect, replace=True) -class XeGPUInsertPrefetchOp(XeGPUInsertPrefetchOp): - """Specialization for XeGPUInsertPrefetchOp class.""" +class InsertPrefetchOp(InsertPrefetchOp): + """Specialization for InsertPrefetchOp class.""" def __init__( self, @@ -75,8 +75,8 @@ def __init__( @_ods_cext.register_operation(_Dialect, replace=True) -class XeGPUHoistDescOp(XeGPUHoistDescOp): - """Specialization for XeGPUHoistDescOp class.""" +class HoistDescOp(HoistDescOp): + """Specialization for HoistDescOp class.""" def __init__( self, @@ -94,8 +94,8 @@ def __init__( @_ods_cext.register_operation(_Dialect, replace=True) -class XeGPUSetGPULaunchThreadsOp(XeGPUSetGPULaunchThreadsOp): - """Specialization for XeGPUSetGPULaunchThreadsOp class.""" +class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp): + """Specialization for SetGPULaunchThreadsOp class.""" def __init__( self, From bf7cf0a24ab04ddcaba38484e9a17b6e6e28c06b Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 17:39:25 +0300 Subject: [PATCH 4/7] xegpu: remove load_data argument from set_dpas_layout transform op --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 7 +-- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 50 ++----------------- mlir/python/mlir/dialects/transform/xegpu.py | 2 - mlir/test/Dialect/XeGPU/transform-ops.mlir | 32 +----------- 4 files changed, 8 insertions(+), 83 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index cdcb48716efa4..e2f19d41802bf 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -51,23 +51,20 @@ def SetDPASLayoutOp : Op : $loadData, DenseI32ArrayAttr : $instData); let results = (outs); let assemblyFormat = "$dpasOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` " - "$sgData (`load_data` `=` $loadData^)? `inst_data` `=` $instData attr-dict `:` type($dpasOp)"; + "$sgData `inst_data` `=` $instData attr-dict `:` type($dpasOp)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 8198cfef8c10d..3b359b53b62c6 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -610,23 +610,6 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( << "Expected inst_data to be a 2D vector"; } - llvm::ArrayRef loadData = instData; - if (getLoadData().has_value()) { - loadData = getLoadData().value(); - if (loadData.size() != 2) { - return emitSilenceableFailure(getLoc()) - << "Expected load_data to be a 2D vector"; - } - if (loadData[0] < instData[0] || loadData[1] < instData[1]) { - return emitSilenceableFailure(getLoc()) - << "load_data size must be larger or equal to inst_data size"; - } - if (loadData[0] % instData[0] != 0 || loadData[1] % instData[1] != 0) { - return emitSilenceableFailure(getLoc()) - << "load_data must be evenly divisible by inst_data"; - } - } - // Replace descriptor op using layout attribute. 
Value opVec = dpasOp.getOperation()->getOperand(tileIndex); auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation()); @@ -634,35 +617,12 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; } auto descOp = *maybeDescOp; - // Layout for the load op. - auto loadLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, loadData); - descOp = setDescLayout(rewriter, descOp, loadLayoutAttr); - // Layout for the instruction. - auto instLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); + // Set layout attribute. + auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); + descOp = setDescLayout(rewriter, descOp, layoutAttr); if (tileIndex == 2) { - // C operand: set layout attribute for the dpas op result - xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], instLayoutAttr); - } - - if (loadLayoutAttr != instLayoutAttr) { - // Insert convert layout op after load op. - auto maybeLoadOp = getUserOfType(descOp.getResult()); - if (!maybeLoadOp) { - return emitSilenceableFailure(getLoc()) - << "Expected a xegpu.load_nd op as a user of the descriptor op."; - } - auto loadOp = *maybeLoadOp; - rewriter.setInsertionPointAfter(loadOp.getOperation()); - auto source = loadOp.getResult(); - auto convLayoutOp = rewriter.create( - loadOp.getLoc(), source.getType(), source, - loadLayoutAttr, instLayoutAttr); - // Replace load op result with the converted layout. - rewriter.replaceUsesWithIf( - source, convLayoutOp.getResult(), - [&](OpOperand &use) { - return use.getOwner() != convLayoutOp.getOperation(); - }); + // C operand: set layout attribute for the dpas op result. + xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr); } return DiagnosedSilenceableFailure::success(); diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 00034db30ee0c..6324d80bf4fac 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -28,7 +28,6 @@ def __init__( sg_data: Union[Sequence[int], Attribute], inst_data: Union[Sequence[int], Attribute], *, - load_data: Optional[Union[Sequence[int], Attribute]] = None, loc=None, ip=None, ): @@ -38,7 +37,6 @@ def __init__( sg_layout, sg_data, inst_data, - loadData=load_data, loc=loc, ip=ip ) diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index 11d2b000b03b1..e4aec4f775b31 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -107,37 +107,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op // CHECK: transform.xegpu.set_dpas_layout %{{.*}} - transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] load_data = [8, 16] inst_data = [8, 16] : !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: @set_dpas_layout_load_a -func.func @set_dpas_layout_load_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { - // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 - // CHECK-SAME: #xegpu.layout> - %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> - // CHECK: 
%[[V1:.+]] = xegpu.load_nd %[[V0]] - // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] - // CHECK-SAME: resMap = #xegpu.layout - // CHECK-SAME: srcMap = #xegpu.layout - %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> - %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> - %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> - %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> - // CHECK: = xegpu.dpas %[[V2]] - %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: transform.xegpu.set_dpas_layout %{{.*}} - transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] load_data = [32, 16] inst_data = [8, 16] : !transform.any_op + transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op transform.yield } } From a45730c5722bc9878c7a4e0003800390bf41cf3e Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 18:21:52 +0300 Subject: [PATCH 5/7] xegpu: rename set_dpas_layout to set_operand_layout rename tileIndex to operandIndex remove all references to dpas ops where possible --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 26 +++---- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 71 ++++++++++--------- mlir/python/mlir/dialects/transform/xegpu.py | 25 ++++--- mlir/test/Dialect/XeGPU/transform-ops.mlir | 24 +++---- 4 files changed, 73 insertions(+), 73 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index e2f19d41802bf..ece830bd6d772 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -41,21 +41,21 @@ def HoistDescOp : Op ]> { - let summary = "Set xegpu.layout attribute to an DPAS op operand."; + let summary = "Set xegpu.layout attribute to an xegpu op operand."; let description = [{ - Given a `xegpu.dpas` operation, this transform adds `xegpu.layout` + Given an xegpu operation, this transform adds `xegpu.layout` attribute to it's operand's tensor descriptor. The target operand is - defined by the `tileIndex` argument. The layout is defined by the + defined by the `operandIndex` argument. The layout is defined by the `sg_layout`, `sg_data` and `inst_data` attributes. }]; - let arguments = (ins TransformHandleTypeInterface : $dpasOp, - I64Attr : $tileIndex, + let arguments = (ins TransformHandleTypeInterface : $target, + I64Attr : $operandIndex, DenseI32ArrayAttr : $sgLayout, DenseI32ArrayAttr : $sgData, DenseI32ArrayAttr : $instData); @@ -63,8 +63,8 @@ def SetDPASLayoutOp : Op(dpasPtr); - if (!dpasOp) { + Operation *targetPtr = *targetOps.begin(); + // For now only DPAS op is supported. 
+ auto targetOp = dyn_cast(targetPtr); + if (!targetOp) { return emitSilenceableFailure(getLoc()) - << "Expected a xegpu.dpas op, but got: " << dpasPtr->getName(); + << "Expected a xegpu.dpas op, but got: " << targetPtr->getName(); } Operation *loopPtr = *loopOps.begin(); @@ -478,16 +479,16 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, << "Expected a scf.for op, but got: " << loopPtr->getName(); } - auto parentLoop = dpasOp->getParentOfType(); + auto parentLoop = targetOp->getParentOfType(); if (!parentLoop || parentLoop != forOp) { return emitSilenceableFailure(getLoc()) - << "dpasOp is not contained in the given scf.for loop."; + << "target op is not contained in the given scf.for loop."; } - int64_t tileIndex = getTileIndex(); - if (tileIndex >= dpasOp.getNumOperands()) { + int64_t operandIndex = getOperandIndex(); + if (operandIndex >= targetOp.getNumOperands()) { return emitSilenceableFailure(getLoc()) - << "tileIndex exceeds the number of op operands."; + << "operandIndex exceeds the number of op operands."; } auto sgLayout = getSgLayout(); @@ -503,8 +504,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, } // Find descriptor op of the operand. - Value opVec = dpasOp.getOperation()->getOperand(tileIndex); - auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation()); + Value opVec = targetOp.getOperation()->getOperand(operandIndex); + auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation()); if (!maybeDescOp) { return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; } @@ -554,42 +555,42 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, newForOp.setLowerBound(forOp.getLowerBound()); // Fuse with the original loop, keep track of cloned ops. - SmallVector sourceOps{dpasOp.getOperation()}, targetOps; - auto fusedLoop = - fuseForLoops(newForOp, forOp, rewriter, sourceOps, targetOps); + SmallVector sourceOps{targetOp.getOperation()}, dstOps; + auto fusedLoop = fuseForLoops(newForOp, forOp, rewriter, sourceOps, dstOps); assert(fusedLoop && "failed to fuse loops"); - // Get the cloned dpas op. - auto clonedDpasOp = targetOps[0]; - if (!clonedDpasOp) { + // Get the cloned target op. + auto clonedTargetOp = dstOps[0]; + if (!clonedTargetOp) { return emitSilenceableFailure(getLoc()) - << "Failed to find cloned dpas op in the fused loop."; + << "Failed to find cloned target op in the fused loop."; } // Map result handles. results.set(cast(getTransformedLoopOp()), {fusedLoop}); - results.set(cast(getTransformedDpasOp()), {clonedDpasOp}); + results.set(cast(getTransformedTargetOp()), {clonedTargetOp}); return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( +DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - auto dpasOp = dyn_cast(target); - if (!dpasOp) { + // For now only DPAS op is supported. 
+ auto targetOp = dyn_cast(target); + if (!targetOp) { auto diag = emitSilenceableFailure(getLoc()) << "Expected a xegpu.dpas op, but got: " << target->getName(); diag.attachNote(target->getLoc()) << "target op"; return diag; } - int64_t tileIndex = getTileIndex(); - if (tileIndex >= dpasOp.getNumOperands()) { + int64_t operandIndex = getOperandIndex(); + if (operandIndex >= targetOp.getNumOperands()) { return emitSilenceableFailure(getLoc()) - << "tileIndex exceeds the number of op operands."; + << "operandIndex exceeds the number of op operands."; } auto sgLayout = getSgLayout(); @@ -611,8 +612,8 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( } // Replace descriptor op using layout attribute. - Value opVec = dpasOp.getOperation()->getOperand(tileIndex); - auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation()); + Value opVec = targetOp.getOperation()->getOperand(operandIndex); + auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation()); if (!maybeDescOp) { return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; } @@ -620,17 +621,17 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne( // Set layout attribute. auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); descOp = setDescLayout(rewriter, descOp, layoutAttr); - if (tileIndex == 2) { + if (operandIndex == 2) { // C operand: set layout attribute for the dpas op result. - xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr); + xegpu::setLayoutAttr(targetOp.getOperation()->getResults()[0], layoutAttr); } return DiagnosedSilenceableFailure::success(); } -void transform::SetDPASLayoutOp::getEffects( +void transform::SetOperandLayoutOp::getEffects( ::llvm::SmallVectorImpl &effects) { - onlyReadsHandle(getDpasOpMutable(), effects); + onlyReadsHandle(getTargetMutable(), effects); modifiesPayload(effects); } diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 6324d80bf4fac..8f097e59f19e2 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -17,13 +17,13 @@ @_ods_cext.register_operation(_Dialect, replace=True) -class SetDPASLayoutOp(SetDPASLayoutOp): - """Specialization for SetDPASLayoutOp class.""" +class SetOperandLayoutOp(SetOperandLayoutOp): + """Specialization for SetOperandLayoutOp class.""" def __init__( self, - dpas_op: Union[Operation, Value], - tile_index: Union[int, Attribute], + target: Union[Operation, Value], + index: Union[int, Attribute], sg_layout: Union[Sequence[int], Attribute], sg_data: Union[Sequence[int], Attribute], inst_data: Union[Sequence[int], Attribute], @@ -32,8 +32,8 @@ def __init__( ip=None, ): super().__init__( - dpas_op, - tile_index, + target, + index, sg_layout, sg_data, inst_data, @@ -48,23 +48,22 @@ class InsertPrefetchOp(InsertPrefetchOp): def __init__( self, - dpas_op: Union[Operation, Value], + target: Union[Operation, Value], loop_op: Union[Operation, Value], - tile_index: Union[int, Attribute], + index: Union[int, Attribute], sg_layout: Union[Sequence[int], Attribute], sg_data: Union[Sequence[int], Attribute], loc=None, ip=None, ): - # results = get_op_result_or_op_results(dpas_op, loop_op) - transformed_dpas_type = transform.AnyOpType.get() + transformed_target_type = transform.AnyOpType.get() transformed_loop_type = transform.AnyOpType.get() super().__init__( - transformed_dpas_type, + transformed_target_type, transformed_loop_type, - 
_get_op_result_or_value(dpas_op), + _get_op_result_or_value(target), _get_op_result_or_value(loop_op), - tile_index, + index, sg_layout, sg_data, loc=loc, diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index e4aec4f775b31..b6e480bab1f0b 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -27,8 +27,8 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: @set_dpas_layout_a -func.func @set_dpas_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { +// CHECK-LABEL: @set_operand_layout_a +func.func @set_operand_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 // CHECK-SAME: #xegpu.layout> %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> @@ -46,16 +46,16 @@ func.func @set_dpas_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: transform.xegpu.set_dpas_layout %{{.*}} - transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op transform.yield } } // ----- -// CHECK-LABEL: @set_dpas_layout_b -func.func @set_dpas_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { +// CHECK-LABEL: @set_operand_layout_b +func.func @set_operand_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { // CHECK: = xegpu.create_nd_tdesc %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> // CHECK: = xegpu.load_nd @@ -75,16 +75,16 @@ func.func @set_dpas_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: transform.xegpu.set_dpas_layout %{{.*}} - transform.xegpu.set_dpas_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op transform.yield } } // ----- -// CHECK-LABEL: @set_dpas_layout_c -func.func @set_dpas_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { +// CHECK-LABEL: @set_operand_layout_c +func.func @set_operand_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { // CHECK: = xegpu.create_nd_tdesc %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> // CHECK: = xegpu.load_nd @@ -106,8 +106,8 @@ func.func @set_dpas_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40 module attributes 
{transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: transform.xegpu.set_dpas_layout %{{.*}} - transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op transform.yield } } From 8b11bfd78e935577f402b67f235eabd8ec7441fe Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 18:23:04 +0300 Subject: [PATCH 6/7] xegpu: add tests for transform op python bindings --- .../python/dialects/transform_xegpu_ext.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 mlir/test/python/dialects/transform_xegpu_ext.py diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py new file mode 100644 index 0000000000000..673348fc45cb3 --- /dev/null +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -0,0 +1,99 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import xegpu +from mlir.dialects.transform import structured + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def setOperandLayout(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.SetOperandLayoutOp( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOperandLayout + # CHECK: transform.xegpu.set_operand_layout % + # CHECK: index = 0 + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def insertPrefetch(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + for_op = structured.MatchOp.match_op_names(sequence.bodyTarget, ["scf.for"]) + dpas_op = structured.MatchOp.match_op_names(for_op, ["xegpu.dpas"]) + xegpu.InsertPrefetchOp( + dpas_op, + for_op, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetch + # CHECK: %[[FOR_OP:.*]] = transform.structured.match + # CHECK: %[[DPAS_OP:.*]] = transform.structured.match + # CHECK: transform.xegpu.insert_prefetch %[[DPAS_OP]] %[[FOR_OP]] + # CHECK: index = 0 + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + + +@run +def hoistDescOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("scf.for"), + ) + with InsertionPoint(sequence.body): + xegpu.HoistDescOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: hoistDescOp + # CHECK: transform.xegpu.hoist_desc_ops + + +@run +def setGPULaunchThreadsOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("gpu.launch"), + ) + with InsertionPoint(sequence.body): + xegpu.SetGPULaunchThreadsOp( + sequence.bodyTarget, + threads=[8, 4, 1] + ) + 
transform.YieldOp() + # CHECK-LABEL: TEST: setGPULaunchThreadsOp + # CHECK: transform.xegpu.set_gpu_launch_threads + # CHECK: threads = [8, 4, 1] From 018491e2c137c86c7f7ed1b206d12304e9b839c9 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Fri, 8 Aug 2025 18:50:11 +0300 Subject: [PATCH 7/7] xegpu: code formatting --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 5 ++--- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 16 +++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index ece830bd6d772..ddbf88538afb0 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -110,9 +110,8 @@ def SetGPULaunchThreadsOp let summary = "Set number of threads for a given gpu.launch operation"; let description = [{Set number of threads for a given gpu.launch operation}]; - let arguments = (ins TransformHandleTypeInterface - : $launchOp, DenseI32ArrayAttr - : $threads); + let arguments = (ins TransformHandleTypeInterface : $launchOp, + DenseI32ArrayAttr : $threads); let results = (outs); let assemblyFormat = "$launchOp `threads` `=` $threads attr-dict `:` type($launchOp)"; diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index ec2b7802e9045..47f84658ee414 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -413,10 +413,11 @@ scf::ForOp fuseForLoops(scf::ForOp target, scf::ForOp source, return fusedLoop; } -DiagnosedSilenceableFailure transform::HoistDescOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure +transform::HoistDescOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { auto loopOp = dyn_cast(target); if (!loopOp) { @@ -449,8 +450,8 @@ void transform::HoistDescOp::getEffects( DiagnosedSilenceableFailure transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { + transform::TransformResults &results, + transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); auto loopOps = state.getPayloadOps(getLoopOp()); @@ -619,7 +620,8 @@ DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne( } auto descOp = *maybeDescOp; // Set layout attribute. - auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); + auto layoutAttr = + createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); descOp = setDescLayout(rewriter, descOp, layoutAttr); if (operandIndex == 2) { // C operand: set layout attribute for the dpas op result.
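
For reviewers who want to see the ops from this series composed end to end, the following transform script is an illustrative sketch only and is not part of any patch above: the matched payload ops and all sg_layout, sg_data, inst_data, and threads values are assumptions borrowed from the transform-ops.mlir and transform_xegpu_ext.py tests.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Handles to the payload ops; a single xegpu.dpas and gpu.launch are assumed.
    %dpas = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %launch = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Assign layouts to the A, B and C operands of the matched DPAS op.
    transform.xegpu.set_operand_layout %dpas index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
    transform.xegpu.set_operand_layout %dpas index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op
    transform.xegpu.set_operand_layout %dpas index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
    // Pin the launch configuration to the assumed subgroup grid.
    transform.xegpu.set_gpu_launch_threads %launch threads = [8, 4, 1] : !transform.any_op
    transform.yield
  }
}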