Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export class QueryAgent {
/**
* Run the query agent.
*
* @deprecated Use {@link ask} instead.
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
Expand Down Expand Up @@ -93,9 +94,53 @@ export class QueryAgent {
return mapResponse(await response.json());
}

/**
* Ask query agent a question.
*
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
async ask(
query: string,
{ collections, context }: QueryAgentRunOptions = {},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new ask / askStream methods should have the context option removed now (they'll be superseded by the conversational query option) 👍

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I know, just created this PR as easier to review (as this is mostly copy-paste).
And will add conversational params on the next PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but yeah can remove it in this PR already

): Promise<QueryAgentResponse> {
const targetCollections = collections ?? this.collections;
if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

const { host, bearerToken, headers } =
await this.client.getConnectionDetails();

const response = await fetch(`${this.agentsHost}/agent/query`, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
},
body: JSON.stringify({
headers,
query,
collections: mapCollections(targetCollections),
system_prompt: this.systemPrompt,
previous_response: context ? mapApiResponse(context) : undefined,
}),
});

if (!response.ok) {
await handleError(await response.text());
}

return mapResponse(await response.json());
}

/**
* Stream responses from the query agent.
*
* @deprecated Use {@link askStream} instead.
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
Expand All @@ -107,20 +152,23 @@ export class QueryAgent {
includeFinalState: false;
},
): AsyncGenerator<StreamedTokens>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress: false;
includeFinalState?: true;
},
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress?: true;
includeFinalState: false;
},
): AsyncGenerator<ProgressMessage | StreamedTokens>;
/** @deprecated Use {@link askStream} instead. */
stream(
query: string,
options?: QueryAgentStreamOptions & {
Expand Down Expand Up @@ -188,6 +236,101 @@ export class QueryAgent {
}
}

/**
* Ask query agent a question and stream the response.
*
* @param query - The natural language query string for the agent.
* @param options - Additional options for the run.
* @returns The response from the query agent.
*/
askStream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress: false;
includeFinalState: false;
},
): AsyncGenerator<StreamedTokens>;
askStream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress: false;
includeFinalState?: true;
},
): AsyncGenerator<StreamedTokens | QueryAgentResponse>;
askStream(
query: string,
options: QueryAgentStreamOptions & {
includeProgress?: true;
includeFinalState: false;
},
): AsyncGenerator<ProgressMessage | StreamedTokens>;
askStream(
query: string,
options?: QueryAgentStreamOptions & {
includeProgress?: true;
includeFinalState?: true;
},
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse>;
async *askStream(
query: string,
{
collections,
context,
includeProgress,
includeFinalState,
}: QueryAgentStreamOptions = {},
): AsyncGenerator<ProgressMessage | StreamedTokens | QueryAgentResponse> {
const targetCollections = collections ?? this.collections;

if (!targetCollections) {
throw Error("No collections provided to the query agent.");
}

const { host, bearerToken, headers } =
await this.client.getConnectionDetails();

const sseStream = fetchServerSentEvents(
`${this.agentsHost}/agent/stream_query`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: bearerToken!,
"X-Weaviate-Cluster-Url": host,
"X-Agent-Request-Origin": "typescript-client",
},
body: JSON.stringify({
headers,
query,
collections: mapCollections(targetCollections),
system_prompt: this.systemPrompt,
previous_response: context ? mapApiResponse(context) : undefined,
include_progress: includeProgress ?? true,
include_final_state: includeFinalState ?? true,
}),
},
);

for await (const event of sseStream) {
if (event.event === "error") {
await handleError(event.data);
}

let output: ProgressMessage | StreamedTokens | QueryAgentResponse;
if (event.event === "progress_message") {
output = mapProgressMessageFromSSE(event);
} else if (event.event === "streamed_tokens") {
output = mapStreamedTokensFromSSE(event);
} else if (event.event === "final_state") {
output = mapResponseFromSSE(event);
} else {
throw new Error(`Unexpected event type: ${event.event}: ${event.data}`);
}

yield output;
}
}

/**
* Run the Query Agent search-only mode.
*
Expand Down