@@ -766,10 +766,6 @@ async function generateRenderResponse(
766766 // will happen in the subsequent revalidation request
767767 let statusCode = 200 ;
768768 let url = new URL ( request . url ) ;
769- // TODO: Can this be done with a pathname extension instead of a header?
770- // If not, make sure we strip this at the SSR server and it can only be set
771- // by us to avoid cache-poisoning attempts
772-
773769 let isSubmission = isMutationMethod ( request . method ) ;
774770 let routeIdsToLoad =
775771 ! isSubmission && url . searchParams . has ( "_routes" )
@@ -804,55 +800,61 @@ async function generateRenderResponse(
804800 // POST `request` to `query` and process our action there.
805801 let formState : unknown ;
806802 let skipRevalidation = false ;
803+ let potentialCSRFAttackError : unknown | undefined ;
807804 if ( request . method === "POST" ) {
808- throwIfPotentialCSRFAttack ( request . headers , allowedActionOrigins ) ;
809-
810- ctx . runningAction = true ;
811- let result = await processServerAction (
812- request ,
813- basename ,
814- decodeReply ,
815- loadServerAction ,
816- decodeAction ,
817- decodeFormState ,
818- onError ,
819- temporaryReferences ,
820- ) ;
821- ctx . runningAction = false ;
822-
823- if ( isResponse ( result ) ) {
824- return generateRedirectResponse (
825- result ,
826- actionResult ,
827- basename ,
828- isDataRequest ,
829- generateResponse ,
830- temporaryReferences ,
831- ( ctx . redirect as unknown as Response ) ?. headers ,
832- ) ;
833- }
805+ try {
806+ throwIfPotentialCSRFAttack ( request . headers , allowedActionOrigins ) ;
834807
835- skipRevalidation = result ?. skipRevalidation ?? false ;
836- actionResult = result ?. actionResult ;
837- formState = result ?. formState ;
838- request = result ?. revalidationRequest ?? request ;
839-
840- if ( ctx . redirect ) {
841- return generateRedirectResponse (
842- ctx . redirect ,
843- actionResult ,
808+ ctx . runningAction = true ;
809+ let result = await processServerAction (
810+ request ,
844811 basename ,
845- isDataRequest ,
846- generateResponse ,
812+ decodeReply ,
813+ loadServerAction ,
814+ decodeAction ,
815+ decodeFormState ,
816+ onError ,
847817 temporaryReferences ,
848- undefined ,
849- ) ;
818+ ) . finally ( ( ) => {
819+ ctx . runningAction = false ;
820+ } ) ;
821+
822+ if ( isResponse ( result ) ) {
823+ return generateRedirectResponse (
824+ result ,
825+ actionResult ,
826+ basename ,
827+ isDataRequest ,
828+ generateResponse ,
829+ temporaryReferences ,
830+ ( ctx . redirect as unknown as Response ) ?. headers ,
831+ ) ;
832+ }
833+
834+ skipRevalidation = result ?. skipRevalidation ?? false ;
835+ actionResult = result ?. actionResult ;
836+ formState = result ?. formState ;
837+ request = result ?. revalidationRequest ?? request ;
838+
839+ if ( ctx . redirect ) {
840+ return generateRedirectResponse (
841+ ctx . redirect ,
842+ actionResult ,
843+ basename ,
844+ isDataRequest ,
845+ generateResponse ,
846+ temporaryReferences ,
847+ undefined ,
848+ ) ;
849+ }
850+ } catch ( error ) {
851+ potentialCSRFAttackError = error ;
850852 }
851853 }
852854
853855 let staticContext = await query (
854856 request ,
855- skipRevalidation
857+ skipRevalidation || ! ! potentialCSRFAttackError
856858 ? {
857859 filterMatchesToLoad : ( ) => false ,
858860 }
@@ -871,6 +873,13 @@ async function generateRenderResponse(
871873 ) ;
872874 }
873875
876+ if ( potentialCSRFAttackError ) {
877+ staticContext . errors ??= { } ;
878+ staticContext . errors [ staticContext . matches [ 0 ] . route . id ] =
879+ potentialCSRFAttackError ;
880+ staticContext . statusCode = 400 ;
881+ }
882+
874883 return generateStaticContextResponse (
875884 routes ,
876885 basename ,
0 commit comments