Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/interfaces.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { SessionStorage, SessionIdStorageStrategy, data } from 'react-router';
import type { SessionStorage, SessionIdStorageStrategy, data, SessionData } from 'react-router';
import type { OauthTokens, User } from '@workos-inc/node';

export type DataWithResponseInit<T> = ReturnType<typeof data<T>>;
Expand All @@ -16,6 +16,19 @@ export interface AuthLoaderSuccessData {
user: User;
}

export interface RefreshErrorOptions {
error: unknown;
request: Request;
sessionData: SessionData;
}

export interface RefreshSuccessOptions {
accessToken: string;
user: User;
impersonator: Impersonator | null;
organizationId: string | null;
}

export interface Impersonator {
email: string;
reason: string | null;
Expand Down Expand Up @@ -67,6 +80,8 @@ export interface GetAuthURLOptions {
export type AuthKitLoaderOptions = {
ensureSignedIn?: boolean;
debug?: boolean;
onSessionRefreshError?: (options: RefreshErrorOptions) => void | Response | Promise<void | Response>;
onSessionRefreshSuccess?: (options: RefreshSuccessOptions) => void | Promise<void>;
} & (
| {
storage?: never;
Expand Down
37 changes: 37 additions & 0 deletions src/session.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,43 @@ describe('session', () => {
expect(response.headers.get('Set-Cookie')).toBe('destroyed-session-cookie');
}
});

it('calls onSessionRefreshSuccess when provided', async () => {
const onSessionRefreshSuccess = jest.fn();
await authkitLoader(createLoaderArgs(createMockRequest()), {
onSessionRefreshSuccess,
});

expect(onSessionRefreshSuccess).toHaveBeenCalled();
});

it('calls onSessionRefreshError when provided and refresh fails', async () => {
authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid'));
const onSessionRefreshError = jest.fn().mockReturnValue(redirect('/error'));

await authkitLoader(createLoaderArgs(createMockRequest()), {
onSessionRefreshError,
});

expect(onSessionRefreshError).toHaveBeenCalled();
});

it('allows redirect from onSessionRefreshError callback', async () => {
authenticateWithRefreshToken.mockRejectedValue(new Error('Refresh token invalid'));

try {
await authkitLoader(createLoaderArgs(createMockRequest()), {
onSessionRefreshError: () => {
throw redirect('/');
},
});
fail('Expected redirect response to be thrown');
} catch (response: unknown) {
assertIsResponse(response);
expect(response.status).toBe(302);
expect(response.headers.get('Location')).toBe('/');
}
});
});
});

Expand Down
156 changes: 102 additions & 54 deletions src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ export type TypedResponse<T> = Response & {
json(): Promise<T>;
};

export class SessionRefreshError extends Error {
constructor(cause: unknown) {
super('Session refresh error', { cause });
this.name = 'SessionRefreshError';
}
}

/**
* This function is used to refresh the session by using the refresh token.
* It will authenticate the user with the refresh token and return a new session object.
Expand Down Expand Up @@ -89,7 +96,7 @@ export async function refreshSession(request: Request, { organizationId }: { org

async function updateSession(request: Request, debug: boolean) {
const session = await getSessionFromCookie(request.headers.get('Cookie') as string);
const { commitSession, getSession, destroySession } = await getSessionStorage();
const { commitSession, getSession } = await getSessionStorage();

// If no session, just continue
if (!session) {
Expand Down Expand Up @@ -138,13 +145,7 @@ async function updateSession(request: Request, debug: boolean) {
// istanbul ignore next
if (debug) console.log('Failed to refresh. Deleting cookie and redirecting.', e);

const cookieSession = await getSession(request.headers.get('Cookie'));

throw redirect('/', {
headers: {
'Set-Cookie': await destroySession(cookieSession),
},
});
throw new SessionRefreshError(e);
}
}

Expand Down Expand Up @@ -287,6 +288,8 @@ export async function authkitLoader<Data = unknown>(
const {
ensureSignedIn = false,
debug = false,
onSessionRefreshSuccess,
onSessionRefreshError,
storage,
cookie,
} = typeof loaderOrOptions === 'object' ? loaderOrOptions : options;
Expand All @@ -295,62 +298,107 @@ export async function authkitLoader<Data = unknown>(
const { getSession, destroySession } = await configureSessionStorage({ storage, cookieName });

const { request } = loaderArgs;
const session = await updateSession(request, debug);

if (!session) {
if (ensureSignedIn) {
const returnPathname = getReturnPathname(request.url);
try {
// Try to get session, this might throw SessionRefreshError
const session = await updateSession(request, debug);

if (!session) {
// No session found case (not authenticated)
if (ensureSignedIn) {
const returnPathname = getReturnPathname(request.url);
const cookieSession = await getSession(request.headers.get('Cookie'));

throw redirect(await getAuthorizationUrl({ returnPathname }), {
headers: {
'Set-Cookie': await destroySession(cookieSession),
},
});
}

const auth: UnauthorizedData = {
user: null,
accessToken: null,
impersonator: null,
organizationId: null,
permissions: null,
entitlements: null,
role: null,
sessionId: null,
sealedSession: null,
};

return await handleAuthLoader(loader, loaderArgs, auth);
}

// Session found and valid (or refreshed successfully)
const {
sessionId,
organizationId = null,
role = null,
permissions = [],
entitlements = [],
} = getClaimsFromAccessToken(session.accessToken);

const cookieSession = await getSession(request.headers.get('Cookie'));
const { impersonator = null } = session;

// checking for 'headers' in session determines if the session was refreshed or not
if (onSessionRefreshSuccess && 'headers' in session) {
await onSessionRefreshSuccess({
accessToken: session.accessToken,
user: session.user,
impersonator,
organizationId,
});
}

const auth: AuthorizedData = {
user: session.user,
sessionId,
accessToken: session.accessToken,
organizationId,
role,
permissions,
entitlements,
impersonator,
sealedSession: cookieSession.get('jwt'),
};

return await handleAuthLoader(loader, loaderArgs, auth, session);
} catch (error) {
if (error instanceof SessionRefreshError) {
const cookieSession = await getSession(request.headers.get('Cookie'));

throw redirect(await getAuthorizationUrl({ returnPathname }), {
if (onSessionRefreshError) {
try {
const result = await onSessionRefreshError({
error: error.cause,
request,
sessionData: cookieSession,
});

if (result instanceof Response) {
return result;
}
} catch (callbackError) {
// If callback throws a Response (like redirect), propagate it
if (callbackError instanceof Response) {
throw callbackError;
}
}
}

throw redirect('/', {
headers: {
'Set-Cookie': await destroySession(cookieSession),
},
});
}

const auth: UnauthorizedData = {
user: null,
accessToken: null,
impersonator: null,
organizationId: null,
permissions: null,
entitlements: null,
role: null,
sessionId: null,
sealedSession: null,
};

return await handleAuthLoader(loader, loaderArgs, auth);
// Propagate other errors
throw error;
}

// istanbul ignore next
const {
sessionId,
organizationId = null,
role = null,
permissions = [],
entitlements = [],
} = getClaimsFromAccessToken(session.accessToken);

const cookieSession = await getSession(request.headers.get('Cookie'));

// istanbul ignore next
const { impersonator = null } = session;

const auth: AuthorizedData = {
user: session.user,
sessionId,
accessToken: session.accessToken,
organizationId,
role,
permissions,
entitlements,
impersonator,
sealedSession: cookieSession.get('jwt'),
};

return await handleAuthLoader(loader, loaderArgs, auth, session);
}

async function handleAuthLoader(
Expand Down