Skip to content

Commit 21718f0

Browse files
author
Yaxun Liu
committed
cast to global addr
1 parent 0b06286 commit 21718f0

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

llvm/lib/Target/AMDGPU/AMDGPUSplitKernelArguments.cpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,24 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
9393
SmallVector<Argument *, 8> StructArgs;
9494
SmallVector<Type *, 8> NewArgTypes;
9595

96+
// Helper function to convert address space if needed
97+
auto convertAddressSpace = [](Type *Ty) -> Type * {
98+
if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
99+
if (PtrTy->getAddressSpace() == AMDGPUAS::FLAT_ADDRESS) {
100+
// Create a new pointer type in the global address space
101+
return PointerType::get(PtrTy->getContext(), AMDGPUAS::GLOBAL_ADDRESS);
102+
}
103+
}
104+
return Ty;
105+
};
106+
96107
// Collect struct arguments and new argument types
97108
unsigned OriginalArgIndex = 0;
98109
unsigned NewArgIndex = 0;
99110
for (Argument &Arg : F.args()) {
100111
LLVM_DEBUG(dbgs() << "Processing argument: " << Arg << "\n");
101112
if (Arg.use_empty()) {
102-
NewArgTypes.push_back(Arg.getType());
113+
NewArgTypes.push_back(convertAddressSpace(Arg.getType()));
103114
NewArgMappings.push_back(
104115
std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
105116
++NewArgIndex;
@@ -181,7 +192,8 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
181192
ArgToLoadsMap[&Arg] = Loads;
182193
ArgToGEPsMap[&Arg] = GEPs;
183194
for (LoadInst *LI : Loads) {
184-
NewArgTypes.push_back(LI->getType());
195+
Type *NewType = convertAddressSpace(LI->getType());
196+
NewArgTypes.push_back(NewType);
185197

186198
// Compute offset
187199
uint64_t Offset = 0;
@@ -198,7 +210,7 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
198210
++NewArgIndex;
199211
}
200212
} else {
201-
NewArgTypes.push_back(Arg.getType());
213+
NewArgTypes.push_back(convertAddressSpace(Arg.getType()));
202214
// Include mapping if indices have changed
203215
if (NewArgIndex != OriginalArgIndex)
204216
NewArgMappings.push_back(
@@ -281,14 +293,33 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
281293
if (ArgToLoadsMap.count(&Arg)) {
282294
for (LoadInst *LI : ArgToLoadsMap[&Arg]) {
283295
NewArgIt->setName(LI->getName());
284-
VMap[LI] = &*NewArgIt++;
296+
Value *NewArg = &*NewArgIt++;
297+
// Only insert cast if we're dealing with pointers
298+
if (isa<PointerType>(NewArg->getType()) &&
299+
isa<PointerType>(LI->getType())) {
300+
IRBuilder<> Builder(LI);
301+
Value *CastedArg = Builder.CreatePointerBitCastOrAddrSpaceCast(
302+
NewArg, LI->getType());
303+
VMap[LI] = CastedArg;
304+
} else {
305+
VMap[LI] = NewArg;
306+
}
285307
}
286-
// After replacing loads, replace uses of Arg with Undef
287308
UndefValue *UndefArg = UndefValue::get(Arg.getType());
288309
Arg.replaceAllUsesWith(UndefArg);
289310
} else {
290311
NewArgIt->setName(Arg.getName());
291-
Arg.replaceAllUsesWith(&*NewArgIt);
312+
Value *NewArg = &*NewArgIt;
313+
// Only insert cast if we're dealing with pointers
314+
if (isa<PointerType>(NewArg->getType()) &&
315+
isa<PointerType>(Arg.getType())) {
316+
IRBuilder<> Builder(&*NewF->begin()->begin());
317+
Value *CastedArg =
318+
Builder.CreatePointerBitCastOrAddrSpaceCast(NewArg, Arg.getType());
319+
Arg.replaceAllUsesWith(CastedArg);
320+
} else {
321+
Arg.replaceAllUsesWith(NewArg);
322+
}
292323
++NewArgIt;
293324
}
294325
}

0 commit comments

Comments
 (0)