Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/language/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export class DocumentLoadError extends Error {

export async function loadDocument(
fileName: string,
pluginModelFiles: string[] = [],
additionalModelFiles: string[] = [],
): Promise<
{ success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] }
> {
Expand Down Expand Up @@ -50,9 +50,9 @@ export async function loadDocument(
URI.file(path.resolve(path.join(_dirname, '../res', STD_LIB_MODULE_NAME))),
);

// load plugin model files
// load additional model files
const pluginDocs = await Promise.all(
pluginModelFiles.map((file) =>
additionalModelFiles.map((file) =>
services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))),
),
);
Expand Down
14 changes: 14 additions & 0 deletions packages/plugins/policy/src/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
// --- Post mutation work ---

if (hasPostUpdatePolicies && result.rows.length > 0) {
// verify if before-update rows and post-update rows still id-match
if (beforeUpdateInfo) {
invariant(beforeUpdateInfo.rows.length === result.rows.length);
const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel);
for (const postRow of result.rows) {
const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
if (!beforeRow) {
throw new QueryError(
'Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.',
);
}
}
}

// entities updated filter
const idConditions = this.buildIdConditions(mutationModel, result.rows);

Expand Down
16 changes: 12 additions & 4 deletions packages/runtime/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}

let result = query;

const buildFieldRef = (model: string, field: string, modelAlias: string) => {
const fieldDef = requireField(this.schema, model, field);
const eb = expressionBuilder<any, any>();
return fieldDef.originModel
? this.fieldRef(fieldDef.originModel, field, eb, fieldDef.originModel)
: this.fieldRef(model, field, eb, modelAlias);
};

enumerate(orderBy).forEach((orderBy) => {
for (const [field, value] of Object.entries<any>(orderBy)) {
if (!value) {
Expand All @@ -838,8 +847,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [k, v] of Object.entries<SortOrder>(value)) {
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
result = result.orderBy(
(eb) =>
aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS),
(eb) => aggregate(eb, buildFieldRef(model, k, modelAlias), field as AGGREGATE_OPERATORS),
sql.raw(this.negateSort(v, negated)),
);
}
Expand All @@ -852,7 +860,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
for (const [k, v] of Object.entries<string>(value)) {
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
result = result.orderBy(
(eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)),
(eb) => eb.fn.count(buildFieldRef(model, k, modelAlias)),
sql.raw(this.negateSort(v, negated)),
);
}
Expand All @@ -865,7 +873,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const fieldDef = requireField(this.schema, model, field);

if (!fieldDef.relation) {
const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias);
const fieldRef = buildFieldRef(model, field, modelAlias);
if (value === 'asc' || value === 'desc') {
result = result.orderBy(fieldRef, this.negateSort(value, negated));
} else if (
Expand Down
19 changes: 17 additions & 2 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,29 @@ export async function generateTsSchemaInPlace(schemaPath: string) {
return compileAndLoad(workDir);
}

export async function loadSchema(schema: string) {
export async function loadSchema(schema: string, additionalSchemas?: Record<string, string>) {
if (!schema.includes('datasource ')) {
schema = `${makePrelude('sqlite')}\n\n${schema}`;
}

// create a temp folder
const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'zenstack-schema'));

// create a temp file
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
const tempFile = path.join(tempDir, `schema.zmodel`);
fs.writeFileSync(tempFile, schema);

if (additionalSchemas) {
for (const [fileName, content] of Object.entries(additionalSchemas)) {
let name = fileName;
if (!name.endsWith('.zmodel')) {
name += '.zmodel';
}
const filePath = path.join(tempDir, name);
fs.writeFileSync(filePath, content);
}
}

const r = await loadDocument(tempFile);
expect(r).toSatisfy(
(r) => r.success,
Expand Down
49 changes: 49 additions & 0 deletions tests/regression/test/v2-migrated/issue-1014.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

// TODO: field-level policy support
describe.skip('Regression for issue 1014', () => {
it('update', async () => {
const db = await createPolicyTestClient(
`
model User {
id Int @id() @default(autoincrement())
name String
posts Post[]
}

model Post {
id Int @id() @default(autoincrement())
title String
content String?
author User? @relation(fields: [authorId], references: [id])
authorId Int? @allow('update', true, true)

@@allow('read', true)
}
`,
);

const user = await db.$unuseAll().user.create({ data: { name: 'User1' } });
const post = await db.$unuseAll().post.create({ data: { title: 'Post1' } });
await expect(db.post.update({ where: { id: post.id }, data: { authorId: user.id } })).toResolveTruthy();
});

it('read', async () => {
const db = await createPolicyTestClient(
`
model Post {
id Int @id() @default(autoincrement())
title String @allow('read', true, true)
content String
}
`,
);

const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } });
await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull();
await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({
title: 'Post1',
});
});
});
52 changes: 52 additions & 0 deletions tests/regression/test/v2-migrated/issue-1058.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { createTestClient } from '@zenstackhq/testtools';
import { it } from 'vitest';

it('verifies issue 1058', async () => {
const schema = `
model User {
id String @id @default(cuid())
name String
userRankings UserRanking[]
userFavorites UserFavorite[]
}
model Entity {
id String @id @default(cuid())
name String
type String
userRankings UserRanking[]
userFavorites UserFavorite[]
@@delegate(type)
}
model Person extends Entity {
}
model Studio extends Entity {
}
model UserRanking {
id String @id @default(cuid())
rank Int
entityId String
entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction)
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction)
}
model UserFavorite {
id String @id @default(cuid())
entityId String
entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction)
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction)
}
`;

await createTestClient(schema, { provider: 'postgresql' });
});
53 changes: 53 additions & 0 deletions tests/regression/test/v2-migrated/issue-1078.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

describe('Regression for issue 1078', () => {
it('regression1', async () => {
const db = await createPolicyTestClient(
`
model Counter {
id String @id

name String
value Int

@@validate(value >= 0)
@@allow('all', true)
}
`,
);

await expect(
db.counter.create({
data: { id: '1', name: 'It should create', value: 1 },
}),
).toResolveTruthy();

//! This query fails validation
await expect(
db.counter.update({
where: { id: '1' },
data: { name: 'It should update' },
}),
).toResolveTruthy();
});

// TODO: field-level policy support
it.skip('regression2', async () => {
const db = await createPolicyTestClient(
`
model Post {
id Int @id() @default(autoincrement())
title String @allow('read', true, true)
content String
}
`,
);

const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } });
await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull();
await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({
title: 'Post1',
});
});
});
Loading