Skip to content

Commit 0dfc748

Browse files
author
git apple-llvm automerger
committed
Merge commit '0dae924c1f66' from llvm.org/main into next
2 parents b706a01 + 0dae924 commit 0dfc748

File tree

7 files changed

+208
-72
lines changed

7 files changed

+208
-72
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4414,10 +4414,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
44144414
return std::nullopt;
44154415
}
44164416

4417+
// Helper function to extract string value from bind name variant
4418+
static std::optional<llvm::StringRef> getBindNameStringValue(
4419+
const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4420+
&bindNameValue) {
4421+
if (!bindNameValue.has_value())
4422+
return std::nullopt;
4423+
4424+
return std::visit(
4425+
[](const auto &attr) -> std::optional<llvm::StringRef> {
4426+
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4427+
mlir::StringAttr>) {
4428+
return attr.getValue();
4429+
} else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
4430+
mlir::SymbolRefAttr>) {
4431+
return attr.getLeafReference();
4432+
} else {
4433+
return std::nullopt;
4434+
}
4435+
},
4436+
bindNameValue.value());
4437+
}
4438+
44174439
static bool compareDeviceTypeInfo(
44184440
mlir::acc::RoutineOp op,
4419-
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
4420-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
4441+
llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
4442+
llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
4443+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
4444+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
44214445
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
44224446
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
44234447
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4427,9 +4451,13 @@ static bool compareDeviceTypeInfo(
44274451
for (uint32_t dtypeInt = 0;
44284452
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
44294453
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
4430-
if (op.getBindNameValue(dtype) !=
4431-
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4432-
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
4454+
auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
4455+
if (bindNameValue !=
4456+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4457+
bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
4458+
bindNameValue !=
4459+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
4460+
bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
44334461
return false;
44344462
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
44354463
return false;
@@ -4476,8 +4504,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
44764504
void createOpenACCRoutineConstruct(
44774505
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
44784506
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
4479-
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
4480-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4507+
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
4508+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
4509+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4510+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
44814511
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
44824512
llvm::SmallVector<mlir::Attribute> &gangDimValues,
44834513
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4490,7 +4520,8 @@ void createOpenACCRoutineConstruct(
44904520
0) {
44914521
// If the routine is already specified with the same clauses, just skip
44924522
// the operation creation.
4493-
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
4523+
if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
4524+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
44944525
gangDeviceTypes, gangDimValues,
44954526
gangDimDeviceTypes, seqDeviceTypes,
44964527
workerDeviceTypes, vectorDeviceTypes) &&
@@ -4507,8 +4538,10 @@ void createOpenACCRoutineConstruct(
45074538
modBuilder.create<mlir::acc::RoutineOp>(
45084539
loc, routineOpStr,
45094540
mlir::SymbolRefAttr::get(builder.getContext(), funcName),
4510-
getArrayAttrOrNull(builder, bindNames),
4511-
getArrayAttrOrNull(builder, bindNameDeviceTypes),
4541+
getArrayAttrOrNull(builder, bindIdNames),
4542+
getArrayAttrOrNull(builder, bindStrNames),
4543+
getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
4544+
getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
45124545
getArrayAttrOrNull(builder, workerDeviceTypes),
45134546
getArrayAttrOrNull(builder, vectorDeviceTypes),
45144547
getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
@@ -4525,8 +4558,10 @@ static void interpretRoutineDeviceInfo(
45254558
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
45264559
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
45274560
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
4528-
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
4529-
llvm::SmallVector<mlir::Attribute> &bindNames,
4561+
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
4562+
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
4563+
llvm::SmallVector<mlir::Attribute> &bindIdNames,
4564+
llvm::SmallVector<mlir::Attribute> &bindStrNames,
45304565
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
45314566
llvm::SmallVector<mlir::Attribute> &gangDimValues,
45324567
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4559,16 +4594,18 @@ static void interpretRoutineDeviceInfo(
45594594
if (dinfo.bindNameOpt().has_value()) {
45604595
const auto &bindName = dinfo.bindNameOpt().value();
45614596
mlir::Attribute bindNameAttr;
4562-
if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
4597+
if (const auto &bindSym{
4598+
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4599+
bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
4600+
bindIdNames.push_back(bindNameAttr);
4601+
bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
4602+
} else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
45634603
bindNameAttr = builder.getStringAttr(*bindStr);
4564-
} else if (const auto &bindSym{
4565-
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
4566-
bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
4604+
bindStrNames.push_back(bindNameAttr);
4605+
bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
45674606
} else {
45684607
llvm_unreachable("Unsupported bind name type");
45694608
}
4570-
bindNames.push_back(bindNameAttr);
4571-
bindNameDeviceTypes.push_back(getDeviceTypeAttr());
45724609
}
45734610
}
45744611

@@ -4584,8 +4621,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45844621
bool hasNohost{false};
45854622

45864623
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
4587-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4588-
gangDimDeviceTypes, gangDimValues;
4624+
workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4625+
bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
4626+
gangDimValues;
45894627

45904628
for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
45914629
// Device Independent Attributes
@@ -4594,24 +4632,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
45944632
}
45954633
// Note: Device Independent Attributes are set to the
45964634
// none device type in `info`.
4597-
interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
4598-
vectorDeviceTypes, workerDeviceTypes,
4599-
bindNameDeviceTypes, bindNames, gangDeviceTypes,
4600-
gangDimValues, gangDimDeviceTypes);
4635+
interpretRoutineDeviceInfo(
4636+
converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
4637+
bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
4638+
bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
46014639

46024640
// Device Dependent Attributes
46034641
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
46044642
info.deviceTypeInfos()) {
4605-
interpretRoutineDeviceInfo(
4606-
converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
4607-
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
4608-
gangDimValues, gangDimDeviceTypes);
4643+
interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
4644+
vectorDeviceTypes, workerDeviceTypes,
4645+
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4646+
bindIdNames, bindStrNames, gangDeviceTypes,
4647+
gangDimValues, gangDimDeviceTypes);
46094648
}
46104649
}
46114650
createOpenACCRoutineConstruct(
4612-
converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
4613-
bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
4614-
seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
4651+
converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
4652+
bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
4653+
gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
4654+
workerDeviceTypes, vectorDeviceTypes);
46154655
}
46164656

46174657
static void

flang/test/Lower/OpenACC/acc-routine.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
44

5-
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
6-
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
5+
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
6+
! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
7+
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
78
! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
89
! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
910
! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
1011
! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
11-
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
12+
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
1213
! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
1314
! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
1415
! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost

flang/test/Lower/OpenACC/acc-routine03.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ subroutine sub2(a)
3030
end subroutine
3131

3232
! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
33-
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
33+
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
3434
! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
3535
! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Interfaces/ControlFlowInterfaces.h"
3030
#include "mlir/Interfaces/LoopLikeInterface.h"
3131
#include "mlir/Interfaces/SideEffectInterfaces.h"
32+
#include <variant>
3233

3334
#define GET_TYPEDEF_CLASSES
3435
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
27722772
}];
27732773

27742774
let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
2775-
OptionalAttr<StrArrayAttr>:$bindName,
2776-
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
2775+
OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
2776+
OptionalAttr<StrArrayAttr>:$bindStrName,
2777+
OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
2778+
OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
27772779
OptionalAttr<DeviceTypeArrayAttr>:$worker,
27782780
OptionalAttr<DeviceTypeArrayAttr>:$vector,
27792781
OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
@@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
28152817
std::optional<int64_t> getGangDimValue();
28162818
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
28172819

2818-
std::optional<llvm::StringRef> getBindNameValue();
2819-
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
2820+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
2821+
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
28202822
}];
28212823

28222824
let assemblyFormat = [{
28232825
$sym_name `func` `(` $func_name `)`
28242826
oilist (
2825-
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
2827+
`bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
28262828
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
28272829
| `worker` custom<DeviceTypeArrayAttr>($worker)
28282830
| `vector` custom<DeviceTypeArrayAttr>($vector)

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/SmallSet.h"
2222
#include "llvm/ADT/TypeSwitch.h"
2323
#include "llvm/Support/LogicalResult.h"
24+
#include <variant>
2425

2526
using namespace mlir;
2627
using namespace acc;
@@ -3461,40 +3462,88 @@ LogicalResult acc::RoutineOp::verify() {
34613462
return success();
34623463
}
34633464

3464-
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3465-
mlir::ArrayAttr &deviceTypes) {
3466-
llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3467-
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3465+
static ParseResult parseBindName(OpAsmParser &parser,
3466+
mlir::ArrayAttr &bindIdName,
3467+
mlir::ArrayAttr &bindStrName,
3468+
mlir::ArrayAttr &deviceIdTypes,
3469+
mlir::ArrayAttr &deviceStrTypes) {
3470+
llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
3471+
llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
3472+
llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
3473+
llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
34683474

34693475
if (failed(parser.parseCommaSeparatedList([&]() {
3470-
if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3476+
mlir::Attribute newAttr;
3477+
bool isSymbolRefAttr;
3478+
auto parseResult = parser.parseAttribute(newAttr);
3479+
if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3480+
bindIdNameAttrs.push_back(symbolRefAttr);
3481+
isSymbolRefAttr = true;
3482+
} else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3483+
bindStrNameAttrs.push_back(stringAttr);
3484+
isSymbolRefAttr = false;
3485+
}
3486+
if (parseResult)
34713487
return failure();
34723488
if (failed(parser.parseOptionalLSquare())) {
3473-
deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3474-
parser.getContext(), mlir::acc::DeviceType::None));
3489+
if (isSymbolRefAttr) {
3490+
deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3491+
parser.getContext(), mlir::acc::DeviceType::None));
3492+
} else {
3493+
deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3494+
parser.getContext(), mlir::acc::DeviceType::None));
3495+
}
34753496
} else {
3476-
if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3477-
parser.parseRSquare())
3478-
return failure();
3497+
if (isSymbolRefAttr) {
3498+
if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3499+
parser.parseRSquare())
3500+
return failure();
3501+
} else {
3502+
if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3503+
parser.parseRSquare())
3504+
return failure();
3505+
}
34793506
}
34803507
return success();
34813508
})))
34823509
return failure();
34833510

3484-
bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3485-
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3511+
bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3512+
bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3513+
deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3514+
deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
34863515

34873516
return success();
34883517
}
34893518

34903519
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
3491-
std::optional<mlir::ArrayAttr> bindName,
3492-
std::optional<mlir::ArrayAttr> deviceTypes) {
3493-
llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3494-
[&](const auto &pair) {
3495-
p << std::get<0>(pair);
3496-
printSingleDeviceType(p, std::get<1>(pair));
3497-
});
3520+
std::optional<mlir::ArrayAttr> bindIdName,
3521+
std::optional<mlir::ArrayAttr> bindStrName,
3522+
std::optional<mlir::ArrayAttr> deviceIdTypes,
3523+
std::optional<mlir::ArrayAttr> deviceStrTypes) {
3524+
// Create combined vectors for all bind names and device types
3525+
llvm::SmallVector<mlir::Attribute> allBindNames;
3526+
llvm::SmallVector<mlir::Attribute> allDeviceTypes;
3527+
3528+
// Append bindIdName and deviceIdTypes
3529+
if (hasDeviceTypeValues(deviceIdTypes)) {
3530+
allBindNames.append(bindIdName->begin(), bindIdName->end());
3531+
allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3532+
}
3533+
3534+
// Append bindStrName and deviceStrTypes
3535+
if (hasDeviceTypeValues(deviceStrTypes)) {
3536+
allBindNames.append(bindStrName->begin(), bindStrName->end());
3537+
allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3538+
}
3539+
3540+
// Print the combined sequence
3541+
if (!allBindNames.empty())
3542+
llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3543+
[&](const auto &pair) {
3544+
p << std::get<0>(pair);
3545+
printSingleDeviceType(p, std::get<1>(pair));
3546+
});
34983547
}
34993548

35003549
static ParseResult parseRoutineGangClause(OpAsmParser &parser,
@@ -3654,19 +3703,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
36543703
return hasDeviceType(getSeq(), deviceType);
36553704
}
36563705

3657-
std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3706+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3707+
RoutineOp::getBindNameValue() {
36583708
return getBindNameValue(mlir::acc::DeviceType::None);
36593709
}
36603710

3661-
std::optional<llvm::StringRef>
3711+
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
36623712
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3663-
if (!hasDeviceTypeValues(getBindNameDeviceType()))
3713+
if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
3714+
!hasDeviceTypeValues(getBindStrNameDeviceType())) {
36643715
return std::nullopt;
3665-
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3666-
auto attr = (*getBindName())[*pos];
3716+
}
3717+
3718+
if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
3719+
auto attr = (*getBindIdName())[*pos];
3720+
auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3721+
assert(symbolRefAttr && "expected SymbolRef");
3722+
return symbolRefAttr;
3723+
}
3724+
3725+
if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
3726+
auto attr = (*getBindStrName())[*pos];
36673727
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3668-
return stringAttr.getValue();
3728+
assert(stringAttr && "expected String");
3729+
return stringAttr;
36693730
}
3731+
36703732
return std::nullopt;
36713733
}
36723734

0 commit comments

Comments
 (0)