@@ -325,14 +325,15 @@ xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
325325xegpu::CreateNdDescOp setDescLayout (transform::TransformRewriter &rewriter,
326326 xegpu::CreateNdDescOp descOp,
327327 xegpu::LayoutAttr layout) {
328- auto ctx = rewriter.getContext ();
329328 auto oldTensorDesc = descOp.getResult ();
330329 auto descShapedType = cast<ShapedType>(oldTensorDesc.getType ());
331- // This discards any block_tdesc_attr attributes.
332- auto descType = xegpu::TensorDescType::get (ctx, descShapedType.getShape (),
333- descShapedType.getElementType (),
334- /* encoding=*/ nullptr ,
335- /* layout=*/ layout);
330+ // TODO inherit desc attributes from old op (if any)
331+ auto descType = xegpu::TensorDescType::get (
332+ descShapedType.getShape (), descShapedType.getElementType (),
333+ /* array_length=*/ 1 ,
334+ /* boundary_check=*/ true ,
335+ /* memory_space=*/ xegpu::MemorySpace::Global,
336+ /* layout=*/ layout);
336337
337338 rewriter.setInsertionPointAfter (descOp);
338339 auto newDescOp = rewriter.replaceOpWithNewOp <xegpu::CreateNdDescOp>(
@@ -589,7 +590,7 @@ transform::GetDescOp::applyToOne(transform::TransformRewriter &rewriter,
589590 return diag;
590591 }
591592
592- int64_t operandIndex = getOperandIndex () ? getOperandIndex (). value () : 0 ;
593+ int64_t operandIndex = getOperandIndex ();
593594 if (operandIndex >= targetOp.getNumOperands ()) {
594595 return emitSilenceableFailure (getLoc ())
595596 << " operandIndex exceeds the number of op operands." ;
@@ -618,7 +619,7 @@ DiagnosedSilenceableFailure transform::SetResultLayoutOp::applyToOne(
618619 transform::ApplyToEachResultList &results,
619620 transform::TransformState &state) {
620621
621- int64_t resultIndex = getResultIndex () ? getResultIndex (). value () : 0 ;
622+ int64_t resultIndex = getResultIndex ();
622623 if (resultIndex >= target->getNumResults ()) {
623624 return emitSilenceableFailure (getLoc ())
624625 << " resultIndex exceeds the number of op results." ;
@@ -642,30 +643,21 @@ DiagnosedSilenceableFailure transform::SetResultLayoutOp::applyToOne(
642643 << " Expected inst_data to be a 2D vector" ;
643644 }
644645
645- // For now only desc op or dpas op are supported.
646+ // For now only create_nd_desc op is supported.
646647 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
647- auto dpasOp = dyn_cast<xegpu::DpasOp>(target);
648- if (!descOp && !dpasOp) {
648+ if (!descOp) {
649649 auto diag = emitSilenceableFailure (getLoc ())
650- << " Expected a xegpu.create_nd_desc or xegpu.dpas op, but got: " << target->getName ();
650+ << " Expected a xegpu.create_nd_desc op, but got: "
651+ << target->getName ();
651652 diag.attachNote (target->getLoc ()) << " target op" ;
652653 return diag;
653654 }
654655
656+ // Set layout attr in desc op's return type. Replaces old desc op.
655657 auto layoutAttr =
656658 createLayoutAttr (rewriter.getContext (), sgLayout, sgData, instData);
657- if (descOp) {
658- // Replace desc op with a new op that has the layout attr in return type.
659- auto newdescOp = setDescLayout (rewriter, descOp, layoutAttr);
660- results.push_back (newdescOp.getOperation ());
661- }
662- if (dpasOp) {
663- // Set layout attribute for the dpas op result.
664- // NOTE this actually does not create a new handle ...
665- // NOTE should not invalidate the handle ... should be a separate op?
666- xegpu::setLayoutAttr (dpasOp.getOperation ()->getResults ()[0 ], layoutAttr);
667- results.push_back (dpasOp.getOperation ());
668- }
659+ auto newdescOp = setDescLayout (rewriter, descOp, layoutAttr);
660+ results.push_back (newdescOp.getOperation ());
669661 return DiagnosedSilenceableFailure::success ();
670662}
671663
@@ -676,7 +668,76 @@ void transform::SetResultLayoutOp::getEffects(
676668 modifiesPayload (effects);
677669}
678670
679- DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne (
671+ DiagnosedSilenceableFailure transform::SetOpLayoutAttrOp::applyToOne (
672+ transform::TransformRewriter &rewriter, Operation *target,
673+ transform::ApplyToEachResultList &results,
674+ transform::TransformState &state) {
675+
676+ bool resultTarget = getResult ();
677+ bool operandTarget = getOperand ();
678+
679+ if (resultTarget && operandTarget) {
680+ return emitSilenceableFailure (getLoc ())
681+ << " `result` and `operand` cannot be both set." ;
682+ }
683+ if (!resultTarget && !operandTarget) {
684+ return emitSilenceableFailure (getLoc ())
685+ << " Either `result` or `operand` must be set." ;
686+ }
687+
688+ int64_t index = getIndex ();
689+ if (resultTarget && index >= target->getNumResults ()) {
690+ return emitSilenceableFailure (getLoc ())
691+ << " index exceeds the number of op results." ;
692+ }
693+ if (!resultTarget && index >= target->getNumOperands ()) {
694+ return emitSilenceableFailure (getLoc ())
695+ << " index exceeds the number of op operands." ;
696+ }
697+
698+ auto sgLayout = getSgLayout ();
699+ if (sgLayout.size () != 2 ) {
700+ return emitSilenceableFailure (getLoc ())
701+ << " Expected sg_layout to be a 2D vector" ;
702+ }
703+
704+ auto sgData = getSgData ();
705+ if (sgData.size () != 2 ) {
706+ return emitSilenceableFailure (getLoc ())
707+ << " Expected sg_data to be a 2D vector" ;
708+ }
709+
710+ auto instData = getInstData ();
711+ if (instData.size () != 2 ) {
712+ return emitSilenceableFailure (getLoc ())
713+ << " Expected inst_data to be a 2D vector" ;
714+ }
715+
716+ // For now only dpas op is supported.
717+ if (!isa<xegpu::DpasOp>(target)) {
718+ auto diag = emitSilenceableFailure (getLoc ())
719+ << " Expected a xegpu.dpas op, but got: " << target->getName ();
720+ diag.attachNote (target->getLoc ()) << " target op" ;
721+ return diag;
722+ }
723+ auto layoutAttr =
724+ createLayoutAttr (rewriter.getContext (), sgLayout, sgData, instData);
725+ // Set layout attribute for the op result or operand
726+ if (resultTarget) {
727+ xegpu::setLayoutAttr (target->getResult (index), layoutAttr);
728+ } else {
729+ xegpu::setLayoutAttr (target->getOpOperand (index), layoutAttr);
730+ }
731+ return DiagnosedSilenceableFailure::success ();
732+ }
733+
734+ void transform::SetOpLayoutAttrOp::getEffects (
735+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
736+ onlyReadsHandle (getTargetMutable (), effects);
737+ modifiesPayload (effects);
738+ }
739+
740+ DiagnosedSilenceableFailure transform::ConvertOperandLayoutOp::applyToOne (
680741 transform::TransformRewriter &rewriter, Operation *target,
681742 transform::ApplyToEachResultList &results,
682743 transform::TransformState &state) {
@@ -714,26 +775,48 @@ DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne(
714775 << " Expected inst_data to be a 2D vector" ;
715776 }
716777
717- // Replace descriptor op using layout attribute .
718- Value opVec = targetOp. getOperation () ->getOperand (operandIndex);
778+ // Find desc op.
779+ Value opVec = target ->getOperand (operandIndex);
719780 auto maybeDescOp = findDescriptorOp (opVec, targetOp.getOperation ());
720781 if (!maybeDescOp) {
721782 return emitSilenceableFailure (getLoc ()) << " Could not find descriptor op." ;
722783 }
723784 auto descOp = *maybeDescOp;
724- // Set layout attribute.
785+ // Get load op.
786+ auto maybeLoadOp = getUserOfType<xegpu::LoadNdOp>(descOp.getResult ());
787+ if (!maybeLoadOp) {
788+ return emitSilenceableFailure (getLoc ())
789+ << " Expected a xegpu.load_nd op as a user of the descriptor op." ;
790+ }
791+ auto loadOp = *maybeLoadOp;
792+ // Get load op operand value layout
793+ auto producerLayoutAttr = xegpu::getLayoutAttr (loadOp.getOperand (0 ));
794+ if (!producerLayoutAttr) {
795+ return emitSilenceableFailure (getLoc ())
796+ << " Operand producer op does not have a layout attr." ;
797+ }
798+
799+ // New layout attr
725800 auto layoutAttr =
726801 createLayoutAttr (rewriter.getContext (), sgLayout, sgData, instData);
727- descOp = setDescLayout (rewriter, descOp, layoutAttr);
728- if (operandIndex == 2 ) {
729- // C operand: set layout attribute for the dpas op result.
730- xegpu::setLayoutAttr (targetOp.getOperation ()->getResults ()[0 ], layoutAttr);
802+
803+ if (producerLayoutAttr != layoutAttr) {
804+ rewriter.setInsertionPointAfter (loadOp.getOperation ());
805+ auto source = loadOp.getResult ();
806+ auto convLayoutOp = rewriter.create <xegpu::ConvertLayoutOp>(
807+ loadOp.getLoc (), source.getType (), source, producerLayoutAttr,
808+ layoutAttr);
809+ // Replace load op result with the converted layout.
810+ rewriter.replaceUsesWithIf (
811+ source, convLayoutOp.getResult (), [&](OpOperand &use) {
812+ return use.getOwner () != convLayoutOp.getOperation ();
813+ });
731814 }
732815
733816 return DiagnosedSilenceableFailure::success ();
734817}
735818
736- void transform::SetOperandLayoutOp ::getEffects (
819+ void transform::ConvertOperandLayoutOp ::getEffects (
737820 ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
738821 onlyReadsHandle (getTargetMutable (), effects);
739822 modifiesPayload (effects);
0 commit comments