@@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
328328 } else {
329329 // Cast 5. The two layouts are equivalent. We should probably remove
330330 // these in RemoveLayoutConversion.
331- rewriter.replaceOp (op, adaptor.getSrc ());
331+ auto dstCvt = requiresI32Conversion (dstTy);
332+ auto srcCvt = requiresI32Conversion (srcTy);
333+ if (dstCvt || srcCvt) {
334+ auto inVals = unpackLLElements (op.getLoc (), adaptor.getSrc (), rewriter);
335+ inVals = unpackI32s (inVals, srcTy, rewriter, op.getLoc (),
336+ getTypeConverter ());
337+ inVals =
338+ packI32s (inVals, dstTy, rewriter, op.getLoc (), getTypeConverter ());
339+ auto res = packLLElements (op.getLoc (), getTypeConverter (), inVals,
340+ rewriter, op.getType ());
341+ rewriter.replaceOp (op, res);
342+ } else {
343+ rewriter.replaceOp (op, adaptor.getSrc ());
344+ }
332345 return success ();
333346 }
334347 }
@@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
342355 StringAttr kRegister = str_attr (" register" );
343356 assert (!cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
344357
358+ auto srcTy = op.getSrc ().getType ();
359+ auto dstTy = op.getType ();
345360 auto inVals = unpackLLElements (loc, adaptor.getSrc (), rewriter);
361+ inVals = unpackI32s (inVals, srcTy, rewriter, loc, getTypeConverter ());
346362 SmallVector<Value> outVals (numRegs);
347- for (int i = 0 ; i < outVals. size () ; i++) {
363+ for (int i = 0 ; i < numRegs ; i++) {
348364 // Remove free masks from the register index
349365 // For example, if idx = 0b00111, and masks = 0b00100, then we get
350366 // 0b00011. It means that register 7 (0b111) has the same value as
@@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
355371 : idx;
356372 outVals[i] = inVals[srcIdx];
357373 }
374+ outVals = packI32s (outVals, dstTy, rewriter, loc, getTypeConverter ());
358375 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
359376 op.getType ());
360377 rewriter.replaceOp (op, result);
@@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
386403 if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
387404 if (auto nvidiaMma =
388405 dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent ())) {
389- if (product (getCTAsPerCGA (nvidiaMma)) > 1 ) {
390- return false ;
391- }
392406 if (useLegacyMMAConversion) {
393407 return false ;
394408 }
@@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
398412 dotOperand.getKWidth () * dstTy.getElementTypeBitWidth () > 64 ;
399413 return largeKWidth && nvidiaMma.isAmpere ();
400414 }
415+ return false ;
401416 }
402417 if (isa<BlockedEncodingAttr>(layout)) {
403418 return true ;
@@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
439454 inVals[it.index ()] = ptrtoint (llvmElemTy, it.value ());
440455 }
441456 }
457+ inVals = unpackI32s (inVals, srcTy, rewriter, loc, getTypeConverter ());
442458
443459 // Pretty sure this is the identity function ATM
444460 // It'd be better to simply call `quotient({kBlock})` and
@@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
458474 }
459475 }
460476
461- // FIXME [Dot LL]
462- // We know it's just for largeKWidth case in Ampere
463- // In this case, we need to pack the outputs into i32
464- if (isa<DotOperandEncodingAttr>(dstTy.getEncoding ())) {
465- auto concat = [&](Value a, Value b) {
466- return or_ (zext (i32_ty, bitcast (a, i16_ty)),
467- shl (zext (i32_ty, bitcast (b, i16_ty)), i32_val (16 )));
468- };
469-
470- SmallVector<Value> outVals32 (outVals.size () / 2 );
471- for (int i = 0 ; i < outVals32.size (); ++i) {
472- outVals32[i] = concat (outVals[2 * i], outVals[2 * i + 1 ]);
473- }
474- outVals = outVals32;
475- }
476-
477+ outVals = packI32s (outVals, dstTy, rewriter, loc, getTypeConverter ());
477478 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
478479 op.getType ());
479480 rewriter.replaceOp (op, result);
0 commit comments