diff --git a/src/auth.spec.ts b/src/auth.spec.ts index 46722cf..91af635 100644 --- a/src/auth.spec.ts +++ b/src/auth.spec.ts @@ -1,16 +1,41 @@ -import { getSignInUrl, getSignUpUrl, signOut } from './auth.js'; +import type { User } from '@workos-inc/node'; +import { data, redirect } from 'react-router'; +import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization } from './auth.js'; import * as authorizationUrl from './get-authorization-url.js'; import * as session from './session.js'; +import { assertIsResponse } from './test-utils/test-helpers.js'; const terminateSession = jest.mocked(session.terminateSession); +const refreshSession = jest.mocked(session.refreshSession); jest.mock('./session', () => ({ terminateSession: jest.fn().mockResolvedValue(new Response()), + refreshSession: jest.fn(), })); +// Mock redirect and data from react-router +jest.mock('react-router', () => { + const originalModule = jest.requireActual('react-router'); + return { + ...originalModule, + redirect: jest.fn().mockImplementation((to, init) => { + const response = new Response(null, { + status: 302, + headers: { Location: to, ...(init?.headers || {}) }, + }); + return response; + }), + data: jest.fn().mockImplementation((value, init) => ({ + data: value, + init, + })), + }; +}); + describe('auth', () => { beforeEach(() => { jest.spyOn(authorizationUrl, 'getAuthorizationUrl'); + jest.clearAllMocks(); }); describe('getSignInUrl', () => { @@ -39,4 +64,216 @@ describe('auth', () => { expect(terminateSession).toHaveBeenCalledWith(request); }); }); + + describe('switchToOrganization', () => { + const request = new Request('https://example.com'); + const organizationId = 'org_123456'; + + // Create a mock user that matches the User type + const mockUser = { + id: 'user-1', + email: 'test@example.com', + emailVerified: true, + firstName: 'Test', + lastName: 'User', + profilePictureUrl: 'https://example.com/avatar.jpg', + object: 'user', + createdAt: '2021-01-01T00:00:00Z', + updatedAt: '2021-01-01T00:00:00Z', + lastSignInAt: '2021-01-01T00:00:00Z', + externalId: null, + } as User; + + // Mock the return type of refreshSession + const mockAuthResponse = { + user: mockUser, + sessionId: 'session-123', + accessToken: 'new-access-token', + organizationId: 'org_123456' as string | undefined, + role: 'admin' as string | undefined, + permissions: ['read', 'write'] as string[] | undefined, + entitlements: ['premium'] as string[] | undefined, + impersonator: null, + sealedSession: 'sealed-session-data', + headers: { + 'Set-Cookie': 'new-cookie-value', + }, + }; + + beforeEach(() => { + refreshSession.mockResolvedValue(mockAuthResponse); + }); + + it('should call refreshSession with the correct params', async () => { + await switchToOrganization(request, organizationId); + + expect(refreshSession).toHaveBeenCalledWith(request, { organizationId }); + }); + + it('should return data with success and auth when no returnTo is provided', async () => { + const result = await switchToOrganization(request, organizationId); + + expect(data).toHaveBeenCalledWith( + { success: true, auth: mockAuthResponse }, + { + headers: { + 'Set-Cookie': 'new-cookie-value', + }, + }, + ); + expect(result).toEqual({ + data: { success: true, auth: mockAuthResponse }, + init: { + headers: { + 'Set-Cookie': 'new-cookie-value', + }, + }, + }); + }); + + it('should redirect to returnTo when provided', async () => { + const returnTo = '/dashboard'; + const result = await switchToOrganization(request, organizationId, { returnTo }); + + expect(redirect).toHaveBeenCalledWith(returnTo, { + headers: { + 'Set-Cookie': 'new-cookie-value', + }, + }); + + assertIsResponse(result); + expect(result.status).toBe(302); + expect(result.headers.get('Location')).toBe(returnTo); + expect(result.headers.get('Set-Cookie')).toBe('new-cookie-value'); + }); + + it('should handle case when refreshSession throws a redirect', async () => { + const redirectResponse = new Response(null, { + status: 302, + headers: { Location: '/login' }, + }); + refreshSession.mockRejectedValueOnce(redirectResponse); + + try { + await switchToOrganization(request, organizationId); + fail('Expected redirect response to be thrown'); + } catch (response) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toBe('/login'); + } + }); + + it('should redirect to authorization URL for SSO_required errors', async () => { + const authUrl = 'https://api.workos.com/sso/authorize'; + const errorWithSSOCause = new Error('SSO Required', { + cause: { error: 'sso_required' }, + }); + + refreshSession.mockRejectedValueOnce(errorWithSSOCause); + (authorizationUrl.getAuthorizationUrl as jest.Mock).mockResolvedValueOnce(authUrl); + + const result = await switchToOrganization(request, organizationId); + + expect(authorizationUrl.getAuthorizationUrl).toHaveBeenCalled(); + expect(redirect).toHaveBeenCalledWith(authUrl); + + assertIsResponse(result); + expect(result.status).toBe(302); + expect(result.headers.get('Location')).toBe(authUrl); + }); + + it('should handle mfa_enrollment errors', async () => { + const authUrl = 'https://api.workos.com/sso/authorize'; + const errorWithMFACause = new Error('MFA Enrollment Required', { + cause: { error: 'mfa_enrollment' }, + }); + + refreshSession.mockRejectedValueOnce(errorWithMFACause); + (authorizationUrl.getAuthorizationUrl as jest.Mock).mockResolvedValueOnce(authUrl); + + const result = await switchToOrganization(request, organizationId); + + expect(authorizationUrl.getAuthorizationUrl).toHaveBeenCalled(); + expect(redirect).toHaveBeenCalledWith(authUrl); + + assertIsResponse(result); + expect(result.status).toBe(302); + expect(result.headers.get('Location')).toBe(authUrl); + }); + + it('should return error data for Error instances', async () => { + const error = new Error('Invalid organization'); + refreshSession.mockRejectedValueOnce(error); + + const result = await switchToOrganization(request, organizationId); + + expect(data).toHaveBeenCalledWith( + { + success: false, + error: 'Invalid organization', + }, + { status: 400 }, + ); + expect(result).toEqual({ + data: { + success: false, + error: 'Invalid organization', + }, + init: { status: 400 }, + }); + }); + + it('should return error data for non-Error objects', async () => { + const error = 'String error message'; + refreshSession.mockRejectedValueOnce(error); + + await switchToOrganization(request, organizationId); + + expect(data).toHaveBeenCalledWith( + { + success: false, + error: 'String error message', + }, + { status: 400 }, + ); + }); + + it('should handle when Set-Cookie header is missing', async () => { + // Create a mock without the Set-Cookie header + const mockResponseWithoutCookie = { + ...mockAuthResponse, + headers: {}, + }; + refreshSession.mockResolvedValueOnce(mockResponseWithoutCookie); + + await switchToOrganization(request, organizationId); + + expect(data).toHaveBeenCalledWith( + { success: true, auth: mockResponseWithoutCookie }, + { + headers: { + 'Set-Cookie': '', + }, + }, + ); + }); + + it('should handle when returnTo is provided but Set-Cookie header is missing', async () => { + // Create a mock without the Set-Cookie header + const mockResponseWithoutCookie = { + ...mockAuthResponse, + headers: {}, + }; + refreshSession.mockResolvedValueOnce(mockResponseWithoutCookie); + + await switchToOrganization(request, organizationId, { returnTo: '/dashboard' }); + + expect(redirect).toHaveBeenCalledWith('/dashboard', { + headers: { + 'Set-Cookie': '', + }, + }); + }); + }); }); diff --git a/src/auth.ts b/src/auth.ts index ace06b3..12ace7c 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -1,16 +1,69 @@ +import { data, redirect } from 'react-router'; import { getAuthorizationUrl } from './get-authorization-url.js'; -import { terminateSession } from './session.js'; +import { refreshSession, terminateSession } from './session.js'; -async function getSignInUrl(returnPathname?: string) { +export async function getSignInUrl(returnPathname?: string) { return getAuthorizationUrl({ returnPathname, screenHint: 'sign-in' }); } -async function getSignUpUrl(returnPathname?: string) { +export async function getSignUpUrl(returnPathname?: string) { return getAuthorizationUrl({ returnPathname, screenHint: 'sign-up' }); } -async function signOut(request: Request) { +export async function signOut(request: Request) { return await terminateSession(request); } -export { getSignInUrl, getSignUpUrl, signOut }; +/** + * Switches the current session to a different organization. + * @param request - The incoming request object. + * @param organizationId - The ID of the organization to switch to. + * @param options - Optional parameters. + * @returns A redirect response to the specified returnTo URL or a data response with the updated auth data. + */ +export async function switchToOrganization( + request: Request, + organizationId: string, + { returnTo }: { returnTo?: string } = {}, +) { + try { + const auth = await refreshSession(request, { organizationId }); + + // if returnTo is provided, redirect to there + if (returnTo) { + return redirect(returnTo, { + headers: { + 'Set-Cookie': auth.headers?.['Set-Cookie'] ?? '', + }, + }); + } + + // otherwise return the updated auth data + return data( + { success: true, auth }, + { + headers: { + 'Set-Cookie': auth.headers?.['Set-Cookie'] ?? '', + }, + }, + ); + } catch (error) { + if (error instanceof Response && error.status === 302) { + throw error; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const errorCause: any = error instanceof Error ? error.cause : null; + if (errorCause?.error === 'sso_required' || errorCause?.error === 'mfa_enrollment') { + return redirect(await getAuthorizationUrl({ organizationId })); + } + + return data( + { + success: false, + error: error instanceof Error ? error.message : String(error), + }, + { status: 400 }, + ); + } +} diff --git a/src/get-authorization-url.ts b/src/get-authorization-url.ts index b6ff1e2..db6b34e 100644 --- a/src/get-authorization-url.ts +++ b/src/get-authorization-url.ts @@ -1,17 +1,24 @@ import { getConfig } from './config.js'; -import { GetAuthURLOptions } from './interfaces.js'; import { getWorkOS } from './workos.js'; -async function getAuthorizationUrl(options: GetAuthURLOptions = {}) { - const { returnPathname, screenHint } = options; +interface GetAuthURLOptions { + screenHint?: 'sign-up' | 'sign-in'; + returnPathname?: string; + organizationId?: string; + redirectUri?: string; + loginHint?: string; +} + +export async function getAuthorizationUrl(options: GetAuthURLOptions = {}) { + const { returnPathname, screenHint, organizationId, redirectUri, loginHint } = options; return getWorkOS().userManagement.getAuthorizationUrl({ provider: 'authkit', clientId: getConfig('clientId'), - redirectUri: getConfig('redirectUri'), + redirectUri: redirectUri || getConfig('redirectUri'), state: returnPathname ? btoa(JSON.stringify({ returnPathname })) : undefined, screenHint, + organizationId, + loginHint, }); } - -export { getAuthorizationUrl }; diff --git a/src/index.ts b/src/index.ts index 14c27a2..2704836 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,18 +1,18 @@ -import { getSignInUrl, getSignUpUrl, signOut } from './auth.js'; +import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization } from './auth.js'; import { authLoader } from './authkit-callback-route.js'; import { configure, getConfig } from './config.js'; -import { authkitLoader } from './session.js'; +import { authkitLoader, refreshSession } from './session.js'; import { getWorkOS } from './workos.js'; export { authLoader, - // authkitLoader, - // - getSignInUrl, - getSignUpUrl, - signOut, configure, getConfig, + getSignInUrl, + getSignUpUrl, getWorkOS, + refreshSession, + signOut, + switchToOrganization, }; diff --git a/src/session.spec.ts b/src/session.spec.ts index 2b7ef4b..acef0d9 100644 --- a/src/session.spec.ts +++ b/src/session.spec.ts @@ -7,7 +7,7 @@ import { getSessionStorage as getSessionStorageMock, } from './sessionStorage.js'; import { Session } from './interfaces.js'; -import { authkitLoader, encryptSession, terminateSession } from './session.js'; +import { authkitLoader, encryptSession, terminateSession, refreshSession } from './session.js'; import { assertIsResponse } from './test-utils/test-helpers.js'; import { getWorkOS } from './workos.js'; import { getConfig } from './config.js'; @@ -529,4 +529,130 @@ describe('session', () => { }); }); }); + + describe('refreshSession', () => { + const createMockRequest = (cookie = 'test-cookie', url = 'http://example.com./some-path') => + new Request(url, { + headers: new Headers({ + Cookie: cookie, + }), + }); + + let getSession: jest.Mock; + let destroySession: jest.Mock; + let commitSession: jest.Mock; + let mockSession: ReactRouterSession; + + beforeEach(() => { + getSession = jest.fn(); + destroySession = jest.fn().mockResolvedValue('destroyed-session-cookie'); + commitSession = jest.fn().mockResolvedValue('new-session-cookie'); + + mockSession = createMockSession({ + has: jest.fn().mockReturnValue(true), + get: jest.fn().mockReturnValue('encrypted-jwt'), + set: jest.fn(), + }); + + getSessionStorage.mockResolvedValue({ + cookieName: 'wos-cookie', + getSession, + destroySession, + commitSession, + }); + + getSession.mockResolvedValue(mockSession); + + const validSessionData = { + accessToken: 'valid.token', + refreshToken: 'refresh.token', + user: { + id: 'user-1', + email: 'test@example.com', + firstName: 'Test', + lastName: 'User', + object: 'user', + }, + impersonator: null, + }; + unsealData.mockResolvedValue(validSessionData); + sealData.mockResolvedValue('new-encrypted-jwt'); + + authenticateWithRefreshToken.mockResolvedValue({ + accessToken: 'new.valid.token', + refreshToken: 'new.refresh.token', + } as AuthenticationResponse); + + // Mock JWT decoding + (jose.decodeJwt as jest.Mock).mockReturnValue({ + sid: 'new-session-id', + org_id: 'org-123', + role: 'user', + permissions: ['read'], + entitlements: ['basic'], + }); + }); + + it('should refresh the session successfully', async () => { + const refreshedSession = await refreshSession(createMockRequest()); + + expect(getSessionStorage).toHaveBeenCalled(); + expect(authenticateWithRefreshToken).toHaveBeenCalledWith({ + clientId: expect.any(String), + refreshToken: 'refresh.token', + organizationId: undefined, + }); + + expect(mockSession.set).toHaveBeenCalledWith('jwt', 'new-encrypted-jwt'); + expect(commitSession).toHaveBeenCalledWith(mockSession); + + expect(refreshedSession).toEqual({ + user: expect.objectContaining({ id: 'user-1' }), + sessionId: 'new-session-id', + accessToken: 'new.valid.token', + organizationId: 'org-123', + role: 'user', + permissions: ['read'], + entitlements: ['basic'], + impersonator: null, + sealedSession: 'encrypted-jwt', + headers: { + 'Set-Cookie': 'new-session-cookie', + }, + }); + }); + + it('should refresh the session with organizationId', async () => { + await refreshSession(createMockRequest(), { organizationId: 'org-456' }); + + expect(authenticateWithRefreshToken).toHaveBeenCalledWith({ + clientId: expect.any(String), + refreshToken: 'refresh.token', + organizationId: 'org-456', + }); + }); + + it('should redirect to sign-in when no session exists', async () => { + // Mock no session found + unsealData.mockResolvedValue(null); + + try { + await refreshSession(createMockRequest()); + fail('Expected redirect response to be thrown'); + } catch (response: unknown) { + assertIsResponse(response); + expect(response.status).toBe(302); + expect(response.headers.get('Location')).toMatch(/^https:\/\/auth\.workos\.com\/oauth/); + } + }); + + it('should throw error when refresh fails', async () => { + // Mock refresh token failure + authenticateWithRefreshToken.mockRejectedValue(new Error('Invalid refresh token')); + + await expect(refreshSession(createMockRequest())).rejects.toThrow( + 'Failed to refresh session: Invalid refresh token', + ); + }); + }); }); diff --git a/src/session.ts b/src/session.ts index e134b52..17063d5 100644 --- a/src/session.ts +++ b/src/session.ts @@ -14,7 +14,7 @@ import { sealData, unsealData } from 'iron-session'; import { createRemoteJWKSet, decodeJwt, jwtVerify } from 'jose'; import { getConfig } from './config.js'; import { configureSessionStorage, getSessionStorage } from './sessionStorage.js'; -import { isResponse, isRedirect, isJsonResponse } from './utils.js'; +import { isJsonResponse, isRedirect, isResponse } from './utils.js'; // must be a type since this is a subtype of response // interfaces must conform to the types they extend @@ -22,6 +22,71 @@ export type TypedResponse = Response & { json(): Promise; }; +/** + * This function is used to refresh the session by using the refresh token. + * It will authenticate the user with the refresh token and return a new session object. + * @param request - The request object + * @param options - Optional configuration options + * @returns A promise that resolves to the new session object + */ +export async function refreshSession(request: Request, { organizationId }: { organizationId?: string } = {}) { + const { getSession, commitSession } = await getSessionStorage(); + const session = await getSessionFromCookie(request.headers.get('Cookie') as string); + + if (!session) { + throw redirect(await getAuthorizationUrl()); + } + + try { + const { accessToken, refreshToken } = await getWorkOS().userManagement.authenticateWithRefreshToken({ + clientId: getConfig('clientId'), + refreshToken: session.refreshToken, + organizationId, + }); + + const newSession = { + accessToken, + refreshToken, + user: session.user, + impersonator: session.impersonator, + headers: {} as Record, + }; + + const cookieSession = await getSession(request.headers.get('Cookie')); + cookieSession.set('jwt', await encryptSession(newSession)); + const cookie = await commitSession(cookieSession); + + newSession.headers = { + 'Set-Cookie': cookie, + }; + + const { + sessionId, + organizationId: newOrgId, + role, + permissions, + entitlements, + } = getClaimsFromAccessToken(accessToken); + + return { + user: session.user, + sessionId, + accessToken, + organizationId: newOrgId, + role, + permissions, + entitlements, + impersonator: session.impersonator || null, + sealedSession: cookieSession.get('jwt'), + headers: newSession.headers, + }; + } catch (error) { + throw new Error(`Failed to refresh session: ${error instanceof Error ? error.message : String(error)}`, { + cause: error, + }); + } +} + async function updateSession(request: Request, debug: boolean) { const session = await getSessionFromCookie(request.headers.get('Cookie') as string); const { commitSession, getSession, destroySession } = await getSessionStorage(); @@ -83,7 +148,7 @@ async function updateSession(request: Request, debug: boolean) { } } -async function encryptSession(session: Session) { +export async function encryptSession(session: Session) { return sealData(session, { password: getConfig('cookiePassword'), ttl: 0, @@ -122,7 +187,7 @@ type AuthorizedAuthLoader = (args: LoaderFunctionArgs & { auth: Authorized * ); * } */ -async function authkitLoader( +export async function authkitLoader( loaderArgs: LoaderFunctionArgs, options: AuthKitLoaderOptions & { ensureSignedIn: true }, ): Promise>; @@ -142,7 +207,7 @@ async function authkitLoader( * return authkitLoader({ request }); * } */ -async function authkitLoader( +export async function authkitLoader( loaderArgs: LoaderFunctionArgs, options?: AuthKitLoaderOptions, ): Promise>; @@ -172,7 +237,7 @@ async function authkitLoader( * ); * } */ -async function authkitLoader( +export async function authkitLoader( loaderArgs: LoaderFunctionArgs, loader: AuthorizedAuthLoader, options: AuthKitLoaderOptions & { ensureSignedIn: true }, @@ -207,13 +272,13 @@ async function authkitLoader( * ); * } */ -async function authkitLoader( +export async function authkitLoader( loaderArgs: LoaderFunctionArgs, loader: AuthLoader, options?: AuthKitLoaderOptions, ): Promise>; -async function authkitLoader( +export async function authkitLoader( loaderArgs: LoaderFunctionArgs, loaderOrOptions?: AuthLoader | AuthorizedAuthLoader | AuthKitLoaderOptions, options: AuthKitLoaderOptions = {}, @@ -328,7 +393,7 @@ async function handleAuthLoader( return data({ ...loaderResult, ...auth }, session ? { headers: { ...session.headers } } : undefined); } -async function terminateSession(request: Request) { +export async function terminateSession(request: Request) { const { getSession, destroySession } = await getSessionStorage(); const encryptedSession = await getSession(request.headers.get('Cookie')); const { accessToken } = (await getSessionFromCookie( @@ -402,5 +467,3 @@ function getReturnPathname(url: string): string { // istanbul ignore next return `${newUrl.pathname}${newUrl.searchParams.size > 0 ? '?' + newUrl.searchParams.toString() : ''}`; } - -export { authkitLoader, encryptSession, terminateSession }; diff --git a/tsconfig.json b/tsconfig.json index 814c084..33b85ee 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,8 +1,8 @@ { "$schema": "http://json.schemastore.org/tsconfig", "compilerOptions": { - "target": "ES2021", - "lib": ["DOM", "ES2021", "DOM.Iterable"], + "target": "ES2022", + "lib": ["DOM", "ES2022", "DOM.Iterable"], "jsx": "react", "module": "Node16", "moduleResolution": "Node",