Skip to content

Commit 4779cd2

Browse files
committed
feat: implement zmodel import
1 parent 7c7183d commit 4779cd2

File tree

14 files changed

+394
-19
lines changed

14 files changed

+394
-19
lines changed

packages/cli/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@
4444
},
4545
"devDependencies": {
4646
"@types/better-sqlite3": "^7.6.13",
47-
"@types/tmp": "^0.2.6",
47+
"@types/tmp": "catalog:",
4848
"@zenstackhq/eslint-config": "workspace:*",
4949
"@zenstackhq/runtime": "workspace:*",
5050
"@zenstackhq/testtools": "workspace:*",
5151
"@zenstackhq/typescript-config": "workspace:*",
5252
"better-sqlite3": "^11.8.1",
53-
"tmp": "^0.2.3"
53+
"tmp": "catalog:"
5454
}
5555
}

packages/cli/src/actions/action-utils.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { loadDocument } from '@zenstackhq/language';
2+
import { isDataSource } from '@zenstackhq/language/ast';
23
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
34
import colors from 'colors';
45
import fs from 'node:fs';
@@ -41,6 +42,9 @@ export async function loadSchemaDocument(schemaFile: string) {
4142
});
4243
throw new CliError('Failed to load schema');
4344
}
45+
loadResult.warnings.forEach((warn) => {
46+
console.warn(colors.yellow(warn));
47+
});
4448
return loadResult.model;
4549
}
4650

@@ -54,6 +58,9 @@ export function handleSubProcessError(err: unknown) {
5458

5559
export async function generateTempPrismaSchema(zmodelPath: string, folder?: string) {
5660
const model = await loadSchemaDocument(zmodelPath);
61+
if (!model.declarations.some(isDataSource)) {
62+
throw new CliError('Schema must define a datasource');
63+
}
5764
const prismaSchema = await new PrismaSchemaGenerator(model).generate();
5865
if (!folder) {
5966
folder = path.dirname(zmodelPath);

packages/cli/test/ts-schema-gen.test.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import { ExpressionUtils } from '@zenstackhq/runtime/schema';
2-
import { generateTsSchema } from '@zenstackhq/testtools';
2+
import { createTestProject, generateTsSchema, generateTsSchemaInPlace } from '@zenstackhq/testtools';
3+
import fs from 'node:fs';
4+
import path from 'node:path';
35
import { describe, expect, it } from 'vitest';
46

57
describe('TypeScript schema generation tests', () => {
@@ -325,4 +327,37 @@ model User extends Base {
325327
},
326328
});
327329
});
330+
331+
it('merges all declarations from imported modules', async () => {
332+
const workDir = createTestProject();
333+
fs.writeFileSync(
334+
path.join(workDir, 'a.zmodel'),
335+
`
336+
enum Role {
337+
Admin
338+
User
339+
}
340+
`,
341+
);
342+
fs.writeFileSync(
343+
path.join(workDir, 'b.zmodel'),
344+
`
345+
import './a'
346+
347+
datasource db {
348+
provider = 'sqlite'
349+
url = 'file:./test.db'
350+
}
351+
352+
model User {
353+
id Int @id
354+
role Role
355+
}
356+
`,
357+
);
358+
359+
const { schema } = await generateTsSchemaInPlace(path.join(workDir, 'b.zmodel'));
360+
expect(schema.enums).toMatchObject({ Role: expect.any(Object) });
361+
expect(schema.models).toMatchObject({ User: expect.any(Object) });
362+
});
328363
});

packages/language/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
"@zenstackhq/eslint-config": "workspace:*",
6363
"@zenstackhq/typescript-config": "workspace:*",
6464
"@zenstackhq/common-helpers": "workspace:*",
65-
"langium-cli": "catalog:"
65+
"langium-cli": "catalog:",
66+
"tmp": "catalog:",
67+
"@types/tmp": "catalog:"
6668
},
6769
"volta": {
6870
"node": "18.19.1",

packages/language/src/index.ts

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import { URI } from 'langium';
1+
import { isAstNode, URI, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium';
22
import { NodeFileSystem } from 'langium/node';
33
import fs from 'node:fs';
44
import path from 'node:path';
55
import { fileURLToPath } from 'node:url';
6-
import type { Model } from './ast';
6+
import { isDataSource, type AstNode, type Model } from './ast';
77
import { STD_LIB_MODULE_NAME } from './constants';
88
import { createZModelLanguageServices } from './module';
9+
import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils';
910

1011
export function createZModelServices() {
1112
return createZModelLanguageServices(NodeFileSystem);
@@ -60,8 +61,12 @@ export async function loadDocument(
6061
const langiumDocuments = services.shared.workspace.LangiumDocuments;
6162
const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName)));
6263

64+
// load imports
65+
const importedURIs = await loadImports(document, langiumDocuments);
66+
const importedDocuments = await Promise.all(importedURIs.map((uri) => langiumDocuments.getOrCreateDocument(uri)));
67+
6368
// build the document together with standard library, plugin modules, and imported documents
64-
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document], {
69+
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
6570
validation: true,
6671
});
6772

@@ -95,11 +100,104 @@ export async function loadDocument(
95100
};
96101
}
97102

103+
const model = document.parseResult.value as Model;
104+
105+
// merge all declarations into the main document
106+
const imported = mergeImportsDeclarations(langiumDocuments, model);
107+
108+
// remove imported documents
109+
imported.forEach((model) => {
110+
langiumDocuments.deleteDocument(model.$document!.uri);
111+
services.shared.workspace.IndexManager.remove(model.$document!.uri);
112+
});
113+
114+
// extra validation after merging imported declarations
115+
const additionalErrors = validationAfterImportMerge(model);
116+
if (additionalErrors.length > 0) {
117+
return {
118+
success: false,
119+
errors: additionalErrors,
120+
warnings,
121+
};
122+
}
123+
98124
return {
99125
success: true,
100126
model: document.parseResult.value as Model,
101127
warnings,
102128
};
103129
}
104130

131+
async function loadImports(
132+
document: LangiumDocument<AstNode>,
133+
documents: LangiumDocuments,
134+
uris: Set<string> = new Set(),
135+
) {
136+
const uriString = document.uri.toString();
137+
if (!uris.has(uriString)) {
138+
uris.add(uriString);
139+
const model = document.parseResult.value as Model;
140+
for (const imp of model.imports) {
141+
const importedModel = resolveImport(documents, imp);
142+
if (importedModel) {
143+
const importedDoc = getDocument(importedModel);
144+
await loadImports(importedDoc, documents, uris);
145+
}
146+
}
147+
}
148+
return Array.from(uris)
149+
.filter((x) => uriString != x)
150+
.map((e) => URI.parse(e));
151+
}
152+
153+
function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) {
154+
const importedModels = resolveTransitiveImports(documents, model);
155+
156+
const importedDeclarations = importedModels.flatMap((m) => m.declarations);
157+
model.declarations.push(...importedDeclarations);
158+
159+
// remove import directives
160+
model.imports = [];
161+
162+
// fix $containerIndex
163+
linkContentToContainer(model);
164+
165+
return importedModels;
166+
}
167+
168+
function linkContentToContainer(node: AstNode): void {
169+
for (const [name, value] of Object.entries(node)) {
170+
if (!name.startsWith('$')) {
171+
if (Array.isArray(value)) {
172+
value.forEach((item, index) => {
173+
if (isAstNode(item)) {
174+
(item as Mutable<AstNode>).$container = node;
175+
(item as Mutable<AstNode>).$containerProperty = name;
176+
(item as Mutable<AstNode>).$containerIndex = index;
177+
}
178+
});
179+
} else if (isAstNode(value)) {
180+
(value as Mutable<AstNode>).$container = node;
181+
(value as Mutable<AstNode>).$containerProperty = name;
182+
}
183+
}
184+
}
185+
}
186+
187+
function validationAfterImportMerge(model: Model) {
188+
const errors: string[] = [];
189+
const dataSources = model.declarations.filter((d) => isDataSource(d));
190+
if (dataSources.length > 1) {
191+
errors.push('Validation error: Multiple datasource declarations are not allowed');
192+
}
193+
194+
// at most one `@@auth` model
195+
const decls = getDataModelAndTypeDefs(model, true);
196+
const authDecls = decls.filter((d) => hasAttribute(d, '@@auth'));
197+
if (authDecls.length > 1) {
198+
errors.push('Validation error: Multiple `@@auth` declarations are not allowed');
199+
}
200+
return errors;
201+
}
202+
105203
export * from './module';

packages/language/src/utils.ts

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { invariant } from '@zenstackhq/common-helpers';
2-
import { AstUtils, URI, type AstNode, type LangiumDocuments, type Reference } from 'langium';
2+
import { AstUtils, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Reference } from 'langium';
33
import fs from 'node:fs';
44
import path from 'path';
55
import { STD_LIB_MODULE_NAME, type ExpressionContext } from './constants';
@@ -577,3 +577,28 @@ export function getAllAttributes(
577577
attributes.push(...decl.attributes);
578578
return attributes;
579579
}
580+
581+
/**
582+
* Retrieve the document in which the given AST node is contained. A reference to the document is
583+
* usually held by the root node of the AST.
584+
*
585+
* @throws an error if the node is not contained in a document.
586+
*/
587+
export function getDocument<T extends AstNode = AstNode>(node: AstNode): LangiumDocument<T> {
588+
const rootNode = findRootNode(node);
589+
const result = rootNode.$document;
590+
if (!result) {
591+
throw new Error('AST node has no document.');
592+
}
593+
return result as LangiumDocument<T>;
594+
}
595+
596+
/**
597+
* Returns the root node of the given AST node by following the `$container` references.
598+
*/
599+
export function findRootNode(node: AstNode): AstNode {
600+
while (node.$container) {
601+
node = node.$container;
602+
}
603+
return node;
604+
}

packages/language/src/validators/schema-validator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ export default class SchemaValidator implements AstValidator<Model> {
4747
private validateImports(model: Model, accept: ValidationAcceptor) {
4848
model.imports.forEach((imp) => {
4949
const importedModel = resolveImport(this.documents, imp);
50-
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
5150
if (!importedModel) {
51+
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
5252
accept('error', `Cannot find model file ${importPath}`, {
5353
node: imp,
5454
});
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import { invariant } from '@zenstackhq/common-helpers';
2+
import fs from 'node:fs';
3+
import path from 'node:path';
4+
import tmp from 'tmp';
5+
import { describe, expect, it } from 'vitest';
6+
import { loadDocument } from '../src';
7+
import { DataModel, isDataModel } from '../src/ast';
8+
9+
describe('Import tests', () => {
10+
it('merges declarations', async () => {
11+
const { name } = tmp.dirSync();
12+
fs.writeFileSync(
13+
path.join(name, 'a.zmodel'),
14+
`
15+
model A {
16+
id Int @id
17+
name String
18+
}
19+
`,
20+
);
21+
fs.writeFileSync(
22+
path.join(name, 'b.zmodel'),
23+
`
24+
import './a'
25+
model B {
26+
id Int @id
27+
}
28+
`,
29+
);
30+
31+
const model = await expectLoaded(path.join(name, 'b.zmodel'));
32+
expect(model.declarations.filter(isDataModel)).toHaveLength(2);
33+
expect(model.imports).toHaveLength(0);
34+
});
35+
36+
it('resolves imported symbols', async () => {
37+
const { name } = tmp.dirSync();
38+
fs.writeFileSync(
39+
path.join(name, 'a.zmodel'),
40+
`
41+
enum Role {
42+
Admin
43+
User
44+
}
45+
`,
46+
);
47+
fs.writeFileSync(
48+
path.join(name, 'b.zmodel'),
49+
`
50+
import './a'
51+
model User {
52+
id Int @id
53+
role Role
54+
}
55+
`,
56+
);
57+
58+
const model = await expectLoaded(path.join(name, 'b.zmodel'));
59+
expect((model.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('Role');
60+
});
61+
62+
it('supports cyclic imports', async () => {
63+
const { name } = tmp.dirSync();
64+
fs.writeFileSync(
65+
path.join(name, 'a.zmodel'),
66+
`
67+
import './b'
68+
model A {
69+
id Int @id
70+
b B?
71+
}
72+
`,
73+
);
74+
fs.writeFileSync(
75+
path.join(name, 'b.zmodel'),
76+
`
77+
}
78+
`,
79+
);
80+
fs.writeFileSync(
81+
path.join(name, 'b.zmodel'),
82+
`
83+
import './a'
84+
model B {
85+
id Int @id
86+
a A @relation(fields: [aId], references: [id])
87+
aId Int @unique
88+
}
89+
`,
90+
);
91+
92+
const modelB = await expectLoaded(path.join(name, 'b.zmodel'));
93+
expect((modelB.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('A');
94+
const modelA = await expectLoaded(path.join(name, 'a.zmodel'));
95+
expect((modelA.declarations[0] as DataModel).fields[1].type.reference?.ref?.name).toBe('B');
96+
});
97+
98+
async function expectLoaded(file: string) {
99+
const result = await loadDocument(file);
100+
if (!result.success) {
101+
console.error('Errors:', result.errors);
102+
throw new Error(`Failed to load document from ${file}`);
103+
}
104+
invariant(result.success);
105+
return result.model;
106+
}
107+
});

0 commit comments

Comments
 (0)