@@ -1157,23 +1157,28 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols(
1157
1157
bool HasCUDAImplicitTransfer (const Expr<SomeType> &expr) {
1158
1158
semantics::UnorderedSymbolSet hostSymbols;
1159
1159
semantics::UnorderedSymbolSet deviceSymbols;
1160
+ semantics::UnorderedSymbolSet cudaSymbols{CollectCudaSymbols (expr)};
1160
1161
1161
1162
SymbolVector symbols{GetSymbolVector (expr)};
1162
1163
std::reverse (symbols.begin (), symbols.end ());
1163
1164
bool skipNext{false };
1164
1165
for (const Symbol &sym : symbols) {
1165
- bool isComponent{sym.owner ().IsDerivedType ()};
1166
- bool skipComponent{false };
1167
- if (!skipNext) {
1168
- if (IsCUDADeviceSymbol (sym)) {
1169
- deviceSymbols.insert (sym);
1170
- } else if (isComponent) {
1171
- skipComponent = true ; // Component is not device. Look on the base.
1172
- } else {
1173
- hostSymbols.insert (sym);
1166
+ if (cudaSymbols.find (sym) != cudaSymbols.end ()) {
1167
+ bool isComponent{sym.owner ().IsDerivedType ()};
1168
+ bool skipComponent{false };
1169
+ if (!skipNext) {
1170
+ if (IsCUDADeviceSymbol (sym)) {
1171
+ deviceSymbols.insert (sym);
1172
+ } else if (isComponent) {
1173
+ skipComponent = true ; // Component is not device. Look on the base.
1174
+ } else {
1175
+ hostSymbols.insert (sym);
1176
+ }
1174
1177
}
1178
+ skipNext = isComponent && !skipComponent;
1179
+ } else {
1180
+ skipNext = false ;
1175
1181
}
1176
- skipNext = isComponent && !skipComponent;
1177
1182
}
1178
1183
bool hasConstant{HasConstant (expr)};
1179
1184
return (hasConstant || (hostSymbols.size () > 0 )) && deviceSymbols.size () > 0 ;
0 commit comments