diff --git a/src/config.test.ts b/src/config.test.ts index 5c793865..2db3fa82 100644 --- a/src/config.test.ts +++ b/src/config.test.ts @@ -41,6 +41,11 @@ describe('Config', () => { MAX_RESULT_LIMIT: undefined, DISABLE_QUERY_DATASOURCE_FILTER_VALIDATION: undefined, DISABLE_METADATA_API_REQUESTS: undefined, + ENABLE_SERVER_LOGGING: undefined, + SERVER_LOG_DIRECTORY: undefined, + INCLUDE_PROJECT_IDS: undefined, + INCLUDE_DATASOURCE_IDS: undefined, + INCLUDE_WORKBOOK_IDS: undefined, }; }); @@ -647,4 +652,73 @@ describe('Config', () => { expect(config.jwtAdditionalPayload).toBe('{}'); }); }); + + describe('Bounded context parsing', () => { + it('should set boundedContext to null sets when no project, datasource, or workbook IDs are provided', () => { + process.env = { + ...process.env, + ...defaultEnvVars, + }; + + const config = new Config(); + expect(config.boundedContext).toEqual({ + projectIds: null, + datasourceIds: null, + workbookIds: null, + }); + }); + + it('should set boundedContext to the specified project, datasource, and workbook IDs when provided', () => { + process.env = { + ...process.env, + ...defaultEnvVars, + INCLUDE_PROJECT_IDS: ' 123, 456, 123 ', // spacing is intentional here to test trimming + INCLUDE_DATASOURCE_IDS: '789,101', + INCLUDE_WORKBOOK_IDS: '112,113', + }; + + const config = new Config(); + expect(config.boundedContext).toEqual({ + projectIds: new Set(['123', '456']), + datasourceIds: new Set(['789', '101']), + workbookIds: new Set(['112', '113']), + }); + }); + + it('should throw error when INCLUDE_PROJECT_IDS is set to an empty string', () => { + process.env = { + ...process.env, + ...defaultEnvVars, + INCLUDE_PROJECT_IDS: '', + }; + + expect(() => new Config()).toThrow( + 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', + ); + }); + + it('should throw error when INCLUDE_DATASOURCE_IDS is set to an empty string', () => { + process.env = { + ...process.env, + ...defaultEnvVars, + INCLUDE_DATASOURCE_IDS: '', + }; + + expect(() => new Config()).toThrow( + 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', + ); + }); + + it('should throw error when INCLUDE_WORKBOOK_IDS is set to an empty string', () => { + process.env = { + ...process.env, + ...defaultEnvVars, + INCLUDE_WORKBOOK_IDS: '', + }; + + expect(() => new Config()).toThrow( + 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', + ); + }); + }); }); diff --git a/src/config.ts b/src/config.ts index 7cb11d42..9518048a 100644 --- a/src/config.ts +++ b/src/config.ts @@ -11,6 +11,12 @@ const __dirname = fileURLToPath(new URL('.', import.meta.url)); const authTypes = ['pat', 'direct-trust'] as const; type AuthType = (typeof authTypes)[number]; +export type BoundedContext = { + projectIds: Set | null; + datasourceIds: Set | null; + workbookIds: Set | null; +}; + export class Config { auth: AuthType; server: string; @@ -37,6 +43,7 @@ export class Config { disableMetadataApiRequests: boolean; enableServerLogging: boolean; serverLogDirectory: string; + boundedContext: BoundedContext; constructor() { const cleansedVars = removeClaudeMcpBundleUserConfigTemplates(process.env); @@ -66,6 +73,9 @@ export class Config { DISABLE_METADATA_API_REQUESTS: disableMetadataApiRequests, ENABLE_SERVER_LOGGING: enableServerLogging, SERVER_LOG_DIRECTORY: serverLogDirectory, + INCLUDE_PROJECT_IDS: includeProjectIds, + INCLUDE_DATASOURCE_IDS: includeDatasourceIds, + INCLUDE_WORKBOOK_IDS: includeWorkbookIds, } = cleansedVars; const defaultPort = 3927; @@ -86,6 +96,29 @@ export class Config { this.disableMetadataApiRequests = disableMetadataApiRequests === 'true'; this.enableServerLogging = enableServerLogging === 'true'; this.serverLogDirectory = serverLogDirectory || join(__dirname, 'logs'); + this.boundedContext = { + projectIds: createSetFromCommaSeparatedString(includeProjectIds), + datasourceIds: createSetFromCommaSeparatedString(includeDatasourceIds), + workbookIds: createSetFromCommaSeparatedString(includeWorkbookIds), + }; + + if (this.boundedContext.projectIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_PROJECT_IDS must have at least one value', + ); + } + + if (this.boundedContext.datasourceIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_DATASOURCE_IDS must have at least one value', + ); + } + + if (this.boundedContext.workbookIds?.size === 0) { + throw new Error( + 'When set, the environment variable INCLUDE_WORKBOOK_IDS must have at least one value', + ); + } const maxResultLimitNumber = maxResultLimit ? parseInt(maxResultLimit) : NaN; this.maxResultLimit = @@ -181,6 +214,22 @@ function getCorsOriginConfig(corsOriginConfig: string): CorsOptions['origin'] { } } +// Creates a set from a comma-separated string of values. +// Returns null if the value is undefined. +function createSetFromCommaSeparatedString(value: string | undefined): Set | null { + if (value === undefined) { + return null; + } + + return new Set( + value + .trim() + .split(',') + .map((id) => id.trim()) + .filter(Boolean), + ); +} + // When the user does not provide a site name in the Claude MCP Bundle configuration, // Claude doesn't replace its value and sets the site name to "${user_config.site_name}". function removeClaudeMcpBundleUserConfigTemplates( diff --git a/src/scripts/createClaudeMcpBundleManifest.ts b/src/scripts/createClaudeMcpBundleManifest.ts index 346df838..ad44761f 100644 --- a/src/scripts/createClaudeMcpBundleManifest.ts +++ b/src/scripts/createClaudeMcpBundleManifest.ts @@ -166,6 +166,30 @@ const envVars = { required: false, sensitive: false, }, + INCLUDE_PROJECT_IDS: { + includeInUserConfig: false, + type: 'string', + title: 'IDs of projects to constrain tool results by', + description: 'A comma-separated list of project IDs to constrain tool results by.', + required: false, + sensitive: false, + }, + INCLUDE_DATASOURCE_IDS: { + includeInUserConfig: false, + type: 'string', + title: 'IDs of datasources to constrain tool results by', + description: 'A comma-separated list of datasource IDs to constrain tool results by.', + required: false, + sensitive: false, + }, + INCLUDE_WORKBOOK_IDS: { + includeInUserConfig: false, + type: 'string', + title: 'IDs of workbooks to constrain tool results by', + description: 'A comma-separated list of workbook IDs to constrain tool results by.', + required: false, + sensitive: false, + }, MAX_RESULT_LIMIT: { includeInUserConfig: false, type: 'number', diff --git a/src/sdks/tableau/apis/datasourcesApi.ts b/src/sdks/tableau/apis/datasourcesApi.ts index 04bb6032..38728eda 100644 --- a/src/sdks/tableau/apis/datasourcesApi.ts +++ b/src/sdks/tableau/apis/datasourcesApi.ts @@ -5,7 +5,7 @@ import { dataSourceSchema } from '../types/dataSource.js'; import { paginationSchema } from '../types/pagination.js'; import { paginationParameters } from './paginationParameters.js'; -const listDatasourcesRestEndpoint = makeEndpoint({ +const listDatasourcesEndpoint = makeEndpoint({ method: 'get', path: '/sites/:siteId/datasources', alias: 'listDatasources', @@ -33,5 +33,15 @@ const listDatasourcesRestEndpoint = makeEndpoint({ }), }); -const datasourcesApi = makeApi([listDatasourcesRestEndpoint]); +const queryDatasourceEndpoint = makeEndpoint({ + method: 'get', + path: '/sites/:siteId/datasources/:datasourceId', + alias: 'queryDatasource', + description: 'Returns information about the specified data source.', + response: z.object({ + datasource: dataSourceSchema, + }), +}); + +const datasourcesApi = makeApi([listDatasourcesEndpoint, queryDatasourceEndpoint]); export const datasourcesApis = [...datasourcesApi] as const satisfies ZodiosEndpointDefinitions; diff --git a/src/sdks/tableau/apis/viewsApi.ts b/src/sdks/tableau/apis/viewsApi.ts index 6857a686..74e539ea 100644 --- a/src/sdks/tableau/apis/viewsApi.ts +++ b/src/sdks/tableau/apis/viewsApi.ts @@ -5,6 +5,14 @@ import { paginationSchema } from '../types/pagination.js'; import { viewSchema } from '../types/view.js'; import { paginationParameters } from './paginationParameters.js'; +const getViewEndpoint = makeEndpoint({ + method: 'get', + path: `/sites/:siteId/views/:viewId`, + alias: 'getView', + description: 'Gets the details of a specific view.', + response: z.object({ view: viewSchema }), +}); + const queryViewDataEndpoint = makeEndpoint({ method: 'get', path: `/sites/:siteId/views/:viewId/data`, @@ -90,6 +98,7 @@ const queryViewsForSiteEndpoint = makeEndpoint({ }); const viewsApi = makeApi([ + getViewEndpoint, queryViewDataEndpoint, queryViewImageEndpoint, queryViewsForWorkbookEndpoint, diff --git a/src/sdks/tableau/methods/datasourcesMethods.ts b/src/sdks/tableau/methods/datasourcesMethods.ts index 31a63902..cc1cf6b9 100644 --- a/src/sdks/tableau/methods/datasourcesMethods.ts +++ b/src/sdks/tableau/methods/datasourcesMethods.ts @@ -50,4 +50,28 @@ export default class DatasourcesMethods extends AuthenticatedMethods => { + return ( + await this._apiClient.queryDatasource({ + params: { siteId, datasourceId }, + ...this.authHeader, + }) + ).datasource; + }; } diff --git a/src/sdks/tableau/methods/pulseMethods.ts b/src/sdks/tableau/methods/pulseMethods.ts index 1d8be358..7e7441c9 100644 --- a/src/sdks/tableau/methods/pulseMethods.ts +++ b/src/sdks/tableau/methods/pulseMethods.ts @@ -7,7 +7,7 @@ import { pulseApis } from '../apis/pulseApi.js'; import { Credentials } from '../types/credentials.js'; import { pulseBundleRequestSchema, - pulseBundleResponseSchema, + PulseBundleResponse, PulseInsightBundleType, PulseMetric, PulseMetricDefinition, @@ -139,7 +139,7 @@ export default class PulseMethods extends AuthenticatedMethods generatePulseMetricValueInsightBundle = async ( bundleRequest: z.infer, bundleType: PulseInsightBundleType, - ): Promise>> => { + ): Promise> => { return await guardAgainstPulseDisabled(async () => { const response = await this._apiClient.generatePulseMetricValueInsightBundle( { bundle_request: bundleRequest.bundle_request }, @@ -150,7 +150,8 @@ export default class PulseMethods extends AuthenticatedMethods }; } -type PulseResult = Result; +export type PulseDisabledError = 'tableau-server' | 'pulse-disabled'; +type PulseResult = Result; async function guardAgainstPulseDisabled(callback: () => Promise): Promise> { try { return new Ok(await callback()); diff --git a/src/sdks/tableau/methods/viewsMethods.ts b/src/sdks/tableau/methods/viewsMethods.ts index 535cef30..e2dd89ae 100644 --- a/src/sdks/tableau/methods/viewsMethods.ts +++ b/src/sdks/tableau/methods/viewsMethods.ts @@ -18,6 +18,19 @@ export default class ViewsMethods extends AuthenticatedMethods super(new Zodios(baseUrl, viewsApis), creds); } + /** + * Gets the details of a specific view. + * + * Required scopes: `tableau:content:read` + * + * @param {string} viewId The ID of the view to get. + * @param {string} siteId - The Tableau site ID + * @link https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_workbooks_and_views.htm#get_view + */ + getView = async ({ viewId, siteId }: { viewId: string; siteId: string }): Promise => { + return (await this._apiClient.getView({ params: { siteId, viewId }, ...this.authHeader })).view; + }; + /** * Returns a specified view rendered as data in comma separated value (CSV) format. * diff --git a/src/sdks/tableau/types/pulse.test.ts b/src/sdks/tableau/types/pulse.test.ts index f9d372a4..cba7ed07 100644 --- a/src/sdks/tableau/types/pulse.test.ts +++ b/src/sdks/tableau/types/pulse.test.ts @@ -115,6 +115,7 @@ function createValidPulseMetric(overrides = {}): any { comparison: { comparison: 'previous_period' }, }, definition_id: 'BBC908D8-29ED-48AB-A78E-ACF8A424C8C3', + datasource_luid: 'A6FC3C9F-4F40-4906-8DB0-AC70C5FB5A11', is_default: true, schema_version: '1.0', metric_version: '1', diff --git a/src/sdks/tableau/types/pulse.ts b/src/sdks/tableau/types/pulse.ts index 36408a87..03762dfb 100644 --- a/src/sdks/tableau/types/pulse.ts +++ b/src/sdks/tableau/types/pulse.ts @@ -67,6 +67,7 @@ export const pulseMetricSchema = z.object({ metric_version: z.coerce.number(), goals: pulseGoalsSchema.optional(), is_followed: z.boolean(), + datasource_luid: z.string(), }); export const pulseRepresentationOptionsSchema = z.object({ @@ -339,6 +340,8 @@ export const pulseBundleResponseSchema = z.object({ }), }); +export type PulseBundleResponse = z.infer; + export const pulseInsightBundleTypeEnum = ['ban', 'springboard', 'basic', 'detail'] as const; export type PulseInsightBundleType = (typeof pulseInsightBundleTypeEnum)[number]; diff --git a/src/tools/contentExploration/mockSearchContentItems.ts b/src/tools/contentExploration/mockSearchContentItems.ts new file mode 100644 index 00000000..eec665e7 --- /dev/null +++ b/src/tools/contentExploration/mockSearchContentItems.ts @@ -0,0 +1,45 @@ +import { ReducedSearchContentResponse } from './searchContentUtils.js'; + +export const mockSearchContentItems = [ + { + uri: 'test-uri-1', + content: { + type: 'workbook', + luid: 'workbook-1-luid', + title: 'Sales Dashboard', + ownerName: 'John Doe', + ownerId: 123, + ownerEmail: 'john.doe@example.com', + projectName: 'Finance', + containerName: 'Finance', + hitsTotal: 150, + hitsSmallSpanTotal: 10, + hitsMediumSpanTotal: 25, + hitsLargeSpanTotal: 50, + favoritesTotal: 5, + modifiedTime: '2024-01-15T10:30:00Z', + createdTime: '2023-12-01T09:00:00Z', + tags: ['dashboard', 'sales'], + }, + }, + { + uri: 'test-uri-2', + content: { + type: 'datasource', + luid: 'datasource-1-luid', + title: 'Customer Data', + ownerName: 'Jane Smith', + ownerId: 456, + ownerEmail: 'jane.smith@example.com', + projectName: 'Marketing', + containerName: 'Marketing', + hitsTotal: 75, + hitsSmallSpanTotal: 5, + hitsMediumSpanTotal: 15, + hitsLargeSpanTotal: 30, + favoritesTotal: 3, + modifiedTime: '2024-01-10T14:20:00Z', + createdTime: '2023-11-15T11:30:00Z', + }, + }, +] satisfies Array; diff --git a/src/tools/contentExploration/searchContent.test.ts b/src/tools/contentExploration/searchContent.test.ts index 02530d7a..3dbeee51 100644 --- a/src/tools/contentExploration/searchContent.test.ts +++ b/src/tools/contentExploration/searchContent.test.ts @@ -3,7 +3,7 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Server } from '../../server.js'; import { getSearchContentTool } from './searchContent.js'; -const mockSearchContentResponse = { +export const mockSearchContentResponse = { next: 'next-page-url', prev: 'prev-page-url', pageIndex: 0, @@ -20,6 +20,7 @@ const mockSearchContentResponse = { ownerName: 'John Doe', ownerId: 123, ownerEmail: 'john.doe@example.com', + projectId: 123456, projectName: 'Finance', containerName: 'Finance', hitsTotal: 150, @@ -36,11 +37,12 @@ const mockSearchContentResponse = { uri: 'test-uri-2', content: { type: 'datasource', - luid: 'datasource-1-luid', + datasourceLuid: 'datasource-1-luid', title: 'Customer Data', ownerName: 'Jane Smith', ownerId: 456, ownerEmail: 'jane.smith@example.com', + projectId: 987654, projectName: 'Marketing', containerName: 'Marketing', hitsTotal: 75, @@ -308,8 +310,10 @@ describe('searchContentTool', () => { const result = await getToolResult({ terms: 'nonexistent' }); expect(result.isError).toBe(false); - const responseData = JSON.parse(result.content[0].text as string); - expect(responseData).toEqual([]); + const responseData = result.content[0].text as string; + expect(responseData).toEqual( + 'No search results were found. Either none exist or you do not have permission to view them', + ); }); }); diff --git a/src/tools/contentExploration/searchContent.ts b/src/tools/contentExploration/searchContent.ts index 4c22ef92..5fbf24b4 100644 --- a/src/tools/contentExploration/searchContent.ts +++ b/src/tools/contentExploration/searchContent.ts @@ -13,6 +13,8 @@ import { Tool } from '../tool.js'; import { buildFilterString, buildOrderByString, + constrainSearchContent, + ReducedSearchContentResponse, reduceSearchContentResponse, } from './searchContentUtils.js'; @@ -62,7 +64,7 @@ This tool searches across all supported content types for objects relevant to th const config = getConfig(); const orderByString = orderBy ? buildOrderByString(orderBy) : undefined; const filterString = filter ? buildFilterString(filter) : undefined; - return await searchContentTool.logAndExecute({ + return await searchContentTool.logAndExecute>({ requestId, args: {}, callback: async () => { @@ -87,6 +89,8 @@ This tool searches across all supported content types for objects relevant to th }), ); }, + constrainSuccessResult: (items) => + constrainSearchContent({ items, boundedContext: getConfig().boundedContext }), }); }, }); diff --git a/src/tools/contentExploration/searchContentUtils.test.ts b/src/tools/contentExploration/searchContentUtils.test.ts index be5942aa..0bd9702e 100644 --- a/src/tools/contentExploration/searchContentUtils.test.ts +++ b/src/tools/contentExploration/searchContentUtils.test.ts @@ -1,7 +1,10 @@ import { OrderBy, SearchContentFilter } from '../../sdks/tableau/types/contentExploration.js'; +import invariant from '../../utils/invariant.js'; +import { mockSearchContentResponse } from './searchContent.test.js'; import { buildFilterString, buildOrderByString, + constrainSearchContent, reduceSearchContentResponse, } from './searchContentUtils.js'; @@ -450,4 +453,119 @@ describe('searchContentUtils', () => { }); }); }); + + describe('constrainSearchContent', () => { + it('should return empty result when no items are found', () => { + const result = constrainSearchContent({ + items: [], + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, + }); + + invariant(result.type === 'empty'); + expect(result.message).toBe( + 'No search results were found. Either none exist or you do not have permission to view them', + ); + }); + + it('should return empty result when all items were filtered out by the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: new Set(['123']), + datasourceIds: null, + workbookIds: null, + }, + }); + + invariant(result.type === 'empty'); + expect(result.message).toBe( + [ + 'The set of allowed content that can be queried is limited by the server configuration.', + 'While search results were found, they were all filtered out by the server configuration.', + ].join(' '), + ); + }); + + it('should return success result when no items were filtered out by the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toBe(items); + }); + + it('should return success result when some items were filtered out by allowed projects in the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: new Set(['123456']), + datasourceIds: null, + workbookIds: null, + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toEqual([items[0]]); + }); + + it('should return success result when some items were filtered out by allowed projects in the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: new Set(['123456']), + datasourceIds: null, + workbookIds: null, + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toEqual([items[0]]); + }); + + it('should return success result when some items were filtered out by allowed datasources in the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: null, + datasourceIds: new Set(['some-other-datasource-luid']), + workbookIds: null, + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toEqual([items[0]]); + }); + + it('should return success result when some items were filtered out by allowed workbooks in the bounded context', () => { + const items = reduceSearchContentResponse(mockSearchContentResponse); + const result = constrainSearchContent({ + items, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: new Set(['some-other-workbook-luid']), + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toEqual([items[1]]); + }); + }); }); diff --git a/src/tools/contentExploration/searchContentUtils.ts b/src/tools/contentExploration/searchContentUtils.ts index 9e9cbd35..e71faa35 100644 --- a/src/tools/contentExploration/searchContentUtils.ts +++ b/src/tools/contentExploration/searchContentUtils.ts @@ -1,8 +1,12 @@ +import { BoundedContext } from '../../config.js'; import { OrderBy, SearchContentFilter, SearchContentResponse, } from '../../sdks/tableau/types/contentExploration.js'; +import { ConstrainedResult } from '../tool.js'; + +export type ReducedSearchContentResponse = Partial>; export function buildOrderByString(orderBy: OrderBy): string { const methodsUsed = new Set(); @@ -74,8 +78,8 @@ export function buildFilterString(filter: SearchContentFilter): string { export function reduceSearchContentResponse( response: SearchContentResponse, -): Array> { - const searchResults: Array> = []; +): Array { + const searchResults: Array = []; if (response.items) { for (const item of response.items) { searchResults.push(getReducedSearchItemContent(item.content)); @@ -84,8 +88,46 @@ export function reduceSearchContentResponse( return searchResults; } -function getReducedSearchItemContent(content: Record): Record { - const reducedContent: Record = {}; +type SearchItemContent = + | 'caption' + | 'comments' + | 'connectedWorkbooksCount' + | 'connectionType' + | 'containerName' + | 'datasourceIsPublished' + | 'datasourceLuid' + | 'downstreamWorkbookCount' + | 'extractCreationPending' + | 'extractRefreshedAt' + | 'extractUpdatedAt' + | 'favoritesTotal' + | 'hasActiveDataQualityWarning' + | 'hasExtracts' + | 'hasSevereDataQualityWarning' + | 'hitsSmallSpanTotal' + | 'hitsTotal' + | 'isCertified' + | 'isConnectable' + | 'locationName' + | 'luid' + | 'modifiedTime' + | 'ownerId' + | 'ownerName' + | 'parentWorkbookName' + | 'projectId' + | 'projectName' + | 'sheetType' + | 'tags' + | 'title' + | 'totalViewCount' + | 'viewCountLastMonth' + | 'type' + | 'workbookDescription'; + +function getReducedSearchItemContent( + content: Record, +): Partial> { + const reducedContent: ReducedSearchContentResponse = {}; if (content.modifiedTime) { reducedContent.modifiedTime = content.modifiedTime; } @@ -135,6 +177,9 @@ function getReducedSearchItemContent(content: Record): Record): Record; + boundedContext: BoundedContext; +}): ConstrainedResult> { + if (items.length === 0) { + return { + type: 'empty', + message: + 'No search results were found. Either none exist or you do not have permission to view them', + }; + } + + const { projectIds, datasourceIds, workbookIds } = boundedContext; + + if (projectIds) { + items = items.filter((item) => { + if (typeof item.projectId === 'number' && projectIds.has(item.projectId.toString())) { + // ⚠️ The Search API returns the project "id" (e.g. 861566) + // but the Project REST APIs return the project "LUID" and there is no good way to look up one from the other. + // Admins who want to use a project filter here will need to provide both the id and LUID in their bounded context. + return true; + } + + return false; + }); + } + + if (datasourceIds) { + items = items.filter((item) => { + if ( + (item.type === 'datasource' || item.type === 'unifieddatasource') && + typeof item.datasourceLuid === 'string' && + !datasourceIds.has(item.datasourceLuid) + ) { + return false; + } + + return true; + }); + } + + if (workbookIds) { + items = items.filter((item) => { + if ( + item.type === 'workbook' && + typeof item.luid === 'string' && + !workbookIds.has(item.luid) + ) { + return false; + } + + return true; + }); + } + + if (items.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed content that can be queried is limited by the server configuration.', + 'While search results were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: items, + }; +} diff --git a/src/tools/getDatasourceMetadata/getDatasourceMetadata.test.ts b/src/tools/getDatasourceMetadata/getDatasourceMetadata.test.ts index 93dba1e4..87218ed4 100644 --- a/src/tools/getDatasourceMetadata/getDatasourceMetadata.test.ts +++ b/src/tools/getDatasourceMetadata/getDatasourceMetadata.test.ts @@ -3,8 +3,11 @@ import { Err, Ok } from 'ts-results-es'; import { Server } from '../../server.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; +import { exportedForTesting as resourceAccessCheckerExportedForTesting } from '../resourceAccessChecker.js'; import { getGetDatasourceMetadataTool } from './getDatasourceMetadata.js'; +const { resetResourceAccessCheckerSingleton } = resourceAccessCheckerExportedForTesting; + const mockReadMetadataResponses = vi.hoisted(() => ({ success: { data: [ @@ -164,8 +167,14 @@ describe('getDatasourceMetadataTool', () => { beforeEach(() => { vi.clearAllMocks(); // Set default config for existing tests + resetResourceAccessCheckerSingleton(); mocks.mockGetConfig.mockReturnValue({ disableMetadataApiRequests: false, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, }); }); @@ -506,6 +515,11 @@ describe('getDatasourceMetadataTool', () => { // Configure to disable metadata API requests mocks.mockGetConfig.mockReturnValue({ disableMetadataApiRequests: true, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, }); mocks.mockReadMetadata.mockResolvedValue(new Ok(mockReadMetadataResponses.success)); @@ -554,6 +568,11 @@ describe('getDatasourceMetadataTool', () => { // Configure to disable metadata API requests mocks.mockGetConfig.mockReturnValue({ disableMetadataApiRequests: true, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, }); const errorMessage = 'ReadMetadata API Error'; @@ -584,6 +603,28 @@ describe('getDatasourceMetadataTool', () => { expect(result.content[0].text).toBe(getVizqlDataServiceDisabledError()); expect(mocks.mockGraphql).not.toHaveBeenCalled(); }); + + it('should return data source not allowed error when datasource is not allowed', async () => { + mocks.mockGetConfig.mockReturnValue({ + boundedContext: { + projectIds: null, + datasourceIds: new Set(['some-other-datasource-luid']), + workbookIds: null, + }, + }); + + const result = await getToolResult(); + expect(result.isError).toBe(true); + expect(result.content[0].text).toBe( + [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + 'Querying the datasource with LUID test-luid is not allowed.', + ].join(' '), + ); + + expect(mocks.mockReadMetadata).not.toHaveBeenCalled(); + expect(mocks.mockGraphql).not.toHaveBeenCalled(); + }); }); async function getToolResult(): Promise { diff --git a/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts b/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts index a6b062f0..f737565b 100644 --- a/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts +++ b/src/tools/getDatasourceMetadata/getDatasourceMetadata.ts @@ -7,6 +7,7 @@ import { useRestApi } from '../../restApiInstance.js'; import { GraphQLResponse } from '../../sdks/tableau/apis/metadataApi.js'; import { Server } from '../../server.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; import { validateDatasourceLuid } from '../validateDatasourceLuid.js'; import { @@ -82,9 +83,14 @@ const paramsSchema = { datasourceLuid: z.string().nonempty(), }; -export type GetDatasourceMetadataError = { - type: 'feature-disabled'; -}; +export type GetDatasourceMetadataError = + | { + type: 'feature-disabled'; + } + | { + type: 'datasource-not-allowed'; + message: string; + }; export const getGetDatasourceMetadataTool = (server: Server): Tool => { const getDatasourceMetadataTool = new Tool({ @@ -113,6 +119,18 @@ export const getGetDatasourceMetadataTool = (server: Server): Tool { + const isDatasourceAllowedResult = await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid, + restApiArgs: { config, requestId, server }, + }); + + if (!isDatasourceAllowedResult.allowed) { + return new Err({ + type: 'datasource-not-allowed', + message: isDatasourceAllowedResult.message, + }); + } + return await useRestApi({ config, requestId, @@ -150,10 +168,18 @@ export const getGetDatasourceMetadataTool = (server: Server): Tool { + return { + type: 'success', + result: fields, + }; + }, getErrorText: (error: GetDatasourceMetadataError) => { switch (error.type) { case 'feature-disabled': return getVizqlDataServiceDisabledError(); + case 'datasource-not-allowed': + return error.message; } }, }); diff --git a/src/tools/listDatasources/listDatasources.test.ts b/src/tools/listDatasources/listDatasources.test.ts index c8bb9b1b..43a84a78 100644 --- a/src/tools/listDatasources/listDatasources.test.ts +++ b/src/tools/listDatasources/listDatasources.test.ts @@ -1,29 +1,9 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Server } from '../../server.js'; -import { getListDatasourcesTool } from './listDatasources.js'; - -const mockDatasources = { - pagination: { - pageNumber: 1, - pageSize: 10, - totalAvailable: 2, - }, - datasources: [ - { - id: 'ds1', - name: 'Superstore', - description: 'Sample superstore data source', - project: { name: 'Samples', id: 'proj1' }, - }, - { - id: 'ds2', - name: 'Finance', - description: 'Financial analysis data source', - project: { name: 'Finance', id: 'proj2' }, - }, - ], -}; +import invariant from '../../utils/invariant.js'; +import { constrainDatasources, getListDatasourcesTool } from './listDatasources.js'; +import { mockDatasources } from './mockDatasources.js'; const mocks = vi.hoisted(() => ({ mockListDatasources: vi.fn(), @@ -72,6 +52,59 @@ describe('listDatasourcesTool', () => { expect(result.isError).toBe(true); expect(result.content[0].text).toContain(errorMessage); }); + + describe('constrainDatasources', () => { + it('should return empty result when no datasources are found', () => { + const result = constrainDatasources({ + datasources: [], + boundedContext: { projectIds: null, datasourceIds: null, workbookIds: null }, + }); + + invariant(result.type === 'empty'); + expect(result.message).toBe( + 'No datasources were found. Either none exist or you do not have permission to view them', + ); + }); + + it('should return empty results when all datasources were filtered out by the bounded context', () => { + const result = constrainDatasources({ + datasources: mockDatasources.datasources, + boundedContext: { projectIds: new Set(['123']), datasourceIds: null, workbookIds: null }, + }); + + invariant(result.type === 'empty'); + expect(result.message).toBe( + [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + 'While data sources were found, they were all filtered out by the server configuration.', + ].join(' '), + ); + }); + + it('should return success result when no datasources were filtered out by the bounded context', () => { + const result = constrainDatasources({ + datasources: mockDatasources.datasources, + boundedContext: { projectIds: null, datasourceIds: null, workbookIds: null }, + }); + + invariant(result.type === 'success'); + expect(result.result).toBe(mockDatasources.datasources); + }); + + it('should return success result when some datasources were filtered out by the bounded context', () => { + const result = constrainDatasources({ + datasources: mockDatasources.datasources, + boundedContext: { + projectIds: new Set([mockDatasources.datasources[0].project.id]), + datasourceIds: new Set([mockDatasources.datasources[0].id]), + workbookIds: null, + }, + }); + + invariant(result.type === 'success'); + expect(result.result).toEqual([mockDatasources.datasources[0]]); + }); + }); }); async function getToolResult(params: { filter: string }): Promise { diff --git a/src/tools/listDatasources/listDatasources.ts b/src/tools/listDatasources/listDatasources.ts index b80b9701..de370734 100644 --- a/src/tools/listDatasources/listDatasources.ts +++ b/src/tools/listDatasources/listDatasources.ts @@ -2,12 +2,13 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Ok } from 'ts-results-es'; import { z } from 'zod'; -import { getConfig } from '../../config.js'; +import { BoundedContext, getConfig } from '../../config.js'; import { useRestApi } from '../../restApiInstance.js'; +import { DataSource } from '../../sdks/tableau/types/dataSource.js'; import { Server } from '../../server.js'; import { paginate } from '../../utils/paginate.js'; import { genericFilterDescription } from '../genericFilterDescription.js'; -import { Tool } from '../tool.js'; +import { ConstrainedResult, Tool } from '../tool.js'; import { parseAndValidateDatasourcesFilterString } from './datasourcesFilterUtils.js'; const paramsSchema = { @@ -115,9 +116,51 @@ export const getListDatasourcesTool = (server: Server): Tool + constrainDatasources({ datasources, boundedContext: getConfig().boundedContext }), }); }, }); return listDatasourcesTool; }; + +export function constrainDatasources({ + datasources, + boundedContext, +}: { + datasources: Array; + boundedContext: BoundedContext; +}): ConstrainedResult> { + if (datasources.length === 0) { + return { + type: 'empty', + message: + 'No datasources were found. Either none exist or you do not have permission to view them', + }; + } + + const { projectIds, datasourceIds } = boundedContext; + if (projectIds) { + datasources = datasources.filter((datasource) => projectIds.has(datasource.project.id)); + } + + if (datasourceIds) { + datasources = datasources.filter((datasource) => datasourceIds.has(datasource.id)); + } + + if (datasources.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + 'While data sources were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: datasources, + }; +} diff --git a/src/tools/listDatasources/mockDatasources.ts b/src/tools/listDatasources/mockDatasources.ts new file mode 100644 index 00000000..56498d7c --- /dev/null +++ b/src/tools/listDatasources/mockDatasources.ts @@ -0,0 +1,25 @@ +export const mockDatasources = { + pagination: { + pageNumber: 1, + pageSize: 10, + totalAvailable: 2, + }, + datasources: [ + { + id: '2d935df8-fe7e-4fd8-bb14-35eb4ba31d45', + name: 'Superstore Datasource', + project: { + id: 'cbec32db-a4a2-4308-b5f0-4fc67322f359', + name: 'Samples', + }, + }, + { + id: 'ba1da5d9-e92b-4ff2-ad91-4238265d877c', + name: 'Finance Datasource', + project: { + name: 'Finance', + id: '4862efd9-3c24-4053-ae1f-18caf18b6ffe', + }, + }, + ], +}; diff --git a/src/tools/pulse/constrainPulseDefinitions.ts b/src/tools/pulse/constrainPulseDefinitions.ts new file mode 100644 index 00000000..46e5ff53 --- /dev/null +++ b/src/tools/pulse/constrainPulseDefinitions.ts @@ -0,0 +1,38 @@ +import { getConfig } from '../../config.js'; +import { PulseMetricDefinition } from '../../sdks/tableau/types/pulse.js'; +import { ConstrainedResult } from '../tool.js'; + +export function constrainPulseDefinitions( + definitions: Array, +): ConstrainedResult> { + if (definitions.length === 0) { + return { + type: 'empty', + message: + 'No Pulse Metric Definitions were found. Either none exist or you do not have permission to view them', + }; + } + + const { datasourceIds } = getConfig().boundedContext; + + if (datasourceIds) { + definitions = definitions.filter((definition) => + datasourceIds.has(definition.specification.datasource.id), + ); + } + + if (definitions.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed Pulse Metric Definitions that can be queried is limited by the server configuration.', + 'While Pulse Metric Definitions were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: definitions, + }; +} diff --git a/src/tools/pulse/constrainPulseMetrics.ts b/src/tools/pulse/constrainPulseMetrics.ts new file mode 100644 index 00000000..5d192d88 --- /dev/null +++ b/src/tools/pulse/constrainPulseMetrics.ts @@ -0,0 +1,36 @@ +import { getConfig } from '../../config.js'; +import { PulseMetric } from '../../sdks/tableau/types/pulse.js'; +import { ConstrainedResult } from '../tool.js'; + +export function constrainPulseMetrics( + metrics: Array, +): ConstrainedResult> { + if (metrics.length === 0) { + return { + type: 'empty', + message: + 'No Pulse Metrics were found. Either none exist or you do not have permission to view them', + }; + } + + const { datasourceIds } = getConfig().boundedContext; + + if (datasourceIds) { + metrics = metrics.filter((metric) => datasourceIds.has(metric.datasource_luid)); + } + + if (metrics.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed Pulse Metrics that can be queried is limited by the server configuration.', + 'While Pulse Metrics were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: metrics, + }; +} diff --git a/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts b/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts index ba77968a..e66fc837 100644 --- a/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts +++ b/src/tools/pulse/generateMetricValueInsightBundle/generatePulseMetricValueInsightBundleTool.ts @@ -1,10 +1,13 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import { Err } from 'ts-results-es'; import z from 'zod'; import { getConfig } from '../../../config.js'; import { useRestApi } from '../../../restApiInstance.js'; +import { PulseDisabledError } from '../../../sdks/tableau/methods/pulseMethods.js'; import { pulseBundleRequestSchema, + PulseBundleResponse, pulseInsightBundleTypeEnum, } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; @@ -16,6 +19,16 @@ const paramsSchema = { bundleType: z.optional(z.enum(pulseInsightBundleTypeEnum)), }; +export type GeneratePulseMetricValueInsightBundleError = + | { + type: 'feature-disabled'; + reason: PulseDisabledError; + } + | { + type: 'datasource-not-allowed'; + message: string; + }; + export const getGeneratePulseMetricValueInsightBundleTool = ( server: Server, ): Tool => { @@ -138,11 +151,30 @@ Generate an insight bundle for the current aggregated value for Pulse Metric usi }, callback: async ({ bundleRequest, bundleType }, { requestId }): Promise => { const config = getConfig(); - return await generatePulseMetricValueInsightBundleTool.logAndExecute({ + return await generatePulseMetricValueInsightBundleTool.logAndExecute< + PulseBundleResponse, + GeneratePulseMetricValueInsightBundleError + >({ requestId, args: { bundleRequest, bundleType }, callback: async () => { - return await useRestApi({ + const { datasourceIds } = config.boundedContext; + if (datasourceIds) { + const datasourceLuid = + bundleRequest.bundle_request.input.metric.definition.datasource.id; + + if (!datasourceIds.has(datasourceLuid)) { + return new Err({ + type: 'datasource-not-allowed', + message: [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + `Generating the Pulse Metric Value Insight Bundle is not allowed because its definition's datasource with LUID ${datasourceLuid} is not in the allowed set of data sources.`, + ].join(' '), + }); + } + } + + const result = await useRestApi({ config, requestId, server, @@ -153,8 +185,30 @@ Generate an insight bundle for the current aggregated value for Pulse Metric usi bundleType ?? 'ban', ), }); + + if (result.isErr()) { + return new Err({ + type: 'feature-disabled', + reason: result.error, + }); + } + + return result; + }, + constrainSuccessResult: (insightBundle) => { + return { + type: 'success', + result: insightBundle, + }; + }, + getErrorText: (error) => { + switch (error.type) { + case 'feature-disabled': + return getPulseDisabledError(error.reason); + case 'datasource-not-allowed': + return error.message; + } }, - getErrorText: getPulseDisabledError, }); }, }); diff --git a/src/tools/pulse/getPulseDisabledError.ts b/src/tools/pulse/getPulseDisabledError.ts index d100e021..1f236faa 100644 --- a/src/tools/pulse/getPulseDisabledError.ts +++ b/src/tools/pulse/getPulseDisabledError.ts @@ -1,4 +1,6 @@ -export function getPulseDisabledError(reason: 'tableau-server' | 'pulse-disabled'): string { +import { PulseDisabledError } from '../../sdks/tableau/methods/pulseMethods.js'; + +export function getPulseDisabledError(reason: PulseDisabledError): string { switch (reason) { case 'tableau-server': return [ diff --git a/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts b/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts index f174a2eb..58349514 100644 --- a/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts +++ b/src/tools/pulse/listAllMetricDefinitions/listAllPulseMetricDefinitions.ts @@ -6,6 +6,7 @@ import { useRestApi } from '../../../restApiInstance.js'; import { pulseMetricDefinitionViewEnum } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; import { Tool } from '../../tool.js'; +import { constrainPulseDefinitions } from '../constrainPulseDefinitions.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; const paramsSchema = { @@ -60,6 +61,7 @@ Retrieves a list of all published Pulse Metric Definitions using the Tableau RES }, }); }, + constrainSuccessResult: constrainPulseDefinitions, getErrorText: getPulseDisabledError, }); }, diff --git a/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts b/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts index b9e06e2d..bfb34945 100644 --- a/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts +++ b/src/tools/pulse/listMetricDefinitionsFromDefinitionIds/listPulseMetricDefinitionsFromDefinitionIds.ts @@ -6,6 +6,7 @@ import { useRestApi } from '../../../restApiInstance.js'; import { pulseMetricDefinitionViewEnum } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; import { Tool } from '../../tool.js'; +import { constrainPulseDefinitions } from '../constrainPulseDefinitions.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; const paramsSchema = { @@ -74,6 +75,7 @@ Retrieves a list of specific Pulse Metric Definitions using the Tableau REST API }, }); }, + constrainSuccessResult: constrainPulseDefinitions, getErrorText: getPulseDisabledError, }); }, diff --git a/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts b/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts index a2ba878a..bdcf0b24 100644 --- a/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts +++ b/src/tools/pulse/listMetricSubscriptions/listPulseMetricSubscriptions.ts @@ -46,6 +46,72 @@ Retrieves a list of published Pulse Metric Subscriptions for the current user us }, }); }, + constrainSuccessResult: async (subscriptions) => { + if (subscriptions.length === 0) { + return { + type: 'empty', + message: + 'No Pulse Metric Subscriptions were found. Either none exist or you do not have permission to view them', + }; + } + + const { datasourceIds } = getConfig().boundedContext; + + if (!datasourceIds) { + // No datasource IDs to filter by, return all subscriptions. + return { + type: 'success', + result: subscriptions, + }; + } + + const metricsResult = await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:insight_metrics:read'], + callback: async (restApi) => { + return await restApi.pulseMethods.listPulseMetricsFromMetricIds( + subscriptions.map((subscription) => subscription.metric_id), + ); + }, + }); + + if (metricsResult.isErr()) { + return { + type: 'error', + message: [ + 'The set of allowed Pulse Metric Subscriptions that can be queried is limited by the server configuration.', + 'While Pulse Metric Subscriptions were found, retrieving information about them to determine if they are allowed to be viewed failed.', + ].join(' '), + }; + } + + const allowedMetricIds = new Set( + metricsResult.value + .filter((metric) => datasourceIds.has(metric.datasource_luid)) + .map((metric) => metric.id), + ); + + subscriptions = subscriptions.filter((subscription) => + allowedMetricIds.has(subscription.metric_id), + ); + + if (subscriptions.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed Pulse Metric Subscriptions that can be queried is limited by the server configuration.', + 'While Pulse Metric Subscriptions were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: subscriptions, + }; + }, getErrorText: getPulseDisabledError, }); }, diff --git a/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts b/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts index 81da4e1f..0514f513 100644 --- a/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts +++ b/src/tools/pulse/listMetricsFromMetricDefinitionId/listPulseMetricsFromMetricDefinitionId.ts @@ -3,8 +3,11 @@ import { z } from 'zod'; import { getConfig } from '../../../config.js'; import { useRestApi } from '../../../restApiInstance.js'; +import { PulseDisabledError } from '../../../sdks/tableau/methods/pulseMethods.js'; +import { PulseMetric } from '../../../sdks/tableau/types/pulse.js'; import { Server } from '../../../server.js'; import { Tool } from '../../tool.js'; +import { constrainPulseMetrics } from '../constrainPulseMetrics.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; const paramsSchema = { @@ -34,7 +37,10 @@ Retrieves a list of published Pulse Metrics from a Pulse Metric Definition using }, callback: async ({ pulseMetricDefinitionID }, { requestId }): Promise => { const config = getConfig(); - return await listPulseMetricsFromMetricDefinitionIdTool.logAndExecute({ + return await listPulseMetricsFromMetricDefinitionIdTool.logAndExecute< + Array, + PulseDisabledError + >({ requestId, args: { pulseMetricDefinitionID }, callback: async () => { @@ -50,6 +56,7 @@ Retrieves a list of published Pulse Metrics from a Pulse Metric Definition using }, }); }, + constrainSuccessResult: constrainPulseMetrics, getErrorText: getPulseDisabledError, }); }, diff --git a/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts b/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts index 5968bf9c..8c9ba8d7 100644 --- a/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts +++ b/src/tools/pulse/listMetricsFromMetricIds/listPulseMetricsFromMetricIds.ts @@ -5,6 +5,7 @@ import { getConfig } from '../../../config.js'; import { useRestApi } from '../../../restApiInstance.js'; import { Server } from '../../../server.js'; import { Tool } from '../../tool.js'; +import { constrainPulseMetrics } from '../constrainPulseMetrics.js'; import { getPulseDisabledError } from '../getPulseDisabledError.js'; const paramsSchema = { @@ -52,6 +53,7 @@ Retrieves a list of published Pulse Metrics from a list of metric IDs using the }, }); }, + constrainSuccessResult: constrainPulseMetrics, getErrorText: getPulseDisabledError, }); }, diff --git a/src/tools/queryDatasource/datasourceCredentials.test.ts b/src/tools/queryDatasource/datasourceCredentials.test.ts index 33338628..b8ec1381 100644 --- a/src/tools/queryDatasource/datasourceCredentials.test.ts +++ b/src/tools/queryDatasource/datasourceCredentials.test.ts @@ -5,19 +5,20 @@ import { const { resetDatasourceCredentials } = datasourceCredentialsExportedForTesting; -describe('getDatasourceCredentials', () => { - const originalEnv = process.env; +const mocks = vi.hoisted(() => ({ + mockGetConfig: vi.fn(), +})); + +vi.mock('../../config.js', () => ({ + getConfig: mocks.mockGetConfig, +})); +describe('getDatasourceCredentials', () => { beforeEach(() => { resetDatasourceCredentials(); - process.env = { - ...originalEnv, - DATASOURCE_CREDENTIALS: undefined, - }; - }); - - afterEach(() => { - process.env = { ...originalEnv }; + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: undefined, + }); }); it('should return undefined when DATASOURCE_CREDENTIALS is not set', () => { @@ -25,13 +26,14 @@ describe('getDatasourceCredentials', () => { }); it('should return undefined when DATASOURCE_CREDENTIALS is empty', () => { - process.env.DATASOURCE_CREDENTIALS = ''; expect(getDatasourceCredentials('test-luid')).toBeUndefined(); }); it('should return credentials for a valid datasource LUID', () => { - process.env.DATASOURCE_CREDENTIALS = JSON.stringify({ - 'ds-luid': [{ luid: 'test-luid', u: 'test-user', p: 'test-pass' }], + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: JSON.stringify({ + 'ds-luid': [{ luid: 'test-luid', u: 'test-user', p: 'test-pass' }], + }), }); expect(getDatasourceCredentials('ds-luid')).toEqual([ @@ -53,23 +55,30 @@ describe('getDatasourceCredentials', () => { }); it('should return undefined for a non-existent datasource LUID', () => { - process.env.DATASOURCE_CREDENTIALS = JSON.stringify({ - 'ds-luid': [{ luid: 'test-luid', u: 'test-user', p: 'test-pass' }], + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: JSON.stringify({ + 'ds-luid': [{ luid: 'test-luid', u: 'test-user', p: 'test-pass' }], + }), }); expect(getDatasourceCredentials('other-luid')).toBeUndefined(); }); it('should throw error when DATASOURCE_CREDENTIALS is invalid JSON', () => { - process.env.DATASOURCE_CREDENTIALS = 'invalid-json'; + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: 'invalid-json', + }); + expect(() => getDatasourceCredentials('test-luid')).toThrow( 'Invalid datasource credentials format. Could not parse JSON string: invalid-json', ); }); it('should throw error when credential schema is invalid', () => { - process.env.DATASOURCE_CREDENTIALS = JSON.stringify({ - 'ds-luid': [{ luid: 'test-luid', x: 'test-user', y: 'test-pass' }], + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: JSON.stringify({ + 'ds-luid': [{ luid: 'test-luid', x: 'test-user', y: 'test-pass' }], + }), }); expect(() => getDatasourceCredentials('ds-luid')).toThrow(); diff --git a/src/tools/queryDatasource/queryDatasource.test.ts b/src/tools/queryDatasource/queryDatasource.test.ts index 55e8fda1..c37cccaf 100644 --- a/src/tools/queryDatasource/queryDatasource.test.ts +++ b/src/tools/queryDatasource/queryDatasource.test.ts @@ -5,10 +5,12 @@ import { Err, Ok } from 'ts-results-es'; import { QueryOutput } from '../../sdks/tableau/apis/vizqlDataServiceApi.js'; import { Server } from '../../server.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; +import { exportedForTesting as resourceAccessCheckerExportedForTesting } from '../resourceAccessChecker.js'; import { exportedForTesting as datasourceCredentialsExportedForTesting } from './datasourceCredentials.js'; import { getQueryDatasourceTool } from './queryDatasource.js'; const { resetDatasourceCredentials } = datasourceCredentialsExportedForTesting; +const { resetResourceAccessCheckerSingleton } = resourceAccessCheckerExportedForTesting; const mockVdsResponses = vi.hoisted(() => ({ success: { @@ -41,6 +43,7 @@ const mockVdsResponses = vi.hoisted(() => ({ const mocks = vi.hoisted(() => ({ mockQueryDatasource: vi.fn(), + mockGetConfig: vi.fn(), })); vi.mock('../../restApiInstance.js', () => ({ @@ -55,19 +58,23 @@ vi.mock('../../restApiInstance.js', () => ({ ), })); -describe('queryDatasourceTool', () => { - const originalEnv = process.env; +vi.mock('../../config.js', () => ({ + getConfig: mocks.mockGetConfig, +})); +describe('queryDatasourceTool', () => { beforeEach(() => { vi.clearAllMocks(); resetDatasourceCredentials(); - process.env = { - ...originalEnv, - }; - }); - - afterEach(() => { - process.env = { ...originalEnv }; + resetResourceAccessCheckerSingleton(); + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: undefined, + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, + }); }); it('should create a tool instance with correct properties', () => { @@ -136,11 +143,17 @@ describe('queryDatasourceTool', () => { it('should add datasource credentials to the request when provided', async () => { mocks.mockQueryDatasource.mockResolvedValue(new Ok(mockVdsResponses.success)); - - process.env.DATASOURCE_CREDENTIALS = JSON.stringify({ - '71db762b-6201-466b-93da-57cc0aec8ed9': [ - { luid: 'test-luid', u: 'test-user', p: 'test-pass' }, - ], + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: JSON.stringify({ + '71db762b-6201-466b-93da-57cc0aec8ed9': [ + { luid: 'test-luid', u: 'test-user', p: 'test-pass' }, + ], + }), + boundedContext: { + projectIds: null, + datasourceIds: null, + workbookIds: null, + }, }); const result = await getToolResult(); @@ -474,6 +487,28 @@ describe('queryDatasourceTool', () => { expect(result.isError).toBe(true); expect(result.content[0].text).toBe(getVizqlDataServiceDisabledError()); }); + + it('should return data source not allowed error when datasource is not allowed', async () => { + mocks.mockGetConfig.mockReturnValue({ + datasourceCredentials: undefined, + boundedContext: { + projectIds: null, + datasourceIds: new Set(['some-other-datasource-luid']), + workbookIds: null, + }, + }); + + const result = await getToolResult(); + expect(result.isError).toBe(true); + expect(result.content[0].text).toBe( + [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + 'Querying the datasource with LUID 71db762b-6201-466b-93da-57cc0aec8ed9 is not allowed.', + ].join(' '), + ); + + expect(mocks.mockQueryDatasource).not.toHaveBeenCalled(); + }); }); async function getToolResult(): Promise { diff --git a/src/tools/queryDatasource/queryDatasource.ts b/src/tools/queryDatasource/queryDatasource.ts index 40dcbfae..6800fec7 100644 --- a/src/tools/queryDatasource/queryDatasource.ts +++ b/src/tools/queryDatasource/queryDatasource.ts @@ -13,6 +13,7 @@ import { } from '../../sdks/tableau/apis/vizqlDataServiceApi.js'; import { Server } from '../../server.js'; import { getVizqlDataServiceDisabledError } from '../getVizqlDataServiceDisabledError.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; import { getDatasourceCredentials } from './datasourceCredentials.js'; import { handleQueryDatasourceError } from './queryDatasourceErrorHandler.js'; @@ -31,6 +32,10 @@ export type QueryDatasourceError = | { type: 'feature-disabled'; } + | { + type: 'datasource-not-allowed'; + message: string; + } | { type: 'filter-validation'; message: string; @@ -58,6 +63,18 @@ export const getQueryDatasourceTool = (server: Server): Tool { + const isDatasourceAllowedResult = await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid, + restApiArgs: { config, requestId, server }, + }); + + if (!isDatasourceAllowedResult.allowed) { + return new Err({ + type: 'datasource-not-allowed', + message: isDatasourceAllowedResult.message, + }); + } + const datasource: Datasource = { datasourceLuid }; const options = { returnFormat: 'OBJECTS', @@ -118,10 +135,18 @@ export const getQueryDatasourceTool = (server: Server): Tool { + return { + type: 'success', + result: queryOutput, + }; + }, getErrorText: (error: QueryDatasourceError) => { switch (error.type) { case 'feature-disabled': return getVizqlDataServiceDisabledError(); + case 'datasource-not-allowed': + return error.message; case 'filter-validation': return JSON.stringify({ requestId, diff --git a/src/tools/resourceAccessChecker.test.ts b/src/tools/resourceAccessChecker.test.ts new file mode 100644 index 00000000..3f1cc80f --- /dev/null +++ b/src/tools/resourceAccessChecker.test.ts @@ -0,0 +1,585 @@ +import { getConfig } from '../config.js'; +import { Server } from '../server.js'; +import { mockDatasources } from './listDatasources/mockDatasources.js'; +import { exportedForTesting } from './resourceAccessChecker.js'; +import { mockView } from './views/mockView.js'; +import { mockWorkbook } from './workbooks/mockWorkbook.js'; + +const { createResourceAccessChecker } = exportedForTesting; + +const mocks = vi.hoisted(() => ({ + mockGetView: vi.fn(), + mockGetWorkbook: vi.fn(), + mockQueryDatasource: vi.fn(), +})); + +vi.mock('../restApiInstance.js', () => ({ + useRestApi: vi.fn().mockImplementation(async ({ callback }) => + callback({ + viewsMethods: { + getView: mocks.mockGetView, + }, + workbooksMethods: { + getWorkbook: mocks.mockGetWorkbook, + }, + datasourcesMethods: { + queryDatasource: mocks.mockQueryDatasource, + }, + siteId: 'test-site-id', + }), + ), +})); + +describe('ResourceAccessChecker', () => { + const restApiArgs = { config: getConfig(), requestId: 'request-id', server: getServer() }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe('isDatasourceAllowed', () => { + describe('allowed', () => { + it('should return allowed when the datasource LUID is allowed by the datasources in the bounded context', async () => { + const mockDatasource = mockDatasources.datasources[0]; + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: new Set([mockDatasource.id]), + workbookIds: null, + }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Check again to exercise the cache. + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect(mocks.mockQueryDatasource).not.toHaveBeenCalled(); + }); + + it('should return allowed when the datasource exists in a project that is allowed by the projects in the bounded context', async () => { + const mockDatasource = mockDatasources.datasources[0]; + mocks.mockQueryDatasource.mockResolvedValue(mockDatasource); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set([mockDatasource.project.id]), + datasourceIds: new Set([mockDatasource.id]), + workbookIds: null, + }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Query Datasource" API each time. + expect(mocks.mockQueryDatasource).toHaveBeenCalledTimes(2); + }); + }); + + describe('not allowed', () => { + it('should return not allowed when the datasource LUID is not allowed by the datasources in the bounded context', async () => { + const mockDatasource = mockDatasources.datasources[0]; + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: new Set(['some-datasource-luid']), + workbookIds: null, + }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + `Querying the datasource with LUID ${mockDatasource.id} is not allowed.`, + ].join(' '), + }); + }); + + it('should return not allowed when the datasource exists in a project that is not allowed by the projects in the bounded context', async () => { + const mockDatasource = mockDatasources.datasources[0]; + mocks.mockQueryDatasource.mockResolvedValue(mockDatasource); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: null, + workbookIds: null, + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The datasource with LUID ${mockDatasource.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Query Datasource" API each time. + expect(mocks.mockQueryDatasource).toHaveBeenCalledTimes(2); + }); + + it('should return not allowed when the datasource is allowed by the datasources in the bounded context but exists in a project that is not allowed by the projects in the bounded context', async () => { + const mockDatasource = mockDatasources.datasources[0]; + mocks.mockQueryDatasource.mockResolvedValue(mockDatasource); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: new Set([mockDatasource.id]), + workbookIds: null, + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The datasource with LUID ${mockDatasource.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isDatasourceAllowed({ + datasourceLuid: mockDatasource.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Query Datasource" API each time. + expect(mocks.mockQueryDatasource).toHaveBeenCalledTimes(2); + }); + }); + }); + + describe('isWorkbookAllowed', () => { + describe('allowed', () => { + it('should return allowed when the workbook ID is allowed by the workbooks in the bounded context', async () => { + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Check again to exercise the cache. + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect(mocks.mockGetWorkbook).not.toHaveBeenCalled(); + }); + + it('should return allowed when the workbook is in a project that is allowed by the projects in the bounded context', async () => { + mocks.mockGetWorkbook.mockResolvedValue(mockWorkbook); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set([mockWorkbook.project.id]), + datasourceIds: null, + workbookIds: null, + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get Workbook" API each time. + expect(mocks.mockGetWorkbook).toHaveBeenCalledTimes(2); + }); + + it('should return allowed when the workbook is allowed by the workbooks in the bounded context and exists in a project that is allowed by the projects in the bounded context', async () => { + mocks.mockGetWorkbook.mockResolvedValue(mockWorkbook); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set([mockWorkbook.project.id]), + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get Workbook" API each time. + expect(mocks.mockGetWorkbook).toHaveBeenCalledTimes(2); + }); + }); + + describe('not allowed', () => { + it('should return not allowed when the workbook ID is not allowed by the workbooks in the bounded context', async () => { + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: null, + workbookIds: new Set(['some-workbook-id']), + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: [ + 'The set of allowed workbooks that can be queried is limited by the server configuration.', + `Querying the workbook with LUID ${mockWorkbook.id} is not allowed.`, + ].join(' '), + }); + }); + + it('should return not allowed when the workbook is in a project that is not allowed by the projects in the bounded context', async () => { + mocks.mockGetWorkbook.mockResolvedValue(mockWorkbook); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: null, + workbookIds: null, + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The workbook with LUID ${mockWorkbook.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get Workbook" API each time. + expect(mocks.mockGetWorkbook).toHaveBeenCalledTimes(2); + }); + + it('should return not allowed when the workbook is allowed by the workbooks in the bounded context and exists in a project that is not allowed by the projects in the bounded context', async () => { + mocks.mockGetWorkbook.mockResolvedValue(mockWorkbook); + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The workbook with LUID ${mockWorkbook.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isWorkbookAllowed({ + workbookId: mockWorkbook.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get Workbook" API each time. + expect(mocks.mockGetWorkbook).toHaveBeenCalledTimes(2); + }); + }); + }); + + describe('isViewAllowed', () => { + describe('allowed', () => { + it('should return allowed when the view exists in a workbook that is allowed by the workbooks in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is not enabled, we cached the result so we only need to call the "Get View" API once. + expect(mocks.mockGetView).toHaveBeenCalledOnce(); + }); + + it('should return allowed when the view exists in a workbook that is allowed by the projects in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set([mockView.project.id]), + datasourceIds: null, + workbookIds: null, + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is enabled, we can't cache the result and we need to call the "Get View" API each time. + expect(mocks.mockGetView).toHaveBeenCalledTimes(2); + }); + + it('should return allowed when the view exists in a workbook that is allowed by the workbooks and the projects in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set([mockView.project.id]), + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ allowed: true }); + + // Since project filtering is enabled, we can't cache the result and we need to call the "Get View" API each time. + expect(mocks.mockGetView).toHaveBeenCalledTimes(2); + }); + }); + + describe('not allowed', () => { + it('should return not allowed when the view exists in a workbook that is not allowed by the workbooks in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: null, + datasourceIds: null, + workbookIds: new Set(['some-workbook-id']), + }); + + const expectedMessage = [ + 'The set of allowed workbooks that can be queried is limited by the server configuration.', + `The view with LUID ${mockView.id} cannot be queried because it does not belong to an allowed workbook.`, + ].join(' '); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is not enabled, we can cache the result so we only need to call the "Get View" API once. + expect(mocks.mockGetView).toHaveBeenCalledTimes(1); + }); + + it('should return not allowed when the view exists in a workbook that is not allowed by the projects in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: null, + workbookIds: null, + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The view with LUID ${mockView.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get View" API each time. + expect(mocks.mockGetView).toHaveBeenCalledTimes(2); + }); + + it('should return not allowed when the view exists in a workbook that is allowed in the bounded context but exists in a project that is not allowed by the projects in the bounded context', async () => { + mocks.mockGetView.mockResolvedValue(mockView); + const resourceAccessChecker = createResourceAccessChecker({ + projectIds: new Set(['some-project-id']), + datasourceIds: null, + workbookIds: new Set([mockWorkbook.id]), + }); + + const expectedMessage = [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The view with LUID ${mockView.id} cannot be queried because it does not belong to an allowed project.`, + ].join(' '); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + expect( + await resourceAccessChecker.isViewAllowed({ + viewId: mockView.id, + restApiArgs, + }), + ).toEqual({ + allowed: false, + message: expectedMessage, + }); + + // Since project filtering is enabled, we cannot cache the result so we need to call the "Get View" API each time. + expect(mocks.mockGetView).toHaveBeenCalledTimes(2); + }); + }); + }); +}); + +function getServer(): InstanceType { + const server = new Server(); + server.tool = vi.fn(); + return server; +} diff --git a/src/tools/resourceAccessChecker.ts b/src/tools/resourceAccessChecker.ts new file mode 100644 index 00000000..6c0d5fd3 --- /dev/null +++ b/src/tools/resourceAccessChecker.ts @@ -0,0 +1,315 @@ +import { RequestId } from '@modelcontextprotocol/sdk/types.js'; + +import { BoundedContext, Config, getConfig } from '../config.js'; +import { useRestApi } from '../restApiInstance.js'; +import { Server } from '../server.js'; + +type AllowedResult = { allowed: true } | { allowed: false; message: string }; +type RestApiArgs = { + config: Config; + requestId: RequestId; + server: Server; +}; + +class ResourceAccessChecker { + private _allowedProjectIds: Set | null | undefined; + private _allowedDatasourceIds: Set | null | undefined; + private _allowedWorkbookIds: Set | null | undefined; + + private readonly _cachedDatasourceIds: Map; + private readonly _cachedWorkbookIds: Map; + private readonly _cachedViewIds: Map; + + static create(): ResourceAccessChecker { + return new ResourceAccessChecker(); + } + + static createForTesting(boundedContext: BoundedContext): ResourceAccessChecker { + return new ResourceAccessChecker(boundedContext); + } + + // Optional bounded context to use for testing. + private constructor(boundedContext?: BoundedContext) { + // The methods assume these sets are non-empty. + this._allowedProjectIds = boundedContext?.projectIds; + this._allowedDatasourceIds = boundedContext?.datasourceIds; + this._allowedWorkbookIds = boundedContext?.workbookIds; + + this._cachedDatasourceIds = new Map(); + this._cachedWorkbookIds = new Map(); + this._cachedViewIds = new Map(); + } + + private get allowedProjectIds(): Set | null { + if (this._allowedProjectIds === undefined) { + this._allowedProjectIds = getConfig().boundedContext.projectIds; + } + + return this._allowedProjectIds; + } + + private get allowedDatasourceIds(): Set | null { + if (this._allowedDatasourceIds === undefined) { + this._allowedDatasourceIds = getConfig().boundedContext.datasourceIds; + } + + return this._allowedDatasourceIds; + } + + private get allowedWorkbookIds(): Set | null { + if (this._allowedWorkbookIds === undefined) { + this._allowedWorkbookIds = getConfig().boundedContext.workbookIds; + } + + return this._allowedWorkbookIds; + } + + async isDatasourceAllowed({ + datasourceLuid, + restApiArgs, + }: { + datasourceLuid: string; + restApiArgs: RestApiArgs; + }): Promise { + const result = await this._isDatasourceAllowed({ + datasourceLuid, + restApiArgs, + }); + + if (!this.allowedProjectIds) { + // If project filtering is enabled, we cannot cache the result since the datasource may be moved between projects. + this._cachedDatasourceIds.set(datasourceLuid, result); + } + + return result; + } + + async isWorkbookAllowed({ + workbookId, + restApiArgs, + }: { + workbookId: string; + restApiArgs: RestApiArgs; + }): Promise { + const result = await this._isWorkbookAllowed({ + workbookId, + restApiArgs, + }); + + if (!this.allowedProjectIds) { + // If project filtering is enabled, we cannot cache the result since the workbook may be moved between projects. + this._cachedWorkbookIds.set(workbookId, result); + } + + return result; + } + + async isViewAllowed({ + viewId, + restApiArgs, + }: { + viewId: string; + restApiArgs: RestApiArgs; + }): Promise { + const result = await this._isViewAllowed({ + viewId, + restApiArgs, + }); + + if (!this.allowedProjectIds) { + // If project filtering is enabled, we cannot cache the result since the workbook containing the view may be moved between projects. + this._cachedViewIds.set(viewId, result); + } + + return result; + } + + private async _isDatasourceAllowed({ + datasourceLuid, + restApiArgs: { config, requestId, server }, + }: { + datasourceLuid: string; + restApiArgs: RestApiArgs; + }): Promise { + const cachedResult = this._cachedDatasourceIds.get(datasourceLuid); + if (cachedResult) { + return cachedResult; + } + + if (this.allowedDatasourceIds && !this.allowedDatasourceIds.has(datasourceLuid)) { + return { + allowed: false, + message: [ + 'The set of allowed data sources that can be queried is limited by the server configuration.', + `Querying the datasource with LUID ${datasourceLuid} is not allowed.`, + ].join(' '), + }; + } + + if (this.allowedProjectIds) { + const datasourceProjectId = await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:content:read'], + callback: async (restApi) => { + const datasource = await restApi.datasourcesMethods.queryDatasource({ + siteId: restApi.siteId, + datasourceId: datasourceLuid, + }); + + return datasource.project.id; + }, + }); + + if (!this.allowedProjectIds.has(datasourceProjectId)) { + return { + allowed: false, + message: [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The datasource with LUID ${datasourceLuid} cannot be queried because it does not belong to an allowed project.`, + ].join(' '), + }; + } + } + + return { allowed: true }; + } + + private async _isWorkbookAllowed({ + workbookId, + restApiArgs: { config, requestId, server }, + }: { + workbookId: string; + restApiArgs: RestApiArgs; + }): Promise { + const cachedResult = this._cachedWorkbookIds.get(workbookId); + if (cachedResult) { + return cachedResult; + } + + if (this.allowedWorkbookIds && !this.allowedWorkbookIds.has(workbookId)) { + return { + allowed: false, + message: [ + 'The set of allowed workbooks that can be queried is limited by the server configuration.', + `Querying the workbook with LUID ${workbookId} is not allowed.`, + ].join(' '), + }; + } + + if (this.allowedProjectIds) { + const workbookProjectId = await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:content:read'], + callback: async (restApi) => { + const workbook = await restApi.workbooksMethods.getWorkbook({ + siteId: restApi.siteId, + workbookId, + }); + + return workbook.project?.id ?? ''; + }, + }); + + if (!this.allowedProjectIds.has(workbookProjectId)) { + return { + allowed: false, + message: [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The workbook with LUID ${workbookId} cannot be queried because it does not belong to an allowed project.`, + ].join(' '), + }; + } + } + + return { allowed: true }; + } + + private async _isViewAllowed({ + viewId, + restApiArgs: { config, requestId, server }, + }: { + viewId: string; + restApiArgs: RestApiArgs; + }): Promise { + const cachedResult = this._cachedViewIds.get(viewId); + if (cachedResult) { + return cachedResult; + } + + let viewWorkbookId = ''; + let viewProjectId = ''; + + if (this.allowedWorkbookIds) { + const view = await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:content:read'], + callback: async (restApi) => { + return await restApi.viewsMethods.getView({ + siteId: restApi.siteId, + viewId, + }); + }, + }); + + viewWorkbookId = view.workbook?.id ?? ''; + viewProjectId = view.project?.id ?? ''; + + if (!this.allowedWorkbookIds.has(viewWorkbookId)) { + return { + allowed: false, + message: [ + 'The set of allowed workbooks that can be queried is limited by the server configuration.', + `The view with LUID ${viewId} cannot be queried because it does not belong to an allowed workbook.`, + ].join(' '), + }; + } + } + + if (this.allowedProjectIds) { + viewProjectId = + viewProjectId || + (await useRestApi({ + config, + requestId, + server, + jwtScopes: ['tableau:content:read'], + callback: async (restApi) => { + const view = await restApi.viewsMethods.getView({ + siteId: restApi.siteId, + viewId, + }); + + return view.project?.id ?? ''; + }, + })); + + if (!this.allowedProjectIds.has(viewProjectId)) { + return { + allowed: false, + message: [ + 'The set of allowed projects that can be queried is limited by the server configuration.', + `The view with LUID ${viewId} cannot be queried because it does not belong to an allowed project.`, + ].join(' '), + }; + } + } + + return { allowed: true }; + } +} + +let resourceAccessChecker = ResourceAccessChecker.create(); +const exportedForTesting = { + createResourceAccessChecker: ResourceAccessChecker.createForTesting, + resetResourceAccessCheckerSingleton: () => { + resourceAccessChecker = ResourceAccessChecker.create(); + }, +}; + +export { exportedForTesting, resourceAccessChecker }; diff --git a/src/tools/tool.test.ts b/src/tools/tool.test.ts index b466407a..f6df8b5c 100644 --- a/src/tools/tool.test.ts +++ b/src/tools/tool.test.ts @@ -63,6 +63,12 @@ describe('Tool', () => { requestId: '2', args: { param1: 'test' }, callback, + constrainSuccessResult: (result) => { + return { + type: 'success', + result, + }; + }, }); expect(result.isError).toBe(false); @@ -88,6 +94,12 @@ describe('Tool', () => { requestId: '2', args: { param1: 'test' }, callback, + constrainSuccessResult: (result) => { + return { + type: 'success', + result, + }; + }, }); expect(result.isError).toBe(true); @@ -103,6 +115,12 @@ describe('Tool', () => { requestId: '2', args, callback: vi.fn(), + constrainSuccessResult: (result) => { + return { + type: 'success', + result, + }; + }, }); expect(mockParams.argsValidator).toHaveBeenCalledWith(args); @@ -130,10 +148,85 @@ describe('Tool', () => { requestId: '2', args: { param1: 'test' }, callback: () => Promise.resolve(Ok('test')), + constrainSuccessResult: (result) => { + return { + type: 'success', + result, + }; + }, }); expect(result.isError).toBe(true); expect(result.content[0].type).toBe('text'); expect(result.content[0].text).toBe('requestId: 2, error: Test error'); }); + + it('should constrain the success result', async () => { + const tool = new Tool(mockParams); + const successResult = { data: 'success' }; + + const result = await tool.logAndExecute({ + requestId: '2', + args: { param1: 'test' }, + callback: () => Promise.resolve(Ok(successResult)), + constrainSuccessResult: (result) => { + return { + type: 'success', + result: { + ...result, + additionalField: 'extra', + }, + }; + }, + }); + + expect(result.isError).toBe(false); + expect(result.content[0].type).toBe('text'); + expect(JSON.parse(result.content[0].text as string)).toEqual({ + ...successResult, + additionalField: 'extra', + }); + }); + + it('should return empty result when the constrained result is empty', async () => { + const tool = new Tool(mockParams); + const successResult = { data: 'success' }; + + const result = await tool.logAndExecute({ + requestId: '2', + args: { param1: 'test' }, + callback: () => Promise.resolve(Ok(successResult)), + constrainSuccessResult: (_result) => { + return { + type: 'empty', + message: 'No data found', + }; + }, + }); + + expect(result.isError).toBe(false); + expect(result.content[0].type).toBe('text'); + expect(result.content[0].text).toBe('No data found'); + }); + + it('should return error result when the constrained result is error', async () => { + const tool = new Tool(mockParams); + const successResult = { data: 'success' }; + + const result = await tool.logAndExecute({ + requestId: '2', + args: { param1: 'test' }, + callback: () => Promise.resolve(Ok(successResult)), + constrainSuccessResult: (_result) => { + return { + type: 'error', + message: 'An error occurred', + }; + }, + }); + + expect(result.isError).toBe(true); + expect(result.content[0].type).toBe('text'); + expect(result.content[0].text).toBe('An error occurred'); + }); }); diff --git a/src/tools/tool.ts b/src/tools/tool.ts index f8805dc0..5ac788d4 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -14,6 +14,20 @@ type ArgsValidator = Args exte ? (args: z.objectOutputType) => void : never; +export type ConstrainedResult = + | { + type: 'success'; + result: T; + } + | { + type: 'empty'; + message: string; + } + | { + type: 'error'; + message: string; + }; + /** * The parameters for creating a tool instance * @@ -65,6 +79,9 @@ type LogAndExecuteParams // A function that can transform an error result of the callback into a string. // Required if the callback can return an error result. getErrorText?: (error: E) => string; + + // A function that constrains the success result of the tool + constrainSuccessResult: (result: T) => ConstrainedResult | Promise>; }; /** @@ -125,6 +142,7 @@ export class Tool { callback, getSuccessResult, getErrorText, + constrainSuccessResult, }: LogAndExecuteParams): Promise { this.logInvocation({ requestId, args }); @@ -140,8 +158,17 @@ export class Tool { const result = await callback(); if (result.isOk()) { + const constrainedResult = await constrainSuccessResult(result.value); + + if (constrainedResult.type !== 'success') { + return { + isError: constrainedResult.type === 'error', + content: [{ type: 'text', text: constrainedResult.message }], + }; + } + if (getSuccessResult) { - return getSuccessResult(result.value); + return getSuccessResult(constrainedResult.result); } return { @@ -149,7 +176,7 @@ export class Tool { content: [ { type: 'text', - text: JSON.stringify(result.value), + text: JSON.stringify(constrainedResult.result), }, ], }; diff --git a/src/tools/views/getViewData.ts b/src/tools/views/getViewData.ts index d4a0a724..a375c6b8 100644 --- a/src/tools/views/getViewData.ts +++ b/src/tools/views/getViewData.ts @@ -1,16 +1,22 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { Ok } from 'ts-results-es'; +import { Err, Ok } from 'ts-results-es'; import { z } from 'zod'; import { getConfig } from '../../config.js'; import { useRestApi } from '../../restApiInstance.js'; import { Server } from '../../server.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; const paramsSchema = { viewId: z.string(), }; +export type GetViewDataError = { + type: 'view-not-allowed'; + message: string; +}; + export const getGetViewDataTool = (server: Server): Tool => { const getViewDataTool = new Tool({ server, @@ -25,10 +31,22 @@ export const getGetViewDataTool = (server: Server): Tool => callback: async ({ viewId }, { requestId }): Promise => { const config = getConfig(); - return await getViewDataTool.logAndExecute({ + return await getViewDataTool.logAndExecute({ requestId, args: { viewId }, callback: async () => { + const isViewAllowedResult = await resourceAccessChecker.isViewAllowed({ + viewId, + restApiArgs: { config, requestId, server }, + }); + + if (!isViewAllowedResult.allowed) { + return new Err({ + type: 'view-not-allowed', + message: isViewAllowedResult.message, + }); + } + return new Ok( await useRestApi({ config, @@ -44,6 +62,18 @@ export const getGetViewDataTool = (server: Server): Tool => }), ); }, + constrainSuccessResult: (viewData) => { + return { + type: 'success', + result: viewData, + }; + }, + getErrorText: (error: GetViewDataError) => { + switch (error.type) { + case 'view-not-allowed': + return error.message; + } + }, }); }, }); diff --git a/src/tools/views/getViewImage.ts b/src/tools/views/getViewImage.ts index 41fa1b0e..4ba7492f 100644 --- a/src/tools/views/getViewImage.ts +++ b/src/tools/views/getViewImage.ts @@ -1,11 +1,12 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { Ok } from 'ts-results-es'; +import { Err, Ok } from 'ts-results-es'; import { z } from 'zod'; import { getConfig } from '../../config.js'; import { useRestApi } from '../../restApiInstance.js'; import { Server } from '../../server.js'; import { convertPngDataToToolResult } from '../convertPngDataToToolResult.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; const paramsSchema = { @@ -14,6 +15,11 @@ const paramsSchema = { height: z.number().gt(0).optional(), }; +export type GetViewImageError = { + type: 'view-not-allowed'; + message: string; +}; + export const getGetViewImageTool = (server: Server): Tool => { const getViewImageTool = new Tool({ server, @@ -28,10 +34,22 @@ export const getGetViewImageTool = (server: Server): Tool = callback: async ({ viewId, width, height }, { requestId }): Promise => { const config = getConfig(); - return await getViewImageTool.logAndExecute({ + return await getViewImageTool.logAndExecute({ requestId, args: { viewId }, callback: async () => { + const isViewAllowedResult = await resourceAccessChecker.isViewAllowed({ + viewId, + restApiArgs: { config, requestId, server }, + }); + + if (!isViewAllowedResult.allowed) { + return new Err({ + type: 'view-not-allowed', + message: isViewAllowedResult.message, + }); + } + return new Ok( await useRestApi({ config, @@ -50,7 +68,19 @@ export const getGetViewImageTool = (server: Server): Tool = }), ); }, + constrainSuccessResult: (viewImage) => { + return { + type: 'success', + result: viewImage, + }; + }, getSuccessResult: convertPngDataToToolResult, + getErrorText: (error: GetViewImageError) => { + switch (error.type) { + case 'view-not-allowed': + return error.message; + } + }, }); }, }); diff --git a/src/tools/views/listViews.ts b/src/tools/views/listViews.ts index 7b960b89..de88255d 100644 --- a/src/tools/views/listViews.ts +++ b/src/tools/views/listViews.ts @@ -79,7 +79,7 @@ export const getListViewsTool = (server: Server): Tool => { server, jwtScopes: ['tableau:content:read'], callback: async (restApi) => { - const workbooks = await paginate({ + const views = await paginate({ pageConfig: { pageSize, limit: config.maxResultLimit @@ -100,11 +100,48 @@ export const getListViewsTool = (server: Server): Tool => { }, }); - return workbooks; + return views; }, }), ); }, + constrainSuccessResult: (views) => { + if (views.length === 0) { + return { + type: 'empty', + message: + 'No views were found. Either none exist or you do not have permission to view them', + }; + } + + const { projectIds, workbookIds } = getConfig().boundedContext; + if (projectIds) { + views = views.filter((view) => + view.project?.id ? projectIds.has(view.project.id) : false, + ); + } + + if (workbookIds) { + views = views.filter((view) => + view.workbook?.id ? workbookIds.has(view.workbook.id) : false, + ); + } + + if (views.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed views that can be queried is limited by the server configuration.', + 'While views were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: views, + }; + }, }); }, }); diff --git a/src/tools/views/mockView.ts b/src/tools/views/mockView.ts new file mode 100644 index 00000000..99586815 --- /dev/null +++ b/src/tools/views/mockView.ts @@ -0,0 +1,10 @@ +export const mockView = { + id: '4d18c547-bbb1-4187-ae5a-7f78b35adf2d', + name: 'Overview', + project: { + id: 'ae5e9374-2a58-40ab-93e4-a2fd1b07cf7d', + }, + workbook: { + id: '96a43833-27db-40b6-aa80-751efc776b9a', + }, +}; diff --git a/src/tools/workbooks/getWorkbook.test.ts b/src/tools/workbooks/getWorkbook.test.ts index ee114884..35dc4b8f 100644 --- a/src/tools/workbooks/getWorkbook.test.ts +++ b/src/tools/workbooks/getWorkbook.test.ts @@ -1,11 +1,13 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Server } from '../../server.js'; +import { mockView } from '../views/mockView.js'; import { getGetWorkbookTool } from './getWorkbook.js'; import { mockWorkbook } from './mockWorkbook.js'; const mocks = vi.hoisted(() => ({ mockGetWorkbook: vi.fn(), + mockQueryViewsForWorkbook: vi.fn(), })); vi.mock('../../restApiInstance.js', () => ({ @@ -14,6 +16,9 @@ vi.mock('../../restApiInstance.js', () => ({ workbooksMethods: { getWorkbook: mocks.mockGetWorkbook, }, + viewsMethods: { + queryViewsForWorkbook: mocks.mockQueryViewsForWorkbook, + }, siteId: 'test-site-id', }), ), @@ -35,6 +40,7 @@ describe('getWorkbookTool', () => { it('should successfully get workbook', async () => { mocks.mockGetWorkbook.mockResolvedValue(mockWorkbook); + mocks.mockQueryViewsForWorkbook.mockResolvedValue([mockView]); const result = await getToolResult({ workbookId: '96a43833-27db-40b6-aa80-751efc776b9a' }); expect(result.isError).toBe(false); expect(result.content[0].text).toContain('Superstore'); diff --git a/src/tools/workbooks/getWorkbook.ts b/src/tools/workbooks/getWorkbook.ts index 024028b7..69cfa5b4 100644 --- a/src/tools/workbooks/getWorkbook.ts +++ b/src/tools/workbooks/getWorkbook.ts @@ -1,16 +1,23 @@ import { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; -import { Ok } from 'ts-results-es'; +import { Err, Ok } from 'ts-results-es'; import { z } from 'zod'; import { getConfig } from '../../config.js'; import { useRestApi } from '../../restApiInstance.js'; +import { Workbook } from '../../sdks/tableau/types/workbook.js'; import { Server } from '../../server.js'; +import { resourceAccessChecker } from '../resourceAccessChecker.js'; import { Tool } from '../tool.js'; const paramsSchema = { workbookId: z.string(), }; +export type GetWorkbookError = { + type: 'workbook-not-allowed'; + message: string; +}; + export const getGetWorkbookTool = (server: Server): Tool => { const getWorkbookTool = new Tool({ server, @@ -25,10 +32,22 @@ export const getGetWorkbookTool = (server: Server): Tool => callback: async ({ workbookId }, { requestId }): Promise => { const config = getConfig(); - return await getWorkbookTool.logAndExecute({ + return await getWorkbookTool.logAndExecute({ requestId, args: { workbookId }, callback: async () => { + const isWorkbookAllowedResult = await resourceAccessChecker.isWorkbookAllowed({ + workbookId, + restApiArgs: { config, requestId, server }, + }); + + if (!isWorkbookAllowedResult.allowed) { + return new Err({ + type: 'workbook-not-allowed', + message: isWorkbookAllowedResult.message, + }); + } + return new Ok( await useRestApi({ config, @@ -58,6 +77,18 @@ export const getGetWorkbookTool = (server: Server): Tool => }), ); }, + constrainSuccessResult: (workbook) => { + return { + type: 'success', + result: workbook, + }; + }, + getErrorText: (error: GetWorkbookError) => { + switch (error.type) { + case 'workbook-not-allowed': + return error.message; + } + }, }); }, }); diff --git a/src/tools/workbooks/listWorkbooks.test.ts b/src/tools/workbooks/listWorkbooks.test.ts index ecb67d3c..3fc99177 100644 --- a/src/tools/workbooks/listWorkbooks.test.ts +++ b/src/tools/workbooks/listWorkbooks.test.ts @@ -10,7 +10,7 @@ const mockWorkbooks = { pageSize: 10, totalAvailable: 1, }, - workbooks: [mockWorkbook], + workbooks: [{ workbook: mockWorkbook }], }; const mocks = vi.hoisted(() => ({ diff --git a/src/tools/workbooks/listWorkbooks.ts b/src/tools/workbooks/listWorkbooks.ts index c00d6d7c..e5c10712 100644 --- a/src/tools/workbooks/listWorkbooks.ts +++ b/src/tools/workbooks/listWorkbooks.ts @@ -101,6 +101,41 @@ export const getListWorkbooksTool = (server: Server): Tool }), ); }, + constrainSuccessResult: (workbooks) => { + if (workbooks.length === 0) { + return { + type: 'empty', + message: + 'No workbooks were found. Either none exist or you do not have permission to view them', + }; + } + + const { projectIds, workbookIds } = getConfig().boundedContext; + if (projectIds) { + workbooks = workbooks.filter((workbook) => + workbook.project?.id ? projectIds.has(workbook.project.id) : false, + ); + } + + if (workbookIds) { + workbooks = workbooks.filter((workbook) => workbookIds.has(workbook.id)); + } + + if (workbooks.length === 0) { + return { + type: 'empty', + message: [ + 'The set of allowed workbooks that can be queried is limited by the server configuration.', + 'While workbooks were found, they were all filtered out by the server configuration.', + ].join(' '), + }; + } + + return { + type: 'success', + result: workbooks, + }; + }, }); }, }); diff --git a/src/tools/workbooks/mockWorkbook.ts b/src/tools/workbooks/mockWorkbook.ts index c0024250..5191f2e2 100644 --- a/src/tools/workbooks/mockWorkbook.ts +++ b/src/tools/workbooks/mockWorkbook.ts @@ -1,18 +1,13 @@ +import { mockView } from '../views/mockView.js'; + export const mockWorkbook = { - workbook: { - id: '96a43833-27db-40b6-aa80-751efc776b9a', - name: 'Superstore', - contentUrl: 'Superstore', - project: { name: 'Samples', id: 'ae5e9374-2a58-40ab-93e4-a2fd1b07cf7d' }, - showTabs: true, - defaultViewId: '4d18c547-bbb1-4187-ae5a-7f78b35adf2d', - views: { - view: [ - { - id: '4d18c547-bbb1-4187-ae5a-7f78b35adf2d', - name: 'Overview', - }, - ], - }, + id: '96a43833-27db-40b6-aa80-751efc776b9a', + name: 'Superstore', + contentUrl: 'Superstore', + project: { name: 'Samples', id: 'ae5e9374-2a58-40ab-93e4-a2fd1b07cf7d' }, + showTabs: true, + defaultViewId: '4d18c547-bbb1-4187-ae5a-7f78b35adf2d', + views: { + view: [mockView], }, }; diff --git a/types/process-env.d.ts b/types/process-env.d.ts index de12eb88..426a3072 100644 --- a/types/process-env.d.ts +++ b/types/process-env.d.ts @@ -24,6 +24,9 @@ export interface ProcessEnvEx { DISABLE_METADATA_API_REQUESTS: string | undefined; ENABLE_SERVER_LOGGING: string | undefined; SERVER_LOG_DIRECTORY: string | undefined; + INCLUDE_PROJECT_IDS: string | undefined; + INCLUDE_DATASOURCE_IDS: string | undefined; + INCLUDE_WORKBOOK_IDS: string | undefined; } declare global {