diff --git a/packages/language/src/index.ts b/packages/language/src/index.ts index 4b578f31..ab577c7f 100644 --- a/packages/language/src/index.ts +++ b/packages/language/src/index.ts @@ -20,7 +20,7 @@ export class DocumentLoadError extends Error { export async function loadDocument( fileName: string, - pluginModelFiles: string[] = [], + additionalModelFiles: string[] = [], ): Promise< { success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] } > { @@ -50,9 +50,9 @@ export async function loadDocument( URI.file(path.resolve(path.join(_dirname, '../res', STD_LIB_MODULE_NAME))), ); - // load plugin model files + // load additional model files const pluginDocs = await Promise.all( - pluginModelFiles.map((file) => + additionalModelFiles.map((file) => services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))), ), ); diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index f6daf04d..9bc6f664 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -127,6 +127,20 @@ export class PolicyHandler extends OperationNodeTransf // --- Post mutation work --- if (hasPostUpdatePolicies && result.rows.length > 0) { + // verify if before-update rows and post-update rows still id-match + if (beforeUpdateInfo) { + invariant(beforeUpdateInfo.rows.length === result.rows.length); + const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel); + for (const postRow of result.rows) { + const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f])); + if (!beforeRow) { + throw new QueryError( + 'Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.', + ); + } + } + } + // entities updated filter const idConditions = this.buildIdConditions(mutationModel, result.rows); @@ -234,10 +248,15 @@ export class PolicyHandler extends OperationNodeTransf if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) { return undefined; } + + // combine update's where with policy filter + const policyFilter = this.buildPolicyFilter(model, model, 'update'); + const combinedFilter = where ? conjunction(this.dialect, [where.where, policyFilter]) : policyFilter; + const query: SelectQueryNode = { kind: 'SelectQueryNode', from: FromNode.create([TableNode.create(model)]), - where, + where: WhereNode.create(combinedFilter), selections: [...beforeUpdateAccessFields.map((f) => SelectionNode.create(ColumnNode.create(f)))], }; const result = await proceed(query); diff --git a/packages/runtime/src/client/crud/dialects/base-dialect.ts b/packages/runtime/src/client/crud/dialects/base-dialect.ts index 7357c8f5..642297b8 100644 --- a/packages/runtime/src/client/crud/dialects/base-dialect.ts +++ b/packages/runtime/src/client/crud/dialects/base-dialect.ts @@ -36,6 +36,8 @@ import { } from '../../query-utils'; export abstract class BaseCrudDialect { + protected eb = expressionBuilder(); + constructor( protected readonly schema: Schema, protected readonly options: ClientOptions, @@ -51,9 +53,9 @@ export abstract class BaseCrudDialect { // #region common query builders - buildSelectModel(eb: ExpressionBuilder, model: string, modelAlias: string) { + buildSelectModel(model: string, modelAlias: string) { const modelDef = requireModel(this.schema, model); - let result = eb.selectFrom(model === modelAlias ? model : `${model} as ${modelAlias}`); + let result = this.eb.selectFrom(model === modelAlias ? model : `${model} as ${modelAlias}`); // join all delegate bases let joinBase = modelDef.baseModel; while (joinBase) { @@ -73,7 +75,7 @@ export abstract class BaseCrudDialect { // where if (args.where) { - result = result.where((eb) => this.buildFilter(eb, model, modelAlias, args?.where)); + result = result.where(() => this.buildFilter(model, modelAlias, args?.where)); } // skip && take @@ -112,21 +114,16 @@ export abstract class BaseCrudDialect { return result; } - buildFilter( - eb: ExpressionBuilder, - model: string, - modelAlias: string, - where: boolean | object | undefined, - ) { + buildFilter(model: string, modelAlias: string, where: boolean | object | undefined) { if (where === true || where === undefined) { - return this.true(eb); + return this.true(); } if (where === false) { - return this.false(eb); + return this.false(); } - let result = this.true(eb); + let result = this.true(); const _where = flattenCompoundUniqueFilters(this.schema, model, where); for (const [key, payload] of Object.entries(_where)) { @@ -139,33 +136,28 @@ export abstract class BaseCrudDialect { } if (this.isLogicalCombinator(key)) { - result = this.and(eb, result, this.buildCompositeFilter(eb, model, modelAlias, key, payload)); + result = this.and(result, this.buildCompositeFilter(model, modelAlias, key, payload)); continue; } const fieldDef = requireField(this.schema, model, key); if (fieldDef.relation) { - result = this.and(eb, result, this.buildRelationFilter(eb, model, modelAlias, key, fieldDef, payload)); + result = this.and(result, this.buildRelationFilter(model, modelAlias, key, fieldDef, payload)); } else { // if the field is from a base model, build a reference from that model - const fieldRef = this.fieldRef( - fieldDef.originModel ?? model, - key, - eb, - fieldDef.originModel ?? modelAlias, - ); + const fieldRef = this.fieldRef(fieldDef.originModel ?? model, key, fieldDef.originModel ?? modelAlias); if (fieldDef.array) { - result = this.and(eb, result, this.buildArrayFilter(eb, fieldRef, fieldDef, payload)); + result = this.and(result, this.buildArrayFilter(fieldRef, fieldDef, payload)); } else { - result = this.and(eb, result, this.buildPrimitiveFilter(eb, fieldRef, fieldDef, payload)); + result = this.and(result, this.buildPrimitiveFilter(fieldRef, fieldDef, payload)); } } } // call expression builder and combine the results if ('$expr' in _where && typeof _where['$expr'] === 'function') { - result = this.and(eb, result, _where['$expr'](eb)); + result = this.and(result, _where['$expr'](this.eb)); } return result; @@ -183,9 +175,8 @@ export abstract class BaseCrudDialect { const orderByItems = ensureArray(_orderBy).flatMap((obj) => Object.entries(obj)); - const eb = expressionBuilder(); const subQueryAlias = `${model}$cursor$sub`; - const cursorFilter = this.buildFilter(eb, model, subQueryAlias, cursor); + const cursorFilter = this.buildFilter(model, subQueryAlias, cursor); let result = query; const filters: ExpressionWrapper[] = []; @@ -198,17 +189,17 @@ export abstract class BaseCrudDialect { const _order = negateOrderBy ? (order === 'asc' ? 'desc' : 'asc') : order; const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '='; andFilters.push( - eb( - eb.ref(`${modelAlias}.${field}`), + this.eb( + this.eb.ref(`${modelAlias}.${field}`), op, - this.buildSelectModel(eb, model, subQueryAlias) + this.buildSelectModel(model, subQueryAlias) .select(`${subQueryAlias}.${field}`) .where(cursorFilter), ), ); } - filters.push(eb.and(andFilters)); + filters.push(this.eb.and(andFilters)); } result = result.where((eb) => eb.or(filters)); @@ -221,7 +212,6 @@ export abstract class BaseCrudDialect { } protected buildCompositeFilter( - eb: ExpressionBuilder, model: string, modelAlias: string, key: (typeof LOGICAL_COMBINATORS)[number], @@ -229,38 +219,24 @@ export abstract class BaseCrudDialect { ): Expression { return match(key) .with('AND', () => - this.and( - eb, - ...enumerate(payload).map((subPayload) => this.buildFilter(eb, model, modelAlias, subPayload)), - ), + this.and(...enumerate(payload).map((subPayload) => this.buildFilter(model, modelAlias, subPayload))), ) .with('OR', () => - this.or( - eb, - ...enumerate(payload).map((subPayload) => this.buildFilter(eb, model, modelAlias, subPayload)), - ), + this.or(...enumerate(payload).map((subPayload) => this.buildFilter(model, modelAlias, subPayload))), ) - .with('NOT', () => eb.not(this.buildCompositeFilter(eb, model, modelAlias, 'AND', payload))) + .with('NOT', () => this.eb.not(this.buildCompositeFilter(model, modelAlias, 'AND', payload))) .exhaustive(); } - private buildRelationFilter( - eb: ExpressionBuilder, - model: string, - modelAlias: string, - field: string, - fieldDef: FieldDef, - payload: any, - ) { + private buildRelationFilter(model: string, modelAlias: string, field: string, fieldDef: FieldDef, payload: any) { if (!fieldDef.array) { - return this.buildToOneRelationFilter(eb, model, modelAlias, field, fieldDef, payload); + return this.buildToOneRelationFilter(model, modelAlias, field, fieldDef, payload); } else { - return this.buildToManyRelationFilter(eb, model, modelAlias, field, fieldDef, payload); + return this.buildToManyRelationFilter(model, modelAlias, field, fieldDef, payload); } } private buildToOneRelationFilter( - eb: ExpressionBuilder, model: string, modelAlias: string, field: string, @@ -272,10 +248,10 @@ export abstract class BaseCrudDialect { if (ownedByModel && !fieldDef.originModel) { // can be short-circuited to FK null check - return this.and(eb, ...keyPairs.map(({ fk }) => eb(sql.ref(`${modelAlias}.${fk}`), 'is', null))); + return this.and(...keyPairs.map(({ fk }) => this.eb(sql.ref(`${modelAlias}.${fk}`), 'is', null))); } else { // translate it to `{ is: null }` filter - return this.buildToOneRelationFilter(eb, model, modelAlias, field, fieldDef, { is: null }); + return this.buildToOneRelationFilter(model, modelAlias, field, fieldDef, { is: null }); } } @@ -290,10 +266,10 @@ export abstract class BaseCrudDialect { ); const filterResultField = `${field}$filter`; - const joinSelect = eb + const joinSelect = this.eb .selectFrom(`${fieldDef.type} as ${joinAlias}`) - .where(() => this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right))))) - .select(() => eb.fn.count(eb.lit(1)).as(filterResultField)); + .where(() => this.and(...joinPairs.map(([left, right]) => this.eb(sql.ref(left), '=', sql.ref(right))))) + .select(() => this.eb.fn.count(this.eb.lit(1)).as(filterResultField)); const conditions: Expression[] = []; @@ -301,12 +277,12 @@ export abstract class BaseCrudDialect { if ('is' in payload) { if (payload.is === null) { // check if not found - conditions.push(eb(joinSelect, '=', 0)); + conditions.push(this.eb(joinSelect, '=', 0)); } else { // check if found conditions.push( - eb( - joinSelect.where(() => this.buildFilter(eb, fieldDef.type, joinAlias, payload.is)), + this.eb( + joinSelect.where(() => this.buildFilter(fieldDef.type, joinAlias, payload.is)), '>', 0, ), @@ -317,16 +293,15 @@ export abstract class BaseCrudDialect { if ('isNot' in payload) { if (payload.isNot === null) { // check if found - conditions.push(eb(joinSelect, '>', 0)); + conditions.push(this.eb(joinSelect, '>', 0)); } else { conditions.push( this.or( - eb, // is null - eb(joinSelect, '=', 0), + this.eb(joinSelect, '=', 0), // found one that matches the filter - eb( - joinSelect.where(() => this.buildFilter(eb, fieldDef.type, joinAlias, payload.isNot)), + this.eb( + joinSelect.where(() => this.buildFilter(fieldDef.type, joinAlias, payload.isNot)), '=', 0, ), @@ -336,19 +311,18 @@ export abstract class BaseCrudDialect { } } else { conditions.push( - eb( - joinSelect.where(() => this.buildFilter(eb, fieldDef.type, joinAlias, payload)), + this.eb( + joinSelect.where(() => this.buildFilter(fieldDef.type, joinAlias, payload)), '>', 0, ), ); } - return this.and(eb, ...conditions); + return this.and(...conditions); } private buildToManyRelationFilter( - eb: ExpressionBuilder, model: string, modelAlias: string, field: string, @@ -357,7 +331,7 @@ export abstract class BaseCrudDialect { ) { // null check needs to be converted to fk "is null" checks if (payload === null) { - return eb(sql.ref(`${modelAlias}.${field}`), 'is', null); + return this.eb(sql.ref(`${modelAlias}.${field}`), 'is', null); } const relationModel = fieldDef.type; @@ -391,17 +365,15 @@ export abstract class BaseCrudDialect { } else { const relationKeyPairs = getRelationForeignKeyFieldPairs(this.schema, model, field); - let result = this.true(eb); + let result = this.true(); for (const { fk, pk } of relationKeyPairs.keyPairs) { if (relationKeyPairs.ownedByModel) { result = this.and( - eb, result, eb(sql.ref(`${modelAlias}.${fk}`), '=', sql.ref(`${relationFilterSelectAlias}.${pk}`)), ); } else { result = this.and( - eb, result, eb(sql.ref(`${modelAlias}.${pk}`), '=', sql.ref(`${relationFilterSelectAlias}.${fk}`)), ); @@ -411,7 +383,7 @@ export abstract class BaseCrudDialect { } }; - let result = this.true(eb); + let result = this.true(); for (const [key, subPayload] of Object.entries(payload)) { if (!subPayload) { @@ -421,15 +393,12 @@ export abstract class BaseCrudDialect { switch (key) { case 'some': { result = this.and( - eb, result, - eb( - this.buildSelectModel(eb, relationModel, relationFilterSelectAlias) - .select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count')) - .where(buildPkFkWhereRefs(eb)) - .where((eb1) => - this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload), - ), + this.eb( + this.buildSelectModel(relationModel, relationFilterSelectAlias) + .select(() => this.eb.fn.count(this.eb.lit(1)).as('$count')) + .where(buildPkFkWhereRefs(this.eb)) + .where(() => this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), '>', 0, ), @@ -439,16 +408,13 @@ export abstract class BaseCrudDialect { case 'every': { result = this.and( - eb, result, - eb( - this.buildSelectModel(eb, relationModel, relationFilterSelectAlias) + this.eb( + this.buildSelectModel(relationModel, relationFilterSelectAlias) .select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count')) - .where(buildPkFkWhereRefs(eb)) - .where((eb1) => - eb1.not( - this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload), - ), + .where(buildPkFkWhereRefs(this.eb)) + .where(() => + this.eb.not(this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), ), '=', 0, @@ -459,15 +425,12 @@ export abstract class BaseCrudDialect { case 'none': { result = this.and( - eb, result, - eb( - this.buildSelectModel(eb, relationModel, relationFilterSelectAlias) - .select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count')) - .where(buildPkFkWhereRefs(eb)) - .where((eb1) => - this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload), - ), + this.eb( + this.buildSelectModel(relationModel, relationFilterSelectAlias) + .select(() => this.eb.fn.count(this.eb.lit(1)).as('$count')) + .where(buildPkFkWhereRefs(this.eb)) + .where(() => this.buildFilter(relationModel, relationFilterSelectAlias, subPayload)), '=', 0, ), @@ -480,12 +443,7 @@ export abstract class BaseCrudDialect { return result; } - private buildArrayFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - fieldDef: FieldDef, - payload: any, - ) { + private buildArrayFilter(fieldRef: Expression, fieldDef: FieldDef, payload: any) { const clauses: Expression[] = []; const fieldType = fieldDef.type as BuiltinType; @@ -498,27 +456,27 @@ export abstract class BaseCrudDialect { switch (key) { case 'equals': { - clauses.push(this.buildLiteralFilter(eb, fieldRef, fieldType, eb.val(value))); + clauses.push(this.buildLiteralFilter(fieldRef, fieldType, this.eb.val(value))); break; } case 'has': { - clauses.push(eb(fieldRef, '@>', eb.val([value]))); + clauses.push(this.eb(fieldRef, '@>', this.eb.val([value]))); break; } case 'hasEvery': { - clauses.push(eb(fieldRef, '@>', eb.val(value))); + clauses.push(this.eb(fieldRef, '@>', this.eb.val(value))); break; } case 'hasSome': { - clauses.push(eb(fieldRef, '&&', eb.val(value))); + clauses.push(this.eb(fieldRef, '&&', this.eb.val(value))); break; } case 'isEmpty': { - clauses.push(eb(fieldRef, value === true ? '=' : '!=', eb.val([]))); + clauses.push(this.eb(fieldRef, value === true ? '=' : '!=', this.eb.val([]))); break; } @@ -528,27 +486,27 @@ export abstract class BaseCrudDialect { } } - return this.and(eb, ...clauses); + return this.and(...clauses); } - buildPrimitiveFilter(eb: ExpressionBuilder, fieldRef: Expression, fieldDef: FieldDef, payload: any) { + buildPrimitiveFilter(fieldRef: Expression, fieldDef: FieldDef, payload: any) { if (payload === null) { - return eb(fieldRef, 'is', null); + return this.eb(fieldRef, 'is', null); } if (isEnum(this.schema, fieldDef.type)) { - return this.buildEnumFilter(eb, fieldRef, fieldDef, payload); + return this.buildEnumFilter(fieldRef, fieldDef, payload); } return ( match(fieldDef.type as BuiltinType) - .with('String', () => this.buildStringFilter(eb, fieldRef, payload)) + .with('String', () => this.buildStringFilter(fieldRef, payload)) .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => - this.buildNumberFilter(eb, fieldRef, type, payload), + this.buildNumberFilter(fieldRef, type, payload), ) - .with('Boolean', () => this.buildBooleanFilter(eb, fieldRef, payload)) - .with('DateTime', () => this.buildDateTimeFilter(eb, fieldRef, payload)) - .with('Bytes', () => this.buildBytesFilter(eb, fieldRef, payload)) + .with('Boolean', () => this.buildBooleanFilter(fieldRef, payload)) + .with('DateTime', () => this.buildDateTimeFilter(fieldRef, payload)) + .with('Bytes', () => this.buildBytesFilter(fieldRef, payload)) // TODO: JSON filters .with('Json', () => { throw new InternalError('JSON filters are not supported yet'); @@ -560,12 +518,11 @@ export abstract class BaseCrudDialect { ); } - private buildLiteralFilter(eb: ExpressionBuilder, lhs: Expression, type: BuiltinType, rhs: unknown) { - return eb(lhs, '=', rhs !== null && rhs !== undefined ? this.transformPrimitive(rhs, type, false) : rhs); + private buildLiteralFilter(lhs: Expression, type: BuiltinType, rhs: unknown) { + return this.eb(lhs, '=', rhs !== null && rhs !== undefined ? this.transformPrimitive(rhs, type, false) : rhs); } private buildStandardFilter( - eb: ExpressionBuilder, type: BuiltinType, payload: any, lhs: Expression, @@ -577,7 +534,7 @@ export abstract class BaseCrudDialect { ) { if (payload === null || !isPlainObject(payload)) { return { - conditions: [this.buildLiteralFilter(eb, lhs, type, payload)], + conditions: [this.buildLiteralFilter(lhs, type, payload)], consumedKeys: [], }; } @@ -594,41 +551,40 @@ export abstract class BaseCrudDialect { } const rhs = Array.isArray(value) ? value.map(getRhs) : getRhs(value); const condition = match(op) - .with('equals', () => (rhs === null ? eb(lhs, 'is', null) : eb(lhs, '=', rhs))) + .with('equals', () => (rhs === null ? this.eb(lhs, 'is', null) : this.eb(lhs, '=', rhs))) .with('in', () => { invariant(Array.isArray(rhs), 'right hand side must be an array'); if (rhs.length === 0) { - return this.false(eb); + return this.false(); } else { - return eb(lhs, 'in', rhs); + return this.eb(lhs, 'in', rhs); } }) .with('notIn', () => { invariant(Array.isArray(rhs), 'right hand side must be an array'); if (rhs.length === 0) { - return this.true(eb); + return this.true(); } else { - return eb.not(eb(lhs, 'in', rhs)); + return this.eb.not(this.eb(lhs, 'in', rhs)); } }) - .with('lt', () => eb(lhs, '<', rhs)) - .with('lte', () => eb(lhs, '<=', rhs)) - .with('gt', () => eb(lhs, '>', rhs)) - .with('gte', () => eb(lhs, '>=', rhs)) - .with('not', () => eb.not(recurse(value))) + .with('lt', () => this.eb(lhs, '<', rhs)) + .with('lte', () => this.eb(lhs, '<=', rhs)) + .with('gt', () => this.eb(lhs, '>', rhs)) + .with('gte', () => this.eb(lhs, '>=', rhs)) + .with('not', () => this.eb.not(recurse(value))) // aggregations .with(P.union(...AGGREGATE_OPERATORS), (op) => { const innerResult = this.buildStandardFilter( - eb, type, value, - aggregate(eb, lhs, op), + aggregate(this.eb, lhs, op), getRhs, recurse, throwIfInvalid, ); consumedKeys.push(...innerResult.consumedKeys); - return this.and(eb, ...innerResult.conditions); + return this.and(...innerResult.conditions); }) .otherwise(() => { if (throwIfInvalid) { @@ -647,23 +603,18 @@ export abstract class BaseCrudDialect { return { conditions, consumedKeys }; } - private buildStringFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - payload: StringFilter, - ) { + private buildStringFilter(fieldRef: Expression, payload: StringFilter) { let mode: 'default' | 'insensitive' | undefined; if (payload && typeof payload === 'object' && 'mode' in payload) { mode = payload.mode; } const { conditions, consumedKeys } = this.buildStandardFilter( - eb, 'String', payload, - mode === 'insensitive' ? eb.fn('lower', [fieldRef]) : fieldRef, - (value) => this.prepStringCasing(eb, value, mode), - (value) => this.buildStringFilter(eb, fieldRef, value as StringFilter), + mode === 'insensitive' ? this.eb.fn('lower', [fieldRef]) : fieldRef, + (value) => this.prepStringCasing(this.eb, value, mode), + (value) => this.buildStringFilter(fieldRef, value as StringFilter), ); if (payload && typeof payload === 'object') { @@ -676,18 +627,18 @@ export abstract class BaseCrudDialect { const condition = match(key) .with('contains', () => mode === 'insensitive' - ? eb(fieldRef, 'ilike', sql.val(`%${value}%`)) - : eb(fieldRef, 'like', sql.val(`%${value}%`)), + ? this.eb(fieldRef, 'ilike', sql.val(`%${value}%`)) + : this.eb(fieldRef, 'like', sql.val(`%${value}%`)), ) .with('startsWith', () => mode === 'insensitive' - ? eb(fieldRef, 'ilike', sql.val(`${value}%`)) - : eb(fieldRef, 'like', sql.val(`${value}%`)), + ? this.eb(fieldRef, 'ilike', sql.val(`${value}%`)) + : this.eb(fieldRef, 'like', sql.val(`${value}%`)), ) .with('endsWith', () => mode === 'insensitive' - ? eb(fieldRef, 'ilike', sql.val(`%${value}`)) - : eb(fieldRef, 'like', sql.val(`%${value}`)), + ? this.eb(fieldRef, 'ilike', sql.val(`%${value}`)) + : this.eb(fieldRef, 'like', sql.val(`%${value}`)), ) .otherwise(() => { throw new QueryError(`Invalid string filter key: ${key}`); @@ -699,7 +650,7 @@ export abstract class BaseCrudDialect { } } - return this.and(eb, ...conditions); + return this.and(...conditions); } private prepStringCasing( @@ -720,93 +671,66 @@ export abstract class BaseCrudDialect { } } - private buildNumberFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - type: BuiltinType, - payload: any, - ) { + private buildNumberFilter(fieldRef: Expression, type: BuiltinType, payload: any) { const { conditions } = this.buildStandardFilter( - eb, type, payload, fieldRef, (value) => this.transformPrimitive(value, type, false), - (value) => this.buildNumberFilter(eb, fieldRef, type, value), + (value) => this.buildNumberFilter(fieldRef, type, value), ); - return this.and(eb, ...conditions); + return this.and(...conditions); } - private buildBooleanFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - payload: BooleanFilter, - ) { + private buildBooleanFilter(fieldRef: Expression, payload: BooleanFilter) { const { conditions } = this.buildStandardFilter( - eb, 'Boolean', payload, fieldRef, (value) => this.transformPrimitive(value, 'Boolean', false), - (value) => this.buildBooleanFilter(eb, fieldRef, value as BooleanFilter), + (value) => this.buildBooleanFilter(fieldRef, value as BooleanFilter), true, ['equals', 'not'], ); - return this.and(eb, ...conditions); + return this.and(...conditions); } - private buildDateTimeFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - payload: DateTimeFilter, - ) { + private buildDateTimeFilter(fieldRef: Expression, payload: DateTimeFilter) { const { conditions } = this.buildStandardFilter( - eb, 'DateTime', payload, fieldRef, (value) => this.transformPrimitive(value, 'DateTime', false), - (value) => this.buildDateTimeFilter(eb, fieldRef, value as DateTimeFilter), + (value) => this.buildDateTimeFilter(fieldRef, value as DateTimeFilter), true, ); - return this.and(eb, ...conditions); + return this.and(...conditions); } - private buildBytesFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - payload: BytesFilter, - ) { + private buildBytesFilter(fieldRef: Expression, payload: BytesFilter) { const conditions = this.buildStandardFilter( - eb, 'Bytes', payload, fieldRef, (value) => this.transformPrimitive(value, 'Bytes', false), - (value) => this.buildBytesFilter(eb, fieldRef, value as BytesFilter), + (value) => this.buildBytesFilter(fieldRef, value as BytesFilter), true, ['equals', 'in', 'notIn', 'not'], ); - return this.and(eb, ...conditions.conditions); + return this.and(...conditions.conditions); } - private buildEnumFilter( - eb: ExpressionBuilder, - fieldRef: Expression, - fieldDef: FieldDef, - payload: any, - ) { + private buildEnumFilter(fieldRef: Expression, fieldDef: FieldDef, payload: any) { const conditions = this.buildStandardFilter( - eb, 'String', payload, fieldRef, (value) => value, - (value) => this.buildEnumFilter(eb, fieldRef, fieldDef, value), + (value) => this.buildEnumFilter(fieldRef, fieldDef, value), true, ['equals', 'in', 'notIn', 'not'], ); - return this.and(eb, ...conditions.conditions); + return this.and(...conditions.conditions); } buildOrderBy( @@ -826,6 +750,14 @@ export abstract class BaseCrudDialect { } let result = query; + + const buildFieldRef = (model: string, field: string, modelAlias: string) => { + const fieldDef = requireField(this.schema, model, field); + return fieldDef.originModel + ? this.fieldRef(fieldDef.originModel, field, fieldDef.originModel) + : this.fieldRef(model, field, modelAlias); + }; + enumerate(orderBy).forEach((orderBy) => { for (const [field, value] of Object.entries(orderBy)) { if (!value) { @@ -838,8 +770,7 @@ export abstract class BaseCrudDialect { for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( - (eb) => - aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS), + (eb) => aggregate(eb, buildFieldRef(model, k, modelAlias), field as AGGREGATE_OPERATORS), sql.raw(this.negateSort(v, negated)), ); } @@ -852,7 +783,7 @@ export abstract class BaseCrudDialect { for (const [k, v] of Object.entries(value)) { invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`); result = result.orderBy( - (eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)), + (eb) => eb.fn.count(buildFieldRef(model, k, modelAlias)), sql.raw(this.negateSort(v, negated)), ); } @@ -865,7 +796,7 @@ export abstract class BaseCrudDialect { const fieldDef = requireField(this.schema, model, field); if (!fieldDef.relation) { - const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias); + const fieldRef = buildFieldRef(model, field, modelAlias); if (value === 'asc' || value === 'desc') { result = result.orderBy(fieldRef, this.negateSort(value, negated)); } else if ( @@ -898,11 +829,10 @@ export abstract class BaseCrudDialect { const sort = this.negateSort(value._count, negated); result = result.orderBy((eb) => { const subQueryAlias = `${modelAlias}$orderBy$${field}$count`; - let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias); + let subQuery = this.buildSelectModel(relationModel, subQueryAlias); const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias); subQuery = subQuery.where(() => this.and( - eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right))), ), ); @@ -915,10 +845,7 @@ export abstract class BaseCrudDialect { result = result.leftJoin(relationModel, (join) => { const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, relationModel); return join.on((eb) => - this.and( - eb, - ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right))), - ), + this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))), ); }); result = this.buildOrderBy(result, fieldDef.type, relationModel, value, false, negated); @@ -964,7 +891,7 @@ export abstract class BaseCrudDialect { } jsonObject[field] = eb.ref(`${subModel.name}.${field}`); } - return this.buildJsonObject(eb, jsonObject).as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); + return this.buildJsonObject(jsonObject).as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); }); } @@ -972,13 +899,12 @@ export abstract class BaseCrudDialect { } protected buildModelSelect( - eb: ExpressionBuilder, model: GetModels, subQueryAlias: string, payload: true | FindArgs, true>, selectAllFields: boolean, ) { - let subQuery = this.buildSelectModel(eb, model, subQueryAlias); + let subQuery = this.buildSelectModel(model, subQueryAlias); if (selectAllFields) { subQuery = this.buildSelectAllFields( @@ -1005,7 +931,7 @@ export abstract class BaseCrudDialect { const fieldDef = requireField(this.schema, model, field); if (fieldDef.computed) { // TODO: computed field from delegate base? - return query.select((eb) => this.fieldRef(model, field, eb, modelAlias).as(field)); + return query.select(() => this.fieldRef(model, field, modelAlias).as(field)); } else if (!fieldDef.originModel) { // regular field return query.select(sql.ref(`${modelAlias}.${field}`).as(field)); @@ -1085,14 +1011,14 @@ export abstract class BaseCrudDialect { value.where && typeof value.where === 'object' ) { - const filter = this.buildFilter(eb, fieldModel, fieldModel, value.where); + const filter = this.buildFilter(fieldModel, fieldModel, value.where); fieldCountQuery = fieldCountQuery.where(filter); } jsonObject[field] = fieldCountQuery; } - return this.buildJsonObject(eb, jsonObject); + return this.buildJsonObject(jsonObject); } // #endregion @@ -1103,12 +1029,12 @@ export abstract class BaseCrudDialect { return negated ? (sort === 'asc' ? 'desc' : 'asc') : sort; } - public true(eb: ExpressionBuilder): Expression { - return eb.lit(this.transformPrimitive(true, 'Boolean', false) as boolean); + public true(): Expression { + return this.eb.lit(this.transformPrimitive(true, 'Boolean', false) as boolean); } - public false(eb: ExpressionBuilder): Expression { - return eb.lit(this.transformPrimitive(false, 'Boolean', false) as boolean); + public false(): Expression { + return this.eb.lit(this.transformPrimitive(false, 'Boolean', false) as boolean); } public isTrue(expression: Expression) { @@ -1127,40 +1053,34 @@ export abstract class BaseCrudDialect { return (node as ValueNode).value === false || (node as ValueNode).value === 0; } - and(eb: ExpressionBuilder, ...args: Expression[]) { + and(...args: Expression[]) { const nonTrueArgs = args.filter((arg) => !this.isTrue(arg)); if (nonTrueArgs.length === 0) { - return this.true(eb); + return this.true(); } else if (nonTrueArgs.length === 1) { return nonTrueArgs[0]!; } else { - return eb.and(nonTrueArgs); + return this.eb.and(nonTrueArgs); } } - or(eb: ExpressionBuilder, ...args: Expression[]) { + or(...args: Expression[]) { const nonFalseArgs = args.filter((arg) => !this.isFalse(arg)); if (nonFalseArgs.length === 0) { - return this.false(eb); + return this.false(); } else if (nonFalseArgs.length === 1) { return nonFalseArgs[0]!; } else { - return eb.or(nonFalseArgs); + return this.eb.or(nonFalseArgs); } } - not(eb: ExpressionBuilder, ...args: Expression[]) { - return eb.not(this.and(eb, ...args)); + not(...args: Expression[]) { + return this.eb.not(this.and(...args)); } - fieldRef( - model: string, - field: string, - eb: ExpressionBuilder, - modelAlias?: string, - inlineComputedField = true, - ) { - return buildFieldRef(this.schema, model, field, this.options, eb, modelAlias, inlineComputedField); + fieldRef(model: string, field: string, modelAlias?: string, inlineComputedField = true) { + return buildFieldRef(this.schema, model, field, this.options, this.eb, modelAlias, inlineComputedField); } protected canJoinWithoutNestedSelect( @@ -1221,18 +1141,12 @@ export abstract class BaseCrudDialect { /** * Builds an Kysely expression that returns a JSON object for the given key-value pairs. */ - abstract buildJsonObject( - eb: ExpressionBuilder, - value: Record>, - ): ExpressionWrapper; + abstract buildJsonObject(value: Record>): ExpressionWrapper; /** * Builds an Kysely expression that returns the length of an array. */ - abstract buildArrayLength( - eb: ExpressionBuilder, - array: Expression, - ): ExpressionWrapper; + abstract buildArrayLength(array: Expression): ExpressionWrapper; /** * Builds an array literal SQL string for the given values. diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index b6c40661..0a60c350 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -153,7 +153,7 @@ export class PostgresCrudDialect extends BaseCrudDiale if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) { // build join directly - tbl = this.buildModelSelect(eb, relationModel, relationSelectName, payload, false); + tbl = this.buildModelSelect(relationModel, relationSelectName, payload, false); // parent join filter tbl = this.buildRelationJoinFilter( @@ -167,13 +167,7 @@ export class PostgresCrudDialect extends BaseCrudDiale } else { // join with a nested query tbl = eb.selectFrom(() => { - let subQuery = this.buildModelSelect( - eb, - relationModel, - `${relationSelectName}$t`, - payload, - true, - ); + let subQuery = this.buildModelSelect(relationModel, `${relationSelectName}$t`, payload, true); // parent join filter subQuery = this.buildRelationJoinFilter( @@ -237,7 +231,7 @@ export class PostgresCrudDialect extends BaseCrudDiale } else { const joinPairs = buildJoinPairs(this.schema, model, parentAlias, relationField, relationModelAlias); query = query.where((eb) => - this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))), + this.and(...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))), ); } return query; @@ -303,10 +297,7 @@ export class PostgresCrudDialect extends BaseCrudDiale ...Object.entries(relationModelDef.fields) .filter(([, value]) => !value.relation) .filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true)) - .map(([field]) => [ - sql.lit(field), - this.fieldRef(relationModel, field, eb, relationModelAlias, false), - ]) + .map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, relationModelAlias, false)]) .flatMap((v) => v), ); } else if (payload.select) { @@ -329,7 +320,7 @@ export class PostgresCrudDialect extends BaseCrudDiale ? // reference the synthesized JSON field eb.ref(`${parentResultName}$${field}.$data`) : // reference a plain field - this.fieldRef(relationModel, field, eb, relationModelAlias, false); + this.fieldRef(relationModel, field, relationModelAlias, false); return [sql.lit(field), fieldValue]; } }) @@ -396,8 +387,8 @@ export class PostgresCrudDialect extends BaseCrudDiale return query; } - override buildJsonObject(eb: ExpressionBuilder, value: Record>) { - return eb.fn( + override buildJsonObject(value: Record>) { + return this.eb.fn( 'jsonb_build_object', Object.entries(value).flatMap(([key, value]) => [sql.lit(key), value]), ); @@ -415,11 +406,8 @@ export class PostgresCrudDialect extends BaseCrudDiale return true; } - override buildArrayLength( - eb: ExpressionBuilder, - array: Expression, - ): ExpressionWrapper { - return eb.fn('array_length', [array]); + override buildArrayLength(array: Expression): ExpressionWrapper { + return this.eb.fn('array_length', [array]); } override buildArrayLiteralSQL(values: unknown[]): string { diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 5c024dfb..e163f464 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -155,7 +155,7 @@ export class SqliteCrudDialect extends BaseCrudDialect if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) { // join without needing a nested select on relation model - tbl = this.buildModelSelect(eb, relationModel, subQueryName, payload, false); + tbl = this.buildModelSelect(relationModel, subQueryName, payload, false); // add parent join filter tbl = this.buildRelationJoinFilter(tbl, model, relationField, subQueryName, parentAlias); @@ -166,7 +166,7 @@ export class SqliteCrudDialect extends BaseCrudDialect const selectModelAlias = `${parentAlias}$${relationField}$sub`; // select all fields - let selectModelQuery = this.buildModelSelect(eb, relationModel, selectModelAlias, payload, true); + let selectModelQuery = this.buildModelSelect(relationModel, selectModelAlias, payload, true); // add parent join filter selectModelQuery = this.buildRelationJoinFilter( @@ -203,10 +203,7 @@ export class SqliteCrudDialect extends BaseCrudDialect ...Object.entries(relationModelDef.fields) .filter(([, value]) => !value.relation) .filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true)) - .map(([field]) => [ - sql.lit(field), - this.fieldRef(relationModel, field, eb, subQueryName, false), - ]) + .map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, subQueryName, false)]) .flatMap((v) => v), ); } else if (payload.select) { @@ -237,7 +234,7 @@ export class SqliteCrudDialect extends BaseCrudDialect } else { return [ sql.lit(field), - this.fieldRef(relationModel, field, eb, subQueryName, false) as ArgsType, + this.fieldRef(relationModel, field, subQueryName, false) as ArgsType, ]; } } @@ -345,8 +342,8 @@ export class SqliteCrudDialect extends BaseCrudDialect return query; } - override buildJsonObject(eb: ExpressionBuilder, value: Record>) { - return eb.fn( + override buildJsonObject(value: Record>) { + return this.eb.fn( 'json_object', Object.entries(value).flatMap(([key, value]) => [sql.lit(key), value]), ); @@ -364,11 +361,8 @@ export class SqliteCrudDialect extends BaseCrudDialect return false; } - override buildArrayLength( - eb: ExpressionBuilder, - array: Expression, - ): ExpressionWrapper { - return eb.fn('json_array_length', [array]); + override buildArrayLength(array: Expression): ExpressionWrapper { + return this.eb.fn('json_array_length', [array]); } override buildArrayLiteralSQL(_values: unknown[]): string { diff --git a/packages/runtime/src/client/crud/operations/aggregate.ts b/packages/runtime/src/client/crud/operations/aggregate.ts index fe111481..5df07608 100644 --- a/packages/runtime/src/client/crud/operations/aggregate.ts +++ b/packages/runtime/src/client/crud/operations/aggregate.ts @@ -1,4 +1,3 @@ -import type { ExpressionBuilder } from 'kysely'; import { sql } from 'kysely'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; @@ -18,8 +17,8 @@ export class AggregateOperationHandler extends BaseOpe // table and where let subQuery = this.dialect - .buildSelectModel(eb as ExpressionBuilder, this.model, this.model) - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + .buildSelectModel(this.model, this.model) + .where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where)); // select fields: collect fields from aggregation body const selectedFields: string[] = []; diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 2aca2980..1832874a 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -143,7 +143,7 @@ export abstract class BaseOperationHandler { args: FindArgs, true> | undefined, ): Promise { // table - let query = this.dialect.buildSelectModel(expressionBuilder(), model, model); + let query = this.dialect.buildSelectModel(model, model); if (args) { query = this.dialect.buildFilterSortTake(model, args, query, model); @@ -1043,7 +1043,7 @@ export abstract class BaseOperationHandler { const idFields = requireIdFields(this.schema, model); const query = kysely .updateTable(model) - .where((eb) => this.dialect.buildFilter(eb, model, model, combinedWhere)) + .where(() => this.dialect.buildFilter(model, model, combinedWhere)) .set(updateFields) .returning(idFields as any) .modifyEnd( @@ -1155,7 +1155,7 @@ export abstract class BaseOperationHandler { const key = Object.keys(payload)[0]; const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, false); const eb = expressionBuilder(); - const fieldRef = this.dialect.fieldRef(model, field, eb); + const fieldRef = this.dialect.fieldRef(model, field); return match(key) .with('set', () => value) @@ -1178,7 +1178,7 @@ export abstract class BaseOperationHandler { const key = Object.keys(payload)[0]; const value = this.dialect.transformPrimitive(payload[key!], fieldDef.type as BuiltinType, true); const eb = expressionBuilder(); - const fieldRef = this.dialect.fieldRef(model, field, eb); + const fieldRef = this.dialect.fieldRef(model, field); return match(key) .with('set', () => value) @@ -1273,7 +1273,7 @@ export abstract class BaseOperationHandler { if (!shouldFallbackToIdFilter) { // simple filter query = query - .where((eb) => this.dialect.buildFilter(eb, model, model, where)) + .where(() => this.dialect.buildFilter(model, model, where)) .$if(limit !== undefined, (qb) => qb.limit(limit!)); } else { query = query.where((eb) => @@ -1284,8 +1284,8 @@ export abstract class BaseOperationHandler { ), 'in', this.dialect - .buildSelectModel(eb, filterModel, filterModel) - .where(this.dialect.buildFilter(eb, filterModel, filterModel, where)) + .buildSelectModel(filterModel, filterModel) + .where(this.dialect.buildFilter(filterModel, filterModel, where)) .select(this.buildIdFieldRefs(kysely, filterModel)) .$if(limit !== undefined, (qb) => qb.limit(limit!)), ), @@ -1968,7 +1968,7 @@ export abstract class BaseOperationHandler { } if (!needIdFilter) { - query = query.where((eb) => this.dialect.buildFilter(eb, model, model, where)); + query = query.where(() => this.dialect.buildFilter(model, model, where)); } else { query = query.where((eb) => eb( @@ -1978,8 +1978,8 @@ export abstract class BaseOperationHandler { ), 'in', this.dialect - .buildSelectModel(eb, filterModel, filterModel) - .where((eb) => this.dialect.buildFilter(eb, filterModel, filterModel, where)) + .buildSelectModel(filterModel, filterModel) + .where(() => this.dialect.buildFilter(filterModel, filterModel, where)) .select(this.buildIdFieldRefs(kysely, filterModel)) .$if(limit !== undefined, (qb) => qb.limit(limit!)), ), diff --git a/packages/runtime/src/client/crud/operations/count.ts b/packages/runtime/src/client/crud/operations/count.ts index 9c321d98..90451745 100644 --- a/packages/runtime/src/client/crud/operations/count.ts +++ b/packages/runtime/src/client/crud/operations/count.ts @@ -1,4 +1,3 @@ -import type { ExpressionBuilder } from 'kysely'; import { sql } from 'kysely'; import type { SchemaDef } from '../../../schema'; import { BaseOperationHandler } from './base'; @@ -16,8 +15,8 @@ export class CountOperationHandler extends BaseOperati // nested query for filtering and pagination let subQuery = this.dialect - .buildSelectModel(eb as ExpressionBuilder, this.model, this.model) - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + .buildSelectModel(this.model, this.model) + .where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where)); if (parsedArgs?.select && typeof parsedArgs.select === 'object') { // select fields diff --git a/packages/runtime/src/client/crud/operations/group-by.ts b/packages/runtime/src/client/crud/operations/group-by.ts index 14bb77b5..4f4a083f 100644 --- a/packages/runtime/src/client/crud/operations/group-by.ts +++ b/packages/runtime/src/client/crud/operations/group-by.ts @@ -1,4 +1,3 @@ -import { expressionBuilder } from 'kysely'; import { match } from 'ts-pattern'; import type { SchemaDef } from '../../../schema'; import { aggregate, getField } from '../../query-utils'; @@ -19,7 +18,7 @@ export class GroupByOperationHandler extends BaseOpera let subQuery = eb .selectFrom(this.model) .selectAll() - .where((eb1) => this.dialect.buildFilter(eb1, this.model, this.model, parsedArgs?.where)); + .where(() => this.dialect.buildFilter(this.model, this.model, parsedArgs?.where)); // skip & take const skip = parsedArgs?.skip; @@ -44,7 +43,7 @@ export class GroupByOperationHandler extends BaseOpera return subQuery.as('$sub'); }); - const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, expressionBuilder(), '$sub'); + const fieldRef = (field: string) => this.dialect.fieldRef(this.model, field, '$sub'); // groupBy const bys = typeof parsedArgs.by === 'string' ? [parsedArgs.by] : (parsedArgs.by as string[]); @@ -56,7 +55,7 @@ export class GroupByOperationHandler extends BaseOpera } if (parsedArgs.having) { - query = query.having((eb) => this.dialect.buildFilter(eb, this.model, '$sub', parsedArgs.having)); + query = query.having(() => this.dialect.buildFilter(this.model, '$sub', parsedArgs.having)); } // select all by fields diff --git a/packages/runtime/src/client/functions.ts b/packages/runtime/src/client/functions.ts index 35390916..3f1bc806 100644 --- a/packages/runtime/src/client/functions.ts +++ b/packages/runtime/src/client/functions.ts @@ -100,7 +100,7 @@ export const isEmpty: ZModelFunction = (eb, args, { dialect }: ZModelFuncti if (!field) { throw new Error('"field" parameter is required'); } - return eb(dialect.buildArrayLength(eb, field), '=', sql.lit(0)); + return eb(dialect.buildArrayLength(field), '=', sql.lit(0)); }; export const now: ZModelFunction = () => sql.raw('CURRENT_TIMESTAMP'); diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index fd234378..7662fba0 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -91,14 +91,29 @@ export async function generateTsSchemaInPlace(schemaPath: string) { return compileAndLoad(workDir); } -export async function loadSchema(schema: string) { +export async function loadSchema(schema: string, additionalSchemas?: Record) { if (!schema.includes('datasource ')) { schema = `${makePrelude('sqlite')}\n\n${schema}`; } + // create a temp folder + const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'zenstack-schema')); + // create a temp file - const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`); + const tempFile = path.join(tempDir, `schema.zmodel`); fs.writeFileSync(tempFile, schema); + + if (additionalSchemas) { + for (const [fileName, content] of Object.entries(additionalSchemas)) { + let name = fileName; + if (!name.endsWith('.zmodel')) { + name += '.zmodel'; + } + const filePath = path.join(tempDir, name); + fs.writeFileSync(filePath, content); + } + } + const r = await loadDocument(tempFile); expect(r).toSatisfy( (r) => r.success, diff --git a/tests/regression/test/v2-migrated/issue-1014.test.ts b/tests/regression/test/v2-migrated/issue-1014.test.ts new file mode 100644 index 00000000..70917ff2 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1014.test.ts @@ -0,0 +1,49 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +// TODO: field-level policy support +describe.skip('Regression for issue 1014', () => { + it('update', async () => { + const db = await createPolicyTestClient( + ` + model User { + id Int @id() @default(autoincrement()) + name String + posts Post[] + } + + model Post { + id Int @id() @default(autoincrement()) + title String + content String? + author User? @relation(fields: [authorId], references: [id]) + authorId Int? @allow('update', true, true) + + @@allow('read', true) + } + `, + ); + + const user = await db.$unuseAll().user.create({ data: { name: 'User1' } }); + const post = await db.$unuseAll().post.create({ data: { title: 'Post1' } }); + await expect(db.post.update({ where: { id: post.id }, data: { authorId: user.id } })).toResolveTruthy(); + }); + + it('read', async () => { + const db = await createPolicyTestClient( + ` + model Post { + id Int @id() @default(autoincrement()) + title String @allow('read', true, true) + content String + } + `, + ); + + const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } }); + await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull(); + await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({ + title: 'Post1', + }); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1058.test.ts b/tests/regression/test/v2-migrated/issue-1058.test.ts new file mode 100644 index 00000000..fed09565 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1058.test.ts @@ -0,0 +1,52 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1058', async () => { + const schema = ` + model User { + id String @id @default(cuid()) + name String + + userRankings UserRanking[] + userFavorites UserFavorite[] + } + + model Entity { + id String @id @default(cuid()) + name String + type String + userRankings UserRanking[] + userFavorites UserFavorite[] + + @@delegate(type) + } + + model Person extends Entity { + } + + model Studio extends Entity { + } + + + model UserRanking { + id String @id @default(cuid()) + rank Int + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction) + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + } + + model UserFavorite { + id String @id @default(cuid()) + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction) + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + } + `; + + await createTestClient(schema, { provider: 'postgresql' }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1078.test.ts b/tests/regression/test/v2-migrated/issue-1078.test.ts new file mode 100644 index 00000000..a3af1620 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1078.test.ts @@ -0,0 +1,53 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue 1078', () => { + it('regression1', async () => { + const db = await createPolicyTestClient( + ` + model Counter { + id String @id + + name String + value Int + + @@validate(value >= 0) + @@allow('all', true) + } + `, + ); + + await expect( + db.counter.create({ + data: { id: '1', name: 'It should create', value: 1 }, + }), + ).toResolveTruthy(); + + //! This query fails validation + await expect( + db.counter.update({ + where: { id: '1' }, + data: { name: 'It should update' }, + }), + ).toResolveTruthy(); + }); + + // TODO: field-level policy support + it.skip('regression2', async () => { + const db = await createPolicyTestClient( + ` + model Post { + id Int @id() @default(autoincrement()) + title String @allow('read', true, true) + content String + } + `, + ); + + const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } }); + await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull(); + await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({ + title: 'Post1', + }); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1080.test.ts b/tests/regression/test/v2-migrated/issue-1080.test.ts new file mode 100644 index 00000000..0f46beca --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1080.test.ts @@ -0,0 +1,129 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1080', async () => { + const db = await createPolicyTestClient( + ` +model Project { + id String @id @unique @default(uuid()) + Fields Field[] + + @@allow('all', true) +} + +model Field { + id String @id @unique @default(uuid()) + name String + Project Project @relation(fields: [projectId], references: [id]) + projectId String + + @@allow('all', true) +} + `, + ); + + const project = await db.project.create({ + include: { Fields: true }, + data: { + Fields: { + create: [{ name: 'first' }, { name: 'second' }], + }, + }, + }); + + let updated = await db.project.update({ + where: { id: project.id }, + include: { Fields: true }, + data: { + Fields: { + upsert: [ + { + where: { id: project.Fields[0].id }, + create: { name: 'first1' }, + update: { name: 'first1' }, + }, + { + where: { id: project.Fields[1].id }, + create: { name: 'second1' }, + update: { name: 'second1' }, + }, + ], + }, + }, + }); + expect(updated).toMatchObject({ + Fields: expect.arrayContaining([ + expect.objectContaining({ name: 'first1' }), + expect.objectContaining({ name: 'second1' }), + ]), + }); + + updated = await db.project.update({ + where: { id: project.id }, + include: { Fields: true }, + data: { + Fields: { + upsert: { + where: { id: project.Fields[0].id }, + create: { name: 'first2' }, + update: { name: 'first2' }, + }, + }, + }, + }); + expect(updated).toMatchObject({ + Fields: expect.arrayContaining([ + expect.objectContaining({ name: 'first2' }), + expect.objectContaining({ name: 'second1' }), + ]), + }); + + updated = await db.project.update({ + where: { id: project.id }, + include: { Fields: true }, + data: { + Fields: { + upsert: { + where: { id: project.Fields[0].id }, + create: { name: 'first3' }, + update: { name: 'first3' }, + }, + update: { + where: { id: project.Fields[1].id }, + data: { name: 'second3' }, + }, + }, + }, + }); + expect(updated).toMatchObject({ + Fields: expect.arrayContaining([ + expect.objectContaining({ name: 'first3' }), + expect.objectContaining({ name: 'second3' }), + ]), + }); + + updated = await db.project.update({ + where: { id: project.id }, + include: { Fields: true }, + data: { + Fields: { + upsert: { + where: { id: 'non-exist' }, + create: { name: 'third1' }, + update: { name: 'third1' }, + }, + update: { + where: { id: project.Fields[1].id }, + data: { name: 'second4' }, + }, + }, + }, + }); + expect(updated).toMatchObject({ + Fields: expect.arrayContaining([ + expect.objectContaining({ name: 'first3' }), + expect.objectContaining({ name: 'second4' }), + expect.objectContaining({ name: 'third1' }), + ]), + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1123.test.ts b/tests/regression/test/v2-migrated/issue-1123.test.ts new file mode 100644 index 00000000..3c1cb4d0 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1123.test.ts @@ -0,0 +1,43 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1123', async () => { + const db = await createPolicyTestClient( + ` +model Content { + id String @id @default(cuid()) + published Boolean @default(false) + contentType String + likes Like[] + @@delegate(contentType) + @@allow('all', true) +} + +model Post extends Content { + title String +} + +model Image extends Content { + url String +} + +model Like { + id String @id @default(cuid()) + content Content @relation(fields: [contentId], references: [id]) + contentId String + @@allow('all', true) +} + `, + ); + + await db.post.create({ + data: { + title: 'a post', + likes: { create: {} }, + }, + }); + + await expect(db.content.findFirst({ include: { _count: { select: { likes: true } } } })).resolves.toMatchObject({ + _count: { likes: 1 }, + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1135.test.ts b/tests/regression/test/v2-migrated/issue-1135.test.ts new file mode 100644 index 00000000..41df934f --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1135.test.ts @@ -0,0 +1,76 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1135', async () => { + const db = await createTestClient( + ` +model Attachment { + id String @id @default(cuid()) + url String + myEntityId String + myEntity Entity @relation(fields: [myEntityId], references: [id], onUpdate: NoAction) +} + +model Entity { + id String @id @default(cuid()) + name String + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt @default(now()) + + attachments Attachment[] + + type String + @@delegate(type) +} + +model Person extends Entity { + age Int? +} + `, + { + extraSourceFiles: { + 'main.ts': ` +import { ZenStackClient } from '@zenstackhq/runtime'; +import { schema } from './schema'; + +const db = new ZenStackClient(schema, {} as any); + +db.person.create({ + data: { + name: 'test', + attachments: { + create: { + url: 'https://...', + }, + }, + }, +}); + `, + }, + }, + ); + + await expect( + db.person.create({ + data: { + name: 'test', + attachments: { + create: { + url: 'https://...', + }, + }, + }, + include: { attachments: true }, + }), + ).resolves.toMatchObject({ + id: expect.any(String), + name: 'test', + attachments: [ + { + id: expect.any(String), + url: 'https://...', + myEntityId: expect.any(String), + }, + ], + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1149.test.ts b/tests/regression/test/v2-migrated/issue-1149.test.ts new file mode 100644 index 00000000..404c3969 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1149.test.ts @@ -0,0 +1,90 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1149', async () => { + const schema = ` + model User { + id String @id @default(cuid()) + name String + + userRankings UserRanking[] + userFavorites UserFavorite[] + } + + model Entity { + id String @id @default(cuid()) + name String + type String + userRankings UserRanking[] + userFavorites UserFavorite[] + + @@delegate(type) + } + + model Person extends Entity { + } + + model Studio extends Entity { + } + + + model UserRanking { + id String @id @default(cuid()) + rank Int + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction) + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + } + + model UserFavorite { + id String @id @default(cuid()) + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction) + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction) + } + `; + + const db = await createTestClient(schema); + + const user = await db.user.create({ data: { name: 'user' } }); + const person = await db.person.create({ data: { name: 'person' } }); + + await expect( + db.userRanking.createMany({ + data: { + rank: 1, + entityId: person.id, + userId: user.id, + }, + }), + ).resolves.toMatchObject({ count: 1 }); + + await expect( + db.userRanking.createMany({ + data: [ + { + rank: 2, + entityId: person.id, + userId: user.id, + }, + { + rank: 3, + entityId: person.id, + userId: user.id, + }, + ], + }), + ).resolves.toMatchObject({ count: 2 }); + + await expect(db.userRanking.findMany()).resolves.toEqual( + expect.arrayContaining([ + expect.objectContaining({ rank: 1 }), + expect.objectContaining({ rank: 2 }), + expect.objectContaining({ rank: 3 }), + ]), + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1167.test.ts b/tests/regression/test/v2-migrated/issue-1167.test.ts new file mode 100644 index 00000000..9a18c374 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1167.test.ts @@ -0,0 +1,19 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1167', async () => { + await loadSchema( + ` +model FileAsset { + id String @id @default(cuid()) + delegate_type String + @@delegate(delegate_type) + @@map("file_assets") +} + +model ImageAsset extends FileAsset { + @@map("image_assets") +} + `, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1179.test.ts b/tests/regression/test/v2-migrated/issue-1179.test.ts new file mode 100644 index 00000000..b6a21879 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1179.test.ts @@ -0,0 +1,26 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('regression', async () => { + await loadSchema( + ` +type Base { + id String @id @default(uuid()) +} + +model User with Base { + email String + posts Post[] + @@allow('all', auth() == this) +} + +model Post { + id String @id @default(uuid()) + + user User @relation(fields: [userId], references: [id]) + userId String + @@allow('all', auth().id == userId) +} + `, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1235.test.ts b/tests/regression/test/v2-migrated/issue-1235.test.ts new file mode 100644 index 00000000..e5d17b6a --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1235.test.ts @@ -0,0 +1,39 @@ +import { createPolicyTestClient, testLogger } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue 1235', () => { + it('regression1', async () => { + const db = await createPolicyTestClient( + ` +model Post { + id Int @id @default(autoincrement()) + @@deny('post-update', before().id != id) + @@allow('all', true) +} + `, + { log: testLogger }, + ); + + const post = await db.post.create({ data: {} }); + await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).rejects.toThrow( + /updating id fields is not supported/, + ); + }); + + it('regression2', async () => { + const db = await createPolicyTestClient( + ` +model Post { + id Int @id @default(autoincrement()) + @@deny('post-update', before().id != this.id) + @@allow('all', true) +} + `, + ); + + const post = await db.post.create({ data: {} }); + await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).rejects.toThrow( + /updating id fields is not supported/, + ); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1241.test.ts b/tests/regression/test/v2-migrated/issue-1241.test.ts new file mode 100644 index 00000000..bddfa4e8 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1241.test.ts @@ -0,0 +1,84 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { randomBytes } from 'crypto'; +import { expect, it } from 'vitest'; + +it('verifies issue 1241', async () => { + const db = await createPolicyTestClient( + ` +model User { + id String @id @default(uuid()) + todos Todo[] + + @@auth + @@allow('all', true) +} + +model Todo { + id String @id @default(uuid()) + + user_id String + user User @relation(fields: [user_id], references: [id]) + + images File[] @relation("todo_images") + documents File[] @relation("todo_documents") + + @@allow('all', true) +} + +model File { + id String @id @default(uuid()) + s3_key String @unique + label String + + todo_image_id String? + todo_image Todo? @relation("todo_images", fields: [todo_image_id], references: [id]) + + todo_document_id String? + todo_document Todo? @relation("todo_documents", fields: [todo_document_id], references: [id]) + + @@allow('all', true) +} + `, + ); + + const user = await db.$unuseAll().user.create({ + data: {}, + }); + await db.$unuseAll().todo.create({ + data: { + user_id: user.id, + + images: { + create: new Array(3).fill(null).map((_, i) => ({ + s3_key: randomBytes(8).toString('hex'), + label: `img-label-${i + 1}`, + })), + }, + + documents: { + create: new Array(3).fill(null).map((_, i) => ({ + s3_key: randomBytes(8).toString('hex'), + label: `doc-label-${i + 1}`, + })), + }, + }, + }); + + const todo = await db.todo.findFirst({ where: {}, include: { documents: true } }); + await expect( + db.todo.update({ + where: { id: todo.id }, + data: { + documents: { + update: todo.documents.map((doc: any) => { + return { + where: { s3_key: doc.s3_key }, + data: { label: 'updated' }, + }; + }), + }, + }, + include: { documents: true }, + }), + ).toResolveTruthy(); +}); diff --git a/tests/regression/test/v2-migrated/issue-1243.test.ts b/tests/regression/test/v2-migrated/issue-1243.test.ts new file mode 100644 index 00000000..1122f90f --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1243.test.ts @@ -0,0 +1,52 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, it } from 'vitest'; + +describe('Regression for issue 1243', () => { + it('uninheritable fields', async () => { + const schema = ` + model Base { + id String @id @default(cuid()) + type String + foo String + + @@delegate(type) + @@index([foo]) + @@map('base') + @@unique([foo]) + } + + model Item1 extends Base { + x String + } + + model Item2 extends Base { + y String + } + `; + + await createTestClient(schema); + }); + + it('multiple id fields', async () => { + const schema = ` + model Base { + id1 String + id2 String + type String + + @@delegate(type) + @@id([id1, id2]) + } + + model Item1 extends Base { + x String + } + + model Item2 extends Base { + y String + } + `; + + await createTestClient(schema); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1257.test.ts b/tests/regression/test/v2-migrated/issue-1257.test.ts new file mode 100644 index 00000000..38fc799e --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1257.test.ts @@ -0,0 +1,48 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1257', async () => { + await loadSchema( + ` +import "./user" +import "./image" + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +}`, + { + base: ` +type Base { + id Int @id @default(autoincrement()) +} +`, + user: ` +import "./base" +import "./image" + +enum Role { + Admin +} + +model User with Base { + email String @unique + role Role + @@auth +} +`, + image: ` +import "./user" +import "./base" + +model Image with Base { + width Int @default(0) + height Int @default(0) + + @@allow('read', true) + @@allow('all', auth().role == Admin) +} +`, + }, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1265.test.ts b/tests/regression/test/v2-migrated/issue-1265.test.ts new file mode 100644 index 00000000..d97be964 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1265.test.ts @@ -0,0 +1,26 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +// TODO: zod schema support +it.skip('verifies issue 1265', async () => { + const { zodSchemas } = await createTestClient( + ` + model User { + id String @id @default(uuid()) + posts Post[] + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String @default('xyz') + userId String @default(auth().id) + user User @relation(fields: [userId], references: [id]) + @@allow('all', true) + } + `, + ); + + expect(zodSchemas.models.PostCreateSchema.safeParse({ title: 'Post 1' }).success).toBeTruthy(); + expect(zodSchemas.input.PostInputSchema.create.safeParse({ data: { title: 'Post 1' } }).success).toBeTruthy(); +}); diff --git a/tests/regression/test/v2-migrated/issue-1271.test.ts b/tests/regression/test/v2-migrated/issue-1271.test.ts new file mode 100644 index 00000000..d6acb7e6 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1271.test.ts @@ -0,0 +1,188 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1271', async () => { + const db = await createPolicyTestClient( + ` +model User { + id String @id @default(uuid()) + + @@auth + @@allow('all', true) +} + +model Test { + id String @id @default(uuid()) + linkingTable LinkingTable[] + key String @default('test') + locale String @default('EN') + + @@unique([key, locale]) + @@allow("all", true) +} + +model LinkingTable { + test_id String + test Test @relation(fields: [test_id], references: [id]) + + another_test_id String + another_test AnotherTest @relation(fields: [another_test_id], references: [id]) + + @@id([test_id, another_test_id]) + @@allow("all", true) +} + +model AnotherTest { + id String @id @default(uuid()) + status String + linkingTable LinkingTable[] + + @@allow("all", true) +} + `, + ); + + const test = await db.test.create({ + data: { + key: 'test1', + }, + }); + const anotherTest = await db.anotherTest.create({ + data: { + status: 'available', + }, + }); + + const updated = await db.test.upsert({ + where: { + key_locale: { + key: test.key, + locale: test.locale, + }, + }, + create: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + update: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + include: { + linkingTable: true, + }, + }); + + expect(updated.linkingTable).toHaveLength(1); + expect(updated.linkingTable[0]).toMatchObject({ another_test_id: anotherTest.id }); + + const test2 = await db.test.upsert({ + where: { + key_locale: { + key: 'test2', + locale: 'locale2', + }, + }, + create: { + key: 'test2', + locale: 'locale2', + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + update: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + include: { + linkingTable: true, + }, + }); + expect(test2).toMatchObject({ key: 'test2', locale: 'locale2' }); + expect(test2.linkingTable).toHaveLength(1); + expect(test2.linkingTable[0]).toMatchObject({ another_test_id: anotherTest.id }); + + const linkingTable = test2.linkingTable[0]; + + // connectOrCreate: connect case + const test3 = await db.test.create({ + data: { + key: 'test3', + locale: 'locale3', + }, + }); + console.log('test3 created:', test3); + const updated2 = await db.linkingTable.update({ + where: { + test_id_another_test_id: { + test_id: linkingTable.test_id, + another_test_id: linkingTable.another_test_id, + }, + }, + data: { + test: { + connectOrCreate: { + where: { + key_locale: { + key: test3.key, + locale: test3.locale, + }, + }, + create: { + key: 'test4', + locale: 'locale4', + }, + }, + }, + another_test: { connect: { id: anotherTest.id } }, + }, + include: { test: true }, + }); + expect(updated2).toMatchObject({ + test: expect.objectContaining({ key: 'test3', locale: 'locale3' }), + another_test_id: anotherTest.id, + }); + + // connectOrCreate: create case + const updated3 = await db.linkingTable.update({ + where: { + test_id_another_test_id: { + test_id: updated2.test_id, + another_test_id: updated2.another_test_id, + }, + }, + data: { + test: { + connectOrCreate: { + where: { + key_locale: { + key: 'test4', + locale: 'locale4', + }, + }, + create: { + key: 'test4', + locale: 'locale4', + }, + }, + }, + another_test: { connect: { id: anotherTest.id } }, + }, + include: { test: true }, + }); + expect(updated3).toMatchObject({ + test: expect.objectContaining({ key: 'test4', locale: 'locale4' }), + another_test_id: anotherTest.id, + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1381.test.ts b/tests/regression/test/v2-migrated/issue-1381.test.ts new file mode 100644 index 00000000..3f168270 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1381.test.ts @@ -0,0 +1,55 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1381', async () => { + await loadSchema( + ` +enum MemberRole { + owner + admin +} + +enum SpaceType { + contractor + public + private +} + +model User { + id String @id @default(cuid()) + name String? + email String? @unique @lower + memberships Membership[] +} + +model Membership { + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + spaceId String + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + role MemberRole + @@id([userId, spaceId]) +} + +model Space { + id String @id @default(cuid()) + name String + type SpaceType @default(private) + memberships Membership[] + options Option[] +} + +model Option { + id String @id @default(cuid()) + name String? + spaceId String? + space Space? @relation(fields: [spaceId], references: [id], onDelete: SetNull) + + @@allow("post-update", + space.type in [contractor, public] && + space.memberships?[space.type in [contractor, public] && auth() == user && role in [owner, admin]] + ) +} + `, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1388.test.ts b/tests/regression/test/v2-migrated/issue-1388.test.ts new file mode 100644 index 00000000..ab3f4701 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1388.test.ts @@ -0,0 +1,32 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1388', async () => { + await loadSchema( + ` +import './auth' +import './post' + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} +`, + { + auth: ` +model User { + id String @id @default(cuid()) + role String +} + `, + post: ` +model Post { + id String @id @default(nanoid(6)) + title String + @@deny('all', auth() == null) + @@allow('all', auth().id == 'user1') +} + `, + }, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1410.test.ts b/tests/regression/test/v2-migrated/issue-1410.test.ts new file mode 100644 index 00000000..c4f7c2db --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1410.test.ts @@ -0,0 +1,143 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1410', async () => { + const db = await createTestClient( + ` + model Drink { + id Int @id @default(autoincrement()) + slug String @unique + + manufacturer_id Int + manufacturer Manufacturer @relation(fields: [manufacturer_id], references: [id]) + + type String + + name String @unique + description String + abv Float + image String? + + gluten Boolean + lactose Boolean + organic Boolean + + containers Container[] + + @@delegate(type) + + @@allow('all', true) + } + + model Beer extends Drink { + style_id Int + style BeerStyle @relation(fields: [style_id], references: [id]) + + ibu Float? + + @@allow('all', true) + } + + model BeerStyle { + id Int @id @default(autoincrement()) + + name String @unique + color String + + beers Beer[] + + @@allow('all', true) + } + + model Wine extends Drink { + style_id Int + style WineStyle @relation(fields: [style_id], references: [id]) + + heavy_score Int? + tannine_score Int? + dry_score Int? + fresh_score Int? + notes String? + + @@allow('all', true) + } + + model WineStyle { + id Int @id @default(autoincrement()) + + name String @unique + color String + + wines Wine[] + + @@allow('all', true) + } + + model Soda extends Drink { + carbonated Boolean + + @@allow('all', true) + } + + model Cocktail extends Drink { + mix Boolean + + @@allow('all', true) + } + + model Container { + barcode String @id + + drink_id Int + drink Drink @relation(fields: [drink_id], references: [id]) + + type String + volume Int + portions Int? + + inventory Int @default(0) + + @@allow('all', true) + } + + model Manufacturer { + id Int @id @default(autoincrement()) + + country_id String + country Country @relation(fields: [country_id], references: [code]) + + name String @unique + description String? + image String? + + drinks Drink[] + + @@allow('all', true) + } + + model Country { + code String @id + name String + + manufacturers Manufacturer[] + + @@allow('all', true) + } + `, + ); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { NOT: { gluten: true } }, + }); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { AND: [{ gluten: true }, { abv: { gt: 50 } }] }, + }); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { OR: [{ AND: [{ NOT: { gluten: true } }] }, { abv: { gt: 50 } }] }, + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1415.test.ts b/tests/regression/test/v2-migrated/issue-1415.test.ts new file mode 100644 index 00000000..0ebbf7e9 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1415.test.ts @@ -0,0 +1,21 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1415', async () => { + await loadSchema( + ` +model User { + id String @id @default(cuid()) + prices Price[] +} + +model Price { + id String @id @default(cuid()) + owner User @relation(fields: [ownerId], references: [id]) + ownerId String @default(auth().id) + priceType String + @@delegate(priceType) +} + `, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1416.test.ts b/tests/regression/test/v2-migrated/issue-1416.test.ts new file mode 100644 index 00000000..461ff068 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1416.test.ts @@ -0,0 +1,36 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { it } from 'vitest'; + +it('verifies issue 1416', async () => { + await loadSchema( + ` +model User { + id String @id @default(cuid()) + role String +} + +model Price { + id String @id @default(nanoid(6)) + entity Entity? @relation(fields: [entityId], references: [id]) + entityId String? + priceType String + @@delegate(priceType) +} + +model MyPrice extends Price { + foo String +} + +model Entity { + id String @id @default(nanoid(6)) + price Price[] + type String + @@delegate(type) +} + +model MyEntity extends Entity { + foo String +} + `, + ); +}); diff --git a/tests/regression/test/v2-migrated/issue-1427.test.ts b/tests/regression/test/v2-migrated/issue-1427.test.ts new file mode 100644 index 00000000..f111288f --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1427.test.ts @@ -0,0 +1,40 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1427', async () => { + const db = await createTestClient( + ` +model User { + id String @id @default(cuid()) + name String + profile Profile? + @@allow('all', true) +} + +model Profile { + id String @id @default(cuid()) + user User @relation(fields: [userId], references: [id]) + userId String @unique + @@allow('all', true) +} + `, + ); + + await db.$unuseAll().user.create({ + data: { + name: 'John', + profile: { + create: {}, + }, + }, + }); + + const found = await db.user.findFirst({ + select: { + id: true, + name: true, + profile: false, + }, + }); + expect(found.profile).toBeUndefined(); +}); diff --git a/tests/regression/test/v2-migrated/issue-1451.test.ts b/tests/regression/test/v2-migrated/issue-1451.test.ts new file mode 100644 index 00000000..f6891819 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1451.test.ts @@ -0,0 +1,56 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +// TODO: field-level policy support +it.skip('verifies issue 1451', async () => { + const db = await createTestClient( + ` +model User { + id String @id + memberships Membership[] +} + +model Space { + id String @id + memberships Membership[] +} + +model Membership { + userId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + spaceId String + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + + role String @deny("update", auth() == user) + employeeReference String? @deny("read, update", space.memberships?[auth() == user && !(role in ['owner', 'admin'])]) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@id([userId, spaceId]) + @@allow('all', true) +} + `, + ); + + await db.$unuseAll().user.create({ + data: { id: '1' }, + }); + + await db.$unuseAll().space.create({ + data: { id: '1' }, + }); + + await db.$unuseAll().membership.create({ + data: { + user: { connect: { id: '1' } }, + space: { connect: { id: '1' } }, + role: 'foo', + employeeReference: 'xyz', + }, + }); + + const r = await db.membership.findMany(); + expect(r).toHaveLength(1); + expect(r[0].employeeReference).toBeUndefined(); +}); diff --git a/tests/regression/test/v2-migrated/issue-1454.test.ts b/tests/regression/test/v2-migrated/issue-1454.test.ts new file mode 100644 index 00000000..d78b74fb --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1454.test.ts @@ -0,0 +1,116 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue 1454', () => { + it('regression1', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + sensitiveInformation String + username String + + purchases Purchase[] + + @@allow('read', auth() == this) +} + +model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) +} + `, + ); + + await db.$unuseAll().user.create({ + data: { username: 'user1', sensitiveInformation: 'sensitive', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(0); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(0); + }); + + // TODO: field-level policy support + it.skip('regression2', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + username String @allow('read', false) + + purchases Purchase[] + + @@allow('read', true) +} + +model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) +} + `, + ); + + const user = await db.$unuseAll().user.create({ + data: { username: 'user1', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { id: user.id } } })).resolves.toHaveLength(1); + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(0); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(0); + }); + + // TODO: field-level policy support + it.skip('regression3', async () => { + const db = await createPolicyTestClient( + ` +model User { + id Int @id @default(autoincrement()) + sensitiveInformation String + username String @allow('read', true, true) + + purchases Purchase[] + + @@allow('read', auth() == this) +} + +model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) +} + `, + ); + + await db.$unuseAll().user.create({ + data: { username: 'user1', sensitiveInformation: 'sensitive', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(1); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(1); + await expect( + db.purchase.findMany({ where: { user: { sensitiveInformation: 'sensitive' } } }), + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ where: { user: { is: { sensitiveInformation: 'sensitive' } } } }), + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ where: { user: { username: 'user1', sensitiveInformation: 'sensitive' } } }), + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ + where: { OR: [{ user: { username: 'user1' } }, { user: { sensitiveInformation: 'sensitive' } }] }, + }), + ).resolves.toHaveLength(1); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1466.test.ts b/tests/regression/test/v2-migrated/issue-1466.test.ts new file mode 100644 index 00000000..932369f7 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1466.test.ts @@ -0,0 +1,195 @@ +import { createTestClient, loadSchema } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue 1466', () => { + it('regression1', async () => { + const db = await createTestClient( + ` +model UserLongLongLongLongLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique +} + +model AssetLongLongLongLongLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongLongLongLongLongName? + assetType String + + @@delegate(assetType) +} + +model VideoLongLongLongLongLongLongLongLongName extends AssetLongLongLongLongLongLongLongLongName { + duration Int +} + `, + { + usePrismaPush: true, + }, + ); + + const video = await db.VideoLongLongLongLongLongLongLongLongName.create({ + data: { duration: 100 }, + }); + + await db.UserLongLongLongLongLongLongLongLongName.create({ + data: { + asset: { connect: { id: video.id } }, + }, + }); + + const userWithAsset = await db.UserLongLongLongLongLongLongLongLongName.findFirst({ + include: { asset: true }, + }); + + expect(userWithAsset).toMatchObject({ + asset: { assetType: 'VideoLongLongLongLongLongLongLongLongName', duration: 100 }, + }); + }); + + it('regression2', async () => { + const db = await createTestClient( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int + + @@unique([assetId]) + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName extends AssetLongLongLongLongName { + duration Int + } + `, + { + usePrismaPush: true, + }, + ); + + const video = await db.VideoLongLongLongLongName.create({ + data: { duration: 100 }, + }); + + await db.UserLongLongLongLongName.create({ + data: { + asset: { connect: { id: video.id } }, + }, + }); + + const userWithAsset = await db.UserLongLongLongLongName.findFirst({ + include: { asset: true }, + }); + + expect(userWithAsset).toMatchObject({ + asset: { assetType: 'VideoLongLongLongLongName', duration: 100 }, + }); + }); + + it('regression3', async () => { + await loadSchema( + ` +model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique +} + +model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) +} + +model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int +} + +model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String +} + `, + ); + }); + + it('regression4', async () => { + await loadSchema( + ` +model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique +} + +model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) +} + +model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int +} + +model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String +} + `, + ); + }); + + it('regression5', async () => { + await loadSchema( + ` +model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique(map: 'assetId_unique') +} + +model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) +} + +model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int +} + +model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String +} + `, + ); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1467.test.ts b/tests/regression/test/v2-migrated/issue-1467.test.ts new file mode 100644 index 00000000..042ef8b6 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1467.test.ts @@ -0,0 +1,44 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1467', async () => { + const db = await createTestClient( + ` + model User { + id Int @id @default(autoincrement()) + type String + } + + model Container { + id Int @id @default(autoincrement()) + drink Drink @relation(fields: [drinkId], references: [id]) + drinkId Int + } + + model Drink { + id Int @id @default(autoincrement()) + name String @unique + containers Container[] + type String + + @@delegate(type) + } + + model Beer extends Drink { + } + `, + ); + + await db.beer.create({ + data: { id: 1, name: 'Beer1' }, + }); + + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + + const beers = await db.beer.findFirst({ + select: { id: true, name: true, _count: { select: { containers: true } } }, + orderBy: { name: 'asc' }, + }); + expect(beers).toMatchObject({ _count: { containers: 2 } }); +}); diff --git a/tests/regression/test/v2-migrated/issue-1483.test.ts b/tests/regression/test/v2-migrated/issue-1483.test.ts new file mode 100644 index 00000000..8802a312 --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1483.test.ts @@ -0,0 +1,67 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { expect, it } from 'vitest'; + +it('verifies issue 1483', async () => { + const db = await createTestClient( + ` +model User { + @@auth + id String @id + edits Edit[] + @@allow('all', true) +} + +model Entity { + + id String @id @default(cuid()) + name String + edits Edit[] + + type String + @@delegate(type) + + @@allow('all', true) +} + +model Person extends Entity { +} + +model Edit { + id String @id @default(cuid()) + + authorId String? + author User? @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + @@allow('all', true) +} + `, + ); + + await db.edit.deleteMany({}); + await db.person.deleteMany({}); + await db.user.deleteMany({}); + + const person = await db.person.create({ + data: { + name: 'test', + }, + }); + + await db.edit.create({ + data: { + entityId: person.id, + }, + }); + + await expect( + db.edit.findMany({ + include: { + author: true, + entity: true, + }, + }), + ).resolves.toHaveLength(1); +}); diff --git a/tests/regression/test/v2-migrated/issue-1487.test.ts b/tests/regression/test/v2-migrated/issue-1487.test.ts new file mode 100644 index 00000000..acf39ead --- /dev/null +++ b/tests/regression/test/v2-migrated/issue-1487.test.ts @@ -0,0 +1,52 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import Decimal from 'decimal.js'; +import { expect, it } from 'vitest'; + +it('verifies issue 1487', async () => { + const db = await createTestClient( + ` +model LineItem { + id Int @id @default(autoincrement()) + price Decimal + createdAt DateTime @default(now()) + + orderId Int + order Order @relation(fields: [orderId], references: [id]) +} +model Order extends BaseType { + total Decimal + createdAt DateTime @default(now()) + lineItems LineItem[] +} +model BaseType { + id Int @id @default(autoincrement()) + entityType String + + @@delegate(entityType) +} + `, + ); + + const create = await db.Order.create({ + data: { + total: new Decimal(100_100.99), + lineItems: { create: [{ price: 90_000.66 }, { price: 20_100.33 }] }, + }, + }); + + const order = await db.Order.findFirst({ where: { id: create.id }, include: { lineItems: true } }); + expect(Decimal.isDecimal(order.total)).toBe(true); + expect(order.createdAt instanceof Date).toBe(true); + expect(order.total.toString()).toEqual('100100.99'); + order.lineItems.forEach((item: any) => { + expect(Decimal.isDecimal(item.price)).toBe(true); + expect(item.price.toString()).not.toEqual('[object Object]'); + }); + + const lineItems = await db.LineItem.findMany(); + lineItems.forEach((item: any) => { + expect(item.createdAt instanceof Date).toBe(true); + expect(Decimal.isDecimal(item.price)).toBe(true); + expect(item.price.toString()).not.toEqual('[object Object]'); + }); +}); diff --git a/tests/regression/test/v2-migrated/issue-764.test.ts b/tests/regression/test/v2-migrated/issue-764.test.ts index 404616fb..b34f1bac 100644 --- a/tests/regression/test/v2-migrated/issue-764.test.ts +++ b/tests/regression/test/v2-migrated/issue-764.test.ts @@ -24,7 +24,7 @@ model Post { `, ); - const user = await db.$unuseAll().user.create({ + const user = await db.user.create({ data: { name: 'Me' }, });