@@ -216,7 +216,7 @@ namespace {
216216static LogicalResult fillShardingOption (Operation *op,
217217 ShardingOption &shardingOption,
218218 SymbolRefAttr cluster,
219- ArrayRef<int32_t > meshAxes,
219+ ArrayRef<MeshAxis > meshAxes,
220220 unsigned loopIdx) {
221221 if ((shardingOption.cluster && cluster &&
222222 shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
230230 if (i == loopIdx)
231231 continue ;
232232
233- for (int32_t axis : meshAxes) {
233+ for (MeshAxis axis : meshAxes) {
234234 if (llvm::is_contained (shardingOption.shardingArray [i], axis)) {
235235 LLVM_DEBUG (DBGS () << " sharding option conflicts because mesh axes "
236236 << axis << " duplicate" );
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
260260 SmallVector<AffineMap> maps = shardingOp.getIndexingMaps ();
261261 unsigned numOperands = op->getNumOperands ();
262262 shardingOption.shardingArray .resize (loopTypes.size ());
263- llvm::SmallVector<int32_t > partialMeshAxes;
263+ llvm::SmallVector<MeshAxis > partialMeshAxes;
264264 Partial partialType;
265265 llvm::SmallSet<unsigned , 4 > visitedLoopIndices;
266266 bool anyShardingInResultsOrOperands = false ;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
277277 // shardingOption[index]
278278 for (auto it : llvm::zip (map.getResults (), shardAttr.getSplitAxes ())) {
279279 AffineExpr expr = std::get<0 >(it);
280- ArrayRef<int32_t > axes = std::get<1 >(it).asArrayRef ();
280+ ArrayRef<MeshAxis > axes = std::get<1 >(it).asArrayRef ();
281281 auto dim = cast<AffineDimExpr>(expr);
282282 unsigned index = dim.getPosition ();
283283 visitedLoopIndices.insert (index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
288288
289289 // Handle the partial axes: at this stage, the exact loop index/indices
290290 // cannot be decided because there could be multiple reduction loops.
291- ArrayRef<int32_t > partialAxes = shardAttr.getPartialAxes ();
291+ ArrayRef<MeshAxis > partialAxes = shardAttr.getPartialAxes ();
292292 if (!partialAxes.empty ()) {
293293 if (!partialMeshAxes.empty ())
294294 return op->emitOpError () << " at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
321321 // then the operands with multiple loop indices.
322322 for (auto it : llvm::zip (map.getResults (), shardAttr.getSplitAxes ())) {
323323 AffineExpr expr = std::get<0 >(it);
324- ArrayRef<int32_t > axes = std::get<1 >(it).asArrayRef ();
324+ ArrayRef<MeshAxis > axes = std::get<1 >(it).asArrayRef ();
325325 FailureOr<llvm::SmallSet<unsigned , 2 >> loopIndices =
326326 checkOperandAffineExpr (expr, numDims);
327327 if (failed (loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
362362 if (!partialMeshAxes.empty ()) {
363363 bool anyNonEmptyReductionLoop = llvm::any_of (
364364 llvm::enumerate (shardingOption.shardingArray ), [&](auto it) {
365- SmallVector<int32_t > &subArray = it.value ();
365+ SmallVector<MeshAxis > &subArray = it.value ();
366366 int64_t idx = it.index ();
367367 return isReductionLoop (loopTypes[idx]) && !subArray.empty ();
368368 });
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
406406 return success ();
407407
408408 auto resultType = result.getType ().cast <RankedTensorType>();
409- SmallVector<SmallVector<int32_t >> splitAxes (resultType.getRank ());
410- SmallVector<int32_t > partialAxes;
409+ SmallVector<SmallVector<MeshAxis >> splitAxes (resultType.getRank ());
410+ SmallVector<MeshAxis > partialAxes;
411411
412412 // process the split axes
413413 for (auto it : llvm::enumerate (map.getResults ())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
431431 assert (partialType == curPartialType &&
432432 " Only one reduction type is supported" );
433433 partialType = curPartialType;
434- const SmallVector<int32_t > &axis = std::get<1 >(it);
434+ const SmallVector<MeshAxis > &axis = std::get<1 >(it);
435435 partialAxes.append (axis);
436436 }
437437 }
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
459459 return success ();
460460 Value operand = opOperand.get ();
461461 auto operandType = operand.getType ().cast <RankedTensorType>();
462- SmallVector<SmallVector<int32_t >> splitAxes (operandType.getRank ());
462+ SmallVector<SmallVector<MeshAxis >> splitAxes (operandType.getRank ());
463463 unsigned numDims = map.getNumDims ();
464464 for (auto it : llvm::enumerate (map.getResults ())) {
465465 int64_t idx = it.index ();
0 commit comments