@@ -419,6 +419,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
419419 }
420420};
421421
422+ // TODO: AMDGPU backend already have all this bitpacking logic, we should move
423+ // it to some common place.
424+ // / Vmcnt, Expcnt and Lgkmcnt are decoded as follows:
425+ // / Vmcnt = Waitcnt[3:0] (pre-gfx9)
426+ // / Vmcnt = Waitcnt[15:14,3:0] (gfx9,10)
427+ // / Vmcnt = Waitcnt[15:10] (gfx11)
428+ // / Expcnt = Waitcnt[6:4] (pre-gfx11)
429+ // / Expcnt = Waitcnt[2:0] (gfx11)
430+ // / Lgkmcnt = Waitcnt[11:8] (pre-gfx10)
431+ // / Lgkmcnt = Waitcnt[13:8] (gfx10)
432+ // / Lgkmcnt = Waitcnt[9:4] (gfx11)
433+ static FailureOr<unsigned > encodeWaitcnt (Chipset chipset, unsigned vmcnt,
434+ unsigned expcnt, unsigned lgkmcnt) {
435+ if (chipset.majorVersion < 9 ) {
436+ vmcnt = std::min (15u , vmcnt);
437+ expcnt = std::min (7u , expcnt);
438+ lgkmcnt = std::min (15u , lgkmcnt);
439+ return vmcnt | (expcnt << 4 ) | (lgkmcnt << 8 );
440+ }
441+ if (chipset.majorVersion == 9 ) {
442+ vmcnt = std::min (63u , vmcnt);
443+ expcnt = std::min (7u , expcnt);
444+ lgkmcnt = std::min (15u , lgkmcnt);
445+ unsigned lowBits = vmcnt & 0xF ;
446+ unsigned highBits = (vmcnt >> 4 ) << 14 ;
447+ unsigned otherCnts = (expcnt << 4 ) | (lgkmcnt << 8 );
448+ return lowBits | highBits | otherCnts;
449+ }
450+ if (chipset.majorVersion == 10 ) {
451+ vmcnt = std::min (63u , vmcnt);
452+ expcnt = std::min (7u , expcnt);
453+ lgkmcnt = std::min (63u , lgkmcnt);
454+ unsigned lowBits = vmcnt & 0xF ;
455+ unsigned highBits = (vmcnt >> 4 ) << 14 ;
456+ unsigned otherCnts = (expcnt << 4 ) | (lgkmcnt << 8 );
457+ return lowBits | highBits | otherCnts;
458+ }
459+ if (chipset.majorVersion == 11 ) {
460+ vmcnt = std::min (63u , vmcnt);
461+ expcnt = std::min (7u , expcnt);
462+ lgkmcnt = std::min (63u , lgkmcnt);
463+ return (vmcnt << 10 ) | expcnt | (lgkmcnt << 4 );
464+ }
465+ return failure ();
466+ }
467+
468+ struct MemoryCounterWaitOpLowering
469+ : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
470+ MemoryCounterWaitOpLowering (const LLVMTypeConverter &converter,
471+ Chipset chipset)
472+ : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
473+ chipset (chipset) {}
474+
475+ Chipset chipset;
476+
477+ LogicalResult
478+ matchAndRewrite (MemoryCounterWaitOp op, OpAdaptor adaptor,
479+ ConversionPatternRewriter &rewriter) const override {
480+ if (chipset.majorVersion >= 12 ) {
481+ Location loc = op.getLoc ();
482+ if (std::optional<int > ds = adaptor.getDs ())
483+ rewriter.create <ROCDL::WaitDscntOp>(loc, *ds);
484+
485+ if (std::optional<int > load = adaptor.getLoad ())
486+ rewriter.create <ROCDL::WaitLoadcntOp>(loc, *load);
487+
488+ if (std::optional<int > store = adaptor.getStore ())
489+ rewriter.create <ROCDL::WaitStorecntOp>(loc, *store);
490+
491+ if (std::optional<int > exp = adaptor.getExp ())
492+ rewriter.create <ROCDL::WaitExpcntOp>(loc, *exp);
493+
494+ rewriter.eraseOp (op);
495+ return success ();
496+ }
497+
498+ auto getVal = [](Attribute attr) -> unsigned {
499+ if (attr)
500+ return cast<IntegerAttr>(attr).getInt ();
501+
502+ // This value will be clamped to the maximum value for the chipset.
503+ return 1024 ;
504+ };
505+ unsigned ds = getVal (adaptor.getDsAttr ());
506+ unsigned exp = getVal (adaptor.getExpAttr ());
507+
508+ unsigned vmcnt = 1024 ;
509+ Attribute load = adaptor.getLoadAttr ();
510+ Attribute store = adaptor.getStoreAttr ();
511+ if (load && store) {
512+ vmcnt = getVal (load) + getVal (store);
513+ } else if (load) {
514+ vmcnt = getVal (load);
515+ } else if (store) {
516+ vmcnt = getVal (store);
517+ }
518+
519+ FailureOr<unsigned > waitcnt = encodeWaitcnt (chipset, vmcnt, exp, ds);
520+ if (failed (waitcnt))
521+ return op.emitOpError (" unsupported chipset" );
522+
523+ rewriter.replaceOpWithNewOp <ROCDL::SWaitcntOp>(op, *waitcnt);
524+ return success ();
525+ }
526+ };
527+
422528struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern <LDSBarrierOp> {
423529 LDSBarrierOpLowering (const LLVMTypeConverter &converter, Chipset chipset)
424530 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
@@ -1825,9 +1931,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
18251931 ROCDL::RawPtrBufferAtomicUminOp>,
18261932 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
18271933 ROCDL::RawPtrBufferAtomicCmpSwap>,
1828- AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering ,
1829- MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering ,
1830- ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1934+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering ,
1935+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering ,
1936+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
18311937 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
18321938 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
18331939 TransposeLoadOpLowering>(converter, chipset);
0 commit comments