Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ function dbgenerated(expr: String?): Any {
function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use.
*/
function search(field: String, search: String): Boolean {
} @@@expressionContext([AccessPolicy])
// /**
// * If the field value matches the search condition with [full-text-search](https://www.prisma.io/docs/concepts/components/prisma-client/full-text-search). Need to enable "fullTextSearch" preview feature to use.
// */
// function search(field: String, search: String): Boolean {
// } @@@expressionContext([AccessPolicy])

/**
* Checks the field value starts with the search string. By default, the search is case-sensitive, and
Expand All @@ -151,25 +151,25 @@ function endsWith(field: String, search: String, caseInSensitive: Boolean?): Boo
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* If the field value (a list) has the given search value
* Checks if the list field has the given search value
*/
function has(field: Any[], search: Any): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* If the field value (a list) has every element of the search list
* Checks if the list field has at least one element of the search list
*/
function hasEvery(field: Any[], search: Any[]): Boolean {
function hasSome(field: Any[], search: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* If the field value (a list) has at least one element of the search list
* Checks if the list field has every element of the search list
*/
function hasSome(field: Any[], search: Any[]): Boolean {
function hasEvery(field: Any[], search: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* If the field value (a list) is empty
* Checks if the list field is empty
*/
function isEmpty(field: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])
Expand Down Expand Up @@ -551,9 +551,9 @@ function length(field: Any): Int {


/**
* Validates a string field value matches a regex.
* Validates a string field value matches a regex pattern.
*/
function regex(field: String, regex: String): Boolean {
function regex(field: String, pattern: String): Boolean {
} @@@expressionContext([ValidationRule])

/**
Expand Down
4 changes: 2 additions & 2 deletions packages/language/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ export function getAllDeclarationsIncludingImports(documents: LangiumDocuments,
}

export function getAuthDecl(decls: (DataModel | TypeDef)[]) {
let authModel = decls.find((m) => hasAttribute(m, '@@auth'));
let authModel = decls.find((d) => hasAttribute(d, '@@auth'));
if (!authModel) {
authModel = decls.find((m) => m.name === 'User');
authModel = decls.find((d) => d.name === 'User');
}
return authModel;
}
Expand Down
6 changes: 6 additions & 0 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ export class ClientImpl<Schema extends SchemaDef> {
return (procOptions[name] as Function).apply(this, [this, ...args]);
}

async $connect() {
await this.kysely.connection().execute(async (conn) => {
await conn.executeQuery(sql`select 1`.compile(this.kysely));
});
}

async $disconnect() {
await this.kysely.destroy();
}
Expand Down
7 changes: 6 additions & 1 deletion packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,12 @@ export type ClientContract<Schema extends SchemaDef> = {
$unuseAll(): ClientContract<Schema>;

/**
* Disconnects the underlying Kysely instance from the database.
* Eagerly connects to the database.
*/
$connect(): Promise<void>;

/**
* Explicitly disconnects from the database.
*/
$disconnect(): Promise<void>;

Expand Down
18 changes: 3 additions & 15 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
result = this.buildSkipTake(result, skip, take);

// orderBy
result = this.buildOrderBy(
result,
model,
modelAlias,
args.orderBy,
skip !== undefined || take !== undefined,
negateOrderBy,
);
result = this.buildOrderBy(result, model, modelAlias, args.orderBy, negateOrderBy);

// distinct
if ('distinct' in args && (args as any).distinct) {
Expand Down Expand Up @@ -748,15 +741,10 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
model: string,
modelAlias: string,
orderBy: OrArray<OrderBy<Schema, GetModels<Schema>, boolean, boolean>> | undefined,
useDefaultIfEmpty: boolean,
negated: boolean,
) {
if (!orderBy) {
if (useDefaultIfEmpty) {
orderBy = makeDefaultOrderBy(this.schema, model);
} else {
return query;
}
return query;
}

let result = query;
Expand Down Expand Up @@ -862,7 +850,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
),
);
});
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated);
result = this.buildOrderBy(result, fieldDef.type, relationModel, value, negated);
}
}
}
Expand Down
9 changes: 1 addition & 8 deletions packages/runtime/src/client/crud/operations/aggregate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,7 @@ export class AggregateOperationHandler<Schema extends SchemaDef> extends BaseOpe
subQuery = this.dialect.buildSkipTake(subQuery, skip, take);

// orderBy
subQuery = this.dialect.buildOrderBy(
subQuery,
this.model,
this.model,
parsedArgs.orderBy,
skip !== undefined || take !== undefined,
negateOrderBy,
);
subQuery = this.dialect.buildOrderBy(subQuery, this.model, this.model, parsedArgs.orderBy, negateOrderBy);

return subQuery.as('$sub');
});
Expand Down
50 changes: 16 additions & 34 deletions packages/runtime/src/client/crud/operations/group-by.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,33 @@ export class GroupByOperationHandler<Schema extends SchemaDef> extends BaseOpera
// parse args
const parsedArgs = this.inputValidator.validateGroupByArgs(this.model, normalizedArgs);

let query = this.kysely.selectFrom((eb) => {
// nested query for filtering and pagination

// where
let subQuery = eb
.selectFrom(this.model)
.selectAll()
.where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where));

// skip & take
const skip = parsedArgs?.skip;
let take = parsedArgs?.take;
let negateOrderBy = false;
if (take !== undefined && take < 0) {
negateOrderBy = true;
take = -take;
}
subQuery = this.dialect.buildSkipTake(subQuery, skip, take);

// default orderBy
subQuery = this.dialect.buildOrderBy(
subQuery,
this.model,
this.model,
undefined,
skip !== undefined || take !== undefined,
negateOrderBy,
);

return subQuery.as('$sub');
});
let query = this.kysely
.selectFrom(this.model)
.where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where));

const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, '$sub');
const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field);

// groupBy
const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]);
query = query.groupBy(bys.map((by) => fieldRef(by)));

// skip & take
const skip = parsedArgs?.skip;
let take = parsedArgs?.take;
let negateOrderBy = false;
if (take !== undefined && take < 0) {
negateOrderBy = true;
take = -take;
}
query = this.dialect.buildSkipTake(query, skip, take);

// orderBy
if (parsedArgs.orderBy) {
query = this.dialect.buildOrderBy(query, this.model, '$sub', parsedArgs.orderBy, false, false);
query = this.dialect.buildOrderBy(query, this.model, this.model, parsedArgs.orderBy, negateOrderBy);
}

if (parsedArgs.having) {
query = query.having(() => this.dialect.buildFilter(this.model, '$sub', parsedArgs.having));
query = query.having(() => this.dialect.buildFilter(this.model, this.model, parsedArgs.having));
}

// select all by fields
Expand Down
38 changes: 31 additions & 7 deletions packages/runtime/src/client/crud/validator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ import {
addStringValidation,
} from './utils';

const schemaCache = new WeakMap<SchemaDef, Map<string, ZodType>>();

type GetSchemaFunc<Schema extends SchemaDef, Options> = (model: GetModels<Schema>, options: Options) => ZodType;

export class InputValidator<Schema extends SchemaDef> {
private schemaCache = new Map<string, ZodType>();

constructor(private readonly client: ClientContract<Schema>) {}

private get schema() {
Expand Down Expand Up @@ -192,6 +192,24 @@ export class InputValidator<Schema extends SchemaDef> {
);
}

private getSchemaCache(cacheKey: string) {
let thisCache = schemaCache.get(this.schema);
if (!thisCache) {
thisCache = new Map<string, ZodType>();
schemaCache.set(this.schema, thisCache);
}
return thisCache.get(cacheKey);
}

private setSchemaCache(cacheKey: string, schema: ZodType) {
let thisCache = schemaCache.get(this.schema);
if (!thisCache) {
thisCache = new Map<string, ZodType>();
schemaCache.set(this.schema, thisCache);
}
return thisCache.set(cacheKey, schema);
}

private validate<T, Options = undefined>(
model: GetModels<Schema>,
operation: string,
Expand All @@ -200,14 +218,16 @@ export class InputValidator<Schema extends SchemaDef> {
args: unknown,
) {
const cacheKey = stableStringify({
type: 'model',
model,
operation,
options,
extraValidationsEnabled: this.extraValidationsEnabled,
});
let schema = this.schemaCache.get(cacheKey!);
let schema = this.getSchemaCache(cacheKey!);
if (!schema) {
schema = getSchema(model, options);
this.schemaCache.set(cacheKey!, schema);
this.setSchemaCache(cacheKey!, schema);
}
const { error, data } = schema.safeParse(args);
if (error) {
Expand Down Expand Up @@ -293,8 +313,12 @@ export class InputValidator<Schema extends SchemaDef> {
}

private makeTypeDefSchema(type: string): z.ZodType {
const key = `$typedef-${type}`;
let schema = this.schemaCache.get(key);
const key = stableStringify({
type: 'typedef',
name: type,
extraValidationsEnabled: this.extraValidationsEnabled,
});
let schema = this.getSchemaCache(key!);
if (schema) {
return schema;
}
Expand All @@ -316,7 +340,7 @@ export class InputValidator<Schema extends SchemaDef> {
),
)
.passthrough();
this.schemaCache.set(key, schema);
this.setSchemaCache(key!, schema);
return schema;
}

Expand Down
28 changes: 23 additions & 5 deletions packages/runtime/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
type RootOperationNode,
} from 'kysely';
import { match } from 'ts-pattern';
import type { GetModels, SchemaDef } from '../../schema';
import type { GetModels, ModelDef, SchemaDef, TypeDefDef } from '../../schema';
import { type ClientImpl } from '../client-impl';
import { TransactionIsolationLevel, type ClientContract } from '../contract';
import { InternalError, QueryError, ZenStackError } from '../errors';
Expand All @@ -42,7 +42,7 @@ type MutationInfo<Schema extends SchemaDef> = {
};

export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQueryExecutor {
private readonly nameMapper: QueryNameMapper;
private readonly nameMapper: QueryNameMapper | undefined;

constructor(
private client: ClientImpl<Schema>,
Expand All @@ -54,7 +54,21 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
private suppressMutationHooks: boolean = false,
) {
super(compiler, adapter, connectionProvider, plugins);
this.nameMapper = new QueryNameMapper(client.$schema);

if (this.schemaHasMappedNames(client.$schema)) {
this.nameMapper = new QueryNameMapper(client.$schema);
}
}

private schemaHasMappedNames(schema: Schema) {
const hasMapAttr = (decl: ModelDef | TypeDefDef) => {
if (decl.attributes?.some((attr) => attr.name === '@@map')) {
return true;
}
return Object.values(decl.fields).some((field) => field.attributes?.some((attr) => attr.name === '@map'));
};

return Object.values(schema.models).some(hasMapAttr) || Object.values(schema.typeDefs ?? []).some(hasMapAttr);
}

private get kysely() {
Expand Down Expand Up @@ -170,7 +184,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer

if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
// no need to handle mutation hooks, just proceed
const finalQuery = this.nameMapper.transformNode(query);
const finalQuery = this.processNameMapping(query);
compiled = this.compileQuery(finalQuery);
if (parameters) {
compiled = { ...compiled, parameters };
Expand All @@ -189,7 +203,7 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
};
}
const finalQuery = this.nameMapper.transformNode(query);
const finalQuery = this.processNameMapping(query);
compiled = this.compileQuery(finalQuery);
if (parameters) {
compiled = { ...compiled, parameters };
Expand Down Expand Up @@ -239,6 +253,10 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
return result;
}

private processNameMapping<Node extends RootOperationNode>(query: Node): Node {
return this.nameMapper?.transformNode(query) ?? query;
}

private createClientForConnection(connection: DatabaseConnection, inTx: boolean) {
const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection));
innerExecutor.suppressMutationHooks = true;
Expand Down
Loading