Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
return nil, err
}
rvs := rangeVars(raw.Stmt)
rss := rangeSubSelects(raw.Stmt)
refs, errs := findParameters(raw.Stmt)
if len(errs) > 0 {
if failfast {
Expand All @@ -173,7 +174,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
return nil, err
}

params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds, rss)
if err := check(err); err != nil {
return nil, err
}
Expand Down
9 changes: 9 additions & 0 deletions internal/compiler/compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ func parseRelation(node ast.Node) (*Relation, error) {
return &Relation{Name: n.Name}, nil
}

case *ast.RangeSubselect:
if n == nil {
return nil, fmt.Errorf("unexpected nil in %T node", n)
}
if n.Alias != nil && n.Alias.Aliasname != nil {
return &Relation{Name: *n.Alias.Aliasname}, nil
}
return nil, fmt.Errorf("no alias in subquery")

default:
return nil, fmt.Errorf("unexpected node type: %T", node)
}
Expand Down
12 changes: 12 additions & 0 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ func rangeVars(root ast.Node) []*ast.RangeVar {
return vars
}

func rangeSubSelects(root ast.Node) []*ast.RangeSubselect {
var rss []*ast.RangeSubselect
find := astutils.VisitorFunc(func(node ast.Node) {
switch n := node.(type) {
case *ast.RangeSubselect:
rss = append(rss, n)
}
})
astutils.Walk(find, root)
return rss
}

func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
m := make(map[int]bool, len(in))
o := make([]paramRef, 0, len(in))
Expand Down
29 changes: 29 additions & 0 deletions internal/compiler/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,32 @@ type Parameter struct {
Number int
Column *Column
}

func (t *Table) toCatalogTable() catalog.Table {
var catalogCols []*catalog.Column
for _, qcol := range t.Columns {
catalogColType := ast.TypeName{}
if qcol.Type != nil {
catalogColType = *qcol.Type
}

catalogCol := &catalog.Column{
Name: qcol.Name,
Type: catalogColType,
IsNotNull: qcol.NotNull,
IsUnsigned: qcol.Unsigned,
IsArray: qcol.IsArray,
ArrayDims: qcol.ArrayDims,
Comment: qcol.Comment,
Length: qcol.Length,
}

catalogCols = append(catalogCols, catalogCol)
}

return catalog.Table{
Rel: t.Rel,
Columns: catalogCols,
Comment: "",
}
}
28 changes: 24 additions & 4 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func dataType(n *ast.TypeName) string {
}
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet,
embeds rewrite.EmbedSet, rss []*ast.RangeSubselect) ([]Parameter, error) {
c := comp.catalog

aliasMap := map[string]*ast.TableName{}
Expand Down Expand Up @@ -67,10 +68,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}
// If the table name doesn't exist, first check if it's a CTE
if _, qcerr := qc.GetTable(fqn); qcerr != nil {
return nil, err
var qcTable *Table
var qcerr error
if qcTable, qcerr = qc.GetTable(fqn); qcerr != nil {
return nil, qcerr
}
continue
table = qcTable.toCatalogTable()
}
err = indexTable(table)
if err != nil {
Expand All @@ -81,6 +84,23 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
}
}

for _, rs := range rss {
fqn, err := ParseTableName(rs)
if err != nil {
return nil, err
}

cols, err := comp.outputColumns(qc, rs.Subquery)
if err != nil {
return nil, err
}
rsTable := Table{Rel: fqn, Columns: cols}
err = indexTable(rsTable.toCatalogTable())
if err != nil {
return nil, err
}
}

// resolve a table for an embed
for _, embed := range embeds {
table, err := c.GetTable(embed.Table)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
https://github.com/sqlc-dev/sqlc/issues/3720

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- name: GetFullNames :many
SELECT
full_name
FROM
(
SELECT
concat(first_name, ' ', last_name) as full_name
FROM
customers
) subquery
WHERE
full_name IN (sqlc.slice ("full_names"));

-- name: GetFullNames2 :many
WITH subquery AS (
SELECT
concat(first_name, ' ', last_name) as full_name
FROM
customers
)
SELECT
full_name
FROM
subquery
WHERE
full_name IN (sqlc.slice ("full_names"));
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE customers (
id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
first_name varchar(255) not null,
last_name varchar(255) not null
) ENGINE = INNODB DEFAULT CHARSET = utf8mb4;
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"version": "2",
"sql": [
{
"schema": "schema.sql",
"queries": "query.sql",
"engine": "mysql",
"gen": {
"go": {
"package": "db",
"out": "db"
}
}
}
]
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading