Skip to content
190 changes: 189 additions & 1 deletion src/query/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import { WeaviateClient } from "weaviate-client";
import { QueryAgent } from "./agent.js";
import { ApiQueryAgentResponse } from "./response/api-response.js";
import { QueryAgentResponse } from "./response/response.js";
import {
QueryAgentResponse,
ComparisonOperator,
SearchModeResponse,
} from "./response/response.js";
import { QueryAgentSearcher } from "./search.js";
import { ApiSearchModeResponse } from "./response/api-response.js";
import { QueryAgentError } from "./response/error.js";

it("runs the query agent", async () => {
const mockClient = {
Expand Down Expand Up @@ -93,3 +100,184 @@ it("runs the query agent", async () => {
display: expect.any(Function),
});
});

it("prepareSearch returns a QueryAgentSearcher", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const agent = new QueryAgent(mockClient, {
systemPrompt: "test system prompt",
});

const searcher = agent.configureSearch("test query");
expect(searcher).toBeInstanceOf(QueryAgentSearcher);
});

it("search-only mode success: caches searches and sends on subsequent request", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const capturedBodies: ApiSearchModeResponse<undefined>[] = [];

const apiSuccess: ApiSearchModeResponse<undefined> = {
original_query: "Test this search only mode!",
searches: [
{
queries: ["search query"],
filters: [
[
{
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filter_operators: "AND",
collection: "test_collection",
},
],
usage: {
requests: 0,
request_tokens: undefined,
response_tokens: undefined,
total_tokens: undefined,
details: undefined,
},
total_time: 1.5,
search_results: {
objects: [
{
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
metadata: {},
references: {},
vectors: {},
properties: {
test_property: 1.0,
text: "hello",
},
},
{
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
metadata: {},
references: {},
vectors: {},
properties: {
test_property: 2.0,
text: "world!",
},
},
],
},
};

// Mock the API response, and capture the request body to assert later
global.fetch = jest.fn((url, init?: RequestInit) => {
if (init && init.body) {
capturedBodies.push(
JSON.parse(init.body as string) as ApiSearchModeResponse<undefined>,
);
}
return Promise.resolve({
ok: true,
json: () => Promise.resolve(apiSuccess),
} as Response);
}) as jest.Mock;

const agent = new QueryAgent(mockClient);
const searcher = agent.configureSearch("test query", {
collections: ["test_collection"],
});

const first = await searcher.run({ limit: 2, offset: 0 });

expect(first).toEqual<SearchModeResponse<undefined>>({
originalQuery: apiSuccess.original_query,
searches: [
{
collection: "test_collection",
queries: ["search query"],
filters: [
[
{
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filterOperators: "AND",
},
],
usage: {
requests: 0,
requestTokens: undefined,
responseTokens: undefined,
totalTokens: undefined,
details: undefined,
},
totalTime: 1.5,
searchResults: apiSuccess.search_results,
});

// First request should have searches: null (generation request)
expect(capturedBodies[0].searches).toBeNull();
const second = await searcher.run({ limit: 2, offset: 1 });
// Second request should include the original searches (execution request)
expect(capturedBodies[1].searches).toEqual(apiSuccess.searches);
// Response mapping should be the same (because response is mocked)
expect(second).toEqual(first);
});

it("search-only mode failure propagates QueryAgentError", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const errorJson = {
error: {
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
},
};

global.fetch = jest.fn(() =>
Promise.resolve({
ok: false,
text: () => Promise.resolve(JSON.stringify(errorJson)),
} as Response),
) as jest.Mock;

const agent = new QueryAgent(mockClient);
const searcher = agent.configureSearch("test query", {
collections: ["test_collection"],
});

try {
await searcher.run({ limit: 2, offset: 0 });
} catch (err) {
expect(err).toBeInstanceOf(QueryAgentError);
expect(err).toMatchObject({
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
});
}
});
33 changes: 33 additions & 0 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { mapApiResponse } from "./response/api-response-mapping.js";
import { fetchServerSentEvents } from "./response/server-sent-events.js";
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
import { handleError } from "./response/error.js";
import { QueryAgentSearcher } from "./search.js";

/**
* An agent for executing agentic queries against Weaviate.
Expand Down Expand Up @@ -185,6 +186,32 @@ export class QueryAgent {
yield output;
}
}

/**
* Configure a QueryAgentSearcher for the search-only mode of the query agent.
*
* This returns a configured QueryAgentSearcher, but does not send any requests or
* run the agent. To do that, you should call the `run` method on the searcher.
*
* This allows you to paginate through a consistent results set, as calling the
* `run` method on the searcher multiple times will result in the same underlying
* searches being performed each time.
*
* @param query - The natural language query string for the agent.
* @param options - Additional options for configuring the searcher.
* @param options.collections - The collections to query. Will override any collections if passed in the constructor.
* @returns A configured QueryAgentSearcher for the search-only mode of the query agent.
*/
configureSearch<T = undefined>(
query: string,
{ collections }: QueryAgentSearchOnlyOptions = {},
): QueryAgentSearcher<T> {
return new QueryAgentSearcher(this.client, query, {
collections: collections ?? this.collections,
systemPrompt: this.systemPrompt,
agentsHost: this.agentsHost,
});
}
}

/** Options for the QueryAgent. */
Expand Down Expand Up @@ -216,3 +243,9 @@ export type QueryAgentStreamOptions = {
/** Include final state in the stream. */
includeFinalState?: boolean;
};

/** Options for the QueryAgent search-only run. */
export type QueryAgentSearchOnlyOptions = {
/** List of collections to query. Will override any collections if passed in the constructor. */
collections?: (string | QueryAgentCollectionConfig)[];
};
1 change: 1 addition & 0 deletions src/query/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from "./agent.js";
export { QueryAgentCollectionConfig } from "./collection.js";
export * from "./response/index.js";
export * from "./search.js";
10 changes: 10 additions & 0 deletions src/query/response/api-response.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { WeaviateReturn } from "weaviate-client";

import {
NumericMetrics,
TextMetrics,
Expand Down Expand Up @@ -177,3 +179,11 @@ export type ApiSource = {
object_id: string;
collection: string;
};

export type ApiSearchModeResponse<T> = {
original_query: string;
searches?: ApiSearchResult[];
usage: ApiUsage;
total_time: number;
search_results: WeaviateReturn<T>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly results also probably are returned from API as underscore_case and not camelCase, so we'd have to map it as well to make it developer friendly 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh, good catch! I've adding mapping from the api snake_case to camelCase 👍 I've also removed all the generics on the types to make this work (we can add them back later if we want, but they were for typing properties but our search results are potentially multi-collection, so maybe didn't make sense anyway 🤷 ).

The types of the search result objects are also an extension of the Weaviate types, to add a collection field (which was missing from the original type).

};
36 changes: 28 additions & 8 deletions src/query/response/response-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
StreamedTokens,
ProgressMessage,
DateFilterValue,
SearchModeResponse,
} from "./response.js";

import {
Expand All @@ -20,6 +21,7 @@ import {
ApiUsage,
ApiSource,
ApiDateFilterValue,
ApiSearchModeResponse,
} from "./api-response.js";

import { ServerSentEvent } from "./server-sent-events.js";
Expand Down Expand Up @@ -47,15 +49,16 @@ export const mapResponse = (
};
};

const mapInnerSearches = (searches: ApiSearchResult[]): SearchResult[] =>
searches.map((result) => ({
collection: result.collection,
queries: result.queries,
filters: result.filters.map(mapPropertyFilters),
filterOperators: result.filter_operators,
}));

const mapSearches = (searches: ApiSearchResult[][]): SearchResult[][] =>
searches.map((searchGroup) =>
searchGroup.map((result) => ({
collection: result.collection,
queries: result.queries,
filters: result.filters.map(mapPropertyFilters),
filterOperators: result.filter_operators,
})),
);
searches.map((searchGroup) => mapInnerSearches(searchGroup));

const mapDatePropertyFilter = (
filterValue: ApiDateFilterValue,
Expand Down Expand Up @@ -298,3 +301,20 @@ export const mapResponseFromSSE = (
display: () => display(properties),
};
};

export const mapSearchOnlyResponse = <T>(
response: ApiSearchModeResponse<T>,
): {
mappedResponse: SearchModeResponse<T>;
apiSearches: ApiSearchResult[] | undefined;
} => {
const apiSearches = response.searches;
const mappedResponse: SearchModeResponse<T> = {
originalQuery: response.original_query,
searches: apiSearches ? mapInnerSearches(apiSearches) : undefined,
usage: mapUsage(response.usage),
totalTime: response.total_time,
searchResults: response.search_results,
};
return { mappedResponse, apiSearches };
};
10 changes: 10 additions & 0 deletions src/query/response/response.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { WeaviateReturn } from "weaviate-client";

export type QueryAgentResponse = {
outputType: "finalState";
originalQuery: string;
Expand Down Expand Up @@ -260,3 +262,11 @@ export type StreamedTokens = {
outputType: "streamedTokens";
delta: string;
};

export type SearchModeResponse<T> = {
originalQuery: string;
searches?: SearchResult[];
usage: Usage;
totalTime: number;
searchResults: WeaviateReturn<T>;
};
Loading