Skip to content

Commit 6ec9a46

Browse files
authored
fix: scalar subquery in select list should infer correct type (#412)
1 parent 88abda0 commit 6ec9a46

File tree

4 files changed

+143
-31
lines changed

4 files changed

+143
-31
lines changed

.changeset/wet-ears-cross.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@ts-safeql/generate": patch
3+
---
4+
5+
Fixed type inference for scalar subqueries in select lists.

packages/generate/src/ast-describe.ts

Lines changed: 97 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@ import {
1010
} from "./ast-decribe.utils";
1111
import { ResolvedColumn, SourcesResolver, getSources } from "./ast-get-sources";
1212
import { PgColRow, PgEnumsMaps, PgTypesMap } from "./generate";
13-
import { FlattenedRelationWithJoins } from "./utils/get-relations-with-joins";
13+
import { getNonNullableColumns } from "./utils/get-nonnullable-columns";
14+
import {
15+
FlattenedRelationWithJoins,
16+
flattenRelationsWithJoinsMap,
17+
getRelationsWithJoins,
18+
} from "./utils/get-relations-with-joins";
1419

1520
type ASTDescriptionOptions = {
1621
parsed: LibPgQueryAST.ParseResult;
17-
relations: FlattenedRelationWithJoins[];
1822
typesMap: Map<string, { override: boolean; value: string }>;
1923
typeExprMap: Map<string, Map<string, Map<string, string>>>;
2024
overridenColumnTypesMap: Map<string, Map<string, string>>;
21-
nonNullableColumns: Set<string>;
2225
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
2326
pgTypes: PgTypesMap;
2427
pgEnums: PgEnumsMaps;
@@ -29,6 +32,8 @@ type ASTDescriptionContext = ASTDescriptionOptions & {
2932
select: LibPgQueryAST.SelectStmt;
3033
resolver: SourcesResolver;
3134
resolved: WeakMap<LibPgQueryAST.Node, string>;
35+
nonNullableColumns: Set<string>;
36+
relations: FlattenedRelationWithJoins[];
3237
toTypeScriptType: (
3338
params: { oid: number; baseOid: number | null } | { name: string },
3439
) => ASTDescribedColumnType;
@@ -43,15 +48,28 @@ export type ASTDescribedColumnType =
4348
| { kind: "type"; value: string; type: string; base?: string }
4449
| { kind: "literal"; value: string; base: ASTDescribedColumnType };
4550

46-
export function getASTDescription(
47-
params: ASTDescriptionOptions,
48-
): Map<number, ASTDescribedColumn | undefined> {
51+
export function getASTDescription(params: ASTDescriptionOptions): {
52+
map: Map<number, ASTDescribedColumn | undefined>;
53+
meta: {
54+
relations: FlattenedRelationWithJoins[];
55+
nonNullableColumns: Set<string>;
56+
};
57+
} {
4958
const select = params.parsed.stmts[0]?.stmt?.SelectStmt;
5059

5160
if (select === undefined) {
52-
return new Map();
61+
return {
62+
map: new Map(),
63+
meta: {
64+
relations: [],
65+
nonNullableColumns: new Set(),
66+
},
67+
};
5368
}
5469

70+
const nonNullableColumns = getNonNullableColumns(params.parsed);
71+
const relations = flattenRelationsWithJoinsMap(getRelationsWithJoins(params.parsed));
72+
5573
function getTypeByOid(oid: number) {
5674
const name = params.pgTypes.get(oid)?.name;
5775

@@ -74,10 +92,12 @@ export function getASTDescription(
7492

7593
const context: ASTDescriptionContext = {
7694
...params,
95+
nonNullableColumns,
96+
relations,
7797
resolver: getSources({
78-
relations: params.relations,
98+
relations: relations,
7999
select: select,
80-
nonNullableColumns: params.nonNullableColumns,
100+
nonNullableColumns: nonNullableColumns,
81101
pgColsBySchemaAndTableName: params.pgColsBySchemaAndTableName,
82102
}),
83103
select: select,
@@ -176,7 +196,13 @@ export function getASTDescription(
176196
final.set(i, result);
177197
}
178198

179-
return final;
199+
return {
200+
map: final,
201+
meta: {
202+
relations,
203+
nonNullableColumns,
204+
},
205+
};
180206
}
181207

182208
function mergeColumns(columns: (ASTDescribedColumn | undefined)[]): ASTDescribedColumn | undefined {
@@ -247,6 +273,10 @@ function getDescribedNode(params: {
247273
return getDescribedAExpr({ alias: alias, node: node.A_Expr, context });
248274
}
249275

276+
if (node.SelectStmt !== undefined) {
277+
return getDescribedSelectStmt({ alias: alias, node: node.SelectStmt, context });
278+
}
279+
250280
return [];
251281
}
252282

@@ -488,24 +518,74 @@ function getDescribedSubLink({
488518
context,
489519
node,
490520
}: GetDescribedParamsOf<LibPgQueryAST.SubLink>): ASTDescribedColumn[] {
521+
const getSubLinkType = (): ASTDescribedColumnType => {
522+
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXISTS_SUBLINK) {
523+
return context.toTypeScriptType({ name: "bool" });
524+
}
525+
526+
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXPR_SUBLINK) {
527+
const described = node.subselect?.SelectStmt
528+
? getDescribedNode({
529+
alias: undefined,
530+
node: { SelectStmt: node.subselect.SelectStmt },
531+
context,
532+
})
533+
: [];
534+
535+
return described.length > 0
536+
? described[0].type
537+
: context.toTypeScriptType({ name: "unknown" });
538+
}
539+
540+
return context.toTypeScriptType({ name: "unknown" });
541+
};
542+
491543
return [
492544
{
493545
name: alias ?? "exists",
494546
type: resolveType({
495547
context: context,
496548
nullable: false,
497-
type: (() => {
498-
if (node.subLinkType === LibPgQueryAST.SubLinkType.EXISTS_SUBLINK) {
499-
return context.toTypeScriptType({ name: "bool" });
500-
}
501-
502-
return context.toTypeScriptType({ name: "unknown" });
503-
})(),
549+
type: getSubLinkType(),
504550
}),
505551
},
506552
];
507553
}
508554

555+
function getDescribedSelectStmt({
556+
alias,
557+
context,
558+
node,
559+
}: GetDescribedParamsOf<LibPgQueryAST.SelectStmt>): ASTDescribedColumn[] {
560+
const subParsed: LibPgQueryAST.ParseResult = {
561+
version: 0,
562+
stmts: [{ stmt: { SelectStmt: node }, stmtLocation: 0, stmtLen: 0 }],
563+
};
564+
565+
const subDescription = getASTDescription({
566+
parsed: subParsed,
567+
typesMap: context.typesMap,
568+
typeExprMap: context.typeExprMap,
569+
overridenColumnTypesMap: context.overridenColumnTypesMap,
570+
pgColsBySchemaAndTableName: context.pgColsBySchemaAndTableName,
571+
pgTypes: context.pgTypes,
572+
pgEnums: context.pgEnums,
573+
pgFns: context.pgFns,
574+
});
575+
576+
const firstColumn = subDescription.map.get(0);
577+
if (firstColumn) {
578+
return [
579+
{
580+
name: alias ?? firstColumn.name,
581+
type: firstColumn.type,
582+
},
583+
];
584+
}
585+
586+
return [];
587+
}
588+
509589
function getDescribedCoalesceExpr({
510590
alias,
511591
context,

packages/generate/src/generate.test.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2056,7 +2056,6 @@ test(`jsonb subselect ->> key => string | null`, async () => {
20562056
},
20572057
],
20582058
],
2059-
unknownColumns: ["extracted_value"],
20602059
});
20612060
});
20622061

@@ -2383,3 +2382,40 @@ test("jsonb ->> operator should return string | null", async () => {
23832382
],
23842383
});
23852384
});
2385+
2386+
test("scalar subquery in select list should infer correct type", async () => {
2387+
await testQuery({
2388+
schema: `
2389+
CREATE TABLE tbl (
2390+
id INTEGER PRIMARY KEY,
2391+
col TEXT
2392+
);
2393+
`,
2394+
query: `SELECT (SELECT col FROM tbl LIMIT 1) AS col`,
2395+
expected: [
2396+
[
2397+
"col",
2398+
{
2399+
kind: "union",
2400+
value: [
2401+
{ kind: "type", value: "string", type: "text" },
2402+
{ kind: "type", value: "null", type: "null" },
2403+
],
2404+
},
2405+
],
2406+
],
2407+
});
2408+
});
2409+
2410+
test("scalar subquery with WHERE should infer non-nullable type", async () => {
2411+
await testQuery({
2412+
schema: `
2413+
CREATE TABLE tbl (
2414+
id INTEGER PRIMARY KEY,
2415+
col TEXT
2416+
);
2417+
`,
2418+
query: `SELECT (SELECT col FROM tbl WHERE col IS NOT NULL LIMIT 1) AS col`,
2419+
expected: [["col", { kind: "type", value: "string", type: "text" }]],
2420+
});
2421+
});

packages/generate/src/generate.ts

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@ import { either } from "fp-ts";
1616
import postgres from "postgres";
1717
import { ASTDescribedColumn, getASTDescription } from "./ast-describe";
1818
import { ColType } from "./utils/colTypes";
19-
import { getNonNullableColumns } from "./utils/get-nonnullable-columns";
20-
import {
21-
FlattenedRelationWithJoins,
22-
flattenRelationsWithJoinsMap,
23-
getRelationsWithJoins,
24-
} from "./utils/get-relations-with-joins";
19+
import { FlattenedRelationWithJoins } from "./utils/get-relations-with-joins";
2520
import * as parser from "libpg-query";
2621

2722
type JSToPostgresTypeMap = Record<string, unknown>;
@@ -262,15 +257,11 @@ async function generate(
262257
}
263258

264259
const parsed = await parser.parse(query.text);
265-
const relationsWithJoins = flattenRelationsWithJoinsMap(getRelationsWithJoins(parsed));
266-
const nonNullableColumnsBasedOnAST = getNonNullableColumns(parsed);
267260

268261
const astQueryDescription = getASTDescription({
269262
parsed: parsed,
270-
relations: relationsWithJoins,
271263
typesMap: typesMap,
272264
overridenColumnTypesMap: overridenColumnTypesMap,
273-
nonNullableColumns: nonNullableColumnsBasedOnAST,
274265
pgColsBySchemaAndTableName: pgColsBySchemaAndTableName,
275266
pgTypes: pgTypes,
276267
pgEnums: pgEnums,
@@ -283,21 +274,21 @@ async function generate(
283274
.get(col.table)
284275
?.find((x) => x.colNum === col.number);
285276

286-
const astDescribed = astQueryDescription.get(position);
277+
const astDescribed = astQueryDescription.map.get(position);
287278

288279
return {
289280
described: col,
290281
astDescribed: astDescribed,
291282
introspected: introspected,
292-
isNonNullableBasedOnAST: nonNullableColumnsBasedOnAST.has(col.name),
283+
isNonNullableBasedOnAST: astQueryDescription.meta.nonNullableColumns.has(col.name),
293284
};
294285
});
295286

296287
const context: GenerateContext = {
297288
columns,
298289
pgTypes,
299290
pgEnums,
300-
relationsWithJoins,
291+
relationsWithJoins: astQueryDescription.meta.relations,
301292
overrides: {
302293
types: typesMap,
303294
columns: overridenColumnTypesMap,

0 commit comments

Comments
 (0)