diff --git a/src/query/agent.test.ts b/src/query/agent.test.ts index d0eb948..90d1ffb 100644 --- a/src/query/agent.test.ts +++ b/src/query/agent.test.ts @@ -1,8 +1,15 @@ import { WeaviateClient } from "weaviate-client"; import { QueryAgent } from "./agent.js"; import { ApiQueryAgentResponse } from "./response/api-response.js"; -import { QueryAgentResponse, ComparisonOperator } from "./response/response.js"; -import { ApiSearchModeResponse } from "./response/api-response.js"; +import { + QueryAgentResponse, + ComparisonOperator, + AskModeResponse, +} from "./response/response.js"; +import { + ApiSearchModeResponse, + ApiAskModeResponse, +} from "./response/api-response.js"; import { QueryAgentError } from "./response/error.js"; it("runs the query agent", async () => { @@ -96,6 +103,134 @@ it("runs the query agent", async () => { }); }); +it("runs the query agent ask", async () => { + const mockClient = { + getConnectionDetails: jest.fn().mockResolvedValue({ + host: "test-cluster", + bearerToken: "test-token", + headers: { "X-Provider": "test-key" }, + }), + } as unknown as WeaviateClient; + + const apiSuccess: ApiAskModeResponse = { + searches: [ + { + query: "search query", + filters: { + filter_type: "integer", + property_name: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + collection: "test_collection", + sort_property: undefined, + }, + { + query: undefined, + filters: { + filter_type: "integer", + property_name: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + collection: "test_collection", + sort_property: { + property_name: "test_property", + order: "ascending", + tie_break: { + property_name: "test_property_2", + order: "descending", + tie_break: undefined, + }, + }, + }, + ], + aggregations: [], + usage: { + model_units: 1, + usage_in_plan: true, + remaining_plan_requests: 2, + }, + total_time: 1.5, + is_partial_answer: false, + missing_information: [], + final_answer: "Test answer", + sources: [ + { + object_id: "123", + collection: "test-collection", + }, + ], + }; + + global.fetch = jest.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(apiSuccess), + } as Response), + ) as jest.Mock; + + const agent = new QueryAgent(mockClient, { + systemPrompt: "test system prompt", + }); + + const response = await agent.ask("What is the capital of France?", { + collections: ["test-collection"], + }); + + expect(response).toEqual({ + outputType: "finalState", + searches: [ + { + collection: "test_collection", + query: "search query", + filters: { + filterType: "integer", + propertyName: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + sortProperty: undefined, + }, + { + collection: "test_collection", + query: undefined, + filters: { + filterType: "integer", + propertyName: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + sortProperty: { + propertyName: "test_property", + order: "ascending", + tieBreak: { + propertyName: "test_property_2", + order: "descending", + }, + }, + }, + ], + aggregations: [], + usage: { + modelUnits: 1, + usageInPlan: true, + remainingPlanRequests: 2, + }, + totalTime: 1.5, + isPartialAnswer: false, + missingInformation: [], + finalAnswer: "Test answer", + sources: [ + { + objectId: "123", + collection: "test-collection", + }, + ], + display: expect.any(Function), + }); +}); + it("search-only mode success: caches searches and sends on subsequent request", async () => { const mockClient = { getConnectionDetails: jest.fn().mockResolvedValue({ @@ -118,6 +253,26 @@ it("search-only mode success: caches searches and sends on subsequent request", value: 0, }, collection: "test_collection", + sort_property: undefined, + }, + { + query: undefined, + filters: { + filter_type: "integer", + property_name: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + collection: "test_collection", + sort_property: { + property_name: "test_property", + order: "ascending", + tie_break: { + property_name: "test_property_2", + order: "descending", + tie_break: undefined, + }, + }, }, ], usage: { @@ -202,6 +357,25 @@ it("search-only mode success: caches searches and sends on subsequent request", operator: ComparisonOperator.GreaterThan, value: 0, }, + sortProperty: undefined, + }, + { + collection: "test_collection", + query: undefined, + filters: { + filterType: "integer", + propertyName: "test_property", + operator: ComparisonOperator.GreaterThan, + value: 0, + }, + sortProperty: { + propertyName: "test_property", + order: "ascending", + tieBreak: { + propertyName: "test_property_2", + order: "descending", + }, + }, }, ], usage: { diff --git a/src/query/response/api-response.ts b/src/query/response/api-response.ts index ec0776b..60eed88 100644 --- a/src/query/response/api-response.ts +++ b/src/query/response/api-response.ts @@ -19,10 +19,17 @@ export type ApiAskModeResponse = { sources?: ApiSource[]; }; +export type ApiQuerySort = { + property_name: string; + order: "ascending" | "descending"; + tie_break?: ApiQuerySort; +}; + export type ApiSearch = { query?: string; filters?: ApiPropertyFilter | ApiFilterAndOr; collection: string; + sort_property?: ApiQuerySort; }; export type ApiAggregation = { diff --git a/src/query/response/response-mapping.ts b/src/query/response/response-mapping.ts index 0ce2fef..9cc6fe2 100644 --- a/src/query/response/response-mapping.ts +++ b/src/query/response/response-mapping.ts @@ -16,6 +16,7 @@ import { AskModeResponse, Search, ModelUnitUsage, + QuerySort, } from "./response.js"; import { @@ -32,6 +33,7 @@ import { ApiAskModeResponse, ApiSearch, ApiModelUnitUsage, + ApiQuerySort, } from "./api-response.js"; import { ServerSentEvent } from "./server-sent-events.js"; @@ -67,8 +69,17 @@ const mapSearches = (searches: ApiSearch[]): Search[] => query: search.query, filters: search.filters ? mapFilter(search.filters) : undefined, collection: search.collection, + sortProperty: search.sort_property + ? mapQuerySort(search.sort_property) + : undefined, })); +const mapQuerySort = (sort: ApiQuerySort): QuerySort => ({ + propertyName: sort.property_name, + order: sort.order, + tieBreak: sort.tie_break ? mapQuerySort(sort.tie_break) : undefined, +}); + const mapUsage = (usage: ApiModelUnitUsage): ModelUnitUsage => ({ modelUnits: usage.model_units, usageInPlan: usage.usage_in_plan, diff --git a/src/query/response/response.ts b/src/query/response/response.ts index dd23eff..21808c0 100644 --- a/src/query/response/response.ts +++ b/src/query/response/response.ts @@ -17,6 +17,7 @@ export type Search = { query?: string; filters?: PropertyFilter | FilterAndOr; collection: string; + sortProperty?: QuerySort; }; export type Aggregation = { @@ -37,6 +38,12 @@ export type ModelUnitUsage = { remainingPlanRequests: number; }; +export type QuerySort = { + propertyName: string; + order: "ascending" | "descending"; + tieBreak?: QuerySort; +}; + export type QueryAgentResponse = { outputType: "finalState"; originalQuery: string;