diff --git a/src/interfaces.ts b/src/interfaces.ts index 7b42d7c..1ee60d3 100644 --- a/src/interfaces.ts +++ b/src/interfaces.ts @@ -1,4 +1,4 @@ -import type { SessionStorage, SessionIdStorageStrategy, data } from 'react-router'; +import type { SessionStorage, SessionIdStorageStrategy, data, SessionData } from 'react-router'; import type { OauthTokens, User } from '@workos-inc/node'; export type DataWithResponseInit = ReturnType>; @@ -16,6 +16,19 @@ export interface AuthLoaderSuccessData { user: User; } +export interface RefreshErrorOptions { + error: unknown; + request: Request; + sessionData: SessionData; +} + +export interface RefreshSuccessOptions { + accessToken: string; + user: User; + impersonator: Impersonator | null; + organizationId: string | null; +} + export interface Impersonator { email: string; reason: string | null; @@ -67,6 +80,8 @@ export interface GetAuthURLOptions { export type AuthKitLoaderOptions = { ensureSignedIn?: boolean; debug?: boolean; + onSessionRefreshError?: (options: RefreshErrorOptions) => void | Response | Promise; + onSessionRefreshSuccess?: (options: RefreshSuccessOptions) => void | Promise; } & ( | { storage?: never; diff --git a/src/session.spec.ts b/src/session.spec.ts index 3074236..54794d0 100644 --- a/src/session.spec.ts +++ b/src/session.spec.ts @@ -559,6 +559,43 @@ describe('session', () => { expect(response.headers.get('Set-Cookie')).toBe('destroyed-session-cookie'); } }); + + it('calls onSessionRefreshSuccess when provided', async () => { + const onSessionRefreshSuccess = jest.fn(); + await authkitLoader(createLoaderArgs(createMockRequest()), { + onSessionRefreshSuccess, + }); + + expect(onSessionRefreshSuccess).toHaveBeenCalled(); + }); + + it('calls onSessionRefreshError when provided and refresh fails', async () => { + authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid')); + const onSessionRefreshError = jest.fn().mockReturnValue(redirect('/error')); + + await authkitLoader(createLoaderArgs(createMockRequest()), { + onSessionRefreshError, + }); + + expect(onSessionRefreshError).toHaveBeenCalled(); + }); + + it('allows redirect from onSessionRefreshError callback', async () => { + authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid')); + + try { + await authkitLoader(createLoaderArgs(createMockRequest()), { + onSessionRefreshError: () => { + throw redirect('/'); + }, + }); + fail('Expected redirect response to be thrown'); + } catch (response: unknown) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toBe('/'); + } + }); }); }); diff --git a/src/session.ts b/src/session.ts index 87d569d..6e93197 100644 --- a/src/session.ts +++ b/src/session.ts @@ -22,6 +22,13 @@ export type TypedResponse = Response & { json(): Promise; }; +export class SessionRefreshError extends Error { + constructor(cause: unknown) { + super('Session refresh error', { cause }); + this.name = 'SessionRefreshError'; + } +} + /** * This function is used to refresh the session by using the refresh token. * It will authenticate the user with the refresh token and return a new session object. @@ -89,7 +96,7 @@ export async function refreshSession(request: Request, { organizationId }: { org async function updateSession(request: Request, debug: boolean) { const session = await getSessionFromCookie(request.headers.get('Cookie') as string); - const { commitSession, getSession, destroySession } = await getSessionStorage(); + const { commitSession, getSession } = await getSessionStorage(); // If no session, just continue if (!session) { @@ -138,13 +145,7 @@ async function updateSession(request: Request, debug: boolean) { // istanbul ignore next if (debug) console.log('Failed to refresh. Deleting cookie and redirecting.', e); - const cookieSession = await getSession(request.headers.get('Cookie')); - - throw redirect('/', { - headers: { - 'Set-Cookie': await destroySession(cookieSession), - }, - }); + throw new SessionRefreshError(e); } } @@ -287,6 +288,8 @@ export async function authkitLoader( const { ensureSignedIn = false, debug = false, + onSessionRefreshSuccess, + onSessionRefreshError, storage, cookie, } = typeof loaderOrOptions === 'object' ? loaderOrOptions : options; @@ -295,62 +298,107 @@ export async function authkitLoader( const { getSession, destroySession } = await configureSessionStorage({ storage, cookieName }); const { request } = loaderArgs; - const session = await updateSession(request, debug); - if (!session) { - if (ensureSignedIn) { - const returnPathname = getReturnPathname(request.url); + try { + // Try to get session, this might throw SessionRefreshError + const session = await updateSession(request, debug); + + if (!session) { + // No session found case (not authenticated) + if (ensureSignedIn) { + const returnPathname = getReturnPathname(request.url); + const cookieSession = await getSession(request.headers.get('Cookie')); + + throw redirect(await getAuthorizationUrl({ returnPathname }), { + headers: { + 'Set-Cookie': await destroySession(cookieSession), + }, + }); + } + + const auth: UnauthorizedData = { + user: null, + accessToken: null, + impersonator: null, + organizationId: null, + permissions: null, + entitlements: null, + role: null, + sessionId: null, + sealedSession: null, + }; + + return await handleAuthLoader(loader, loaderArgs, auth); + } + + // Session found and valid (or refreshed successfully) + const { + sessionId, + organizationId = null, + role = null, + permissions = [], + entitlements = [], + } = getClaimsFromAccessToken(session.accessToken); + + const cookieSession = await getSession(request.headers.get('Cookie')); + const { impersonator = null } = session; + + // checking for 'headers' in session determines if the session was refreshed or not + if (onSessionRefreshSuccess && 'headers' in session) { + await onSessionRefreshSuccess({ + accessToken: session.accessToken, + user: session.user, + impersonator, + organizationId, + }); + } + + const auth: AuthorizedData = { + user: session.user, + sessionId, + accessToken: session.accessToken, + organizationId, + role, + permissions, + entitlements, + impersonator, + sealedSession: cookieSession.get('jwt'), + }; + + return await handleAuthLoader(loader, loaderArgs, auth, session); + } catch (error) { + if (error instanceof SessionRefreshError) { const cookieSession = await getSession(request.headers.get('Cookie')); - throw redirect(await getAuthorizationUrl({ returnPathname }), { + if (onSessionRefreshError) { + try { + const result = await onSessionRefreshError({ + error: error.cause, + request, + sessionData: cookieSession, + }); + + if (result instanceof Response) { + return result; + } + } catch (callbackError) { + // If callback throws a Response (like redirect), propagate it + if (callbackError instanceof Response) { + throw callbackError; + } + } + } + + throw redirect('/', { headers: { 'Set-Cookie': await destroySession(cookieSession), }, }); } - const auth: UnauthorizedData = { - user: null, - accessToken: null, - impersonator: null, - organizationId: null, - permissions: null, - entitlements: null, - role: null, - sessionId: null, - sealedSession: null, - }; - - return await handleAuthLoader(loader, loaderArgs, auth); + // Propagate other errors + throw error; } - - // istanbul ignore next - const { - sessionId, - organizationId = null, - role = null, - permissions = [], - entitlements = [], - } = getClaimsFromAccessToken(session.accessToken); - - const cookieSession = await getSession(request.headers.get('Cookie')); - - // istanbul ignore next - const { impersonator = null } = session; - - const auth: AuthorizedData = { - user: session.user, - sessionId, - accessToken: session.accessToken, - organizationId, - role, - permissions, - entitlements, - impersonator, - sealedSession: cookieSession.get('jwt'), - }; - - return await handleAuthLoader(loader, loaderArgs, auth, session); } async function handleAuthLoader(