Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.13",
"@types/tmp": "^0.2.6",
"@types/tmp": "catalog:",
"@zenstackhq/eslint-config": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"@zenstackhq/testtools": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"better-sqlite3": "^11.8.1",
"tmp": "^0.2.3"
"tmp": "catalog:"
}
}
7 changes: 7 additions & 0 deletions packages/cli/src/actions/action-utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { loadDocument } from '@zenstackhq/language';
import { isDataSource } from '@zenstackhq/language/ast';
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'node:fs';
Expand Down Expand Up @@ -41,6 +42,9 @@ export async function loadSchemaDocument(schemaFile: string) {
});
throw new CliError('Failed to load schema');
}
loadResult.warnings.forEach((warn) => {
console.warn(colors.yellow(warn));
});
return loadResult.model;
}

Expand All @@ -54,6 +58,9 @@ export function handleSubProcessError(err: unknown) {

export async function generateTempPrismaSchema(zmodelPath: string, folder?: string) {
const model = await loadSchemaDocument(zmodelPath);
if (!model.declarations.some(isDataSource)) {
throw new CliError('Schema must define a datasource');
}
const prismaSchema = await new PrismaSchemaGenerator(model).generate();
if (!folder) {
folder = path.dirname(zmodelPath);
Expand Down
37 changes: 36 additions & 1 deletion packages/cli/test/ts-schema-gen.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { ExpressionUtils } from '@zenstackhq/runtime/schema';
import { generateTsSchema } from '@zenstackhq/testtools';
import { createTestProject, generateTsSchema, generateTsSchemaInPlace } from '@zenstackhq/testtools';
import fs from 'node:fs';
import path from 'node:path';
import { describe, expect, it } from 'vitest';

describe('TypeScript schema generation tests', () => {
Expand Down Expand Up @@ -325,4 +327,37 @@ model User extends Base {
},
});
});

it('merges all declarations from imported modules', async () => {
const workDir = createTestProject();
fs.writeFileSync(
path.join(workDir, 'a.zmodel'),
`
enum Role {
Admin
User
}
`,
);
fs.writeFileSync(
path.join(workDir, 'b.zmodel'),
`
import './a'

datasource db {
provider = 'sqlite'
url = 'file:./test.db'
}

model User {
id Int @id
role Role
}
`,
);

const { schema } = await generateTsSchemaInPlace(path.join(workDir, 'b.zmodel'));
expect(schema.enums).toMatchObject({ Role: expect.any(Object) });
expect(schema.models).toMatchObject({ User: expect.any(Object) });
});
});
4 changes: 3 additions & 1 deletion packages/language/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
"@zenstackhq/eslint-config": "workspace:*",
"@zenstackhq/typescript-config": "workspace:*",
"@zenstackhq/common-helpers": "workspace:*",
"langium-cli": "catalog:"
"langium-cli": "catalog:",
"tmp": "catalog:",
"@types/tmp": "catalog:"
},
"volta": {
"node": "18.19.1",
Expand Down
104 changes: 101 additions & 3 deletions packages/language/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { URI } from 'langium';
import { isAstNode, URI, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium';
import { NodeFileSystem } from 'langium/node';
import fs from 'node:fs';
import path from 'node:path';
import { fileURLToPath } from 'node:url';
import type { Model } from './ast';
import { isDataSource, type AstNode, type Model } from './ast';
import { STD_LIB_MODULE_NAME } from './constants';
import { createZModelLanguageServices } from './module';
import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils';

export function createZModelServices() {
return createZModelLanguageServices(NodeFileSystem);
Expand Down Expand Up @@ -60,8 +61,12 @@ export async function loadDocument(
const langiumDocuments = services.shared.workspace.LangiumDocuments;
const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName)));

// load imports
const importedURIs = await loadImports(document, langiumDocuments);
const importedDocuments = await Promise.all(importedURIs.map((uri) => langiumDocuments.getOrCreateDocument(uri)));

// build the document together with standard library, plugin modules, and imported documents
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document], {
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
validation: true,
});

Expand Down Expand Up @@ -95,11 +100,104 @@ export async function loadDocument(
};
}

const model = document.parseResult.value as Model;

// merge all declarations into the main document
const imported = mergeImportsDeclarations(langiumDocuments, model);

// remove imported documents
imported.forEach((model) => {
langiumDocuments.deleteDocument(model.$document!.uri);
services.shared.workspace.IndexManager.remove(model.$document!.uri);
});

// extra validation after merging imported declarations
const additionalErrors = validationAfterImportMerge(model);
if (additionalErrors.length > 0) {
return {
success: false,
errors: additionalErrors,
warnings,
};
}

return {
success: true,
model: document.parseResult.value as Model,
warnings,
};
}

async function loadImports(
document: LangiumDocument<AstNode>,
documents: LangiumDocuments,
uris: Set<string> = new Set(),
) {
const uriString = document.uri.toString();
if (!uris.has(uriString)) {
uris.add(uriString);
const model = document.parseResult.value as Model;
for (const imp of model.imports) {
const importedModel = resolveImport(documents, imp);
if (importedModel) {
const importedDoc = getDocument(importedModel);
await loadImports(importedDoc, documents, uris);
}
}
}
return Array.from(uris)
.filter((x) => uriString != x)
.map((e) => URI.parse(e));
}

function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) {
const importedModels = resolveTransitiveImports(documents, model);

const importedDeclarations = importedModels.flatMap((m) => m.declarations);
model.declarations.push(...importedDeclarations);

// remove import directives
model.imports = [];

// fix $containerIndex
linkContentToContainer(model);

return importedModels;
}

function linkContentToContainer(node: AstNode): void {
for (const [name, value] of Object.entries(node)) {
if (!name.startsWith('$')) {
if (Array.isArray(value)) {
value.forEach((item, index) => {
if (isAstNode(item)) {
(item as Mutable<AstNode>).$container = node;
(item as Mutable<AstNode>).$containerProperty = name;
(item as Mutable<AstNode>).$containerIndex = index;
}
});
} else if (isAstNode(value)) {
(value as Mutable<AstNode>).$container = node;
(value as Mutable<AstNode>).$containerProperty = name;
}
}
}
}

function validationAfterImportMerge(model: Model) {
const errors: string[] = [];
const dataSources = model.declarations.filter((d) => isDataSource(d));
if (dataSources.length > 1) {
errors.push('Validation error: Multiple datasource declarations are not allowed');
}

// at most one `@@auth` model
const decls = getDataModelAndTypeDefs(model, true);
const authDecls = decls.filter((d) => hasAttribute(d, '@@auth'));
if (authDecls.length > 1) {
errors.push('Validation error: Multiple `@@auth` declarations are not allowed');
}
return errors;
}

export * from './module';
27 changes: 26 additions & 1 deletion packages/language/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { invariant } from '@zenstackhq/common-helpers';
import { AstUtils, URI, type AstNode, type LangiumDocuments, type Reference } from 'langium';
import { AstUtils, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Reference } from 'langium';
import fs from 'node:fs';
import path from 'path';
import { STD_LIB_MODULE_NAME, type ExpressionContext } from './constants';
Expand Down Expand Up @@ -577,3 +577,28 @@ export function getAllAttributes(
attributes.push(...decl.attributes);
return attributes;
}

/**
* Retrieve the document in which the given AST node is contained. A reference to the document is
* usually held by the root node of the AST.
*
* @throws an error if the node is not contained in a document.
*/
export function getDocument<T extends AstNode = AstNode>(node: AstNode): LangiumDocument<T> {
const rootNode = findRootNode(node);
const result = rootNode.$document;
if (!result) {
throw new Error('AST node has no document.');
}
return result as LangiumDocument<T>;
}

/**
* Returns the root node of the given AST node by following the `$container` references.
*/
export function findRootNode(node: AstNode): AstNode {
while (node.$container) {
node = node.$container;
}
return node;
}
2 changes: 1 addition & 1 deletion packages/language/src/validators/schema-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ export default class SchemaValidator implements AstValidator<Model> {
private validateImports(model: Model, accept: ValidationAcceptor) {
model.imports.forEach((imp) => {
const importedModel = resolveImport(this.documents, imp);
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
if (!importedModel) {
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
accept('error', `Cannot find model file ${importPath}`, {
node: imp,
});
Expand Down
107 changes: 107 additions & 0 deletions packages/language/test/import.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import { invariant } from '@zenstackhq/common-helpers';
import fs from 'node:fs';
import path from 'node:path';
import tmp from 'tmp';
import { describe, expect, it } from 'vitest';
import { loadDocument } from '../src';
import { DataModel, isDataModel } from '../src/ast';

describe('Import tests', () => {
it('merges declarations', async () => {
const { name } = tmp.dirSync();
fs.writeFileSync(
path.join(name, 'a.zmodel'),
`
model A {
id Int @id
name String
}
`,
);
fs.writeFileSync(
path.join(name, 'b.zmodel'),
`
import './a'
model B {
id Int @id
}
`,
);

const model = await expectLoaded(path.join(name, 'b.zmodel'));
expect(model.declarations.filter(isDataModel)).toHaveLength(2);
expect(model.imports).toHaveLength(0);
});

it('resolves imported symbols', async () => {
const { name } = tmp.dirSync();
fs.writeFileSync(
path.join(name, 'a.zmodel'),
`
enum Role {
Admin
User
}
`,
);
fs.writeFileSync(
path.join(name, 'b.zmodel'),
`
import './a'
model User {
id Int @id
role Role
}
`,
);

const model = await expectLoaded(path.join(name, 'b.zmodel'));
expect((model.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('Role');
});

it('supports cyclic imports', async () => {
const { name } = tmp.dirSync();
fs.writeFileSync(
path.join(name, 'a.zmodel'),
`
import './b'
model A {
id Int @id
b B?
}
`,
);
fs.writeFileSync(
path.join(name, 'b.zmodel'),
`
}
`,
);
fs.writeFileSync(
path.join(name, 'b.zmodel'),
`
import './a'
model B {
id Int @id
a A @relation(fields: [aId], references: [id])
aId Int @unique
}
`,
);

const modelB = await expectLoaded(path.join(name, 'b.zmodel'));
expect((modelB.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('A');
const modelA = await expectLoaded(path.join(name, 'a.zmodel'));
expect((modelA.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('B');
});

async function expectLoaded(file: string) {
const result = await loadDocument(file);
if (!result.success) {
console.error('Errors:', result.errors);
throw new Error(`Failed to load document from ${file}`);
}
invariant(result.success);
return result.model;
}
});
Loading