@@ -8086,13 +8086,76 @@ static SDValue getZT0FrameIndex(MachineFrameInfo &MFI,
80868086 DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
80878087}
80888088
8089+ // Emit a call to __arm_sme_save or __arm_sme_restore.
8090+ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8091+ SelectionDAG &DAG,
8092+ AArch64FunctionInfo *Info, SDLoc DL,
8093+ SDValue Chain, bool IsSave) {
8094+ MachineFunction &MF = DAG.getMachineFunction();
8095+ AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8096+ FuncInfo->setSMESaveBufferUsed();
8097+ TargetLowering::ArgListTy Args;
8098+ Args.emplace_back(
8099+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
8100+ PointerType::getUnqual(*DAG.getContext()));
8101+
8102+ RTLIB::Libcall LC =
8103+ IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
8104+ SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
8105+ TLI.getPointerTy(DAG.getDataLayout()));
8106+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
8107+ TargetLowering::CallLoweringInfo CLI(DAG);
8108+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
8109+ TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
8110+ return TLI.LowerCallTo(CLI).second;
8111+ }
8112+
8113+ static SDValue emitRestoreZALazySave(SDValue Chain, SDLoc DL,
8114+ const AArch64TargetLowering &TLI,
8115+ const AArch64RegisterInfo &TRI,
8116+ AArch64FunctionInfo &FuncInfo,
8117+ SelectionDAG &DAG) {
8118+ // Conditionally restore the lazy save using a pseudo node.
8119+ RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
8120+ TPIDR2Object &TPIDR2 = FuncInfo.getTPIDR2Obj();
8121+ SDValue RegMask = DAG.getRegisterMask(TRI.getCallPreservedMask(
8122+ DAG.getMachineFunction(), TLI.getLibcallCallingConv(LC)));
8123+ SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8124+ TLI.getLibcallName(LC), TLI.getPointerTy(DAG.getDataLayout()));
8125+ SDValue TPIDR2_EL0 = DAG.getNode(
8126+ ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Chain,
8127+ DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8128+ // Copy the address of the TPIDR2 block into X0 before 'calling' the
8129+ // RESTORE_ZA pseudo.
8130+ SDValue Glue;
8131+ SDValue TPIDR2Block = DAG.getFrameIndex(
8132+ TPIDR2.FrameIndex,
8133+ DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8134+ Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, TPIDR2Block, Glue);
8135+ Chain =
8136+ DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8137+ {Chain, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8138+ RestoreRoutine, RegMask, Chain.getValue(1)});
8139+ // Finally reset the TPIDR2_EL0 register to 0.
8140+ Chain = DAG.getNode(
8141+ ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8142+ DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8143+ DAG.getConstant(0, DL, MVT::i64));
8144+ TPIDR2.Uses++;
8145+ return Chain;
8146+ }
8147+
80898148SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
80908149 SelectionDAG &DAG) const {
80918150 assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
80928151 SDValue Glue = Chain.getValue(1);
80938152
80948153 MachineFunction &MF = DAG.getMachineFunction();
8095- SMEAttrs SMEFnAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
8154+ auto &FuncInfo = *MF.getInfo<AArch64FunctionInfo>();
8155+ auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
8156+ const AArch64RegisterInfo &TRI = *Subtarget.getRegisterInfo();
8157+
8158+ SMEAttrs SMEFnAttrs = FuncInfo.getSMEFnAttrs();
80968159
80978160 // The following conditions are true on entry to an exception handler:
80988161 // - PSTATE.SM is 0.
@@ -8107,14 +8170,43 @@ SDValue AArch64TargetLowering::lowerEHPadEntry(SDValue Chain, SDLoc const &DL,
81078170 // These mode changes are usually optimized away in catch blocks as they
81088171 // occur before the __cxa_begin_catch (which is a non-streaming function),
81098172 // but are necessary in some cases (such as for cleanups).
8173+ //
8174+ // Additionally, if the function has ZA or ZT0 state, we must restore it.
81108175
8176+ // [COND_]SMSTART SM
81118177 if (SMEFnAttrs.hasStreamingInterfaceOrBody())
8112- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8113- /*Glue*/ Glue, AArch64SME::Always);
8178+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain,
8179+ /*Glue*/ Glue, AArch64SME::Always);
8180+ else if (SMEFnAttrs.hasStreamingCompatibleInterface())
8181+ Chain = changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8182+ AArch64SME::IfCallerIsStreaming);
8183+
8184+ if (getTM().useNewSMEABILowering())
8185+ return Chain;
81148186
8115- if (SMEFnAttrs.hasStreamingCompatibleInterface())
8116- return changeStreamingMode(DAG, DL, /*Enable=*/true, Chain, Glue,
8117- AArch64SME::IfCallerIsStreaming);
8187+ if (SMEFnAttrs.hasAgnosticZAInterface()) {
8188+ // Restore full ZA
8189+ Chain = emitSMEStateSaveRestore(*this, DAG, &FuncInfo, DL, Chain,
8190+ /*IsSave=*/false);
8191+ } else if (SMEFnAttrs.hasZAState() || SMEFnAttrs.hasZT0State()) {
8192+ // SMSTART ZA
8193+ Chain = DAG.getNode(
8194+ AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
8195+ DAG.getTargetConstant(int32_t(AArch64SVCR::SVCRZA), DL, MVT::i32));
8196+
8197+ // Restore ZT0
8198+ if (SMEFnAttrs.hasZT0State()) {
8199+ SDValue ZT0FrameIndex =
8200+ getZT0FrameIndex(MF.getFrameInfo(), FuncInfo, DAG);
8201+ Chain =
8202+ DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8203+ {Chain, DAG.getConstant(0, DL, MVT::i32), ZT0FrameIndex});
8204+ }
8205+
8206+ // Restore ZA
8207+ if (SMEFnAttrs.hasZAState())
8208+ Chain = emitRestoreZALazySave(Chain, DL, *this, TRI, FuncInfo, DAG);
8209+ }
81188210
81198211 return Chain;
81208212}
@@ -9232,30 +9324,6 @@ SDValue AArch64TargetLowering::changeStreamingMode(
92329324 return GetCheckVL(SMChange.getValue(0), SMChange.getValue(1));
92339325}
92349326
9235- // Emit a call to __arm_sme_save or __arm_sme_restore.
9236- static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
9237- SelectionDAG &DAG,
9238- AArch64FunctionInfo *Info, SDLoc DL,
9239- SDValue Chain, bool IsSave) {
9240- MachineFunction &MF = DAG.getMachineFunction();
9241- AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9242- FuncInfo->setSMESaveBufferUsed();
9243- TargetLowering::ArgListTy Args;
9244- Args.emplace_back(
9245- DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64),
9246- PointerType::getUnqual(*DAG.getContext()));
9247-
9248- RTLIB::Libcall LC =
9249- IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE;
9250- SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
9251- TLI.getPointerTy(DAG.getDataLayout()));
9252- auto *RetTy = Type::getVoidTy(*DAG.getContext());
9253- TargetLowering::CallLoweringInfo CLI(DAG);
9254- CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
9255- TLI.getLibcallCallingConv(LC), RetTy, Callee, std::move(Args));
9256- return TLI.LowerCallTo(CLI).second;
9257- }
9258-
92599327static AArch64SME::ToggleCondition
92609328getSMToggleCondition(const SMECallAttrs &CallAttrs) {
92619329 if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
@@ -10015,33 +10083,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1001510083 {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
1001610084
1001710085 if (RequiresLazySave) {
10018- // Conditionally restore the lazy save using a pseudo node.
10019- RTLIB::Libcall LC = RTLIB::SMEABI_TPIDR2_RESTORE;
10020- TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
10021- SDValue RegMask = DAG.getRegisterMask(
10022- TRI->getCallPreservedMask(MF, getLibcallCallingConv(LC)));
10023- SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
10024- getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
10025- SDValue TPIDR2_EL0 = DAG.getNode(
10026- ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
10027- DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
10028- // Copy the address of the TPIDR2 block into X0 before 'calling' the
10029- // RESTORE_ZA pseudo.
10030- SDValue Glue;
10031- SDValue TPIDR2Block = DAG.getFrameIndex(
10032- TPIDR2.FrameIndex,
10033- DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
10034- Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
10035- Result =
10036- DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
10037- {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
10038- RestoreRoutine, RegMask, Result.getValue(1)});
10039- // Finally reset the TPIDR2_EL0 register to 0.
10040- Result = DAG.getNode(
10041- ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
10042- DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
10043- DAG.getConstant(0, DL, MVT::i64));
10044- TPIDR2.Uses++;
10086+ Result = emitRestoreZALazySave(Result, DL, *this, *TRI, *FuncInfo, DAG);
1004510087 } else if (RequiresSaveAllZA) {
1004610088 Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Result,
1004710089 /*IsSave=*/false);
0 commit comments