@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
464464
465465template <typename UniformOp, typename NonUniformOp>
466466static Value createGroupReduceOpImpl (OpBuilder &builder, Location loc,
467- Value arg, bool isGroup, bool isUniform) {
467+ Value arg, bool isGroup, bool isUniform,
468+ std::optional<uint32_t > clusterSize) {
468469 Type type = arg.getType ();
469470 auto scope = mlir::spirv::ScopeAttr::get (builder.getContext (),
470471 isGroup ? spirv::Scope::Workgroup
471472 : spirv::Scope::Subgroup);
472- auto groupOp = spirv::GroupOperationAttr::get (builder.getContext (),
473- spirv::GroupOperation::Reduce);
473+ auto groupOp = spirv::GroupOperationAttr::get (
474+ builder.getContext (), clusterSize.has_value ()
475+ ? spirv::GroupOperation::ClusteredReduce
476+ : spirv::GroupOperation::Reduce);
474477 if (isUniform) {
475478 return builder.create <UniformOp>(loc, type, scope, groupOp, arg)
476479 .getResult ();
477480 }
478- return builder.create <NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
481+
482+ Value clusterSizeValue;
483+ if (clusterSize.has_value ())
484+ clusterSizeValue = builder.create <spirv::ConstantOp>(
485+ loc, builder.getI32Type (),
486+ builder.getIntegerAttr (builder.getI32Type (), *clusterSize));
487+
488+ return builder
489+ .create <NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
479490 .getResult ();
480491}
481492
482- static std::optional<Value> createGroupReduceOp (OpBuilder &builder,
483- Location loc, Value arg,
484- gpu::AllReduceOperation opType,
485- bool isGroup, bool isUniform ) {
493+ static std::optional<Value>
494+ createGroupReduceOp (OpBuilder &builder, Location loc, Value arg,
495+ gpu::AllReduceOperation opType, bool isGroup ,
496+ bool isUniform, std::optional< uint32_t > clusterSize ) {
486497 enum class ElemType { Float, Boolean, Integer };
487- using FuncT = Value (*)(OpBuilder &, Location, Value, bool , bool );
498+ using FuncT = Value (*)(OpBuilder &, Location, Value, bool , bool ,
499+ std::optional<uint32_t >);
488500 struct OpHandler {
489501 gpu::AllReduceOperation kind;
490502 ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
548560
549561 for (const OpHandler &handler : handlers)
550562 if (handler.kind == opType && elementType == handler.elemType )
551- return handler.func (builder, loc, arg, isGroup, isUniform);
563+ return handler.func (builder, loc, arg, isGroup, isUniform, clusterSize );
552564
553565 return std::nullopt ;
554566}
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
571583
572584 auto result =
573585 createGroupReduceOp (rewriter, op.getLoc (), adaptor.getValue (), *opType,
574- /* isGroup*/ true , op.getUniform ());
586+ /* isGroup*/ true , op.getUniform (), std:: nullopt );
575587 if (!result)
576588 return failure ();
577589
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
589601 LogicalResult
590602 matchAndRewrite (gpu::SubgroupReduceOp op, OpAdaptor adaptor,
591603 ConversionPatternRewriter &rewriter) const override {
592- if (op.getClusterSize ())
604+ if (op.getClusterStride () > 1 ) {
593605 return rewriter.notifyMatchFailure (
594- op, " lowering for clustered reduce not implemented" );
606+ op, " lowering for cluster stride > 1 is not implemented" );
607+ }
595608
596609 if (!isa<spirv::ScalarType>(adaptor.getValue ().getType ()))
597610 return rewriter.notifyMatchFailure (op, " reduction type is not a scalar" );
598611
599- auto result = createGroupReduceOp (rewriter, op. getLoc (), adaptor. getValue (),
600- adaptor.getOp (),
601- /* isGroup=*/ false , adaptor.getUniform ());
612+ auto result = createGroupReduceOp (
613+ rewriter, op. getLoc (), adaptor. getValue (), adaptor.getOp (),
614+ /* isGroup=*/ false , adaptor.getUniform (), op. getClusterSize ());
602615 if (!result)
603616 return failure ();
604617
0 commit comments