diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..5924606402a02 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td) +mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls) +mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen) + +add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h new file mode 100644 index 0000000000000..25b58273b95b2 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h @@ -0,0 +1,29 @@ +//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H +#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" + +#define GET_OP_CLASSES +#include + +namespace mlir { +class DialectRegistry; + +namespace xegpu { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td new file mode 100644 index 0000000000000..ddbf88538afb0 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -0,0 +1,128 @@ +//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef XEGPU_EXTENSION +#define XEGPU_EXTENSION + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def HoistDescOp : Op +]> { + + let summary = "Hoists xegpu tile descriptor ops outside the containing loop"; + let description = [{ + Hoists `xepu.create_nd_tdesc` out of the loop. 
If the + descriptor's offset is loop dependent, a `xegpu.update_nd_offset` op is + inserted in the loop to increment the offset. + }]; + + let arguments = (ins TransformHandleTypeInterface : $loop); + let results = (outs TransformHandleTypeInterface : $transformed); + + let assemblyFormat = "$loop attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +def SetOperandLayoutOp : Op +]> { + + let summary = "Set xegpu.layout attribute to an xegpu op operand."; + let description = [{ + Given an xegpu operation, this transform adds `xegpu.layout` + attribute to it's operand's tensor descriptor. The target operand is + defined by the `operandIndex` argument. The layout is defined by the + `sg_layout`, `sg_data` and `inst_data` attributes. + }]; + + let arguments = (ins TransformHandleTypeInterface : $target, + I64Attr : $operandIndex, + DenseI32ArrayAttr : $sgLayout, + DenseI32ArrayAttr : $sgData, + DenseI32ArrayAttr : $instData); + + let results = (outs); + + let assemblyFormat = + "$target `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` " + "$sgData `inst_data` `=` $instData attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +def InsertPrefetchOp : Op]> { + + let summary = "Adds xegpu prefetch ops to matmul operand tiles."; + let description = [{ + Given an xegpu operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes. + }]; + + let arguments = (ins TransformHandleTypeInterface : $target, + TransformHandleTypeInterface : $loopOp, + I64Attr : $operandIndex, + DenseI32ArrayAttr : $sgLayout, + DenseI32ArrayAttr : $sgData); + + let results = (outs TransformHandleTypeInterface : $transformedTargetOp, + TransformHandleTypeInterface : $transformedLoopOp); + + let assemblyFormat = + "$target $loopOp `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` " + "$sgData attr-dict `:` functional-type(operands, results)"; +} + +// TODO this should be handled with gpu transform ops. +// Add gpu mapping to scf.forall op and use something like +// transform.gpu.map_forall_to_blocks to convert to gpu.launch op. 
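+// Until such a mapping exists, this op simply overwrites the block sizes of
+// an existing gpu.launch op with the requested constants.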
+def SetGPULaunchThreadsOp + : Op + ]> { + + let summary = "Set number of threads for a given gpu.launch operation"; + let description = [{Set number of threads for a given gpu.launch operation}]; + + let arguments = (ins TransformHandleTypeInterface : $launchOp, + DenseI32ArrayAttr : $threads); + let results = (outs); + let assemblyFormat = + "$launchOp `threads` `=` $threads attr-dict `:` type($launchOp)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter & rewriter, + ::mlir::Operation * target, + ::mlir::transform::ApplyToEachResultList & results, + ::mlir::transform::TransformState & state); + }]; +} + +#endif // XEGPU_EXTENSION diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c3aeba7..8adac87014486 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -55,6 +55,7 @@ #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -114,6 +115,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { vector::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); + xegpu::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6af908b..46b8251a57797 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..48fe841afaa83 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000000000..47f84658ee414 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,675 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include + +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "xegpu-transforms" + +using namespace mlir; + +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include + >(); +} + +#define GET_OP_CLASSES +#include + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions(); +} + +/// Recurse operands and collect all producer ops in the given region. +void collectProducerOps(Operation *op, Region &inRegion, + SmallVector &ops) { + for (auto val : op->getOperands()) { + if (const auto definingOp = val.getDefiningOp(); + definingOp && definingOp->getParentRegion() == &inRegion) { + ops.push_back(definingOp); + collectProducerOps(definingOp, inRegion, ops); + } + } +} + +/// Returns all producer ops in the given region +SmallVector getProducerOpsInRegion(Operation *op, Region &inRegion, + bool includeOp = true) { + SmallVector producerOps; + if (includeOp) { + producerOps.push_back(op); + } + collectProducerOps(op, inRegion, producerOps); + return producerOps; +} + +/// Find xegpu.create_nd_desc op for the given operand value. +static std::optional +findDescriptorOp(Value operandValue, Operation *userOp) { + // FIXME more generic way of finding desc op that may be outside the loop + Value currentValue = operandValue; + if (!currentValue.getDefiningOp()) { + // Desc op may reside outside a loop. 
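+    // The operand may be a loop-carried block argument: trace it back to the
+    // corresponding scf.for init value and search for the producing
+    // create_nd_tdesc op from there.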
+ auto forOp = userOp->getParentOfType(); + if (!forOp) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to find operand desc op, def op not a loop."); + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast(currentValue)) { + auto numInductionVars = forOp.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = forOp.getInits()[iterArgIdx]; + } else { + LLVM_DEBUG(llvm::dbgs() + << "Failed to find operand desc op, def op not an init val."); + return std::nullopt; + } + } + auto findDescOp = [](Value val) -> std::optional { + Operation *producerOp = val.getDefiningOp(); + while (producerOp) { + if (auto maybeDescOp = dyn_cast(producerOp)) { + return maybeDescOp; + } + if (producerOp->getNumOperands() == 0) + break; + producerOp = producerOp->getOperand(0).getDefiningOp(); + } + return std::nullopt; + }; + return findDescOp(currentValue); +} + +// Get user of type T in immediate users of the value. +template +static std::optional getUserOfType(Value value) { + auto users = value.getUsers(); + auto it = llvm::find_if(users, [&](Operation *op) { return isa(op); }); + if (it != users.end()) { + return cast(*it); + } + return std::nullopt; +} + +/// Add offset update op after create desc op if tile is updated in the loop. +xegpu::CreateNdDescOp insertUpdateOp(transform::TransformRewriter &rewriter, + scf::ForOp parentLoopOp, + xegpu::CreateNdDescOp descOp) { + // DescOp offset is an affine map with loop dependent and independent + // components. The new desc op will be loop independent, i.e. it uses the + // constant offset. The remainder offset, 'offset - constant', will be used + // in the update offset op. + + // Compute the constant offset. + // Clone desc op producers and replace loop variable with lower bound. + rewriter.setInsertionPointAfter(descOp); + auto loc = descOp.getLoc(); + IRMapping mapping; + SmallVector clonedOps; + auto producers = getProducerOpsInRegion(descOp.getOperation(), + parentLoopOp.getRegion(), true); + for (auto &op : llvm::reverse(producers)) { + auto newOp = rewriter.clone(*op, mapping); + clonedOps.push_back(newOp); + } + // Replace loop induction variable. + rewriter.replaceUsesWithIf(parentLoopOp.getInductionVar(), + parentLoopOp.getLowerBound(), [&](OpOperand &use) { + return ::llvm::is_contained(clonedOps, + use.getOwner()); + }); + auto newDescOp = cast(clonedOps.back()); + + // Compute offset for update operation: original offset - constant offset. + llvm::SmallVector origDynamicOffsets, constDynamicOffsets, + dynamicOffsets; + llvm::SmallVector origStaticOffsets, constStaticOffsets, + staticOffsets; + dispatchIndexOpFoldResults(descOp.getMixedOffsets(), origDynamicOffsets, + origStaticOffsets); + dispatchIndexOpFoldResults(newDescOp.getMixedOffsets(), constDynamicOffsets, + constStaticOffsets); + int64_t dynIndex = 0; + for (auto [i, origStaticOffset] : llvm::enumerate(origStaticOffsets)) { + if (origStaticOffset != ShapedType::kDynamic) { + // Original offset was a constant, difference must be 0. + staticOffsets.push_back(0); + } else { + auto origDynOffset = origDynamicOffsets[dynIndex]; + auto cstDynOffset = constDynamicOffsets[dynIndex]; + auto subValue = rewriter.createOrFold( + loc, origDynOffset.getType(), origDynOffset, cstDynOffset); + auto maybeIntValue = getConstantIntValue(subValue); + if (maybeIntValue) { + // Folded to a constant int. + staticOffsets.push_back(*maybeIntValue); + } else { + // Dynamic offset. 
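+        // The difference did not fold to a constant: keep it as a dynamic
+        // offset operand and mark the corresponding static slot as kDynamic.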
+ dynamicOffsets.push_back(subValue); + staticOffsets.push_back(ShapedType::kDynamic); + } + dynIndex++; + } + } + + // Insert an offset update op if non-trivial offset. + bool allZeros = llvm::all_of(staticOffsets, [](int64_t s) { return s == 0; }); + if (!dynamicOffsets.empty() || !allZeros) { + auto tile = newDescOp.getResult(); + auto offsetOp = rewriter.create( + loc, tile.getType(), tile, dynamicOffsets, staticOffsets); + // replace subsequent uses of the descriptor with the offset descriptor + rewriter.replaceUsesWithIf( + descOp.getResult(), offsetOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != offsetOp.getOperation(); + }); + } + rewriter.replaceOp(descOp, newDescOp); + return newDescOp; +} + +/// Add offset update ops after create desc ops in the loop body. +LogicalResult insertOffsetUpdateOps(transform::TransformRewriter &rewriter, + scf::ForOp loopOp) { + // Find all create desc operations in the loop body + SmallVector createDescOps; + for (auto &op : loopOp.getBody()->getOperations()) { + if (isa(op)) { + createDescOps.push_back(&op); + } + } + if (createDescOps.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No xegpu.create_nd_desc ops found in the loop body.\n"); + return failure(); + } + // Split to desc and offset update ops. + for (auto &op : createDescOps) { + auto descOp = cast(op); + insertUpdateOp(rewriter, loopOp, descOp); + } + return success(); +} + +/// Check if an op can be hoisted out of the loop. +static bool canBeHoisted(Operation *op, LoopLikeOpInterface &loopLike) { + return llvm::all_of(op->getOperands(), [&](Value value) { + return loopLike.isDefinedOutsideOfLoop(value); + }); +} + +/// Hoist create desc ops out of the loop. +/// If offset update ops exist, add values to loop iter_args and yield +FailureOr hoistDescOps(transform::TransformRewriter &rewriter, + scf::ForOp loopOp) { + SmallVector descOps; + auto loopLike = cast(loopOp.getOperation()); + for (auto &op : loopOp.getBody()->getOperations()) { + if (auto descOp = dyn_cast(op)) { + if (canBeHoisted(descOp.getOperation(), loopLike)) { + descOps.push_back(descOp); + } + } + } + if (descOps.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No hoistable create_nd_desc ops found in the loop body.\n"); + return loopOp; + } + + SmallVector initValues, yieldValues; + for (auto &descOp : descOps) { + // We assume tensor desc is used by an offset update op, find it. + auto maybeOffsetOp = getUserOfType(descOp.getResult()); + if (!maybeOffsetOp) { + continue; + } + auto offsetOp = *maybeOffsetOp; + + // Hoist desc op. + auto producers = + getProducerOpsInRegion(descOp.getOperation(), loopOp.getRegion(), true); + for (auto &op : llvm::reverse(producers)) { + rewriter.moveOpBefore(op, loopOp); + } + + // Offset update op must be converted to increment the offset, instead of + // defining an absolute offset wrt the original descriptor tile. + // In offset update producer ops, replace loop variable with step size. + auto offsetProducerOps = + getProducerOpsInRegion(offsetOp.getOperation(), loopOp.getRegion()); + rewriter.replaceUsesWithIf( + loopOp.getInductionVar(), loopOp.getStep(), [&](OpOperand &use) { + return llvm::is_contained(offsetProducerOps, use.getOwner()); + }); + // Offsetted desc now points to next tile, users must use the current tile + rewriter.replaceAllUsesWith(offsetOp.getResult(), offsetOp.getTensorDesc()); + // Add to loop init/yield values. + initValues.push_back(descOp.getResult()); + yieldValues.push_back(offsetOp.getResult()); + } + // Rewrite loop with new init/yield values. 
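+  // The hoisted descriptors become extra loop iter_args and the incremented
+  // descriptors are yielded, so the next iteration sees the advanced tile,
+  // e.g. (illustrative):
+  //   %d = xegpu.create_nd_tdesc ...          // hoisted above the loop
+  //   scf.for ... iter_args(%t = %d, ...) {
+  //     ... tile users read %t ...
+  //     %next = xegpu.update_nd_offset %t, [...]
+  //     scf.yield ..., %next
+  //   }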
+ NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, + llvm::ArrayRef newBBArgs) { + return yieldValues; + }; + auto maybeNewLoop = loopOp.replaceWithAdditionalYields( + rewriter, initValues, + /*replaceInitOperandUsesInLoop=*/true, yieldFn); + if (failed(maybeNewLoop)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to generate a new loop.\n"); + return failure(); + } + return cast(*maybeNewLoop); +} + +/// Create a layout attribute from the given parameters. +xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef sgLayout, + ArrayRef sgData, + std::optional> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. +xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, + xegpu::LayoutAttr layout) { + auto ctx = rewriter.getContext(); + auto oldTensorDesc = descOp.getResult(); + auto descShapedType = cast(oldTensorDesc.getType()); + // This discards any block_tdesc_attr attributes. + auto descType = xegpu::TensorDescType::get(ctx, descShapedType.getShape(), + descShapedType.getElementType(), + /*encoding=*/nullptr, + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp( + descOp, descType, descOp.getSource(), descOp.getMixedOffsets(), + descOp.getMixedSizes(), descOp.getMixedStrides()); + + return newDescOp; +} + +/// Fuse two scf.for loops into one. Keeps track of source operations to their +/// cloned targets. Returns the new fused loop. +scf::ForOp fuseForLoops(scf::ForOp target, scf::ForOp source, + RewriterBase &rewriter, + SmallVector &sourceOps, + SmallVector &targetOps) { + // This method is modified from mlir::fuseIndependentSiblingForLoops to + // trace the source ops to their cloned targets. + + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused init_args, with target's init_args before source's init_args. + SmallVector fusedInitArgs; + llvm::append_range(fusedInitArgs, target.getInitArgs()); + llvm::append_range(fusedInitArgs, source.getInitArgs()); + + // Create a new scf.for op after the source loop (with scf.yield terminator + // (without arguments) only in case its init_args is empty). + rewriter.setInsertionPointAfter(source); + scf::ForOp fusedLoop = rewriter.create( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), fusedInitArgs); + + // Map original induction variables and operands to those of the fused loop. + IRMapping mapping; + mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Merge target's body into the new (fused) for loop and then source's body. 
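+  // Clone rather than move the bodies so that clonedOpsMapping can record
+  // which clone corresponds to each of the requested source ops.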
+ rewriter.setInsertionPointToStart(fusedLoop.getBody()); + IRMapping clonedOpsMapping; + for (Operation &op : target.getBody()->without_terminator()) { + auto newOp = rewriter.clone(op, mapping); + clonedOpsMapping.map(&op, newOp); + } + for (Operation &op : source.getBody()->without_terminator()) { + auto newOp = rewriter.clone(op, mapping); + clonedOpsMapping.map(&op, newOp); + } + // Map the given source operations to their cloned targets. + auto opsMap = clonedOpsMapping.getOperationMap(); + for (Operation *op : sourceOps) { + auto it = opsMap.find(op); + if (it != opsMap.end()) { + targetOps.push_back(it->second); + } else { + targetOps.push_back(nullptr); + } + } + + // Build fused yield results by appropriately mapping original yield operands. + SmallVector yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create(source.getLoc(), yieldResults); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); + + return fusedLoop; +} + +DiagnosedSilenceableFailure +transform::HoistDescOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + auto loopOp = dyn_cast(target); + if (!loopOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a scf.for op, but got: " << target->getName(); + } + + if (failed(insertOffsetUpdateOps(rewriter, loopOp))) { + return emitSilenceableFailure(getLoc()) + << "No desc ops found in the loop body " << target->getName(); + } + auto newLoopOp = hoistDescOps(rewriter, loopOp); + if (failed(newLoopOp)) { + auto diag = emitSilenceableFailure(getLoc()) + << "Failed to hoist xegpu.create_nd_desc ops"; + diag.attachNote(loopOp.getLoc()) << "loop op"; + return diag; + } + loopOp = *newLoopOp; + results.push_back(loopOp.getOperation()); + return DiagnosedSilenceableFailure::success(); +} + +void transform::HoistDescOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + consumesHandle(getLoopMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + + auto targetOps = state.getPayloadOps(getTarget()); + auto loopOps = state.getPayloadOps(getLoopOp()); + + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + if (!llvm::hasSingleElement(loopOps)) { + return emitDefiniteFailure() << "requires exactly one loopOp handle (got " + << llvm::range_size(loopOps) << ")"; + } + + Operation *targetPtr = *targetOps.begin(); + // For now only DPAS op is supported. 
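+  // Operand indices follow the dpas convention: 0 = A, 1 = B, 2 = C
+  // (accumulator).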
+ auto targetOp = dyn_cast(targetPtr); + if (!targetOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a xegpu.dpas op, but got: " << targetPtr->getName(); + } + + Operation *loopPtr = *loopOps.begin(); + auto forOp = dyn_cast(loopPtr); + if (!forOp) { + return emitSilenceableFailure(getLoc()) + << "Expected a scf.for op, but got: " << loopPtr->getName(); + } + + auto parentLoop = targetOp->getParentOfType(); + if (!parentLoop || parentLoop != forOp) { + return emitSilenceableFailure(getLoc()) + << "target op is not contained in the given scf.for loop."; + } + + int64_t operandIndex = getOperandIndex(); + if (operandIndex >= targetOp.getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "operandIndex exceeds the number of op operands."; + } + + auto sgLayout = getSgLayout(); + if (sgLayout.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_layout to be a 2D vector"; + } + + auto sgData = getSgData(); + if (sgData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_data to be a 2D vector"; + } + + // Find descriptor op of the operand. + Value opVec = targetOp.getOperation()->getOperand(operandIndex); + auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + } + auto descOp = *maybeDescOp; + + // Clone reduction loop. + rewriter.setInsertionPoint(forOp); + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep()); + // Clone desc op into it. + rewriter.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + auto newDescOp = cast( + rewriter.clone(*descOp.getOperation(), mapping)); + // Set desc op layout. + auto layout = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, + /*instData=*/std::nullopt); + newDescOp = setDescLayout(rewriter, newDescOp, layout); + + // Insert prefetch op. + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + rewriter.create(newDescOp.getLoc(), + newDescOp.getResult(), readCacheHint, + readCacheHint, readCacheHint); + + // Insert offset update op. + insertUpdateOp(rewriter, newForOp, newDescOp); + // Hoist descriptor op out of the loop. + auto maybenewForOp = hoistDescOps(rewriter, newForOp); + if (failed(maybenewForOp)) { + auto diag = emitSilenceableFailure(getLoc()) + << "Failed to hoist xegpu.create_nd_desc ops"; + diag.attachNote(newForOp.getLoc()) << "loop op"; + return diag; + } + newForOp = *maybenewForOp; + + // Peel first iteration of the loop and reset lower bound to original value. + scf::ForOp firstLoopOp; + if (failed(scf::peelForLoopFirstIteration(rewriter, newForOp, firstLoopOp))) { + auto diag = emitSilenceableFailure(getLoc()) << "Failed to peel the loop"; + } + newForOp.setLowerBound(forOp.getLowerBound()); + + // Fuse with the original loop, keep track of cloned ops. + SmallVector sourceOps{targetOp.getOperation()}, dstOps; + auto fusedLoop = fuseForLoops(newForOp, forOp, rewriter, sourceOps, dstOps); + assert(fusedLoop && "failed to fuse loops"); + + // Get the cloned target op. + auto clonedTargetOp = dstOps[0]; + if (!clonedTargetOp) { + return emitSilenceableFailure(getLoc()) + << "Failed to find cloned target op in the fused loop."; + } + + // Map result handles. 
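+  // Map the fused loop to the transformed-loop handle and the cloned dpas op
+  // to the transformed-target handle.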
+ results.set(cast(getTransformedLoopOp()), {fusedLoop}); + results.set(cast(getTransformedTargetOp()), {clonedTargetOp}); + + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + // For now only DPAS op is supported. + auto targetOp = dyn_cast(target); + if (!targetOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.dpas op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + int64_t operandIndex = getOperandIndex(); + if (operandIndex >= targetOp.getNumOperands()) { + return emitSilenceableFailure(getLoc()) + << "operandIndex exceeds the number of op operands."; + } + + auto sgLayout = getSgLayout(); + if (sgLayout.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_layout to be a 2D vector"; + } + + auto sgData = getSgData(); + if (sgData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected sg_data to be a 2D vector"; + } + + auto instData = getInstData(); + if (instData.size() != 2) { + return emitSilenceableFailure(getLoc()) + << "Expected inst_data to be a 2D vector"; + } + + // Replace descriptor op using layout attribute. + Value opVec = targetOp.getOperation()->getOperand(operandIndex); + auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + } + auto descOp = *maybeDescOp; + // Set layout attribute. + auto layoutAttr = + createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData); + descOp = setDescLayout(rewriter, descOp, layoutAttr); + if (operandIndex == 2) { + // C operand: set layout attribute for the dpas op result. + xegpu::setLayoutAttr(targetOp.getOperation()->getResults()[0], layoutAttr); + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetOperandLayoutOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure transform::SetGPULaunchThreadsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + + auto launchOp = dyn_cast(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + auto threads = getThreads(); + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads to be a 3D vector"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return rewriter.create(launchOp.getLoc(), value); + }; + + // Replace threads in-place. 
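+  // The block sizes of gpu.launch are ordinary operands, so they can be
+  // swapped in place without recreating the op.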
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getLaunchOpMutable(), effects); + modifiesPayload(effects); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb8200..a6a2d6666b327 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -303,6 +303,15 @@ declare_mlir_dialect_extension_python_bindings( "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" ) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/XeGPUTransformOps.td + SOURCES + dialects/transform/xegpu.py + DIALECT_NAME transform + EXTENSION_NAME xegpu_transform) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td new file mode 100644 index 0000000000000..5a5e7b912c4a5 --- /dev/null +++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td @@ -0,0 +1,19 @@ +//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the XeGPU transform ops. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS + +include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td" + +#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py new file mode 100644 index 0000000000000..8f097e59f19e2 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -0,0 +1,109 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._xegpu_transform_ops_gen import * +from .._xegpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext + from .._ods_common import get_op_result_or_value as _get_op_result_or_value +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SetOperandLayoutOp(SetOperandLayoutOp): + """Specialization for SetOperandLayoutOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + index: Union[int, Attribute], + sg_layout: Union[Sequence[int], Attribute], + sg_data: Union[Sequence[int], Attribute], + inst_data: Union[Sequence[int], Attribute], + *, + loc=None, + ip=None, + ): + super().__init__( + target, + index, + sg_layout, + sg_data, + inst_data, + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InsertPrefetchOp(InsertPrefetchOp): + """Specialization for InsertPrefetchOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + loop_op: Union[Operation, Value], + index: Union[int, Attribute], + sg_layout: Union[Sequence[int], Attribute], + sg_data: Union[Sequence[int], Attribute], + loc=None, + ip=None, + ): + transformed_target_type = transform.AnyOpType.get() + transformed_loop_type = transform.AnyOpType.get() + super().__init__( + transformed_target_type, + transformed_loop_type, + _get_op_result_or_value(target), + _get_op_result_or_value(loop_op), + index, + sg_layout, + sg_data, + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class HoistDescOp(HoistDescOp): + """Specialization for HoistDescOp class.""" + + def __init__( + self, + loop_op: Union[Operation, Value], + loc=None, + ip=None, + ): + transformed_loop_type = transform.AnyOpType.get() + super().__init__( + transformed_loop_type, + _get_op_result_or_value(loop_op), + loc=loc, + ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp): + """Specialization for SetGPULaunchThreadsOp class.""" + + def __init__( + self, + launch_op: Union[Operation, Value], + threads: Union[int, Attribute], + loc=None, + ip=None, + ): + super().__init__( + _get_op_result_or_value(launch_op), + threads, + loc=loc, + ip=ip + ) diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir new file mode 100644 index 0000000000000..b6e480bab1f0b --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -0,0 +1,183 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @hoist_desc_ops +func.func @hoist_desc_ops(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + // CHECK: xegpu.create_nd_tdesc + // CHECK-NEXT: scf.for + scf.for %arg1 = %c0 to %c4096 step %c32 { + // CHECK: xegpu.update_nd_offset + // CHECK: xegpu.load_nd + %0 = xegpu.create_nd_tdesc %arg0[0, %arg1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = 
transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: = transform.xegpu.hoist_desc_ops %{{.*}} + %1 = transform.xegpu.hoist_desc_ops %0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_operand_layout_a +func.func @set_operand_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V1]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_operand_layout_b +func.func @set_operand_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: = xegpu.create_nd_tdesc + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + // CHECK: = xegpu.load_nd + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg1 + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK-SAME: #xegpu.layout> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %1, %[[V1]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_operand_layout_c +func.func @set_operand_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: = xegpu.create_nd_tdesc + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> 
!xegpu.tensor_desc<256x32xf16> + // CHECK: = xegpu.load_nd + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + // CHECK: = xegpu.create_nd_tdesc + %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: = xegpu.load_nd + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg2 + %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + // CHECK-SAME: #xegpu.layout> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %1, %3, %[[V1]] {layout_result_0 = #xegpu.layout} + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_operand_layout %{{.*}} + transform.xegpu.set_operand_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a +func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: %[[C32:.+]] = arith.constant 32 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16, #xegpu.layout> + // Peeled first iteration of the loop, canonicalization drops the scf.for + // CHECK: %[[V2:.+]] = scf.for + // CHECK-SAME: iter_args(%[[V1:.+]] = %[[V0]]) + // CHECK: = xegpu.update_nd_offset %[[V1]], [0, %[[C32]]] + // CHECK: xegpu.prefetch_nd %[[V1]] + // Reduction loop + // CHECK: scf.for + // CHECK-SAME: iter_args(%[[V3:.+]] = %[[V2]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: = xegpu.update_nd_offset %[[V3]], [0, %[[C32]]] + // CHECK: xegpu.prefetch_nd %[[V3]] + %3 = xegpu.create_nd_tdesc %arg0[0, %arg3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %5 = xegpu.create_nd_tdesc %arg1[%arg3, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %4, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["xegpu.dpas"]} in %0 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.insert_prefetch %{{.*}} %{{.*}} + %2, %3 = transform.xegpu.insert_prefetch %1 %0 index = 0 sg_layout = 
[32, 1] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads +func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } {SCFToGPU_visited} + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py new file mode 100644 index 0000000000000..673348fc45cb3 --- /dev/null +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -0,0 +1,99 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import xegpu +from mlir.dialects.transform import structured + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def setOperandLayout(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.SetOperandLayoutOp( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOperandLayout + # CHECK: transform.xegpu.set_operand_layout % + # CHECK: index = 0 + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def insertPrefetch(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + for_op = structured.MatchOp.match_op_names(sequence.bodyTarget, ["scf.for"]) + dpas_op = structured.MatchOp.match_op_names(for_op, ["xegpu.dpas"]) + xegpu.InsertPrefetchOp( + dpas_op, + for_op, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetch + # CHECK: %[[FOR_OP:.*]] = transform.structured.match + # CHECK: %[[DPAS_OP:.*]] = transform.structured.match + # CHECK: transform.xegpu.insert_prefetch %[[DPAS_OP]] %[[FOR_OP]] + # CHECK: index = 0 + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + + +@run +def hoistDescOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("scf.for"), + 
) + with InsertionPoint(sequence.body): + xegpu.HoistDescOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: hoistDescOp + # CHECK: transform.xegpu.hoist_desc_ops + + +@run +def setGPULaunchThreadsOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("gpu.lauch"), + ) + with InsertionPoint(sequence.body): + xegpu.SetGPULaunchThreadsOp( + sequence.bodyTarget, + threads=[8, 4, 1] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setGPULaunchThreadsOp + # CHECK: transform.xegpu.set_gpu_launch_threads + # CHECK: threads = [8, 4, 1]