Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 176 additions & 2 deletions src/query/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import { WeaviateClient } from "weaviate-client";
import { QueryAgent } from "./agent.js";
import { ApiQueryAgentResponse } from "./response/api-response.js";
import { QueryAgentResponse, ComparisonOperator } from "./response/response.js";
import { ApiSearchModeResponse } from "./response/api-response.js";
import {
QueryAgentResponse,
ComparisonOperator,
AskModeResponse,
} from "./response/response.js";
import {
ApiSearchModeResponse,
ApiAskModeResponse,
} from "./response/api-response.js";
import { QueryAgentError } from "./response/error.js";

it("runs the query agent", async () => {
Expand Down Expand Up @@ -96,6 +103,134 @@ it("runs the query agent", async () => {
});
});

it("runs the query agent ask", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
host: "test-cluster",
bearerToken: "test-token",
headers: { "X-Provider": "test-key" },
}),
} as unknown as WeaviateClient;

const apiSuccess: ApiAskModeResponse = {
searches: [
{
query: "search query",
filters: {
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
collection: "test_collection",
sort_property: undefined,
},
{
query: undefined,
filters: {
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
collection: "test_collection",
sort_property: {
property_name: "test_property",
order: "ascending",
tie_break: {
property_name: "test_property_2",
order: "descending",
tie_break: undefined,
},
},
},
],
aggregations: [],
usage: {
model_units: 1,
usage_in_plan: true,
remaining_plan_requests: 2,
},
total_time: 1.5,
is_partial_answer: false,
missing_information: [],
final_answer: "Test answer",
sources: [
{
object_id: "123",
collection: "test-collection",
},
],
};

global.fetch = jest.fn(() =>
Promise.resolve({
ok: true,
json: () => Promise.resolve(apiSuccess),
} as Response),
) as jest.Mock;

const agent = new QueryAgent(mockClient, {
systemPrompt: "test system prompt",
});

const response = await agent.ask("What is the capital of France?", {
collections: ["test-collection"],
});

expect(response).toEqual<AskModeResponse>({
outputType: "finalState",
searches: [
{
collection: "test_collection",
query: "search query",
filters: {
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
sortProperty: undefined,
},
{
collection: "test_collection",
query: undefined,
filters: {
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
sortProperty: {
propertyName: "test_property",
order: "ascending",
tieBreak: {
propertyName: "test_property_2",
order: "descending",
},
},
},
],
aggregations: [],
usage: {
modelUnits: 1,
usageInPlan: true,
remainingPlanRequests: 2,
},
totalTime: 1.5,
isPartialAnswer: false,
missingInformation: [],
finalAnswer: "Test answer",
sources: [
{
objectId: "123",
collection: "test-collection",
},
],
display: expect.any(Function),
});
});

it("search-only mode success: caches searches and sends on subsequent request", async () => {
const mockClient = {
getConnectionDetails: jest.fn().mockResolvedValue({
Expand All @@ -118,6 +253,26 @@ it("search-only mode success: caches searches and sends on subsequent request",
value: 0,
},
collection: "test_collection",
sort_property: undefined,
},
{
query: undefined,
filters: {
filter_type: "integer",
property_name: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
collection: "test_collection",
sort_property: {
property_name: "test_property",
order: "ascending",
tie_break: {
property_name: "test_property_2",
order: "descending",
tie_break: undefined,
},
},
},
],
usage: {
Expand Down Expand Up @@ -202,6 +357,25 @@ it("search-only mode success: caches searches and sends on subsequent request",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
sortProperty: undefined,
},
{
collection: "test_collection",
query: undefined,
filters: {
filterType: "integer",
propertyName: "test_property",
operator: ComparisonOperator.GreaterThan,
value: 0,
},
sortProperty: {
propertyName: "test_property",
order: "ascending",
tieBreak: {
propertyName: "test_property_2",
order: "descending",
},
},
},
],
usage: {
Expand Down
7 changes: 7 additions & 0 deletions src/query/response/api-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ export type ApiAskModeResponse = {
sources?: ApiSource[];
};

export type ApiQuerySort = {
property_name: string;
order: "ascending" | "descending";
tie_break?: ApiQuerySort;
};

export type ApiSearch = {
query?: string;
filters?: ApiPropertyFilter | ApiFilterAndOr;
collection: string;
sort_property?: ApiQuerySort;
};

export type ApiAggregation = {
Expand Down
11 changes: 11 additions & 0 deletions src/query/response/response-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
AskModeResponse,
Search,
ModelUnitUsage,
QuerySort,
} from "./response.js";

import {
Expand All @@ -32,6 +33,7 @@ import {
ApiAskModeResponse,
ApiSearch,
ApiModelUnitUsage,
ApiQuerySort,
} from "./api-response.js";

import { ServerSentEvent } from "./server-sent-events.js";
Expand Down Expand Up @@ -67,8 +69,17 @@ const mapSearches = (searches: ApiSearch[]): Search[] =>
query: search.query,
filters: search.filters ? mapFilter(search.filters) : undefined,
collection: search.collection,
sortProperty: search.sort_property
? mapQuerySort(search.sort_property)
: undefined,
}));

const mapQuerySort = (sort: ApiQuerySort): QuerySort => ({
propertyName: sort.property_name,
order: sort.order,
tieBreak: sort.tie_break ? mapQuerySort(sort.tie_break) : undefined,
});

const mapUsage = (usage: ApiModelUnitUsage): ModelUnitUsage => ({
modelUnits: usage.model_units,
usageInPlan: usage.usage_in_plan,
Expand Down
7 changes: 7 additions & 0 deletions src/query/response/response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export type Search = {
query?: string;
filters?: PropertyFilter | FilterAndOr;
collection: string;
sortProperty?: QuerySort;
};

export type Aggregation = {
Expand All @@ -37,6 +38,12 @@ export type ModelUnitUsage = {
remainingPlanRequests: number;
};

export type QuerySort = {
propertyName: string;
order: "ascending" | "descending";
tieBreak?: QuerySort;
};

export type QueryAgentResponse = {
outputType: "finalState";
originalQuery: string;
Expand Down