@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
223223 doInsertBitcast (STI, MRI, GR, I, PtrReg, 0 , NewPtrType);
224224}
225225
226- static void validateGroupAsyncCopyPtr (const SPIRVSubtarget &STI,
227- MachineRegisterInfo *MRI,
228- SPIRVGlobalRegistry &GR, MachineInstr &I ,
229- unsigned OpIdx) {
226+ static void validatePtrUnwrapStructField (const SPIRVSubtarget &STI,
227+ MachineRegisterInfo *MRI,
228+ SPIRVGlobalRegistry &GR ,
229+ MachineInstr &I, unsigned OpIdx) {
230230 MachineFunction *MF = I.getParent ()->getParent ();
231231 Register OpReg = I.getOperand (OpIdx).getReg ();
232232 Register OpTypeReg = getTypeReg (MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
440440 validateLifetimeStart (STI, MRI, GR, MI);
441441 break ;
442442 case SPIRV::OpGroupAsyncCopy:
443- validateGroupAsyncCopyPtr (STI, MRI, GR, MI, 3 );
444- validateGroupAsyncCopyPtr (STI, MRI, GR, MI, 4 );
443+ validatePtrUnwrapStructField (STI, MRI, GR, MI, 3 );
444+ validatePtrUnwrapStructField (STI, MRI, GR, MI, 4 );
445445 break ;
446446 case SPIRV::OpGroupWaitEvents:
447447 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
467467 if (Type->getParent () == Curr && !Curr->pred_empty ())
468468 ToMove.insert (const_cast <MachineInstr *>(Type));
469469 } break ;
470+ case SPIRV::OpExtInst: {
471+ // prefetch
472+ if (!MI.getOperand (2 ).isImm () || !MI.getOperand (3 ).isImm () ||
473+ MI.getOperand (2 ).getImm () != SPIRV::InstructionSet::OpenCL_std)
474+ continue ;
475+ switch (MI.getOperand (3 ).getImm ()) {
476+ case SPIRV::OpenCLExtInst::remquo: {
477+ // The last operand must be of a pointer to the return type.
478+ MachineIRBuilder MIB (MI);
479+ SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType (32 , MIB);
480+ SPIRVType *RetType = MRI->getVRegDef (MI.getOperand (1 ).getReg ());
481+ assert (RetType && " Expected return type" );
482+ validatePtrTypes (
483+ STI, MRI, GR, MI, MI.getNumOperands () - 1 ,
484+ RetType->getOpcode () != SPIRV::OpTypeVector
485+ ? Int32Type
486+ : GR.getOrCreateSPIRVVectorType (
487+ Int32Type, RetType->getOperand (2 ).getImm (), MIB));
488+ } break ;
489+ case SPIRV::OpenCLExtInst::fract:
490+ case SPIRV::OpenCLExtInst::frexp:
491+ case SPIRV::OpenCLExtInst::lgamma_r:
492+ case SPIRV::OpenCLExtInst::modf:
493+ case SPIRV::OpenCLExtInst::sincos:
494+ // The last operand must be of a pointer to the base type represented
495+ // by the previous operand.
496+ assert (MI.getOperand (MI.getNumOperands () - 2 ).isReg () &&
497+ " Expected v-reg" );
498+ validatePtrTypes (
499+ STI, MRI, GR, MI, MI.getNumOperands () - 1 ,
500+ GR.getSPIRVTypeForVReg (
501+ MI.getOperand (MI.getNumOperands () - 2 ).getReg ()));
502+ break ;
503+ case SPIRV::OpenCLExtInst::prefetch:
504+ // Expected `ptr` type is a pointer to float, integer or vector, but
505+ // the pontee value can be wrapped into a struct.
506+ assert (MI.getOperand (MI.getNumOperands () - 2 ).isReg () &&
507+ " Expected v-reg" );
508+ validatePtrUnwrapStructField (STI, MRI, GR, MI,
509+ MI.getNumOperands () - 2 );
510+ break ;
511+ }
512+ } break ;
470513 }
471514 }
472515 for (MachineInstr *MI : ToMove) {
0 commit comments