Skip to content
176 changes: 175 additions & 1 deletion src/query/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import { WeaviateClient } from "weaviate-client";
import { QueryAgent } from "./agent.js";
import { ApiQueryAgentResponse } from "./response/api-response.js";
import { QueryAgentResponse } from "./response/response.js";
import { QueryAgentResponse, ComparisonOperator } from "./response/response.js";
import { ApiSearchModeResponse } from "./response/api-response.js";
import { QueryAgentError } from "./response/error.js";

it("runs the query agent", async () => {
const mockClient = {
Expand Down Expand Up @@ -93,3 +95,175 @@ it("runs the query agent", async () => {
display: expect.any(Function),
});
});

it("search-only mode success: caches searches and sends on subsequent request", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const capturedBodies: ApiSearchModeResponse<undefined>[] = [];

const apiSuccess: ApiSearchModeResponse<undefined> = {
original_query: "Test this search only mode!",
searches: [
{
queries: ["search query"],
filters: [
[
{
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filter_operators: "AND",
collection: "test_collection",
},
],
usage: {
requests: 0,
request_tokens: undefined,
response_tokens: undefined,
total_tokens: undefined,
details: undefined,
},
total_time: 1.5,
search_results: {
objects: [
{
uuid: "e6dc0a31-76f8-4bd3-b563-677ced6eb557",
metadata: {},
references: {},
vectors: {},
properties: {
test_property: 1.0,
text: "hello",
},
},
{
uuid: "cf5401cc-f4f1-4eb9-a6a1-173d34f94339",
metadata: {},
references: {},
vectors: {},
properties: {
test_property: 2.0,
text: "world!",
},
},
],
},
};

// Mock the API response, and capture the request body to assert later
global.fetch = jest.fn((url, init?: RequestInit) => {
if (init && init.body) {
capturedBodies.push(
JSON.parse(init.body as string) as ApiSearchModeResponse<undefined>,
);
}
return Promise.resolve({
ok: true,
json: () => Promise.resolve(apiSuccess),
} as Response);
}) as jest.Mock;

const agent = new QueryAgent(mockClient);

const first = await agent.search("test query", {
limit: 2,
collections: ["test_collection"],
});
expect(first).toMatchObject({
originalQuery: apiSuccess.original_query,
searches: [
{
collection: "test_collection",
queries: ["search query"],
filters: [
[
{
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
],
],
filterOperators: "AND",
},
],
usage: {
requests: 0,
requestTokens: undefined,
responseTokens: undefined,
totalTokens: undefined,
details: undefined,
},
totalTime: 1.5,
searchResults: apiSuccess.search_results,
});
expect(typeof first.next).toBe("function");

// First request should have searches: null (generation request)
expect(capturedBodies[0].searches).toBeNull();

// Second request uses the next method on the first response
const second = await first.next({ limit: 2, offset: 1 });
// Second request should include the original searches (execution request)
expect(capturedBodies[1].searches).toEqual(apiSuccess.searches);
// Response mapping should be the same (because response is mocked)
expect(second).toMatchObject({
originalQuery: apiSuccess.original_query,
searches: first.searches,
usage: first.usage,
totalTime: first.totalTime,
searchResults: first.searchResults,
});
expect(typeof second.next).toBe("function");
});

it("search-only mode failure propagates QueryAgentError", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const errorJson = {
error: {
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
},
};

global.fetch = jest.fn(() =>
Promise.resolve({
ok: false,
text: () => Promise.resolve(JSON.stringify(errorJson)),
} as Response),
) as jest.Mock;

const agent = new QueryAgent(mockClient);
try {
await agent.search("test query", {
limit: 2,
collections: ["test_collection"],
});
} catch (err) {
expect(err).toBeInstanceOf(QueryAgentError);
expect(err).toMatchObject({
message: "Test error message",
code: "test_error_code",
details: { info: "test detail" },
});
}
});
29 changes: 29 additions & 0 deletions src/query/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import { mapApiResponse } from "./response/api-response-mapping.js";
import { fetchServerSentEvents } from "./response/server-sent-events.js";
import { mapCollections, QueryAgentCollectionConfig } from "./collection.js";
import { handleError } from "./response/error.js";
import { QueryAgentSearcher } from "./search.js";
import { SearchModeResponse } from "./response/response.js";

/**
* An agent for executing agentic queries against Weaviate.
Expand Down Expand Up @@ -185,6 +187,25 @@ export class QueryAgent {
yield output;
}
}

/**
* Run the Query Agent search-only mode.
*
* Sends the initial search request and returns the first page of results.
* The returned response includes a `next` method for pagination which
* reuses the same underlying searches to ensure consistency across pages.
*/
async search<T = undefined>(
query: string,
{ limit = 20, collections }: QueryAgentSearchOnlyOptions = {},
): Promise<SearchModeResponse<T>> {
const searcher = new QueryAgentSearcher<T>(this.client, query, {
collections: collections ?? this.collections,
systemPrompt: this.systemPrompt,
agentsHost: this.agentsHost,
});
return searcher.run({ limit, offset: 0 });
}
}

/** Options for the QueryAgent. */
Expand Down Expand Up @@ -216,3 +237,11 @@ export type QueryAgentStreamOptions = {
/** Include final state in the stream. */
includeFinalState?: boolean;
};

/** Options for the QueryAgent search-only run. */
export type QueryAgentSearchOnlyOptions = {
/** The maximum number of results to return. */
limit?: number;
/** List of collections to query. Will override any collections if passed in the constructor. */
collections?: (string | QueryAgentCollectionConfig)[];
};
1 change: 1 addition & 0 deletions src/query/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from "./agent.js";
export { QueryAgentCollectionConfig } from "./collection.js";
export * from "./response/index.js";
export * from "./search.js";
10 changes: 10 additions & 0 deletions src/query/response/api-response.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { WeaviateReturn } from "weaviate-client";

import {
NumericMetrics,
TextMetrics,
Expand Down Expand Up @@ -177,3 +179,11 @@ export type ApiSource = {
object_id: string;
collection: string;
};

export type ApiSearchModeResponse<T> = {
original_query: string;
searches?: ApiSearchResult[];
usage: ApiUsage;
total_time: number;
search_results: WeaviateReturn<T>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly results also probably are returned from API as underscore_case and not camelCase, so we'd have to map it as well to make it developer friendly 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argh, good catch! I've adding mapping from the api snake_case to camelCase 👍 I've also removed all the generics on the types to make this work (we can add them back later if we want, but they were for typing properties but our search results are potentially multi-collection, so maybe didn't make sense anyway 🤷 ).

The types of the search result objects are also an extension of the Weaviate types, to add a collection field (which was missing from the original type).

};
36 changes: 28 additions & 8 deletions src/query/response/response-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
StreamedTokens,
ProgressMessage,
DateFilterValue,
MappedSearchModeResponse,
} from "./response.js";

import {
Expand All @@ -20,6 +21,7 @@ import {
ApiUsage,
ApiSource,
ApiDateFilterValue,
ApiSearchModeResponse,
} from "./api-response.js";

import { ServerSentEvent } from "./server-sent-events.js";
Expand Down Expand Up @@ -47,15 +49,16 @@ export const mapResponse = (
};
};

const mapInnerSearches = (searches: ApiSearchResult[]): SearchResult[] =>
searches.map((result) => ({
collection: result.collection,
queries: result.queries,
filters: result.filters.map(mapPropertyFilters),
filterOperators: result.filter_operators,
}));

const mapSearches = (searches: ApiSearchResult[][]): SearchResult[][] =>
searches.map((searchGroup) =>
searchGroup.map((result) => ({
collection: result.collection,
queries: result.queries,
filters: result.filters.map(mapPropertyFilters),
filterOperators: result.filter_operators,
})),
);
searches.map((searchGroup) => mapInnerSearches(searchGroup));

const mapDatePropertyFilter = (
filterValue: ApiDateFilterValue,
Expand Down Expand Up @@ -298,3 +301,20 @@ export const mapResponseFromSSE = (
display: () => display(properties),
};
};

export const mapSearchOnlyResponse = <T>(
response: ApiSearchModeResponse<T>,
): {
mappedResponse: MappedSearchModeResponse<T>;
apiSearches: ApiSearchResult[] | undefined;
} => {
const apiSearches = response.searches;
const mappedResponse: MappedSearchModeResponse<T> = {
originalQuery: response.original_query,
searches: apiSearches ? mapInnerSearches(apiSearches) : undefined,
usage: mapUsage(response.usage),
totalTime: response.total_time,
searchResults: response.search_results,
};
return { mappedResponse, apiSearches };
};
22 changes: 22 additions & 0 deletions src/query/response/response.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { WeaviateReturn } from "weaviate-client";

export type QueryAgentResponse = {
outputType: "finalState";
originalQuery: string;
Expand Down Expand Up @@ -260,3 +262,23 @@ export type StreamedTokens = {
outputType: "streamedTokens";
delta: string;
};

export type MappedSearchModeResponse<T> = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like internal type and everything in this file is exported to user (see index.ts file).
So Maybe move it to response-mapping.ts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense 👍 Moving to response-mapping.ts would end up with circular dependencies, so I've just removed this type and am using Omit<SearchModeResponse, "next"> in it's place (since it's internal anyway)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works as well :) fyi circular dependencies are supported in JS/TS (especially fine with for types)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, TIL!

originalQuery: string;
searches?: SearchResult[];
usage: Usage;
totalTime: number;
searchResults: WeaviateReturn<T>;
};

/** Options for the executing a prepared QueryAgent search. */
export type SearchExecutionOptions = {
/** The maximum number of results to return. */
limit?: number;
/** The offset of the results to return, for paginating through query result sets. */
offset?: number;
};

export type SearchModeResponse<T> = MappedSearchModeResponse<T> & {
next: (options: SearchExecutionOptions) => Promise<SearchModeResponse<T>>;
};
Loading