@@ -139,7 +139,7 @@ struct LocalLoadOpConversion
139139
140140LogicalResult lowerDistributedToSharedStmatrix (
141141 Location loc, RankedTensorType tensorTy, MemDescType memDescType,
142- Value adaptorSrc, Value smemBase, Type llvmElemTy,
142+ bool transpose, Value adaptorSrc, Value smemBase, Type llvmElemTy,
143143 ConversionPatternRewriter &rewriter, const TargetInfo &targetInfo,
144144 std::pair<size_t , Type> *const llvmOpCount = nullptr ) {
145145 if (!targetInfo.supportLdStMatrix ())
@@ -160,7 +160,11 @@ LogicalResult lowerDistributedToSharedStmatrix(
160160 auto kOffset = S (" offset" );
161161 auto smemPtrTy = ptr_ty (ctx, 3 );
162162 auto bitwidth = tensorTy.getElementTypeBitWidth ();
163- if (bitwidth > 32 )
163+ // In the transpose case, consecutive elements are not stored contiguously
164+ // so we cannot split an fp32
165+ // We could support bitwidth == 8, but it'd be a rather weird layout
166+ // so we don't do that for now
167+ if ((!transpose && bitwidth > 32 ) || (transpose && bitwidth != 16 ))
164168 return failure ();
165169 // Inter block stmatrix is not supported
166170 if (cvt.hasInDim (kBlock ))
@@ -173,31 +177,75 @@ LogicalResult lowerDistributedToSharedStmatrix(
173177 cvt = removeBroadcast.apply (cvt);
174178 srcVals = removeBroadcast.apply (srcVals);
175179
176- auto tile = LinearLayout::identity1D (32 / bitwidth, kReg , kOffset ) *
177- LinearLayout::identity1D (4 , kLane , kOffset );
178- // Find if there is a register permutation that allows us to divideLeft
179- auto maybeAction = regPermForDivideLeft (cvt, tile);
180- if (!maybeAction.has_value ()) {
181- return failure ();
180+ LinearLayout reps;
181+ if (!transpose) {
182+ auto tile = LinearLayout::identity1D (32 / bitwidth, kReg , kOffset ) *
183+ LinearLayout::identity1D (4 , kLane , kOffset );
184+
185+ // Find if there is a register permutation that allows us to divideLeft
186+ // We need to pass the map from regs to offsets, as is cvt
187+ auto maybeAction = regPermForDivideLeft (cvt, tile);
188+ if (!maybeAction.has_value ()) {
189+ return failure ();
190+ }
191+ auto action = maybeAction.value ();
192+ // Check if the action indeed allows us to divideLeft
193+ cvt = action.apply (cvt);
194+ srcVals = action.apply (srcVals);
195+
196+ auto maybeQuot = divideLeft (cvt, tile);
197+ if (!maybeQuot.has_value ()) {
198+ return failure ();
199+ }
200+ reps = zerosLike (tile) * maybeQuot.value ();
201+ } else {
202+ // Division does not quite work here. To define this properly, we would need
203+ // to define a different multiplication that does:
204+ // A *' B = [[0, A], [B, 0]] and define leftDivision for it
205+ // We do it ad-hoc for now, as I beleive there's not much demand for this op
206+ // outside of this lowering
207+
208+ // Divisibility in the sense above is the same as regular divisibility
209+ // You need to see that the tile A is a sublayout of the matrix, and that
210+ // it has zeros above it and to its right.
211+
212+ // In particular, offsets lanes 4, 8, 16 map to offsets 1, 2, 4...
213+ const auto &laneBases = cvt.getBases ().find (kLane )->second ;
214+ for (int i = 0 ; i < 3 ; ++i) {
215+ if (laneBases[i + 2 ][0 ] != (1 << i))
216+ return failure ();
217+ }
218+ // ... and no other basis should depend on 1, 2, 4
219+ // Note that this gives us the usual alignment condition, but we have
220+ // translated it to checking that the matrix to the left of A is all zeros
221+ for (auto dim : cvt.getInDimNames ()) {
222+ const auto &bases = cvt.getBases ().find (dim)->second ;
223+ for (auto [i, basis] : llvm::enumerate (bases)) {
224+ if (dim == kLane && i >= 2 )
225+ continue ;
226+ if (basis[0 ] & 0b111 )
227+ return failure ();
228+ }
229+ }
230+
231+ // Hack: We are not going to use in the rest of the function reps[kLane][2:]
232+ // so we don't need to zero them out
233+ reps = cvt;
182234 }
183- auto action = maybeAction.value ();
184- // Check if the action indeed allows us to divideLeft
185- cvt = action.apply (cvt);
186- auto maybeQuot = divideLeft (cvt, tile);
187- if (!maybeQuot.has_value ()) {
235+
236+ // We must have at least 2 register elements to use stmatrix.trans
237+ if (transpose && reps.getInDimSizeLog2 (kReg ) < llvm::Log2_32 (32 / bitwidth)) {
188238 return failure ();
189239 }
190- auto quot = maybeQuot.value ();
191- srcVals = action.apply (srcVals);
192- // Map from kReg, kLane, kWarp to beginning of each tile
193- auto reps = zerosLike (tile) * quot;
194- assert (reps.getOutDimSize (kOffset ) == cvt.getOutDimSize (kOffset ));
195240
196- // Choose up to 4 packs of 32-bit elements indexed by the next to bases
197- // as the vectorisation factor
198- auto vec = std::min (2 , quot.getInDimSizeLog2 (kReg ));
241+ // Choose up to 4 packs of 32-bit elements indexed by the next (at most) two
242+ // bases as the vectorisation factor. We don't consider the basis of the tile
243+ // for vectorisation so we substract them
244+ auto vec = std::min<int32_t >(2 , reps.getInDimSizeLog2 (kReg ) -
245+ llvm::Log2_32 (32 / bitwidth));
199246
200- // FIXME(Lezcano): Should we bail if any of the other 3 lane bases is zero?
247+ // Map from kReg, kLane, kWarp to beginning of each tile
248+ assert (reps.getOutDimSize (kOffset ) == cvt.getOutDimSize (kOffset ));
201249
202250 auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
203251 // Compute the addresses for the 0th tile
@@ -212,12 +260,24 @@ LogicalResult lowerDistributedToSharedStmatrix(
212260 // given
213261 // by the first `vec` reg bases that are not part of the tile
214262 std::vector<std::vector<int32_t >> laneBases;
215- assert (tile.getInDimSizeLog2 (kLane ) == 2 );
216- for (int i = 0 ; i < 3 ; ++i) {
217- laneBases.push_back (reps.getBasis (kLane , tile.getInDimSizeLog2 (kLane ) + i));
218- }
219- for (int i = 0 ; i < vec; ++i) {
220- laneBases.push_back (reps.getBasis (kReg , tile.getInDimSizeLog2 (kReg ) + i));
263+ if (!transpose) {
264+ auto tileDimSizeReg = llvm::Log2_32 (32 / bitwidth);
265+ auto tileDimSizeLane = 2 ;
266+ for (int i = 0 ; i < 3 ; ++i) {
267+ laneBases.push_back (reps.getBasis (kLane , tileDimSizeLane + i));
268+ }
269+ for (int i = 0 ; i < vec; ++i) {
270+ laneBases.push_back (reps.getBasis (kReg , tileDimSizeReg + i));
271+ }
272+ } else {
273+ // We choose the first basis of the register. In the future we could choose
274+ // a basis that minimises the bank conflicts
275+ laneBases.push_back (reps.getBasis (kReg , 0 ));
276+ laneBases.push_back (reps.getBasis (kLane , 0 ));
277+ laneBases.push_back (reps.getBasis (kLane , 1 ));
278+ for (int i = 0 ; i < vec; ++i) {
279+ laneBases.push_back (reps.getBasis (kReg , i + 1 ));
280+ }
221281 }
222282
223283 LinearLayout addrLayout =
@@ -247,7 +307,8 @@ LogicalResult lowerDistributedToSharedStmatrix(
247307 }
248308 inputs.push_back (b.bitcast (input, i32_ty));
249309 }
250- rewriter.create <triton::nvgpu::StoreMatrixOp>(loc, vecAddr, inputs);
310+ rewriter.create <triton::nvgpu::StoreMatrixOp>(loc, vecAddr, inputs,
311+ /* needTrans=*/ transpose);
251312 }
252313 return success ();
253314}
@@ -271,10 +332,19 @@ struct LocalAllocOpConversion
271332 Value smemBase =
272333 LLVM::getSharedMemoryBase (op.getLoc (), rewriter, targetInfo, op);
273334
274- if (lowerDistributedToSharedStmatrix (op.getLoc (), srcTy, memDescType,
275- adaptor.getSrc (), smemBase, llvmElemTy,
276- rewriter, targetInfo)
277- .failed ()) {
335+ // Try to lower transposed or not
336+ bool lowered = false ;
337+ for (bool transpose : {false , true }) {
338+ lowered =
339+ lowerDistributedToSharedStmatrix (
340+ op.getLoc (), srcTy, memDescType, transpose, adaptor.getSrc (),
341+ smemBase, llvmElemTy, rewriter, targetInfo)
342+ .succeeded ();
343+ if (lowered) {
344+ break ;
345+ }
346+ }
347+ if (!lowered) {
278348 return failure ();
279349 }
280350
@@ -306,11 +376,20 @@ struct LocalStoreOpConversion
306376 getTypeConverter ()->convertType (op.getDst ().getType ().getElementType ());
307377 SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct (
308378 op.getLoc (), adaptor.getDst (), llvmElemTy, rewriter);
309- if (lowerDistributedToSharedStmatrix (op.getLoc (), op.getSrc ().getType (),
310- op.getDst ().getType (),
311- adaptor.getSrc (), smemObj.getBase (),
312- llvmElemTy, rewriter, targetInfo)
313- .failed ()) {
379+
380+ // Try to lower transposed or not
381+ bool lowered = false ;
382+ for (bool transpose : {false , true }) {
383+ lowered = lowerDistributedToSharedStmatrix (
384+ op.getLoc (), op.getSrc ().getType (), op.getDst ().getType (),
385+ transpose, adaptor.getSrc (), smemObj.getBase (), llvmElemTy,
386+ rewriter, targetInfo)
387+ .succeeded ();
388+ if (lowered) {
389+ break ;
390+ }
391+ }
392+ if (!lowered) {
314393 return failure ();
315394 }
316395 rewriter.eraseOp (op);
0 commit comments