Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 92 additions & 17 deletions src/collections/aggregate/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { FilterValue } from '../filters/index.js';

import { WeaviateQueryError } from '../../errors.js';
import { Aggregator } from '../../graphql/index.js';
import { toBase64FromMedia } from '../../index.js';
import { PrimitiveKeys, toBase64FromMedia } from '../../index.js';
import { Bm25QueryProperty } from '../query/types.js';
import { Serialize } from '../serialize/index.js';

export type AggregateBaseOptions<T, M> = {
Expand Down Expand Up @@ -35,6 +36,19 @@ export type AggregateNearOptions<T, M> = AggregateBaseOptions<T, M> & {
targetVector?: string;
};

/**
 * Options accepted by the hybrid aggregation methods, on top of the base aggregate options.
 *
 * Every field except `objectLimit` is handed to the aggregator's `withHybrid` builder;
 * `objectLimit` is applied separately through `withObjectLimit`, and only when it is truthy.
 */
export type AggregateHybridOptions<T, M> = AggregateBaseOptions<T, M> & {
/** Forwarded unchanged as the `alpha` argument of the hybrid clause. */
alpha?: number;
/** Forwarded unchanged as the `maxVectorDistance` argument of the hybrid clause. */
maxVectorDistance?: number;
/** Passed to `withObjectLimit` when truthy, so `0` and `undefined` both leave the limit unset. */
objectLimit?: number;
/**
 * Properties the hybrid query should be run against, forwarded as `properties`.
 *
 * NOTE(review): the implementation casts this array to `string[]` without converting
 * `Bm25QueryProperty` entries — confirm that non-string entries are serialized as intended.
 */
queryProperties?: (PrimitiveKeys<T> | Bm25QueryProperty<T>)[];
/** Name of a single target vector; wrapped in a one-element array and forwarded as `targetVectors`. */
targetVector?: string;
/** Forwarded unchanged as the `vector` argument of the hybrid clause. */
vector?: number[];
};

/**
 * Options for a grouped hybrid aggregation: the hybrid aggregation options plus a mandatory `groupBy`.
 */
export type AggregateGroupByHybridOptions<T, M> = AggregateHybridOptions<T, M> & {
/** What to group the aggregation by: either a property name of `T` or a `GroupByAggregate<T>` value. */
groupBy: (keyof T & string) | GroupByAggregate<T>;
};

/**
 * Options for a grouped near-* aggregation: the near aggregation options plus a mandatory `groupBy`.
 */
export type AggregateGroupByNearOptions<T, M> = AggregateNearOptions<T, M> & {
/** What to group the aggregation by: either a property name of `T` or a `GroupByAggregate<T>` value. */
groupBy: (keyof T & string) | GroupByAggregate<T>;
};
Expand Down Expand Up @@ -346,9 +360,26 @@ class AggregateManager<T> implements Aggregate<T> {
this.tenant = tenant;

this.groupBy = {
// Grouped aggregation over the objects matched by a hybrid search.
hybrid: <M extends PropertiesMetrics<T> | undefined = undefined>(
  query: string,
  opts: AggregateGroupByHybridOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]> => {
  // The builder expects a list of target vectors; the options expose a single name.
  const targetVectors = opts?.targetVector ? [opts.targetVector] : undefined;
  const hybridBuilder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withHybrid({
    query,
    alpha: opts?.alpha,
    maxVectorDistance: opts?.maxVectorDistance,
    properties: opts?.queryProperties as string[],
    targetVectors,
    vector: opts?.vector,
  });
  // Only attach an object limit when a truthy one was supplied.
  const objectLimit = opts?.objectLimit;
  return this.doGroupBy(objectLimit ? hybridBuilder.withObjectLimit(objectLimit) : hybridBuilder);
},
nearImage: async <M extends PropertiesMetrics<T> | undefined = undefined>(
image: string | Buffer,
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]> => {
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearImage({
image: await toBase64FromMedia(image),
Expand All @@ -363,7 +394,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearObject: <M extends PropertiesMetrics<T> | undefined = undefined>(
id: string,
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]> => {
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearObject({
id: id,
Expand All @@ -378,7 +409,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearText: <M extends PropertiesMetrics<T> | undefined = undefined>(
query: string | string[],
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]> => {
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearText({
concepts: Array.isArray(query) ? query : [query],
Expand All @@ -393,7 +424,7 @@ class AggregateManager<T> implements Aggregate<T> {
},
nearVector: <M extends PropertiesMetrics<T> | undefined = undefined>(
vector: number[],
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]> => {
const builder = this.base(opts?.returnMetrics, opts?.filters, opts?.groupBy).withNearVector({
vector: vector,
Expand Down Expand Up @@ -489,6 +520,24 @@ class AggregateManager<T> implements Aggregate<T> {
return new AggregateManager<T>(connection, name, dbVersionSupport, consistencyLevel, tenant);
}

// Aggregation over the objects matched by a hybrid search (no grouping).
hybrid<M extends PropertiesMetrics<T>>(
  query: string,
  opts?: AggregateHybridOptions<T, M>
): Promise<AggregateResult<T, M>> {
  // The builder expects a list of target vectors; the options expose a single name.
  const targetVectors = opts?.targetVector ? [opts.targetVector] : undefined;
  const hybridBuilder = this.base(opts?.returnMetrics, opts?.filters).withHybrid({
    query,
    alpha: opts?.alpha,
    maxVectorDistance: opts?.maxVectorDistance,
    properties: opts?.queryProperties as string[],
    targetVectors,
    vector: opts?.vector,
  });
  // Only attach an object limit when a truthy one was supplied.
  const objectLimit = opts?.objectLimit;
  return this.do(objectLimit ? hybridBuilder.withObjectLimit(objectLimit) : hybridBuilder);
}

async nearImage<M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts?: AggregateNearOptions<T, M>
Expand Down Expand Up @@ -602,6 +651,19 @@ class AggregateManager<T> implements Aggregate<T> {
export interface Aggregate<T> {
/** This namespace contains methods that perform a group by search while aggregating metrics. */
groupBy: AggregateGroupBy<T>;
/**
* Aggregate metrics over the objects returned by a hybrid search on this collection.
*
* This method requires that the objects in the collection have associated vectors.
*
* @param {string} query The text query to search for.
* @param {AggregateHybridOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateResult<T, M>>} The aggregated metrics for the objects returned by the hybrid search.
*/
hybrid<M extends PropertiesMetrics<T>>(
query: string,
opts?: AggregateHybridOptions<T, M>
): Promise<AggregateResult<T, M>>;
/**
* Aggregate metrics over the objects returned by a near image vector search on this collection.
*
Expand Down Expand Up @@ -673,67 +735,80 @@ export interface Aggregate<T> {

export interface AggregateGroupBy<T> {
/**
* Aggregate metrics over the objects returned by a near image vector search on this collection.
* Aggregate metrics over the objects grouped by a specified property and returned by a hybrid search on this collection.
*
* This method requires that the objects in the collection have associated vectors.
*
* @param {string} query The text query to search for.
* @param {AggregateGroupByHybridOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects returned by the hybrid search.
*/
hybrid<M extends PropertiesMetrics<T>>(
query: string,
opts: AggregateGroupByHybridOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects grouped by a specified property and returned by a near image vector search on this collection.
*
* At least one of `certainty`, `distance`, or `objectLimit` must be specified here for the vector search.
*
* This method requires a vectorizer capable of handling base64-encoded images, e.g. `img2vec-neural`, `multi2vec-clip`, and `multi2vec-bind`.
*
* @param {string | Buffer} image The image to search on. This can be a base64 string, a file path string, or a buffer.
* @param {AggregateGroupByNearOptions<T, M>} [opts] The options for the request.
* @param {AggregateGroupByNearOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects returned by the vector search.
*/
nearImage<M extends PropertiesMetrics<T>>(
image: string | Buffer,
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects returned by a near object search on this collection.
* Aggregate metrics over the objects grouped by a specified property and returned by a near object search on this collection.
*
* At least one of `certainty`, `distance`, or `objectLimit` must be specified here for the vector search.
*
* This method requires that the objects in the collection have associated vectors.
*
* @param {string} id The ID of the object to search for.
* @param {AggregateGroupByNearOptions<T, M>} [opts] The options for the request.
* @param {AggregateGroupByNearOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects returned by the vector search.
*/
nearObject<M extends PropertiesMetrics<T>>(
id: string,
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects returned by a near text vector search on this collection.
* Aggregate metrics over the objects grouped by a specified property and returned by a near text vector search on this collection.
*
* At least one of `certainty`, `distance`, or `objectLimit` must be specified here for the vector search.
*
* This method requires a vectorizer capable of handling text, e.g. `text2vec-contextionary`, `text2vec-openai`, etc.
*
* @param {string | string[]} query The text to search for.
* @param {AggregateGroupByNearOptions<T, M>} [opts] The options for the request.
* @param {AggregateGroupByNearOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects returned by the vector search.
*/
nearText<M extends PropertiesMetrics<T>>(
query: string | string[],
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over the objects returned by a near vector search on this collection.
* Aggregate metrics over the objects grouped by a specified property and returned by a near vector search on this collection.
*
* At least one of `certainty`, `distance`, or `objectLimit` must be specified here for the vector search.
*
* This method requires that the objects in the collection have associated vectors.
*
* @param {number[]} vector The vector to search for.
* @param {AggregateGroupByNearOptions<T, M>} [opts] The options for the request.
* @param {AggregateGroupByNearOptions<T, M>} opts The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects returned by the vector search.
*/
nearVector<M extends PropertiesMetrics<T>>(
vector: number[],
opts?: AggregateGroupByNearOptions<T, M>
opts: AggregateGroupByNearOptions<T, M>
): Promise<AggregateGroupByResult<T, M>[]>;
/**
* Aggregate metrics over all the objects in this collection without any vector search.
* Aggregate metrics over all the objects in this collection grouped by a specified property without any vector search.
*
* @param {AggregateGroupByOptions<T, M>} [opts] The options for the request.
* @returns {Promise<AggregateGroupByResult<T, M>[]>} The aggregated metrics for the objects in the collection.
Expand Down
84 changes: 84 additions & 0 deletions src/collections/aggregate/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ describe('Testing of collection.aggregate.overAll with a multi-tenancy collectio

beforeAll(async () => {
client = await weaviate.connectToLocal();
collection = client.collections.get(collectionName);
return client.collections
.create({
name: collectionName,
Expand Down Expand Up @@ -389,3 +390,86 @@ describe('Testing of collection.aggregate.overAll with a multi-tenancy collectio
WeaviateQueryError
));
});

// Integration tests for the search-backed aggregate methods (hybrid, nearText, nearVector, nearObject).
// The suite connects with `weaviate.connectToLocal()`, creates its own collection, fills it with
// identical objects, and deletes the collection again when it finishes.
describe('Testing of collection.aggregate search methods', () => {
  let client: WeaviateClient;
  let collection: Collection;
  const collectionName = 'TestCollectionAggregateSearches';
  // Number of objects inserted during setup; every aggregation below is expected to count all of them.
  const objectCount = 100;

  // ID of the first inserted object, used as the anchor for the nearVector and nearObject searches.
  let uuid: string;

  afterAll(async () => {
    try {
      await client.collections.delete(collectionName);
    } catch (err) {
      // Log before rethrowing so the cause of a failed teardown shows up in the test output.
      console.error(err);
      throw err;
    }
  });

  beforeAll(async () => {
    client = await weaviate.connectToLocal();
    collection = client.collections.get(collectionName);
    await client.collections.create({
      name: collectionName,
      properties: [
        {
          name: 'text',
          dataType: 'text',
        },
      ],
      vectorizers: weaviate.configure.vectorizer.text2VecContextionary(),
    });
    const data: Array<any> = Array.from({ length: objectCount }, () => ({
      properties: {
        text: 'test',
      },
    }));
    const res = await collection.data.insertMany(data);
    uuid = res.uuids[0];
  });

  it('should return an aggregation on a hybrid search', async () => {
    const result = await collection.aggregate.hybrid('test', {
      alpha: 0.5,
      maxVectorDistance: 0,
      queryProperties: ['text'],
      returnMetrics: collection.metrics.aggregate('text').text(['count']),
    });
    expect(result.totalCount).toEqual(objectCount);
    expect(result.properties.text.count).toEqual(objectCount);
  });

  it('should return an aggregation on a nearText search', async () => {
    const result = await collection.aggregate.nearText('test', {
      objectLimit: objectCount,
      returnMetrics: collection.metrics.aggregate('text').text(['count']),
    });
    expect(result.totalCount).toEqual(objectCount);
    expect(result.properties.text.count).toEqual(objectCount);
  });

  it('should return an aggregation on a nearVector search', async () => {
    // Reuse the stored vector of an inserted object as the search vector.
    const obj = await collection.query.fetchObjectById(uuid, { includeVector: true });
    const result = await collection.aggregate.nearVector(obj?.vectors.default!, {
      objectLimit: objectCount,
      returnMetrics: collection.metrics.aggregate('text').text(['count']),
    });
    expect(result.totalCount).toEqual(objectCount);
    expect(result.properties.text.count).toEqual(objectCount);
  });

  it('should return an aggregation on a nearObject search', async () => {
    const result = await collection.aggregate.nearObject(uuid, {
      objectLimit: objectCount,
      returnMetrics: collection.metrics.aggregate('text').text(['count']),
    });
    expect(result.totalCount).toEqual(objectCount);
    expect(result.properties.text.count).toEqual(objectCount);
  });
});
15 changes: 15 additions & 0 deletions src/graphql/aggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Connection from '../connection/index.js';
import { WhereFilter } from '../openapi/types.js';
import { CommandBase } from '../validation/commandBase.js';
import { isValidPositiveIntProperty } from '../validation/number.js';
import Hybrid, { HybridArgs } from './hybrid.js';
import NearMedia, {
NearAudioArgs,
NearDepthArgs,
Expand All @@ -24,6 +25,7 @@ export default class Aggregator extends CommandBase {
private className?: string;
private fields?: string;
private groupBy?: string[];
private hybridString?: string;
private includesNearMediaFilter: boolean;
private limit?: number;
private nearMediaString?: string;
Expand Down Expand Up @@ -133,6 +135,15 @@ export default class Aggregator extends CommandBase {
return this;
};

/**
 * Adds a hybrid search clause to this aggregation.
 *
 * The arguments are turned into their string form via `Hybrid#toString` and stored for use when
 * the query is assembled. If constructing or serializing the clause throws, the failure is
 * recorded with `addError` instead of being propagated, and the builder is still returned.
 *
 * @param args The hybrid search arguments.
 * @returns This builder, to allow chaining.
 */
withHybrid = (args: HybridArgs) => {
  try {
    this.hybridString = new Hybrid(args).toString();
  } catch (e: unknown) {
    // String() gives the same text as toString() for Error instances, but unlike a direct
    // method call it cannot itself throw when the thrown value is null or undefined.
    this.addError(String(e));
  }
  return this;
};

withObjectLimit = (objectLimit: number) => {
if (!isValidPositiveIntProperty(objectLimit)) {
throw new Error('objectLimit must be a non-negative integer');
Expand Down Expand Up @@ -222,6 +233,10 @@ export default class Aggregator extends CommandBase {
args = [...args, `groupBy:${JSON.stringify(this.groupBy)}`];
}

if (this.hybridString) {
args = [...args, `hybrid:${this.hybridString}`];
}

if (this.limit) {
args = [...args, `limit:${this.limit}`];
}
Expand Down
7 changes: 7 additions & 0 deletions src/graphql/hybrid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface HybridArgs {
targetVectors?: string[];
fusionType?: FusionType;
searches?: HybridSubSearch[];
maxVectorDistance?: number;
}

export interface NearTextSubSearch {
Expand Down Expand Up @@ -87,6 +88,7 @@ export default class GraphQLHybrid {
private targetVectors?: string[];
private fusionType?: FusionType;
private searches?: GraphQLHybridSubSearch[];
private maxVectorDistance?: number;

constructor(args: HybridArgs) {
this.alpha = args.alpha;
Expand All @@ -96,6 +98,7 @@ export default class GraphQLHybrid {
this.targetVectors = args.targetVectors;
this.fusionType = args.fusionType;
this.searches = args.searches?.map((search) => new GraphQLHybridSubSearch(search));
this.maxVectorDistance = args.maxVectorDistance;
}

toString() {
Expand Down Expand Up @@ -125,6 +128,10 @@ export default class GraphQLHybrid {
args = [...args, `searches:[${this.searches.map((search) => search.toString()).join(',')}]`];
}

if (this.maxVectorDistance !== undefined) {
args = [...args, `maxVectorDistance:${this.maxVectorDistance}`];
}

return `{${args.join(',')}}`;
}
}
Loading