Skip to content

Commit 3712085

Browse files
authored
add onSessionRefreshSuccess and onSessionRefreshError (#23)
1 parent 4a16303 commit 3712085

File tree

3 files changed

+155
-55
lines changed

3 files changed

+155
-55
lines changed

src/interfaces.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { SessionStorage, SessionIdStorageStrategy, data } from 'react-router';
1+
import type { SessionStorage, SessionIdStorageStrategy, data, SessionData } from 'react-router';
22
import type { OauthTokens, User } from '@workos-inc/node';
33

44
export type DataWithResponseInit<T> = ReturnType<typeof data<T>>;
@@ -16,6 +16,19 @@ export interface AuthLoaderSuccessData {
1616
user: User;
1717
}
1818

19+
export interface RefreshErrorOptions {
20+
error: unknown;
21+
request: Request;
22+
sessionData: SessionData;
23+
}
24+
25+
export interface RefreshSuccessOptions {
26+
accessToken: string;
27+
user: User;
28+
impersonator: Impersonator | null;
29+
organizationId: string | null;
30+
}
31+
1932
export interface Impersonator {
2033
email: string;
2134
reason: string | null;
@@ -67,6 +80,8 @@ export interface GetAuthURLOptions {
6780
export type AuthKitLoaderOptions = {
6881
ensureSignedIn?: boolean;
6982
debug?: boolean;
83+
onSessionRefreshError?: (options: RefreshErrorOptions) => void | Response | Promise<void | Response>;
84+
onSessionRefreshSuccess?: (options: RefreshSuccessOptions) => void | Promise<void>;
7085
} & (
7186
| {
7287
storage?: never;

src/session.spec.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,43 @@ describe('session', () => {
559559
expect(response.headers.get('Set-Cookie')).toBe('destroyed-session-cookie');
560560
}
561561
});
562+
563+
it('calls onSessionRefreshSuccess when provided', async () => {
564+
const onSessionRefreshSuccess = jest.fn();
565+
await authkitLoader(createLoaderArgs(createMockRequest()), {
566+
onSessionRefreshSuccess,
567+
});
568+
569+
expect(onSessionRefreshSuccess).toHaveBeenCalled();
570+
});
571+
572+
it('calls onSessionRefreshError when provided and refresh fails', async () => {
573+
authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid'));
574+
const onSessionRefreshError = jest.fn().mockReturnValue(redirect('/error'));
575+
576+
await authkitLoader(createLoaderArgs(createMockRequest()), {
577+
onSessionRefreshError,
578+
});
579+
580+
expect(onSessionRefreshError).toHaveBeenCalled();
581+
});
582+
583+
it('allows redirect from onSessionRefreshError callback', async () => {
584+
authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid'));
585+
586+
try {
587+
await authkitLoader(createLoaderArgs(createMockRequest()), {
588+
onSessionRefreshError: () => {
589+
throw redirect('/');
590+
},
591+
});
592+
fail('Expected redirect response to be thrown');
593+
} catch (response: unknown) {
594+
assertIsResponse(response);
595+
expect(response.status).toBe(302);
596+
expect(response.headers.get('Location')).toBe('/');
597+
}
598+
});
562599
});
563600
});
564601

src/session.ts

Lines changed: 102 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ export type TypedResponse<T> = Response & {
2222
json(): Promise<T>;
2323
};
2424

25+
export class SessionRefreshError extends Error {
26+
constructor(cause: unknown) {
27+
super('Session refresh error', { cause });
28+
this.name = 'SessionRefreshError';
29+
}
30+
}
31+
2532
/**
2633
* This function is used to refresh the session by using the refresh token.
2734
* It will authenticate the user with the refresh token and return a new session object.
@@ -89,7 +96,7 @@ export async function refreshSession(request: Request, { organizationId }: { org
8996

9097
async function updateSession(request: Request, debug: boolean) {
9198
const session = await getSessionFromCookie(request.headers.get('Cookie') as string);
92-
const { commitSession, getSession, destroySession } = await getSessionStorage();
99+
const { commitSession, getSession } = await getSessionStorage();
93100

94101
// If no session, just continue
95102
if (!session) {
@@ -138,13 +145,7 @@ async function updateSession(request: Request, debug: boolean) {
138145
// istanbul ignore next
139146
if (debug) console.log('Failed to refresh. Deleting cookie and redirecting.', e);
140147

141-
const cookieSession = await getSession(request.headers.get('Cookie'));
142-
143-
throw redirect('/', {
144-
headers: {
145-
'Set-Cookie': await destroySession(cookieSession),
146-
},
147-
});
148+
throw new SessionRefreshError(e);
148149
}
149150
}
150151

@@ -287,6 +288,8 @@ export async function authkitLoader<Data = unknown>(
287288
const {
288289
ensureSignedIn = false,
289290
debug = false,
291+
onSessionRefreshSuccess,
292+
onSessionRefreshError,
290293
storage,
291294
cookie,
292295
} = typeof loaderOrOptions === 'object' ? loaderOrOptions : options;
@@ -295,62 +298,107 @@ export async function authkitLoader<Data = unknown>(
295298
const { getSession, destroySession } = await configureSessionStorage({ storage, cookieName });
296299

297300
const { request } = loaderArgs;
298-
const session = await updateSession(request, debug);
299301

300-
if (!session) {
301-
if (ensureSignedIn) {
302-
const returnPathname = getReturnPathname(request.url);
302+
try {
303+
// Try to get session, this might throw SessionRefreshError
304+
const session = await updateSession(request, debug);
305+
306+
if (!session) {
307+
// No session found case (not authenticated)
308+
if (ensureSignedIn) {
309+
const returnPathname = getReturnPathname(request.url);
310+
const cookieSession = await getSession(request.headers.get('Cookie'));
311+
312+
throw redirect(await getAuthorizationUrl({ returnPathname }), {
313+
headers: {
314+
'Set-Cookie': await destroySession(cookieSession),
315+
},
316+
});
317+
}
318+
319+
const auth: UnauthorizedData = {
320+
user: null,
321+
accessToken: null,
322+
impersonator: null,
323+
organizationId: null,
324+
permissions: null,
325+
entitlements: null,
326+
role: null,
327+
sessionId: null,
328+
sealedSession: null,
329+
};
330+
331+
return await handleAuthLoader(loader, loaderArgs, auth);
332+
}
333+
334+
// Session found and valid (or refreshed successfully)
335+
const {
336+
sessionId,
337+
organizationId = null,
338+
role = null,
339+
permissions = [],
340+
entitlements = [],
341+
} = getClaimsFromAccessToken(session.accessToken);
342+
343+
const cookieSession = await getSession(request.headers.get('Cookie'));
344+
const { impersonator = null } = session;
345+
346+
// checking for 'headers' in session determines if the session was refreshed or not
347+
if (onSessionRefreshSuccess && 'headers' in session) {
348+
await onSessionRefreshSuccess({
349+
accessToken: session.accessToken,
350+
user: session.user,
351+
impersonator,
352+
organizationId,
353+
});
354+
}
355+
356+
const auth: AuthorizedData = {
357+
user: session.user,
358+
sessionId,
359+
accessToken: session.accessToken,
360+
organizationId,
361+
role,
362+
permissions,
363+
entitlements,
364+
impersonator,
365+
sealedSession: cookieSession.get('jwt'),
366+
};
367+
368+
return await handleAuthLoader(loader, loaderArgs, auth, session);
369+
} catch (error) {
370+
if (error instanceof SessionRefreshError) {
303371
const cookieSession = await getSession(request.headers.get('Cookie'));
304372

305-
throw redirect(await getAuthorizationUrl({ returnPathname }), {
373+
if (onSessionRefreshError) {
374+
try {
375+
const result = await onSessionRefreshError({
376+
error: error.cause,
377+
request,
378+
sessionData: cookieSession,
379+
});
380+
381+
if (result instanceof Response) {
382+
return result;
383+
}
384+
} catch (callbackError) {
385+
// If callback throws a Response (like redirect), propagate it
386+
if (callbackError instanceof Response) {
387+
throw callbackError;
388+
}
389+
}
390+
}
391+
392+
throw redirect('/', {
306393
headers: {
307394
'Set-Cookie': await destroySession(cookieSession),
308395
},
309396
});
310397
}
311398

312-
const auth: UnauthorizedData = {
313-
user: null,
314-
accessToken: null,
315-
impersonator: null,
316-
organizationId: null,
317-
permissions: null,
318-
entitlements: null,
319-
role: null,
320-
sessionId: null,
321-
sealedSession: null,
322-
};
323-
324-
return await handleAuthLoader(loader, loaderArgs, auth);
399+
// Propagate other errors
400+
throw error;
325401
}
326-
327-
// istanbul ignore next
328-
const {
329-
sessionId,
330-
organizationId = null,
331-
role = null,
332-
permissions = [],
333-
entitlements = [],
334-
} = getClaimsFromAccessToken(session.accessToken);
335-
336-
const cookieSession = await getSession(request.headers.get('Cookie'));
337-
338-
// istanbul ignore next
339-
const { impersonator = null } = session;
340-
341-
const auth: AuthorizedData = {
342-
user: session.user,
343-
sessionId,
344-
accessToken: session.accessToken,
345-
organizationId,
346-
role,
347-
permissions,
348-
entitlements,
349-
impersonator,
350-
sealedSession: cookieSession.get('jwt'),
351-
};
352-
353-
return await handleAuthLoader(loader, loaderArgs, auth, session);
354402
}
355403

356404
async function handleAuthLoader(

0 commit comments

Comments
 (0)