Skip to content

Commit a072e6d

Browse files
improvement(custom-tools): make them workspace scoped + ux to manage them (#1772)
* improvement(custom-tools): make them workspace scoped * fix auth check * remove comments * add dup check * fix dup error message display * fix tests * fix on app loading of custom tools
1 parent 3b901b3 commit a072e6d

File tree

18 files changed

+8258
-527
lines changed

18 files changed

+8258
-527
lines changed

apps/sim/app/api/tools/custom/route.test.ts

Lines changed: 146 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ describe('Custom Tools API Routes', () => {
1212
const sampleTools = [
1313
{
1414
id: 'tool-1',
15+
workspaceId: 'workspace-123',
1516
userId: 'user-123',
1617
title: 'Weather Tool',
1718
schema: {
@@ -37,6 +38,7 @@ describe('Custom Tools API Routes', () => {
3738
},
3839
{
3940
id: 'tool-2',
41+
workspaceId: 'workspace-123',
4042
userId: 'user-123',
4143
title: 'Calculator Tool',
4244
schema: {
@@ -82,7 +84,20 @@ describe('Custom Tools API Routes', () => {
8284
// Reset all mock implementations
8385
mockSelect.mockReturnValue({ from: mockFrom })
8486
mockFrom.mockReturnValue({ where: mockWhere })
85-
mockWhere.mockReturnValue({ limit: mockLimit })
87+
// where() can be called with limit() or directly awaited
88+
// Create a mock query builder that supports both patterns
89+
mockWhere.mockImplementation((condition) => {
90+
// Return an object that is both awaitable and has a limit() method
91+
const queryBuilder = {
92+
limit: mockLimit,
93+
then: (resolve: (value: typeof sampleTools) => void) => {
94+
resolve(sampleTools)
95+
return queryBuilder
96+
},
97+
catch: (reject: (error: Error) => void) => queryBuilder,
98+
}
99+
return queryBuilder
100+
})
86101
mockLimit.mockResolvedValue(sampleTools)
87102
mockInsert.mockReturnValue({ values: mockValues })
88103
mockValues.mockResolvedValue({ id: 'new-tool-id' })
@@ -99,11 +114,34 @@ describe('Custom Tools API Routes', () => {
99114
delete: mockDelete,
100115
transaction: vi.fn().mockImplementation(async (callback) => {
101116
// Execute the callback with a transaction object that has the same methods
117+
// Create transaction-specific mocks that follow the same pattern
118+
const txMockSelect = vi.fn().mockReturnValue({ from: mockFrom })
119+
const txMockInsert = vi.fn().mockReturnValue({ values: mockValues })
120+
const txMockUpdate = vi.fn().mockReturnValue({ set: mockSet })
121+
const txMockDelete = vi.fn().mockReturnValue({ where: mockWhere })
122+
123+
// Transaction where() should also support the query builder pattern
124+
const txMockWhere = vi.fn().mockImplementation((condition) => {
125+
const queryBuilder = {
126+
limit: mockLimit,
127+
then: (resolve: (value: typeof sampleTools) => void) => {
128+
resolve(sampleTools)
129+
return queryBuilder
130+
},
131+
catch: (reject: (error: Error) => void) => queryBuilder,
132+
}
133+
return queryBuilder
134+
})
135+
136+
// Update mockFrom to return txMockWhere for transaction queries
137+
const txMockFrom = vi.fn().mockReturnValue({ where: txMockWhere })
138+
txMockSelect.mockReturnValue({ from: txMockFrom })
139+
102140
return await callback({
103-
select: mockSelect,
104-
insert: mockInsert,
105-
update: mockUpdate,
106-
delete: mockDelete,
141+
select: txMockSelect,
142+
insert: txMockInsert,
143+
update: txMockUpdate,
144+
delete: txMockDelete,
107145
})
108146
}),
109147
},
@@ -112,8 +150,15 @@ describe('Custom Tools API Routes', () => {
112150
// Mock schema
113151
vi.doMock('@sim/db/schema', () => ({
114152
customTools: {
115-
userId: 'userId', // Add these properties to enable WHERE clauses with eq()
116153
id: 'id',
154+
workspaceId: 'workspaceId',
155+
userId: 'userId',
156+
title: 'title',
157+
},
158+
workflow: {
159+
id: 'id',
160+
workspaceId: 'workspaceId',
161+
userId: 'userId',
117162
},
118163
}))
119164

@@ -122,9 +167,18 @@ describe('Custom Tools API Routes', () => {
122167
getSession: vi.fn().mockResolvedValue(mockSession),
123168
}))
124169

125-
// Mock getUserId
126-
vi.doMock('@/app/api/auth/oauth/utils', () => ({
127-
getUserId: vi.fn().mockResolvedValue('user-123'),
170+
// Mock hybrid auth
171+
vi.doMock('@/lib/auth/hybrid', () => ({
172+
checkHybridAuth: vi.fn().mockResolvedValue({
173+
success: true,
174+
userId: 'user-123',
175+
authType: 'session',
176+
}),
177+
}))
178+
179+
// Mock permissions
180+
vi.doMock('@/lib/permissions/utils', () => ({
181+
getUserEntityPermissions: vi.fn().mockResolvedValue('admin'),
128182
}))
129183

130184
// Mock logger
@@ -137,14 +191,23 @@ describe('Custom Tools API Routes', () => {
137191
}),
138192
}))
139193

140-
// Mock eq function from drizzle-orm
194+
// Mock drizzle-orm functions
141195
vi.doMock('drizzle-orm', async () => {
142196
const actual = await vi.importActual('drizzle-orm')
143197
return {
144198
...(actual as object),
145199
eq: vi.fn().mockImplementation((field, value) => ({ field, value, operator: 'eq' })),
200+
and: vi.fn().mockImplementation((...conditions) => ({ operator: 'and', conditions })),
201+
or: vi.fn().mockImplementation((...conditions) => ({ operator: 'or', conditions })),
202+
isNull: vi.fn().mockImplementation((field) => ({ field, operator: 'isNull' })),
203+
ne: vi.fn().mockImplementation((field, value) => ({ field, value, operator: 'ne' })),
146204
}
147205
})
206+
207+
// Mock utils
208+
vi.doMock('@/lib/utils', () => ({
209+
generateRequestId: vi.fn().mockReturnValue('test-request-id'),
210+
}))
148211
})
149212

150213
afterEach(() => {
@@ -155,9 +218,11 @@ describe('Custom Tools API Routes', () => {
155218
* Test GET endpoint
156219
*/
157220
describe('GET /api/tools/custom', () => {
158-
it('should return tools for authenticated user', async () => {
159-
// Create mock request
160-
const req = createMockRequest('GET')
221+
it('should return tools for authenticated user with workspaceId', async () => {
222+
// Create mock request with workspaceId
223+
const req = new NextRequest(
224+
'http://localhost:3000/api/tools/custom?workspaceId=workspace-123'
225+
)
161226

162227
// Simulate DB returning tools
163228
mockWhere.mockReturnValueOnce(Promise.resolve(sampleTools))
@@ -182,11 +247,16 @@ describe('Custom Tools API Routes', () => {
182247

183248
it('should handle unauthorized access', async () => {
184249
// Create mock request
185-
const req = createMockRequest('GET')
186-
187-
// Mock session to return no user
188-
vi.doMock('@/lib/auth', () => ({
189-
getSession: vi.fn().mockResolvedValue(null),
250+
const req = new NextRequest(
251+
'http://localhost:3000/api/tools/custom?workspaceId=workspace-123'
252+
)
253+
254+
// Mock hybrid auth to return unauthorized
255+
vi.doMock('@/lib/auth/hybrid', () => ({
256+
checkHybridAuth: vi.fn().mockResolvedValue({
257+
success: false,
258+
error: 'Unauthorized',
259+
}),
190260
}))
191261

192262
// Import handler after mocks are set up
@@ -205,114 +275,53 @@ describe('Custom Tools API Routes', () => {
205275
// Create mock request with workflowId parameter
206276
const req = new NextRequest('http://localhost:3000/api/tools/custom?workflowId=workflow-123')
207277

208-
// Import handler after mocks are set up
209-
const { GET } = await import('@/app/api/tools/custom/route')
210-
211-
// Call the handler
212-
const _response = await GET(req)
213-
214-
// Verify getUserId was called with correct parameters
215-
const getUserId = (await import('@/app/api/auth/oauth/utils')).getUserId
216-
expect(getUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
217-
218-
// Verify DB query filters by user
219-
expect(mockWhere).toHaveBeenCalled()
220-
})
221-
})
222-
223-
/**
224-
* Test POST endpoint
225-
*/
226-
describe('POST /api/tools/custom', () => {
227-
it('should create new tools when IDs are not provided', async () => {
228-
// Create test tool data
229-
const newTool = {
230-
title: 'New Tool',
231-
schema: {
232-
type: 'function',
233-
function: {
234-
name: 'newTool',
235-
description: 'A brand new tool',
236-
parameters: {
237-
type: 'object',
238-
properties: {},
239-
required: [],
240-
},
241-
},
242-
},
243-
code: 'return "hello world";',
244-
}
245-
246-
// Create mock request with new tool
247-
const req = createMockRequest('POST', { tools: [newTool] })
248-
249-
// Import handler after mocks are set up
250-
const { POST } = await import('@/app/api/tools/custom/route')
251-
252-
// Call the handler
253-
const response = await POST(req)
254-
const data = await response.json()
255-
256-
// Verify response
257-
expect(response.status).toBe(200)
258-
expect(data).toHaveProperty('success', true)
278+
// Mock workflow lookup to return workspaceId (for limit(1) call)
279+
mockLimit.mockResolvedValueOnce([{ workspaceId: 'workspace-123' }])
259280

260-
// Verify insert was called with correct parameters
261-
expect(mockInsert).toHaveBeenCalled()
262-
expect(mockValues).toHaveBeenCalled()
263-
})
264-
265-
it('should update existing tools when ID is provided', async () => {
266-
// Create test tool data with ID
267-
const updateTool = {
268-
id: 'tool-1',
269-
title: 'Updated Weather Tool',
270-
schema: {
271-
type: 'function',
272-
function: {
273-
name: 'getWeatherUpdate',
274-
description: 'Get updated weather information',
275-
parameters: {
276-
type: 'object',
277-
properties: {},
278-
required: [],
279-
},
281+
// Mock the where() call for fetching tools (returns awaitable query builder)
282+
mockWhere.mockImplementationOnce((condition) => {
283+
const queryBuilder = {
284+
limit: mockLimit,
285+
then: (resolve: (value: typeof sampleTools) => void) => {
286+
resolve(sampleTools)
287+
return queryBuilder
280288
},
281-
},
282-
code: 'return { temperature: 75, conditions: "partly cloudy" };',
283-
}
284-
285-
// Mock DB to find existing tool
286-
mockLimit.mockResolvedValueOnce([sampleTools[0]])
287-
288-
// Create mock request with tool update
289-
const req = createMockRequest('POST', { tools: [updateTool] })
289+
catch: (reject: (error: Error) => void) => queryBuilder,
290+
}
291+
return queryBuilder
292+
})
290293

291294
// Import handler after mocks are set up
292-
const { POST } = await import('@/app/api/tools/custom/route')
295+
const { GET } = await import('@/app/api/tools/custom/route')
293296

294297
// Call the handler
295-
const response = await POST(req)
298+
const response = await GET(req)
296299
const data = await response.json()
297300

298301
// Verify response
299302
expect(response.status).toBe(200)
300-
expect(data).toHaveProperty('success', true)
303+
expect(data).toHaveProperty('data')
301304

302-
// Verify update was called with correct parameters
303-
expect(mockUpdate).toHaveBeenCalled()
304-
expect(mockSet).toHaveBeenCalled()
305+
// Verify DB query was called
305306
expect(mockWhere).toHaveBeenCalled()
306307
})
308+
})
307309

310+
/**
311+
* Test POST endpoint
312+
*/
313+
describe('POST /api/tools/custom', () => {
308314
it('should reject unauthorized requests', async () => {
309-
// Mock session to return no user
310-
vi.doMock('@/lib/auth', () => ({
311-
getSession: vi.fn().mockResolvedValue(null),
315+
// Mock hybrid auth to return unauthorized
316+
vi.doMock('@/lib/auth/hybrid', () => ({
317+
checkHybridAuth: vi.fn().mockResolvedValue({
318+
success: false,
319+
error: 'Unauthorized',
320+
}),
312321
}))
313322

314323
// Create mock request
315-
const req = createMockRequest('POST', { tools: [] })
324+
const req = createMockRequest('POST', { tools: [], workspaceId: 'workspace-123' })
316325

317326
// Import handler after mocks are set up
318327
const { POST } = await import('@/app/api/tools/custom/route')
@@ -333,8 +342,8 @@ describe('Custom Tools API Routes', () => {
333342
code: 'return "invalid";',
334343
}
335344

336-
// Create mock request with invalid tool
337-
const req = createMockRequest('POST', { tools: [invalidTool] })
345+
// Create mock request with invalid tool and workspaceId
346+
const req = createMockRequest('POST', { tools: [invalidTool], workspaceId: 'workspace-123' })
338347

339348
// Import handler after mocks are set up
340349
const { POST } = await import('@/app/api/tools/custom/route')
@@ -354,12 +363,14 @@ describe('Custom Tools API Routes', () => {
354363
* Test DELETE endpoint
355364
*/
356365
describe('DELETE /api/tools/custom', () => {
357-
it('should delete a tool by ID', async () => {
358-
// Mock finding existing tool
366+
it('should delete a workspace-scoped tool by ID', async () => {
367+
// Mock finding existing workspace-scoped tool
359368
mockLimit.mockResolvedValueOnce([sampleTools[0]])
360369

361-
// Create mock request with ID parameter
362-
const req = new NextRequest('http://localhost:3000/api/tools/custom?id=tool-1')
370+
// Create mock request with ID and workspaceId parameters
371+
const req = new NextRequest(
372+
'http://localhost:3000/api/tools/custom?id=tool-1&workspaceId=workspace-123'
373+
)
363374

364375
// Import handler after mocks are set up
365376
const { DELETE } = await import('@/app/api/tools/custom/route')
@@ -412,12 +423,21 @@ describe('Custom Tools API Routes', () => {
412423
expect(data).toHaveProperty('error', 'Tool not found')
413424
})
414425

415-
it('should prevent unauthorized deletion', async () => {
416-
// Mock finding tool that belongs to a different user
417-
const otherUserTool = { ...sampleTools[0], userId: 'different-user' }
418-
mockLimit.mockResolvedValueOnce([otherUserTool])
426+
it('should prevent unauthorized deletion of user-scoped tool', async () => {
427+
// Mock hybrid auth for the DELETE request
428+
vi.doMock('@/lib/auth/hybrid', () => ({
429+
checkHybridAuth: vi.fn().mockResolvedValue({
430+
success: true,
431+
userId: 'user-456', // Different user
432+
authType: 'session',
433+
}),
434+
}))
419435

420-
// Create mock request
436+
// Mock finding user-scoped tool (no workspaceId) that belongs to user-123
437+
const userScopedTool = { ...sampleTools[0], workspaceId: null, userId: 'user-123' }
438+
mockLimit.mockResolvedValueOnce([userScopedTool])
439+
440+
// Create mock request (no workspaceId for user-scoped tool)
421441
const req = new NextRequest('http://localhost:3000/api/tools/custom?id=tool-1')
422442

423443
// Import handler after mocks are set up
@@ -429,13 +449,16 @@ describe('Custom Tools API Routes', () => {
429449

430450
// Verify response
431451
expect(response.status).toBe(403)
432-
expect(data).toHaveProperty('error', 'Unauthorized')
452+
expect(data).toHaveProperty('error', 'Access denied')
433453
})
434454

435455
it('should reject unauthorized requests', async () => {
436-
// Mock session to return no user
437-
vi.doMock('@/lib/auth', () => ({
438-
getSession: vi.fn().mockResolvedValue(null),
456+
// Mock hybrid auth to return unauthorized
457+
vi.doMock('@/lib/auth/hybrid', () => ({
458+
checkHybridAuth: vi.fn().mockResolvedValue({
459+
success: false,
460+
error: 'Unauthorized',
461+
}),
439462
}))
440463

441464
// Create mock request

0 commit comments

Comments
 (0)