Skip to content

Commit 8c935fc

Browse files
authored
fix: correct type inference for CTEs with aggregations and joins (#430)
Improves the type inference logic for Common Table Expressions (CTEs) involving aggregations and joins. Fixes issues where nullable types were incorrectly inferred, especially in cases with LEFT JOINs and aggregated columns. Adds a helper function to adjust context when resolving column references from CTE sources, ensuring accurate type descriptions.
1 parent 5d90bb3 commit 8c935fc

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

.changeset/fast-planes-work.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ts-safeql/generate": patch
3+
---
4+
5+
fix: correct type inference for CTEs with aggregations and joins

packages/generate/src/ast-describe.ts

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,17 +990,43 @@ function getColumnRefOrigins({
990990
}
991991
}
992992

993+
function getContextForColumnRef(
994+
context: ASTDescriptionContext,
995+
node: LibPgQueryAST.ColumnRef,
996+
): ASTDescriptionContext {
997+
if (isColumnTableColumnRef(node.fields) || isColumnTableStarRef(node.fields)) {
998+
const sourceName = node.fields[0].String.sval;
999+
const source = context.resolver.sources.get(sourceName);
1000+
if (source?.kind === "cte") {
1001+
return { ...context, resolver: source.sources };
1002+
}
1003+
}
1004+
1005+
return context;
1006+
}
1007+
9931008
function getDescribedColumnRef({
9941009
alias,
9951010
context,
9961011
node,
9971012
}: GetDescribedParamsOf<LibPgQueryAST.ColumnRef>): ASTDescribedColumn[] {
998-
const origins = getColumnRefOrigins({ alias, context, node })
999-
?.map((origin) => getDescribedNode({ alias, node: origin, context }))
1000-
.flat();
1013+
const definitionNodes = getColumnRefOrigins({ alias, context, node });
10011014

1002-
if (origins) return origins;
1015+
if (definitionNodes) {
1016+
const defContext = getContextForColumnRef(context, node);
1017+
return definitionNodes.flatMap((node) =>
1018+
getDescribedNode({ alias, node, context: defContext }),
1019+
);
1020+
}
10031021

1022+
return getDescribedColumnRefFromSchema({ alias, context, node });
1023+
}
1024+
1025+
function getDescribedColumnRefFromSchema({
1026+
alias,
1027+
context,
1028+
node,
1029+
}: GetDescribedParamsOf<LibPgQueryAST.ColumnRef>): ASTDescribedColumn[] {
10041030
// select *
10051031
if (isColumnStarRef(node.fields)) {
10061032
return getDescribedColumnByResolvedColumns({

packages/generate/src/generate.test.ts

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function runMigrations(sql: SQL) {
5656
id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
5757
nullable_col TEXT
5858
);
59-
59+
6060
CREATE TYPE overriden_enum AS ENUM ('foo', 'bar');
6161
6262
CREATE TABLE test_overriden_enum (
@@ -2443,3 +2443,42 @@ test("nullable columns in INNER JOIN subselect should remain nullable", async ()
24432443
],
24442444
});
24452445
});
2446+
2447+
test("regression: wrong inference of nullable in aggregation", async () => {
2448+
await testQuery({
2449+
query: `
2450+
with subquery as (
2451+
select a_id, array_agg(b_id) as list
2452+
from b
2453+
group by a_id
2454+
)
2455+
select subquery.list
2456+
from a
2457+
left join subquery on (subquery.a_id = a.id);
2458+
`,
2459+
schema: `
2460+
CREATE TABLE a (
2461+
id int primary key,
2462+
name text not null default ''
2463+
);
2464+
2465+
CREATE TABLE b (
2466+
a_id int not null,
2467+
b_id int not null,
2468+
primary key (a_id, b_id)
2469+
);
2470+
`,
2471+
expected: [
2472+
[
2473+
"list",
2474+
{
2475+
kind: "union",
2476+
value: [
2477+
{ kind: "array", value: { kind: "type", value: "number", type: "int4" } },
2478+
{ kind: "type", value: "null", type: "null" },
2479+
],
2480+
},
2481+
],
2482+
],
2483+
});
2484+
});

0 commit comments

Comments
 (0)