Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 238 additions & 1 deletion src/auth.spec.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
import { getSignInUrl, getSignUpUrl, signOut } from './auth.js';
import type { User } from '@workos-inc/node';
import { data, redirect } from 'react-router';
import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization } from './auth.js';
import * as authorizationUrl from './get-authorization-url.js';
import * as session from './session.js';
import { assertIsResponse } from './test-utils/test-helpers.js';

const terminateSession = jest.mocked(session.terminateSession);
const refreshSession = jest.mocked(session.refreshSession);

jest.mock('./session', () => ({
terminateSession: jest.fn().mockResolvedValue(new Response()),
refreshSession: jest.fn(),
}));

// Mock redirect and data from react-router
jest.mock('react-router', () => {
const originalModule = jest.requireActual('react-router');
return {
...originalModule,
redirect: jest.fn().mockImplementation((to, init) => {
const response = new Response(null, {
status: 302,
headers: { Location: to, ...(init?.headers || {}) },
});
return response;
}),
data: jest.fn().mockImplementation((value, init) => ({
data: value,
init,
})),
};
});

describe('auth', () => {
beforeEach(() => {
jest.spyOn(authorizationUrl, 'getAuthorizationUrl');
jest.clearAllMocks();
});

describe('getSignInUrl', () => {
Expand Down Expand Up @@ -39,4 +64,216 @@ describe('auth', () => {
expect(terminateSession).toHaveBeenCalledWith(request);
});
});

describe('switchToOrganization', () => {
const request = new Request('https://example.com');
const organizationId = 'org_123456';

// Create a mock user that matches the User type
const mockUser = {
id: 'user-1',
email: '[email protected]',
emailVerified: true,
firstName: 'Test',
lastName: 'User',
profilePictureUrl: 'https://example.com/avatar.jpg',
object: 'user',
createdAt: '2021-01-01T00:00:00Z',
updatedAt: '2021-01-01T00:00:00Z',
lastSignInAt: '2021-01-01T00:00:00Z',
externalId: null,
} as User;

// Mock the return type of refreshSession
const mockAuthResponse = {
user: mockUser,
sessionId: 'session-123',
accessToken: 'new-access-token',
organizationId: 'org_123456' as string | undefined,
role: 'admin' as string | undefined,
permissions: ['read', 'write'] as string[] | undefined,
entitlements: ['premium'] as string[] | undefined,
impersonator: null,
sealedSession: 'sealed-session-data',
headers: {
'Set-Cookie': 'new-cookie-value',
},
};

beforeEach(() => {
refreshSession.mockResolvedValue(mockAuthResponse);
});

it('should call refreshSession with the correct params', async () => {
await switchToOrganization(request, organizationId);

expect(refreshSession).toHaveBeenCalledWith(request, { organizationId });
});

it('should return data with success and auth when no returnTo is provided', async () => {
const result = await switchToOrganization(request, organizationId);

expect(data).toHaveBeenCalledWith(
{ success: true, auth: mockAuthResponse },
{
headers: {
'Set-Cookie': 'new-cookie-value',
},
},
);
expect(result).toEqual({
data: { success: true, auth: mockAuthResponse },
init: {
headers: {
'Set-Cookie': 'new-cookie-value',
},
},
});
});

it('should redirect to returnTo when provided', async () => {
const returnTo = '/dashboard';
const result = await switchToOrganization(request, organizationId, { returnTo });

expect(redirect).toHaveBeenCalledWith(returnTo, {
headers: {
'Set-Cookie': 'new-cookie-value',
},
});

assertIsResponse(result);
expect(result.status).toBe(302);
expect(result.headers.get('Location')).toBe(returnTo);
expect(result.headers.get('Set-Cookie')).toBe('new-cookie-value');
});

it('should handle case when refreshSession throws a redirect', async () => {
const redirectResponse = new Response(null, {
status: 302,
headers: { Location: '/login' },
});
refreshSession.mockRejectedValueOnce(redirectResponse);

try {
await switchToOrganization(request, organizationId);
fail('Expected redirect response to be thrown');
} catch (response) {
assertIsResponse(response);
expect(response.status).toBe(302);
expect(response.headers.get('Location')).toBe('/login');
}
});

it('should redirect to authorization URL for SSO_required errors', async () => {
const authUrl = 'https://api.workos.com/sso/authorize';
const errorWithSSOCause = new Error('SSO Required', {
cause: { error: 'sso_required' },
});

refreshSession.mockRejectedValueOnce(errorWithSSOCause);
(authorizationUrl.getAuthorizationUrl as jest.Mock).mockResolvedValueOnce(authUrl);

const result = await switchToOrganization(request, organizationId);

expect(authorizationUrl.getAuthorizationUrl).toHaveBeenCalled();
expect(redirect).toHaveBeenCalledWith(authUrl);

assertIsResponse(result);
expect(result.status).toBe(302);
expect(result.headers.get('Location')).toBe(authUrl);
});

it('should handle mfa_enrollment errors', async () => {
const authUrl = 'https://api.workos.com/sso/authorize';
const errorWithMFACause = new Error('MFA Enrollment Required', {
cause: { error: 'mfa_enrollment' },
});

refreshSession.mockRejectedValueOnce(errorWithMFACause);
(authorizationUrl.getAuthorizationUrl as jest.Mock).mockResolvedValueOnce(authUrl);

const result = await switchToOrganization(request, organizationId);

expect(authorizationUrl.getAuthorizationUrl).toHaveBeenCalled();
expect(redirect).toHaveBeenCalledWith(authUrl);

assertIsResponse(result);
expect(result.status).toBe(302);
expect(result.headers.get('Location')).toBe(authUrl);
});

it('should return error data for Error instances', async () => {
const error = new Error('Invalid organization');
refreshSession.mockRejectedValueOnce(error);

const result = await switchToOrganization(request, organizationId);

expect(data).toHaveBeenCalledWith(
{
success: false,
error: 'Invalid organization',
},
{ status: 400 },
);
expect(result).toEqual({
data: {
success: false,
error: 'Invalid organization',
},
init: { status: 400 },
});
});

it('should return error data for non-Error objects', async () => {
const error = 'String error message';
refreshSession.mockRejectedValueOnce(error);

await switchToOrganization(request, organizationId);

expect(data).toHaveBeenCalledWith(
{
success: false,
error: 'String error message',
},
{ status: 400 },
);
});

it('should handle when Set-Cookie header is missing', async () => {
// Create a mock without the Set-Cookie header
const mockResponseWithoutCookie = {
...mockAuthResponse,
headers: {},
};
refreshSession.mockResolvedValueOnce(mockResponseWithoutCookie);

await switchToOrganization(request, organizationId);

expect(data).toHaveBeenCalledWith(
{ success: true, auth: mockResponseWithoutCookie },
{
headers: {
'Set-Cookie': '',
},
},
);
});

it('should handle when returnTo is provided but Set-Cookie header is missing', async () => {
// Create a mock without the Set-Cookie header
const mockResponseWithoutCookie = {
...mockAuthResponse,
headers: {},
};
refreshSession.mockResolvedValueOnce(mockResponseWithoutCookie);

await switchToOrganization(request, organizationId, { returnTo: '/dashboard' });

expect(redirect).toHaveBeenCalledWith('/dashboard', {
headers: {
'Set-Cookie': '',
},
});
});
});
});
63 changes: 58 additions & 5 deletions src/auth.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,69 @@
import { data, redirect } from 'react-router';
import { getAuthorizationUrl } from './get-authorization-url.js';
import { terminateSession } from './session.js';
import { refreshSession, terminateSession } from './session.js';

async function getSignInUrl(returnPathname?: string) {
export async function getSignInUrl(returnPathname?: string) {
return getAuthorizationUrl({ returnPathname, screenHint: 'sign-in' });
}

async function getSignUpUrl(returnPathname?: string) {
export async function getSignUpUrl(returnPathname?: string) {
return getAuthorizationUrl({ returnPathname, screenHint: 'sign-up' });
}

async function signOut(request: Request) {
export async function signOut(request: Request) {
return await terminateSession(request);
}

export { getSignInUrl, getSignUpUrl, signOut };
/**
* Switches the current session to a different organization.
* @param request - The incoming request object.
* @param organizationId - The ID of the organization to switch to.
* @param options - Optional parameters.
* @returns A redirect response to the specified returnTo URL or a data response with the updated auth data.
*/
export async function switchToOrganization(
request: Request,
organizationId: string,
{ returnTo }: { returnTo?: string } = {},
) {
try {
const auth = await refreshSession(request, { organizationId });

// if returnTo is provided, redirect to there
if (returnTo) {
return redirect(returnTo, {
headers: {
'Set-Cookie': auth.headers?.['Set-Cookie'] ?? '',
},
});
}

// otherwise return the updated auth data
return data(
{ success: true, auth },
{
headers: {
'Set-Cookie': auth.headers?.['Set-Cookie'] ?? '',
},
},
);
} catch (error) {
if (error instanceof Response && error.status === 302) {
throw error;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const errorCause: any = error instanceof Error ? error.cause : null;
if (errorCause?.error === 'sso_required' || errorCause?.error === 'mfa_enrollment') {
return redirect(await getAuthorizationUrl({ organizationId }));
}

return data(
{
success: false,
error: error instanceof Error ? error.message : String(error),
},
{ status: 400 },
);
}
}
19 changes: 13 additions & 6 deletions src/get-authorization-url.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import { getConfig } from './config.js';
import { GetAuthURLOptions } from './interfaces.js';
import { getWorkOS } from './workos.js';

async function getAuthorizationUrl(options: GetAuthURLOptions = {}) {
const { returnPathname, screenHint } = options;
interface GetAuthURLOptions {
screenHint?: 'sign-up' | 'sign-in';
returnPathname?: string;
organizationId?: string;
redirectUri?: string;
loginHint?: string;
}

export async function getAuthorizationUrl(options: GetAuthURLOptions = {}) {
const { returnPathname, screenHint, organizationId, redirectUri, loginHint } = options;

return getWorkOS().userManagement.getAuthorizationUrl({
provider: 'authkit',
clientId: getConfig('clientId'),
redirectUri: getConfig('redirectUri'),
redirectUri: redirectUri || getConfig('redirectUri'),
state: returnPathname ? btoa(JSON.stringify({ returnPathname })) : undefined,
screenHint,
organizationId,
loginHint,
});
}

export { getAuthorizationUrl };
14 changes: 7 additions & 7 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import { getSignInUrl, getSignUpUrl, signOut } from './auth.js';
import { getSignInUrl, getSignUpUrl, signOut, switchToOrganization } from './auth.js';
import { authLoader } from './authkit-callback-route.js';
import { configure, getConfig } from './config.js';
import { authkitLoader } from './session.js';
import { authkitLoader, refreshSession } from './session.js';
import { getWorkOS } from './workos.js';

export {
authLoader,
//
authkitLoader,
//
getSignInUrl,
getSignUpUrl,
signOut,
configure,
getConfig,
getSignInUrl,
getSignUpUrl,
getWorkOS,
refreshSession,
signOut,
switchToOrganization,
};
Loading