Skip to content

Commit bd47e6a

Browse files
danmichaeljones and Dan Jones
authored
Add QueryAgent streaming (#11)
* Add utilities to read SSEs
* Add types for streaming responses
* Add new stream method to query agent
* Whitespace
* Better typing for progress message details
* Fix snake_case to camelCase
---------
Co-authored-by: Dan Jones <[email protected]>
1 parent b7012a5 commit bd47e6a

File tree

4 files changed

+248
-2
lines changed

4 files changed

+248
-2
lines changed

src/query/agent.ts

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { WeaviateClient } from "weaviate-client";
2-
import { QueryAgentResponse } from "./response/response.js";
3-
import { mapResponse } from "./response/response-mapping.js";
2+
import { QueryAgentResponse, ProgressMessage, StreamedTokens } from "./response/response.js";
3+
import { mapResponse, mapProgressMessageFromSSE, mapStreamedTokensFromSSE, mapResponseFromSSE } from "./response/response-mapping.js";
44
import { mapApiResponse } from "./response/api-response-mapping.js";
5+
import { fetchServerSentEvents } from "./response/server-sent-events.js";
56
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
67

78
/**
@@ -78,6 +79,58 @@ export class QueryAgent {
7879

7980
return mapResponse(await response.json());
8081
}
82+
83+
/**
 * Send a query to the agent's streaming endpoint and yield its output as it arrives.
 *
 * Every server-sent event is converted before being yielded:
 * `progress_message` becomes a `ProgressMessage`, `streamed_tokens` becomes
 * `StreamedTokens`, and `final_state` becomes a `QueryAgentResponse`.
 *
 * @param query - The natural language query string for the agent.
 * @param options - Additional options for the run. `includeProgress` is sent as `true` when omitted.
 * @returns An async generator of progress messages, streamed tokens and query agent responses.
 * @throws Error if no collections are passed and `this.collections` is not set,
 *   or if the stream delivers an event type other than the three listed above.
 */
async *stream(
  query: string,
  { collections, context, includeProgress }: QueryAgentStreamOptions = {}
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse> {
  // Collections passed per call take precedence over the ones held by the agent.
  const targetCollections = collections ?? this.collections;
  if (!targetCollections) {
    throw Error("No collections provided to the query agent.");
  }

  const { host, bearerToken, headers } = await this.client.getConnectionDetails();

  const endpoint = `${this.agentsHost}/agent/stream_query`;
  const request = {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: bearerToken!,
      "X-Weaviate-Cluster-Url": host,
    },
    body: JSON.stringify({
      headers,
      query,
      collections: mapCollections(targetCollections),
      system_prompt: this.systemPrompt,
      previous_response: context ? mapApiResponse(context) : undefined,
      include_progress: includeProgress ?? true,
    }),
  };

  for await (const sse of fetchServerSentEvents(endpoint, request)) {
    switch (sse.event) {
      case "progress_message":
        yield mapProgressMessageFromSSE(sse);
        break;
      case "streamed_tokens":
        yield mapStreamedTokensFromSSE(sse);
        break;
      case "final_state":
        yield mapResponseFromSSE(sse);
        break;
      default:
        throw new Error(`Unexpected event type: ${sse.event}`);
    }
  }
}
81134
}
82135

83136
/** Options for the QueryAgent. */
@@ -97,3 +150,13 @@ export type QueryAgentRunOptions = {
97150
/** Previous response from the agent. */
98151
context?: QueryAgentResponse;
99152
};
153+
154+
/** Options accepted by the QueryAgent `stream` method. */
155+
export type QueryAgentStreamOptions = {
156+
/** List of collections to query. Will override any collections if passed in the constructor. */
157+
collections?: (string | QueryAgentCollectionConfig)[];
158+
/** Previous response from the agent. */
159+
context?: QueryAgentResponse;
160+
/** Include progress messages in the stream. Treated as `true` when omitted. */
161+
includeProgress?: boolean;
162+
};

src/query/response/response-mapping.ts

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
PropertyAggregation,
77
Usage,
88
Source,
9+
StreamedTokens,
10+
ProgressMessage,
911
} from "./response.js";
1012

1113
import {
@@ -18,10 +20,13 @@ import {
1820
ApiSource,
1921
} from "./api-response.js";
2022

23+
import { ServerSentEvent } from "./server-sent-events.js";
24+
2125
export const mapResponse = (
2226
response: ApiQueryAgentResponse
2327
): QueryAgentResponse => {
2428
const properties: ResponseProperties = {
29+
outputType: "finalState",
2530
originalQuery: response.original_query,
2631
collectionNames: response.collection_names,
2732
searches: mapSearches(response.searches),
@@ -103,3 +108,64 @@ const display = (response: ResponseProperties) => {
103108
};
104109

105110
/** Every `QueryAgentResponse` field except the `display` function, which is attached after the data fields are mapped. */
type ResponseProperties = Omit<QueryAgentResponse, "display">;
111+
112+
/**
 * Shape expected in the JSON data of a progress message event: the
 * `ProgressMessage` fields, with a snake_case `output_type` tag in place of
 * the camelCase `outputType`.
 */
type ProgressMessageJSON = Omit<ProgressMessage, "outputType"> & {
113+
output_type: "progress_message";
114+
};
115+
116+
/**
 * Convert a server-sent event carrying a progress message into a `ProgressMessage`.
 *
 * The event data is parsed as JSON and its `output_type` tag is verified
 * before the fields are copied across.
 *
 * @param sse - The server-sent event whose `data` holds the JSON payload.
 * @returns The progress message, tagged with `outputType: "progressMessage"`.
 * @throws SyntaxError if `sse.data` is not valid JSON.
 * @throws Error if the payload's `output_type` is not `"progress_message"`.
 */
export const mapProgressMessageFromSSE = (sse: ServerSentEvent): ProgressMessage => {
  const payload: ProgressMessageJSON = JSON.parse(sse.data);

  // The declared type is not enforced at runtime, so check the tag explicitly.
  if (payload.output_type !== "progress_message") {
    throw new Error(`Expected output_type "progress_message", got ${payload.output_type}`);
  }

  const { stage, message, details } = payload;
  return { outputType: "progressMessage", stage, message, details };
};
129+
130+
/**
 * Shape expected in the JSON data of a streamed tokens event: the
 * `StreamedTokens` fields, with a snake_case `output_type` tag in place of
 * the camelCase `outputType`.
 */
type StreamedTokensJSON = Omit<StreamedTokens, "outputType"> & {
131+
output_type: "streamed_tokens";
132+
};
133+
134+
/**
 * Convert a server-sent event carrying streamed tokens into a `StreamedTokens` value.
 *
 * @param sse - The server-sent event whose `data` holds the JSON payload.
 * @returns The streamed tokens, tagged with `outputType: "streamedTokens"`.
 * @throws SyntaxError if `sse.data` is not valid JSON.
 * @throws Error if the payload's `output_type` is not `"streamed_tokens"`.
 */
export const mapStreamedTokensFromSSE = (sse: ServerSentEvent): StreamedTokens => {
  const payload: StreamedTokensJSON = JSON.parse(sse.data);
  // Read the tag as a plain string; the declared literal type is not enforced at runtime.
  const tag: string = payload.output_type;

  if (tag === "streamed_tokens") {
    return { outputType: "streamedTokens", delta: payload.delta };
  }

  throw new Error(`Expected output_type "streamed_tokens", got ${tag}`);
};
145+
146+
147+
/**
 * Convert a server-sent event carrying the final agent state into a `QueryAgentResponse`.
 *
 * The event data is parsed as JSON, its snake_case fields are mapped to the
 * camelCase response fields, and a `display` function bound to the mapped
 * data is attached.
 *
 * @param sse - The server-sent event whose `data` holds the JSON payload.
 * @returns The response, tagged with `outputType: "finalState"`.
 * @throws SyntaxError if `sse.data` is not valid JSON.
 */
export const mapResponseFromSSE = (sse: ServerSentEvent): QueryAgentResponse => {
  const payload: ApiQueryAgentResponse = JSON.parse(sse.data);

  const mapped: ResponseProperties = {
    outputType: "finalState",
    originalQuery: payload.original_query,
    collectionNames: payload.collection_names,
    searches: mapSearches(payload.searches),
    aggregations: mapAggregations(payload.aggregations),
    usage: mapUsage(payload.usage),
    totalTime: payload.total_time,
    aggregationAnswer: payload.aggregation_answer,
    hasAggregationAnswer: payload.has_aggregation_answer,
    hasSearchAnswer: payload.has_search_answer,
    isPartialAnswer: payload.is_partial_answer,
    missingInformation: payload.missing_information,
    finalAnswer: payload.final_answer,
    sources: mapSources(payload.sources),
  };

  // display() closes over the data-only object, not the returned response.
  const render = () => display(mapped);

  return { ...mapped, display: render };
};

src/query/response/response.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export type QueryAgentResponse = {
2+
outputType: "finalState";
23
originalQuery: string;
34
collectionNames: string[];
45
searches: SearchResult[][];
@@ -134,3 +135,24 @@ export type Source = {
134135
objectId: string;
135136
collection: string;
136137
};
138+
139+
/**
 * A query string together with a collection name.
 * NOTE(review): presumably the collection the query is run against — confirm against the API.
 */
export type QueryWithCollection = {
140+
query: string;
141+
collection: string;
142+
};
143+
144+
/** Extra information carried in the `details` field of a `ProgressMessage`. */
export type ProgressDetails = {
145+
/** Optional list of query/collection pairs. */
queries?: QueryWithCollection[];
146+
};
147+
148+
/** A progress update from the query agent, identified by `outputType: "progressMessage"`. */
export type ProgressMessage = {
149+
outputType: "progressMessage";
150+
stage: string;
151+
message: string;
152+
details: ProgressDetails;
153+
};
154+
155+
/** A streamed piece of output from the query agent, identified by `outputType: "streamedTokens"`. */
export type StreamedTokens = {
156+
outputType: "streamedTokens";
157+
/** NOTE(review): presumably the newly generated text since the previous event — confirm against the API. */
delta: string;
158+
};
src/query/response/server-sent-events.ts

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/** A single parsed Server-Sent Event; only the "event" and "data" fields are kept. */
export type ServerSentEvent = {
2+
/** Event name; the parser in this file uses "message" when the event has no "event:" line. */
event: string;
3+
/** Event payload; multiple "data:" lines are joined with "\n". */
data: string;
4+
};
5+
6+
7+
/**
 * Fetch Server-Sent Events (SSE) from a URL.
 *
 * All fields other than "event" and "data" are ignored.
 *
 * The request is sent with `Accept: text/event-stream`, overriding any
 * `Accept` header supplied in `init`. The response body is decoded as UTF-8
 * and events are yielded as soon as they are complete. When the body ends,
 * whatever text remains is parsed as a final event.
 *
 * If the consumer stops iterating early, or reading fails, the underlying
 * stream is cancelled so the connection is not left open.
 *
 * @param input - The URL to fetch the SSE from.
 * @param init - The request init options.
 * @returns An async generator of ServerSentEvent objects.
 * @throws Error if the response status is not OK or the response has no body.
 */
export async function* fetchServerSentEvents(
  input: string | URL | globalThis.Request,
  init?: RequestInit
): AsyncGenerator<ServerSentEvent> {
  // Go through Headers so every HeadersInit form (plain object, tuple array,
  // Headers instance) survives; object spread only copies plain objects.
  const headers = new Headers(init?.headers);
  headers.set("Accept", "text/event-stream");

  const response = await fetch(input, { ...init, headers });

  if (!response.ok || !response.body) {
    throw Error(`Query agent streaming failed. ${await response.text()}`);
  }

  const reader = response.body.getReader();
  const textDecoder = new TextDecoder("utf-8");
  let buffer = "";
  let streamEnded = false;

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        streamEnded = true;
        break;
      }

      // Use a buffer to accumulate text until we have a complete SSE (delimited by blank lines)
      buffer += textDecoder.decode(value, { stream: true });

      const { events, remainingBuffer } = parseServerSentEvents(buffer);
      for (const event of events) {
        yield event;
      }
      buffer = remainingBuffer;
    }

    // Emit anything the streaming decoder was still holding (an incomplete
    // multi-byte sequence becomes U+FFFD), then flush the remaining buffer.
    buffer += textDecoder.decode();
    const { events } = parseServerSentEvents(buffer, true);
    for (const event of events) {
      yield event;
    }
  } finally {
    if (!streamEnded) {
      // Reached when the consumer exits early or a read throws. Cancellation
      // is best-effort cleanup: a failure here must not mask the original error.
      await reader.cancel().catch(() => undefined);
    }
    reader.releaseLock();
  }
}
58+
59+
/**
 * Parse the Server-Sent Events contained in a text buffer.
 *
 * Events are delimited by blank lines and may be spread across multiple
 * chunks from the API, so the text after the last blank line is normally
 * handed back to the caller instead of being parsed.
 *
 * Only "event:" and "data:" lines are interpreted; all other lines
 * (comments, "id:", "retry:", ...) are ignored. Following the SSE format, a
 * data value loses at most ONE leading space — any other whitespace is part
 * of the payload — and multiple data lines are joined with "\n". Events
 * whose data is empty are not emitted.
 *
 * @param buffer - The text accumulated so far.
 * @param flush - When `true`, the trailing text is parsed as an event as well
 *   and `remainingBuffer` is returned empty.
 * @returns The parsed events, and the trailing text that is not yet a complete event.
 */
function parseServerSentEvents(buffer: string, flush?: boolean): { events: ServerSentEvent[]; remainingBuffer: string } {
  const blocks = buffer.split(/\r?\n\r?\n/);

  // Unless flushing, put the (possibly incomplete) final block back into the buffer.
  const remainingBuffer = flush === true ? "" : (blocks.pop() ?? "");

  const events: ServerSentEvent[] = [];

  for (const block of blocks) {
    let event = "message";
    const dataLines: string[] = [];

    for (const line of block.split(/\r?\n/)) {
      if (line.startsWith("event:")) {
        // Replace event name if we get one. Names are identifiers, so
        // surrounding whitespace is never meaningful and is trimmed.
        event = line.slice("event:".length).trim();
      } else if (line.startsWith("data:")) {
        const value = line.slice("data:".length);
        // Strip only the single optional space after the colon.
        dataLines.push(value.startsWith(" ") ? value.slice(1) : value);
      }
    }

    // Skip blocks that carried no payload (comments, keep-alives, empty data lines).
    if (dataLines.some((dataLine) => dataLine !== "")) {
      events.push({ event, data: dataLines.join("\n") });
    }
  }

  return { events, remainingBuffer };
}

0 commit comments

Comments
 (0)