From 74d4036cd4b7cfe32b6700ccc4e1f9966d69376e Mon Sep 17 00:00:00 2001 From: Pulung Ragil Date: Fri, 20 Dec 2024 18:11:32 +0700 Subject: [PATCH 1/2] fix: cte and subquery relation name and columns should be registered on query catalog --- internal/compiler/analyze.go | 3 +- internal/compiler/compat.go | 9 ++ internal/compiler/parse.go | 12 ++ internal/compiler/query.go | 29 +++++ internal/compiler/resolve.go | 30 ++++- .../cte_and_subquery_column_alias/issue.md | 1 + .../mysql/db/db.go | 31 +++++ .../mysql/db/models.go | 11 ++ .../mysql/db/query.sql.go | 106 ++++++++++++++++++ .../mysql/query.sql | 26 +++++ .../mysql/schema.sql | 5 + .../mysql/sqlc.yaml | 16 +++ .../postgresql/db/db.go | 31 +++++ .../postgresql/db/models.go | 11 ++ .../postgresql/db/query.sql.go | 106 ++++++++++++++++++ .../postgresql/query.sql | 26 +++++ .../postgresql/schema.sql | 5 + .../postgresql/sqlc.yaml | 16 +++ 18 files changed, 467 insertions(+), 7 deletions(-) create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/issue.md create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/db.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/models.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/query.sql.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/query.sql create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/schema.sql create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/sqlc.yaml create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/db.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/models.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/query.sql.go create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/query.sql create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/schema.sql create mode 100644 internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/sqlc.yaml diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 38d66fce19..381c54bcc1 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -154,6 +154,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } rvs := rangeVars(raw.Stmt) + rss := rangeSubSelects(raw.Stmt) refs, errs := findParameters(raw.Stmt) if len(errs) > 0 { if failfast { @@ -173,7 +174,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds, rss) if err := check(err); err != nil { return nil, err } diff --git a/internal/compiler/compat.go b/internal/compiler/compat.go index 097d889cfb..bdbd90c1b3 100644 --- a/internal/compiler/compat.go +++ b/internal/compiler/compat.go @@ -84,6 +84,15 @@ func parseRelation(node ast.Node) (*Relation, error) { return &Relation{Name: n.Name}, nil } + case *ast.RangeSubselect: + if n == nil { + return nil, fmt.Errorf("unexpected nil in %T node", n) + } + if n.Alias != nil && n.Alias.Aliasname != nil { + return &Relation{Name: *n.Alias.Aliasname}, nil + } + return nil, fmt.Errorf("no alias in subquery") + default: return nil, fmt.Errorf("unexpected node type: %T", node) } diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 681d291122..d378738bf5 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -144,6 +144,18 @@ func rangeVars(root ast.Node) []*ast.RangeVar { return vars } +func rangeSubSelects(root ast.Node) []*ast.RangeSubselect { + var rss []*ast.RangeSubselect + find := astutils.VisitorFunc(func(node ast.Node) { + switch n := node.(type) { + case *ast.RangeSubselect: + rss = append(rss, n) + } + }) + astutils.Walk(find, root) + return rss +} + func uniqueParamRefs(in []paramRef, dollar bool) []paramRef { m := make(map[int]bool, len(in)) o := make([]paramRef, 0, len(in)) diff --git a/internal/compiler/query.go b/internal/compiler/query.go index b3cf9d6154..5eba86877d 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -59,3 +59,32 @@ type Parameter struct { Number int Column *Column } + +func (t *Table) toCatalogTable() catalog.Table { + var catalogCols []*catalog.Column + for _, qcol := range t.Columns { + catalogColType := ast.TypeName{} + if qcol.Type != nil { + catalogColType = *qcol.Type + } + + catalogCol := &catalog.Column{ + Name: qcol.Name, + Type: catalogColType, + IsNotNull: qcol.NotNull, + IsUnsigned: qcol.Unsigned, + IsArray: qcol.IsArray, + ArrayDims: qcol.ArrayDims, + Comment: qcol.Comment, + Length: qcol.Length, + } + + catalogCols = append(catalogCols, catalogCol) + } + + return catalog.Table{ + Rel: t.Rel, + Columns: catalogCols, + Comment: "", + } +} \ No newline at end of file diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..100bdf662f 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -21,7 +21,8 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, + embeds rewrite.EmbedSet, rss []*ast.RangeSubselect) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -67,10 +68,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } // If the table name doesn't exist, first check if it's a CTE - if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err + var qcTable *Table + var qcerr error + if qcTable, qcerr = qc.GetTable(fqn); qcerr != nil { + return nil, qcerr } - continue + table = qcTable.toCatalogTable() } err = indexTable(table) if err != nil { @@ -81,6 +84,23 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } + for _, rs := range rss { + fqn, err := ParseTableName(rs) + if err != nil { + return nil, err + } + + cols, err := comp.outputColumns(qc, rs.Subquery) + if err != nil { + return nil, err + } + rsTable := Table{Rel: fqn, Columns: cols} + err = indexTable(rsTable.toCatalogTable()) + if err != nil { + return nil, err + } + } + // resolve a table for an embed for _, embed := range embeds { table, err := c.GetTable(embed.Table) @@ -268,7 +288,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *ast.BetweenExpr: if n == nil || n.Expr == nil || n.Left == nil || n.Right == nil { - fmt.Println("ast.BetweenExpr is nil") continue } @@ -527,7 +546,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *ast.In: if n == nil || n.List == nil { - fmt.Println("ast.In is nil") continue } diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/issue.md b/internal/endtoend/testdata/cte_and_subquery_column_alias/issue.md new file mode 100644 index 0000000000..3f2dca55fd --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/issue.md @@ -0,0 +1 @@ +https://github.com/sqlc-dev/sqlc/issues/3720 \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/db.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/db.go new file mode 100644 index 0000000000..41b7a34365 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/models.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/models.go new file mode 100644 index 0000000000..d7ca792f7b --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package db + +type Customer struct { + ID uint32 + FirstName string + LastName string +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/query.sql.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/query.sql.go new file mode 100644 index 0000000000..6564e375dd --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/db/query.sql.go @@ -0,0 +1,106 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: query.sql + +package db + +import ( + "context" + "strings" +) + +const getFullNames = `-- name: GetFullNames :many +SELECT + full_name +FROM + ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) subquery +WHERE + full_name IN (/*SLICE:full_names*/?) +` + +func (q *Queries) GetFullNames(ctx context.Context, fullNames []interface{}) ([]string, error) { + query := getFullNames + var queryParams []interface{} + if len(fullNames) > 0 { + for _, v := range fullNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:full_names*/?", strings.Repeat(",?", len(fullNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:full_names*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var full_name string + if err := rows.Scan(&full_name); err != nil { + return nil, err + } + items = append(items, full_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getFullNames2 = `-- name: GetFullNames2 :many +WITH subquery AS ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) +SELECT + full_name +FROM + subquery +WHERE + full_name IN (/*SLICE:full_names*/?) +` + +func (q *Queries) GetFullNames2(ctx context.Context, fullNames []interface{}) ([]string, error) { + query := getFullNames2 + var queryParams []interface{} + if len(fullNames) > 0 { + for _, v := range fullNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:full_names*/?", strings.Repeat(",?", len(fullNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:full_names*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var full_name string + if err := rows.Scan(&full_name); err != nil { + return nil, err + } + items = append(items, full_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/query.sql b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/query.sql new file mode 100644 index 0000000000..f93a2e2861 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/query.sql @@ -0,0 +1,26 @@ +-- name: GetFullNames :many +SELECT + full_name +FROM + ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) subquery +WHERE + full_name IN (sqlc.slice ("full_names")); + +-- name: GetFullNames2 :many +WITH subquery AS ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) +SELECT + full_name +FROM + subquery +WHERE + full_name IN (sqlc.slice ("full_names")); \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/schema.sql b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/schema.sql new file mode 100644 index 0000000000..6d94db9de8 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE customers ( + id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, + first_name varchar(255) not null, + last_name varchar(255) not null +) ENGINE = INNODB DEFAULT CHARSET = utf8mb4; \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/sqlc.yaml b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/sqlc.yaml new file mode 100644 index 0000000000..0b715f453f --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/mysql/sqlc.yaml @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "schema": "schema.sql", + "queries": "query.sql", + "engine": "mysql", + "gen": { + "go": { + "package": "db", + "out": "db" + } + } + } + ] +} \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/db.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/db.go new file mode 100644 index 0000000000..41b7a34365 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/models.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/models.go new file mode 100644 index 0000000000..868cfd4b98 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 + +package db + +type Customer struct { + ID int64 + FirstName string + LastName string +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/query.sql.go b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/query.sql.go new file mode 100644 index 0000000000..7793979cda --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/db/query.sql.go @@ -0,0 +1,106 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: query.sql + +package db + +import ( + "context" + "strings" +) + +const getFullNames = `-- name: GetFullNames :many +SELECT + full_name +FROM + ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) subquery +WHERE + full_name IN ($1) +` + +func (q *Queries) GetFullNames(ctx context.Context, fullNames []interface{}) ([]interface{}, error) { + query := getFullNames + var queryParams []interface{} + if len(fullNames) > 0 { + for _, v := range fullNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:full_names*/?", strings.Repeat(",?", len(fullNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:full_names*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []interface{} + for rows.Next() { + var full_name interface{} + if err := rows.Scan(&full_name); err != nil { + return nil, err + } + items = append(items, full_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getFullNames2 = `-- name: GetFullNames2 :many +WITH subquery AS ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) +SELECT + full_name +FROM + subquery +WHERE + full_name IN ($1) +` + +func (q *Queries) GetFullNames2(ctx context.Context, fullNames []interface{}) ([]interface{}, error) { + query := getFullNames2 + var queryParams []interface{} + if len(fullNames) > 0 { + for _, v := range fullNames { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:full_names*/?", strings.Repeat(",?", len(fullNames))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:full_names*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []interface{} + for rows.Next() { + var full_name interface{} + if err := rows.Scan(&full_name); err != nil { + return nil, err + } + items = append(items, full_name) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/query.sql b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/query.sql new file mode 100644 index 0000000000..f34f1e2e90 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/query.sql @@ -0,0 +1,26 @@ +-- name: GetFullNames :many +SELECT + full_name +FROM + ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) subquery +WHERE + full_name IN (sqlc.slice ('full_names')); + +-- name: GetFullNames2 :many +WITH subquery AS ( + SELECT + concat(first_name, ' ', last_name) as full_name + FROM + customers + ) +SELECT + full_name +FROM + subquery +WHERE + full_name IN (sqlc.slice ('full_names')); \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/schema.sql b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/schema.sql new file mode 100644 index 0000000000..896d4d6fe6 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE customers ( + id bigserial PRIMARY KEY, + first_name varchar(255) not null, + last_name varchar(255) not null +); \ No newline at end of file diff --git a/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/sqlc.yaml b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/sqlc.yaml new file mode 100644 index 0000000000..fa39ef7437 --- /dev/null +++ b/internal/endtoend/testdata/cte_and_subquery_column_alias/postgresql/sqlc.yaml @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "schema": "schema.sql", + "queries": "query.sql", + "engine": "postgresql", + "gen": { + "go": { + "package": "db", + "out": "db" + } + } + } + ] +} \ No newline at end of file From 98e076f987605ccf4b9c24e79278b23d8a23a453 Mon Sep 17 00:00:00 2001 From: Pulung Ragil Date: Fri, 20 Dec 2024 18:30:25 +0700 Subject: [PATCH 2/2] revert: deleted fmt used as log --- internal/compiler/resolve.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 100bdf662f..dcb57cf864 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -288,6 +288,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *ast.BetweenExpr: if n == nil || n.Expr == nil || n.Left == nil || n.Right == nil { + fmt.Println("ast.BetweenExpr is nil") continue } @@ -546,6 +547,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *ast.In: if n == nil || n.List == nil { + fmt.Println("ast.In is nil") continue }