@@ -93,13 +93,24 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
9393 SmallVector<Argument *, 8 > StructArgs;
9494 SmallVector<Type *, 8 > NewArgTypes;
9595
96+ // Helper: given a type, returns the global-address-space pointer type when the input is a flat-address-space pointer; any other type (non-pointer, or a pointer in another address space) is returned unchanged
97+ auto convertAddressSpace = [](Type *Ty) -> Type * {
98+ if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
99+ if (PtrTy->getAddressSpace () == AMDGPUAS::FLAT_ADDRESS) {
100+ // Flat pointer: build the replacement pointer type in the global address space, reusing the original type's context
101+ return PointerType::get (PtrTy->getContext (), AMDGPUAS::GLOBAL_ADDRESS);
102+ }
103+ }
104+ return Ty;
105+ };
106+
96107 // Collect struct arguments and new argument types
97108 unsigned OriginalArgIndex = 0 ;
98109 unsigned NewArgIndex = 0 ;
99110 for (Argument &Arg : F.args ()) {
100111 LLVM_DEBUG (dbgs () << " Processing argument: " << Arg << " \n " );
101112 if (Arg.use_empty ()) {
102- NewArgTypes.push_back (Arg.getType ());
113+ NewArgTypes.push_back (convertAddressSpace ( Arg.getType () ));
103114 NewArgMappings.push_back (
104115 std::make_tuple (NewArgIndex, OriginalArgIndex, 0 ));
105116 ++NewArgIndex;
@@ -181,7 +192,8 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
181192 ArgToLoadsMap[&Arg] = Loads;
182193 ArgToGEPsMap[&Arg] = GEPs;
183194 for (LoadInst *LI : Loads) {
184- NewArgTypes.push_back (LI->getType ());
195+ Type *NewType = convertAddressSpace (LI->getType ());
196+ NewArgTypes.push_back (NewType);
185197
186198 // Compute offset
187199 uint64_t Offset = 0 ;
@@ -198,7 +210,7 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
198210 ++NewArgIndex;
199211 }
200212 } else {
201- NewArgTypes.push_back (Arg.getType ());
213+ NewArgTypes.push_back (convertAddressSpace ( Arg.getType () ));
202214 // Include mapping if indices have changed
203215 if (NewArgIndex != OriginalArgIndex)
204216 NewArgMappings.push_back (
@@ -281,14 +293,33 @@ bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
281293 if (ArgToLoadsMap.count (&Arg)) {
282294 for (LoadInst *LI : ArgToLoadsMap[&Arg]) {
283295 NewArgIt->setName (LI->getName ());
284- VMap[LI] = &*NewArgIt++;
296+ Value *NewArg = &*NewArgIt++;
297+ // Only insert cast if we're dealing with pointers
298+ if (isa<PointerType>(NewArg->getType ()) &&
299+ isa<PointerType>(LI->getType ())) {
300+ IRBuilder<> Builder (LI);
301+ Value *CastedArg = Builder.CreatePointerBitCastOrAddrSpaceCast (
302+ NewArg, LI->getType ());
303+ VMap[LI] = CastedArg;
304+ } else {
305+ VMap[LI] = NewArg;
306+ }
285307 }
286- // After replacing loads, replace uses of Arg with Undef
287308 UndefValue *UndefArg = UndefValue::get (Arg.getType ());
288309 Arg.replaceAllUsesWith (UndefArg);
289310 } else {
290311 NewArgIt->setName (Arg.getName ());
291- Arg.replaceAllUsesWith (&*NewArgIt);
312+ Value *NewArg = &*NewArgIt;
313+ // Only insert cast if we're dealing with pointers
314+ if (isa<PointerType>(NewArg->getType ()) &&
315+ isa<PointerType>(Arg.getType ())) {
316+ IRBuilder<> Builder (&*NewF->begin ()->begin ());
317+ Value *CastedArg =
318+ Builder.CreatePointerBitCastOrAddrSpaceCast (NewArg, Arg.getType ());
319+ Arg.replaceAllUsesWith (CastedArg);
320+ } else {
321+ Arg.replaceAllUsesWith (NewArg);
322+ }
292323 ++NewArgIt;
293324 }
294325 }
0 commit comments