Skip to content

Commit f132c8a

Browse files
committed
should be able to get plurality of threads in getThreadMessages
1 parent 3e46d82 commit f132c8a

File tree

5 files changed

+67
-39
lines changed

5 files changed

+67
-39
lines changed

mcp/src/apis/getMessageContext.ts

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ import {
1111
type zUser,
1212
} from '../types.js';
1313
import { addChannelInfo } from '../util/addChannelInfo.js';
14-
import { convertTsToTimestamp } from '../util/formatTs.js';
15-
import { getChannelIds } from '../util/getChannelIds.js';
1614
import { getUsersMap } from '../util/getUsersMap.js';
1715
import { getMessageFields } from '../util/messageFields.js';
1816
import { messagesToTree } from '../util/messagesToTree.js';
17+
import { normalizeMessageFilterQueryParameters } from '../util/normalizeMessageFilterQueryParameters.js';
1918
import { selectExpandedMessages } from '../util/selectExpandedMessages.js';
2019

2120
const inputSchema = {
@@ -60,25 +59,11 @@ export const getMessageContextFactory: ApiFactory<
6059
users: Record<string, z.infer<typeof zUser>>;
6160
}> => {
6261
const client = await pgPool.connect();
63-
const channelIds = await getChannelIds(
62+
const messageFilters = normalizeMessageFilterQueryParameters(
6463
pgPool,
65-
passedMessageFilters.map((x) => x.channel),
64+
passedMessageFilters,
6665
);
6766

68-
if (channelIds === null || channelIds.length === 0) {
69-
throw new Error('You must pass at least one existing channel id');
70-
}
71-
72-
const messageFilters = passedMessageFilters.reduce<
73-
z.infer<typeof zMessageFilter>[]
74-
>((acc, curr, index) => {
75-
acc.push({
76-
channel: channelIds[index] || '',
77-
ts: convertTsToTimestamp(curr.ts),
78-
});
79-
return acc;
80-
}, []);
81-
8267
try {
8368
const result = await client.query<Message>(
8469
selectExpandedMessages(

mcp/src/apis/getThreadMessages.ts

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,22 @@ import {
44
type Message,
55
type MessageInThread,
66
type ServerContext,
7+
type User,
78
zIncludeFilters,
9+
zMessageFilter,
810
zMessageInThread,
911
zUser,
1012
} from '../types.js';
11-
import { convertTsToTimestamp } from '../util/formatTs.js';
1213
import { getUsersMap } from '../util/getUsersMap.js';
1314
import { getMessageFields } from '../util/messageFields.js';
1415
import { messagesToTree } from '../util/messagesToTree.js';
16+
import { normalizeMessageFilterQueryParameters } from '../util/normalizeMessageFilterQueryParameters.js';
1517

1618
const inputSchema = {
1719
...zIncludeFilters.shape,
18-
channel: z
19-
.string()
20-
.min(1)
21-
.describe('The ID of the channel to fetch messages from.'),
22-
23-
ts: z
24-
.string()
25-
.min(1)
26-
.describe(
27-
'The thread timestamp to fetch messages for. This is the ts of the parent message. Use the `thread_ts` field from a known message in the thread.',
28-
),
20+
messageFilters: z
21+
.array(zMessageFilter)
22+
.describe('The messages to fetch the threads for.'),
2923
} as const;
3024

3125
const outputSchema = {
@@ -52,28 +46,45 @@ export const getThreadMessagesFactory: ApiFactory<
5246
outputSchema,
5347
},
5448
fn: async ({
55-
channel,
5649
includeFiles,
5750
includePermalinks,
58-
ts,
51+
messageFilters: passedMessageFilters,
5952
}): Promise<{
6053
messages: MessageInThread[];
61-
users: Record<string, z.infer<typeof zUser>>;
54+
users: Record<string, User>;
6255
}> => {
56+
const messageFilters = normalizeMessageFilterQueryParameters(
57+
pgPool,
58+
passedMessageFilters,
59+
);
60+
6361
const result = await pgPool.query<Message>(
6462
/* sql */ `
65-
SELECT ${getMessageFields({ includeFiles })} FROM slack.message
66-
WHERE channel_id = $1 AND (thread_ts = $2 OR ts = $2)
67-
ORDER BY ts DESC`, // messagesToTree expects messages in descending order
68-
[channel, convertTsToTimestamp(ts)],
63+
WITH filters AS (
64+
SELECT
65+
f->>'channel' AS channel_id,
66+
(f->>'ts')::timestamptz AS ts
67+
FROM jsonb_array_elements($1::jsonb) AS f
68+
)
69+
SELECT ${getMessageFields({ includeFiles })}
70+
FROM slack.message m
71+
WHERE EXISTS (
72+
SELECT 1 FROM filters f
73+
WHERE m.channel_id = f.channel_id
74+
AND (m.thread_ts = f.ts OR m.ts = f.ts)
75+
)
76+
ORDER BY ts DESC`, // messagesToTree expects messages in descending order
77+
[JSON.stringify(messageFilters)],
6978
);
7079

7180
const { involvedUsers, channels } = messagesToTree(
7281
result.rows,
7382
includePermalinks || false,
7483
);
7584
const users = await getUsersMap(pgPool, involvedUsers);
76-
const messages = channels[channel]?.messages || [];
85+
86+
// Flatten messages from all channels
87+
const messages = Object.values(channels).flatMap((c) => c.messages);
7788

7889
return {
7990
messages,

mcp/src/apis/search.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import {
88
zMessage,
99
} from '../types.js';
1010
import { generatePermalink } from '../util/addMessageLinks.js';
11-
import { findChannel } from '../util/findChannel.js';
1211
import { findUser } from '../util/findUser.js';
1312
import { getChannelIds } from '../util/getChannelIds.js';
1413
import { getMessageKey } from '../util/getMessageKey.js';

mcp/src/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ export const zMessageFilter = z.object({
151151
),
152152
});
153153

154+
export type MessageFilter = z.infer<typeof zMessageFilter>;
155+
154156
export const zTimeFilters = z.object({
155157
timestampStart: z.coerce
156158
.date()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import type { Pool } from 'pg';
2+
import type { MessageFilter } from '../types';
3+
import { convertTsToTimestamp } from './formatTs';
4+
import { getChannelIds } from './getChannelIds';
5+
6+
export const normalizeMessageFilterQueryParameters = async (
7+
pgPool: Pool,
8+
messageFilters: MessageFilter[],
9+
): Promise<MessageFilter[]> => {
10+
const channelIds = await getChannelIds(
11+
pgPool,
12+
messageFilters.map((x) => x.channel),
13+
);
14+
15+
if (channelIds === null || channelIds.length === 0) {
16+
throw new Error('You must pass at least one existing channel id');
17+
}
18+
19+
const normalized = messageFilters.reduce<MessageFilter[]>(
20+
(acc, curr, index) => {
21+
acc.push({
22+
channel: channelIds[index] || '',
23+
ts: convertTsToTimestamp(curr.ts),
24+
});
25+
return acc;
26+
},
27+
[],
28+
);
29+
30+
return normalized;
31+
};

0 commit comments

Comments
 (0)