@@ -452,23 +452,24 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
452452 transform::TransformResults &results,
453453 transform::TransformState &state) {
454454
455- auto dpasOps = state.getPayloadOps (getDpasOp ());
455+ auto targetOps = state.getPayloadOps (getTarget ());
456456 auto loopOps = state.getPayloadOps (getLoopOp ());
457457
458- if (!llvm::hasSingleElement (dpasOps )) {
459- return emitDefiniteFailure () << " requires exactly one dpasOp handle (got "
460- << llvm::range_size (dpasOps ) << " )" ;
458+ if (!llvm::hasSingleElement (targetOps )) {
459+ return emitDefiniteFailure () << " requires exactly one targetOp handle (got "
460+ << llvm::range_size (targetOps ) << " )" ;
461461 }
462462 if (!llvm::hasSingleElement (loopOps)) {
463463 return emitDefiniteFailure () << " requires exactly one loopOp handle (got "
464464 << llvm::range_size (loopOps) << " )" ;
465465 }
466466
467- Operation *dpasPtr = *dpasOps.begin ();
468- auto dpasOp = dyn_cast<xegpu::DpasOp>(dpasPtr);
469- if (!dpasOp) {
467+ Operation *targetPtr = *targetOps.begin ();
468+ // For now only DPAS op is supported.
469+ auto targetOp = dyn_cast<xegpu::DpasOp>(targetPtr);
470+ if (!targetOp) {
470471 return emitSilenceableFailure (getLoc ())
471- << " Expected a xegpu.dpas op, but got: " << dpasPtr ->getName ();
472+ << " Expected a xegpu.dpas op, but got: " << targetPtr ->getName ();
472473 }
473474
474475 Operation *loopPtr = *loopOps.begin ();
@@ -478,16 +479,16 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
478479 << " Expected a scf.for op, but got: " << loopPtr->getName ();
479480 }
480481
481- auto parentLoop = dpasOp ->getParentOfType <scf::ForOp>();
482+ auto parentLoop = targetOp ->getParentOfType <scf::ForOp>();
482483 if (!parentLoop || parentLoop != forOp) {
483484 return emitSilenceableFailure (getLoc ())
484- << " dpasOp is not contained in the given scf.for loop." ;
485+ << " target op is not contained in the given scf.for loop." ;
485486 }
486487
487- int64_t tileIndex = getTileIndex ();
488- if (tileIndex >= dpasOp .getNumOperands ()) {
488+ int64_t operandIndex = getOperandIndex ();
489+ if (operandIndex >= targetOp .getNumOperands ()) {
489490 return emitSilenceableFailure (getLoc ())
490- << " tileIndex exceeds the number of op operands." ;
491+ << " operandIndex exceeds the number of op operands." ;
491492 }
492493
493494 auto sgLayout = getSgLayout ();
@@ -503,8 +504,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
503504 }
504505
505506 // Find descriptor op of the operand.
506- Value opVec = dpasOp .getOperation ()->getOperand (tileIndex );
507- auto maybeDescOp = findDescriptorOp (opVec, dpasOp .getOperation ());
507+ Value opVec = targetOp .getOperation ()->getOperand (operandIndex );
508+ auto maybeDescOp = findDescriptorOp (opVec, targetOp .getOperation ());
508509 if (!maybeDescOp) {
509510 return emitSilenceableFailure (getLoc ()) << " Could not find descriptor op." ;
510511 }
@@ -554,42 +555,42 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
554555 newForOp.setLowerBound (forOp.getLowerBound ());
555556
556557 // Fuse with the original loop, keep track of cloned ops.
557- SmallVector<Operation *> sourceOps{dpasOp.getOperation ()}, targetOps;
558- auto fusedLoop =
559- fuseForLoops (newForOp, forOp, rewriter, sourceOps, targetOps);
558+ SmallVector<Operation *> sourceOps{targetOp.getOperation ()}, dstOps;
559+ auto fusedLoop = fuseForLoops (newForOp, forOp, rewriter, sourceOps, dstOps);
560560 assert (fusedLoop && " failed to fuse loops" );
561561
562- // Get the cloned dpas op.
563- auto clonedDpasOp = targetOps [0 ];
564- if (!clonedDpasOp ) {
562+ // Get the cloned target op.
563+ auto clonedTargetOp = dstOps [0 ];
564+ if (!clonedTargetOp ) {
565565 return emitSilenceableFailure (getLoc ())
566- << " Failed to find cloned dpas op in the fused loop." ;
566+ << " Failed to find cloned target op in the fused loop." ;
567567 }
568568
569569 // Map result handles.
570570 results.set (cast<OpResult>(getTransformedLoopOp ()), {fusedLoop});
571- results.set (cast<OpResult>(getTransformedDpasOp ()), {clonedDpasOp });
571+ results.set (cast<OpResult>(getTransformedTargetOp ()), {clonedTargetOp });
572572
573573 return DiagnosedSilenceableFailure::success ();
574574}
575575
576- DiagnosedSilenceableFailure transform::SetDPASLayoutOp ::applyToOne (
576+ DiagnosedSilenceableFailure transform::SetOperandLayoutOp ::applyToOne (
577577 transform::TransformRewriter &rewriter, Operation *target,
578578 transform::ApplyToEachResultList &results,
579579 transform::TransformState &state) {
580580
581- auto dpasOp = dyn_cast<xegpu::DpasOp>(target);
582- if (!dpasOp) {
581+ // For now only DPAS op is supported.
582+ auto targetOp = dyn_cast<xegpu::DpasOp>(target);
583+ if (!targetOp) {
583584 auto diag = emitSilenceableFailure (getLoc ())
584585 << " Expected a xegpu.dpas op, but got: " << target->getName ();
585586 diag.attachNote (target->getLoc ()) << " target op" ;
586587 return diag;
587588 }
588589
589- int64_t tileIndex = getTileIndex ();
590- if (tileIndex >= dpasOp .getNumOperands ()) {
590+ int64_t operandIndex = getOperandIndex ();
591+ if (operandIndex >= targetOp .getNumOperands ()) {
591592 return emitSilenceableFailure (getLoc ())
592- << " tileIndex exceeds the number of op operands." ;
593+ << " operandIndex exceeds the number of op operands." ;
593594 }
594595
595596 auto sgLayout = getSgLayout ();
@@ -611,26 +612,26 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne(
611612 }
612613
613614 // Replace descriptor op using layout attribute.
614- Value opVec = dpasOp .getOperation ()->getOperand (tileIndex );
615- auto maybeDescOp = findDescriptorOp (opVec, dpasOp .getOperation ());
615+ Value opVec = targetOp .getOperation ()->getOperand (operandIndex );
616+ auto maybeDescOp = findDescriptorOp (opVec, targetOp .getOperation ());
616617 if (!maybeDescOp) {
617618 return emitSilenceableFailure (getLoc ()) << " Could not find descriptor op." ;
618619 }
619620 auto descOp = *maybeDescOp;
620621 // Set layout attribute.
621622 auto layoutAttr = createLayoutAttr (rewriter.getContext (), sgLayout, sgData, instData);
622623 descOp = setDescLayout (rewriter, descOp, layoutAttr);
623- if (tileIndex == 2 ) {
624+ if (operandIndex == 2 ) {
624625 // C operand: set layout attribute for the dpas op result.
625- xegpu::setLayoutAttr (dpasOp .getOperation ()->getResults ()[0 ], layoutAttr);
626+ xegpu::setLayoutAttr (targetOp .getOperation ()->getResults ()[0 ], layoutAttr);
626627 }
627628
628629 return DiagnosedSilenceableFailure::success ();
629630}
630631
631- void transform::SetDPASLayoutOp ::getEffects (
632+ void transform::SetOperandLayoutOp ::getEffects (
632633 ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
633- onlyReadsHandle (getDpasOpMutable (), effects);
634+ onlyReadsHandle (getTargetMutable (), effects);
634635 modifiesPayload (effects);
635636}
636637
0 commit comments