Skip to content

Commit 0df015a

Browse files
authored
feat: implement zmodel import (#126)
* feat: implement zmodel import * addressing PR comments * update
1 parent 7c7183d commit 0df015a

File tree

14 files changed

+391
-48
lines changed

14 files changed

+391
-48
lines changed

packages/cli/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@
4444
},
4545
"devDependencies": {
4646
"@types/better-sqlite3": "^7.6.13",
47-
"@types/tmp": "^0.2.6",
47+
"@types/tmp": "catalog:",
4848
"@zenstackhq/eslint-config": "workspace:*",
4949
"@zenstackhq/runtime": "workspace:*",
5050
"@zenstackhq/testtools": "workspace:*",
5151
"@zenstackhq/typescript-config": "workspace:*",
5252
"better-sqlite3": "^11.8.1",
53-
"tmp": "^0.2.3"
53+
"tmp": "catalog:"
5454
}
5555
}

packages/cli/src/actions/action-utils.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { loadDocument } from '@zenstackhq/language';
2+
import { isDataSource } from '@zenstackhq/language/ast';
23
import { PrismaSchemaGenerator } from '@zenstackhq/sdk';
34
import colors from 'colors';
45
import fs from 'node:fs';
@@ -41,6 +42,9 @@ export async function loadSchemaDocument(schemaFile: string) {
4142
});
4243
throw new CliError('Failed to load schema');
4344
}
45+
loadResult.warnings.forEach((warn) => {
46+
console.warn(colors.yellow(warn));
47+
});
4448
return loadResult.model;
4549
}
4650

@@ -54,6 +58,9 @@ export function handleSubProcessError(err: unknown) {
5458

5559
export async function generateTempPrismaSchema(zmodelPath: string, folder?: string) {
5660
const model = await loadSchemaDocument(zmodelPath);
61+
if (!model.declarations.some(isDataSource)) {
62+
throw new CliError('Schema must define a datasource');
63+
}
5764
const prismaSchema = await new PrismaSchemaGenerator(model).generate();
5865
if (!folder) {
5966
folder = path.dirname(zmodelPath);

packages/cli/test/ts-schema-gen.test.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import { ExpressionUtils } from '@zenstackhq/runtime/schema';
2-
import { generateTsSchema } from '@zenstackhq/testtools';
2+
import { createTestProject, generateTsSchema, generateTsSchemaInPlace } from '@zenstackhq/testtools';
3+
import fs from 'node:fs';
4+
import path from 'node:path';
35
import { describe, expect, it } from 'vitest';
46

57
describe('TypeScript schema generation tests', () => {
@@ -325,4 +327,37 @@ model User extends Base {
325327
},
326328
});
327329
});
330+
331+
it('merges all declarations from imported modules', async () => {
332+
const workDir = createTestProject();
333+
fs.writeFileSync(
334+
path.join(workDir, 'a.zmodel'),
335+
`
336+
enum Role {
337+
Admin
338+
User
339+
}
340+
`,
341+
);
342+
fs.writeFileSync(
343+
path.join(workDir, 'b.zmodel'),
344+
`
345+
import './a'
346+
347+
datasource db {
348+
provider = 'sqlite'
349+
url = 'file:./test.db'
350+
}
351+
352+
model User {
353+
id Int @id
354+
role Role
355+
}
356+
`,
357+
);
358+
359+
const { schema } = await generateTsSchemaInPlace(path.join(workDir, 'b.zmodel'));
360+
expect(schema.enums).toMatchObject({ Role: expect.any(Object) });
361+
expect(schema.models).toMatchObject({ User: expect.any(Object) });
362+
});
328363
});

packages/language/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
"@zenstackhq/eslint-config": "workspace:*",
6363
"@zenstackhq/typescript-config": "workspace:*",
6464
"@zenstackhq/common-helpers": "workspace:*",
65-
"langium-cli": "catalog:"
65+
"langium-cli": "catalog:",
66+
"tmp": "catalog:",
67+
"@types/tmp": "catalog:"
6668
},
6769
"volta": {
6870
"node": "18.19.1",

packages/language/src/index.ts

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import { URI } from 'langium';
1+
import { isAstNode, URI, type LangiumDocument, type LangiumDocuments, type Mutable } from 'langium';
22
import { NodeFileSystem } from 'langium/node';
33
import fs from 'node:fs';
44
import path from 'node:path';
55
import { fileURLToPath } from 'node:url';
6-
import type { Model } from './ast';
6+
import { isDataSource, type AstNode, type Model } from './ast';
77
import { STD_LIB_MODULE_NAME } from './constants';
88
import { createZModelLanguageServices } from './module';
9+
import { getDataModelAndTypeDefs, getDocument, hasAttribute, resolveImport, resolveTransitiveImports } from './utils';
910

1011
export function createZModelServices() {
1112
return createZModelLanguageServices(NodeFileSystem);
@@ -60,8 +61,15 @@ export async function loadDocument(
6061
const langiumDocuments = services.shared.workspace.LangiumDocuments;
6162
const document = await langiumDocuments.getOrCreateDocument(URI.file(path.resolve(fileName)));
6263

64+
// load imports
65+
const importedURIs = await loadImports(document, langiumDocuments);
66+
const importedDocuments: LangiumDocument[] = [];
67+
for (const uri of importedURIs) {
68+
importedDocuments.push(await langiumDocuments.getOrCreateDocument(uri));
69+
}
70+
6371
// build the document together with standard library, plugin modules, and imported documents
64-
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document], {
72+
await services.shared.workspace.DocumentBuilder.build([stdLib, ...pluginDocs, document, ...importedDocuments], {
6573
validation: true,
6674
});
6775

@@ -95,11 +103,100 @@ export async function loadDocument(
95103
};
96104
}
97105

106+
const model = document.parseResult.value as Model;
107+
108+
// merge all declarations into the main document
109+
const imported = mergeImportsDeclarations(langiumDocuments, model);
110+
111+
// remove imported documents
112+
imported.forEach((model) => {
113+
langiumDocuments.deleteDocument(model.$document!.uri);
114+
services.shared.workspace.IndexManager.remove(model.$document!.uri);
115+
});
116+
117+
// extra validation after merging imported declarations
118+
const additionalErrors = validationAfterImportMerge(model);
119+
if (additionalErrors.length > 0) {
120+
return {
121+
success: false,
122+
errors: additionalErrors,
123+
warnings,
124+
};
125+
}
126+
98127
return {
99128
success: true,
100129
model: document.parseResult.value as Model,
101130
warnings,
102131
};
103132
}
104133

134+
async function loadImports(document: LangiumDocument, documents: LangiumDocuments, uris: Set<string> = new Set()) {
135+
const uriString = document.uri.toString();
136+
if (!uris.has(uriString)) {
137+
uris.add(uriString);
138+
const model = document.parseResult.value as Model;
139+
for (const imp of model.imports) {
140+
const importedModel = resolveImport(documents, imp);
141+
if (importedModel) {
142+
const importedDoc = getDocument(importedModel);
143+
await loadImports(importedDoc, documents, uris);
144+
}
145+
}
146+
}
147+
return Array.from(uris)
148+
.filter((x) => uriString != x)
149+
.map((e) => URI.parse(e));
150+
}
151+
152+
function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) {
153+
const importedModels = resolveTransitiveImports(documents, model);
154+
155+
const importedDeclarations = importedModels.flatMap((m) => m.declarations);
156+
model.declarations.push(...importedDeclarations);
157+
158+
// remove import directives
159+
model.imports = [];
160+
161+
// fix $container, $containerIndex, and $containerProperty
162+
linkContentToContainer(model);
163+
164+
return importedModels;
165+
}
166+
167+
function linkContentToContainer(node: AstNode): void {
168+
for (const [name, value] of Object.entries(node)) {
169+
if (!name.startsWith('$')) {
170+
if (Array.isArray(value)) {
171+
value.forEach((item, index) => {
172+
if (isAstNode(item)) {
173+
(item as Mutable<AstNode>).$container = node;
174+
(item as Mutable<AstNode>).$containerProperty = name;
175+
(item as Mutable<AstNode>).$containerIndex = index;
176+
}
177+
});
178+
} else if (isAstNode(value)) {
179+
(value as Mutable<AstNode>).$container = node;
180+
(value as Mutable<AstNode>).$containerProperty = name;
181+
}
182+
}
183+
}
184+
}
185+
186+
function validationAfterImportMerge(model: Model) {
187+
const errors: string[] = [];
188+
const dataSources = model.declarations.filter((d) => isDataSource(d));
189+
if (dataSources.length > 1) {
190+
errors.push('Validation error: Multiple datasource declarations are not allowed');
191+
}
192+
193+
// at most one `@@auth` model
194+
const decls = getDataModelAndTypeDefs(model, true);
195+
const authDecls = decls.filter((d) => hasAttribute(d, '@@auth'));
196+
if (authDecls.length > 1) {
197+
errors.push('Validation error: Multiple `@@auth` declarations are not allowed');
198+
}
199+
return errors;
200+
}
201+
105202
export * from './module';

packages/language/src/utils.ts

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { invariant } from '@zenstackhq/common-helpers';
2-
import { AstUtils, URI, type AstNode, type LangiumDocuments, type Reference } from 'langium';
2+
import { AstUtils, URI, type AstNode, type LangiumDocument, type LangiumDocuments, type Reference } from 'langium';
33
import fs from 'node:fs';
44
import path from 'path';
55
import { STD_LIB_MODULE_NAME, type ExpressionContext } from './constants';
@@ -413,38 +413,13 @@ export function resolveImport(documents: LangiumDocuments, imp: ModelImport) {
413413
}
414414

415415
export function resolveImportUri(imp: ModelImport) {
416-
if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string ("").
417-
418-
if (!imp.path.endsWith('.zmodel')) {
419-
imp.path += '.zmodel';
420-
}
421-
422-
if (
423-
!imp.path.startsWith('.') && // Respect relative paths
424-
!path.isAbsolute(imp.path) // Respect Absolute paths
425-
) {
426-
// use the current model's path as the search context
427-
const contextPath = imp.$container.$document
428-
? path.dirname(imp.$container.$document.uri.fsPath)
429-
: process.cwd();
430-
imp.path = findNodeModulesFile(imp.path, contextPath) ?? imp.path;
416+
if (!imp.path) {
417+
return undefined;
431418
}
432-
433419
const doc = AstUtils.getDocument(imp);
434420
const dir = path.dirname(doc.uri.fsPath);
435-
return URI.file(path.resolve(dir, imp.path));
436-
}
437-
438-
export function findNodeModulesFile(name: string, cwd: string = process.cwd()) {
439-
if (!name) return undefined;
440-
try {
441-
// Use require.resolve to find the module/file. The paths option allows specifying the directory to start from.
442-
const resolvedPath = require.resolve(name, { paths: [cwd] });
443-
return resolvedPath;
444-
} catch {
445-
// If require.resolve fails to find the module/file, it will throw an error.
446-
return undefined;
447-
}
421+
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
422+
return URI.file(path.resolve(dir, importPath));
448423
}
449424

450425
/**
@@ -577,3 +552,28 @@ export function getAllAttributes(
577552
attributes.push(...decl.attributes);
578553
return attributes;
579554
}
555+
556+
/**
557+
* Retrieve the document in which the given AST node is contained. A reference to the document is
558+
* usually held by the root node of the AST.
559+
*
560+
* @throws an error if the node is not contained in a document.
561+
*/
562+
export function getDocument<T extends AstNode = AstNode>(node: AstNode): LangiumDocument<T> {
563+
const rootNode = findRootNode(node);
564+
const result = rootNode.$document;
565+
if (!result) {
566+
throw new Error('AST node has no document.');
567+
}
568+
return result as LangiumDocument<T>;
569+
}
570+
571+
/**
572+
* Returns the root node of the given AST node by following the `$container` references.
573+
*/
574+
export function findRootNode(node: AstNode): AstNode {
575+
while (node.$container) {
576+
node = node.$container;
577+
}
578+
return node;
579+
}

packages/language/src/validators/schema-validator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ export default class SchemaValidator implements AstValidator<Model> {
4747
private validateImports(model: Model, accept: ValidationAcceptor) {
4848
model.imports.forEach((imp) => {
4949
const importedModel = resolveImport(this.documents, imp);
50-
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
5150
if (!importedModel) {
51+
const importPath = imp.path.endsWith('.zmodel') ? imp.path : `${imp.path}.zmodel`;
5252
accept('error', `Cannot find model file ${importPath}`, {
5353
node: imp,
5454
});

0 commit comments

Comments
 (0)