diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..ece1a31d01 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -181,6 +181,12 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } + if c.conf.Engine == config.EngineSQLite { + if err := check(validate.ValidateSQLiteQualifiedColumnRefs(raw.Stmt)); err != nil { + return nil, err + } + } + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) if err := check(err); err != nil { return nil, err diff --git a/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/query.sql b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/query.sql new file mode 100644 index 0000000000..ad3f50c3a3 --- /dev/null +++ b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/query.sql @@ -0,0 +1,9 @@ +-- name: GetByPublicID :one +SELECT * +FROM locations l +WHERE l.public_id = ? +AND EXISTS ( + SELECT 1 + FROM projects p + WHERE p.id = location.project_id +); diff --git a/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/schema.sql b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/schema.sql new file mode 100644 index 0000000000..102bbef7cf --- /dev/null +++ b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/schema.sql @@ -0,0 +1,20 @@ +CREATE TABLE organizations ( + id INTEGER PRIMARY KEY +); + +CREATE TABLE organization_members ( + id INTEGER PRIMARY KEY, + organization_id INTEGER NOT NULL, + account_id INTEGER NOT NULL +); + +CREATE TABLE projects ( + id INTEGER PRIMARY KEY, + organization_id INTEGER NOT NULL +); + +CREATE TABLE locations ( + id INTEGER PRIMARY KEY, + public_id TEXT UNIQUE NOT NULL, + project_id INTEGER NOT NULL +) STRICT; diff --git a/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml new file mode 100644 index 0000000000..3fe71d6f20 --- /dev/null +++ b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml @@ -0,0 +1,12 @@ +version: "2" + +sql: + - engine: "sqlite" + schema: "schema.sql" + queries: "query.sql" + rules: + - sqlc/db-prepare + gen: + go: + package: "db" + out: "generated" diff --git a/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/stderr.txt b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/stderr.txt new file mode 100644 index 0000000000..496a56dbec --- /dev/null +++ b/internal/endtoend/testdata/base/sqlite_invalid_correlated_ref/sqlite/stderr.txt @@ -0,0 +1,2 @@ +# package db +query.sql:8:18: table alias "location" does not exist diff --git a/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt b/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt index 97e43851e0..1eddeaac99 100644 --- a/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt +++ b/internal/endtoend/testdata/invalid_table_alias/sqlite/stderr.txt @@ -1,2 +1,2 @@ # package querytest -query.sql:1:1: sqlite3: SQL logic error: no such column: p.id +query.sql:4:9: table alias "p" does not exist diff --git a/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/query.sql b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/query.sql new file mode 100644 index 0000000000..ad3f50c3a3 --- /dev/null +++ b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/query.sql @@ -0,0 +1,9 @@ +-- name: GetByPublicID :one +SELECT * +FROM locations l +WHERE l.public_id = ? +AND EXISTS ( + SELECT 1 + FROM projects p + WHERE p.id = location.project_id +); diff --git a/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/schema.sql b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/schema.sql new file mode 100644 index 0000000000..102bbef7cf --- /dev/null +++ b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/schema.sql @@ -0,0 +1,20 @@ +CREATE TABLE organizations ( + id INTEGER PRIMARY KEY +); + +CREATE TABLE organization_members ( + id INTEGER PRIMARY KEY, + organization_id INTEGER NOT NULL, + account_id INTEGER NOT NULL +); + +CREATE TABLE projects ( + id INTEGER PRIMARY KEY, + organization_id INTEGER NOT NULL +); + +CREATE TABLE locations ( + id INTEGER PRIMARY KEY, + public_id TEXT UNIQUE NOT NULL, + project_id INTEGER NOT NULL +) STRICT; diff --git a/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml new file mode 100644 index 0000000000..3fe71d6f20 --- /dev/null +++ b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/sqlc.yaml @@ -0,0 +1,12 @@ +version: "2" + +sql: + - engine: "sqlite" + schema: "schema.sql" + queries: "query.sql" + rules: + - sqlc/db-prepare + gen: + go: + package: "db" + out: "generated" diff --git a/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/stderr.txt b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/stderr.txt new file mode 100644 index 0000000000..496a56dbec --- /dev/null +++ b/internal/endtoend/testdata/managed-db/sqlite_invalid_correlated_ref/sqlite/stderr.txt @@ -0,0 +1,2 @@ +# package db +query.sql:8:18: table alias "location" does not exist diff --git a/internal/engine/sqlite/analyzer/analyze.go b/internal/engine/sqlite/analyzer/analyze.go index 3af9f99a30..a9f4725607 100644 --- a/internal/engine/sqlite/analyzer/analyze.go +++ b/internal/engine/sqlite/analyzer/analyze.go @@ -17,6 +17,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/catalog" "github.com/sqlc-dev/sqlc/internal/sql/named" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" + "github.com/sqlc-dev/sqlc/internal/sql/validate" ) type Analyzer struct { @@ -76,6 +77,16 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat } } + // SQLite-specific validation + toValidate := n + if raw, ok := n.(*ast.RawStmt); ok && raw != nil && raw.Stmt != nil { + toValidate = raw.Stmt + } + + if err := validate.ValidateSQLiteQualifiedColumnRefs(toValidate); err != nil { + return nil, err + } + // Prepare the statement to get column and parameter information stmt, _, err := a.conn.Prepare(query) if err != nil { diff --git a/internal/sql/validate/sqlite_qualified_refs.go b/internal/sql/validate/sqlite_qualified_refs.go new file mode 100644 index 0000000000..183040b925 --- /dev/null +++ b/internal/sql/validate/sqlite_qualified_refs.go @@ -0,0 +1,217 @@ +package validate + +import ( + "fmt" + "reflect" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +// ValidateSQLiteQualifiedColumnRefs validates that qualified column references +// only use visible tables/aliases in the current or outer SELECT scopes. +func ValidateSQLiteQualifiedColumnRefs(root ast.Node) error { + return validateNodeSQLite(root, nil) +} + +type scope struct { + parent *scope + names map[string]struct{} +} + +func newScope(parent *scope) *scope { + return &scope{parent: parent, names: map[string]struct{}{}} +} + +func (s *scope) add(name string) { + if name == "" { + return + } + s.names[name] = struct{}{} +} + +func (s *scope) has(name string) bool { + for cur := s; cur != nil; cur = cur.parent { + if _, ok := cur.names[name]; ok { + return true + } + } + return false +} + +func stringSlice(list *ast.List) []string { + if list == nil { + return nil + } + out := make([]string, 0, len(list.Items)) + for _, it := range list.Items { + if s, ok := it.(*ast.String); ok { + out = append(out, s.Str) + } + } + return out +} + +func qualifierFromColumnRef(ref *ast.ColumnRef) (string, bool) { + if ref == nil || ref.Fields == nil { + return "", false + } + items := stringSlice(ref.Fields) + switch len(items) { + case 2: + return items[0], true + case 3: + return items[1], true + default: + return "", false + } +} + +func addFromItemToScope(sc *scope, n ast.Node) { + switch t := n.(type) { + case *ast.RangeVar: + if t.Relname != nil { + sc.add(*t.Relname) + } + if t.Alias != nil && t.Alias.Aliasname != nil { + sc.add(*t.Alias.Aliasname) + } + case *ast.JoinExpr: + addFromItemToScope(sc, t.Larg) + addFromItemToScope(sc, t.Rarg) + case *ast.RangeSubselect: + if t.Alias != nil && t.Alias.Aliasname != nil { + sc.add(*t.Alias.Aliasname) + } + case *ast.RangeFunction: + if t.Alias != nil && t.Alias.Aliasname != nil { + sc.add(*t.Alias.Aliasname) + } + } +} + +func validateNodeSQLite(node ast.Node, parent *scope) error { + switch n := node.(type) { + case *ast.SelectStmt: + sc := newScope(parent) + if n.FromClause != nil { + for _, item := range n.FromClause.Items { + addFromItemToScope(sc, item) + } + } + return walkSQLite(n, sc) + default: + return nil + } +} + +func walkSQLite(node ast.Node, sc *scope) error { + if node == nil { + return nil + } + + if ref, ok := node.(*ast.ColumnRef); ok { + if qual, ok := qualifierFromColumnRef(ref); ok && !sc.has(qual) { + return &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("table alias %q does not exist", qual), + Location: ref.Location, + } + } + } + + switch n := node.(type) { + case *ast.SubLink: + if n.Subselect != nil { + return validateNodeSQLite(n.Subselect, sc) + } + return nil + case *ast.RangeSubselect: + if n.Subquery != nil { + return validateNodeSQLite(n.Subquery, sc) + } + return nil + } + + return walkSQLiteReflect(node, sc) +} + +func walkSQLiteReflect(node ast.Node, sc *scope) error { + v := reflect.ValueOf(node) + if v.Kind() == reflect.Pointer { + if v.IsNil() { + return nil + } + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil + } + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + if t.Field(i).PkgPath != "" { + continue + } + f := v.Field(i) + if !f.IsValid() { + continue + } + + for f.Kind() == reflect.Pointer { + if f.IsNil() { + goto next + } + f = f.Elem() + } + + if f.Type() == reflect.TypeOf(ast.List{}) { + list := f.Addr().Interface().(*ast.List) + for _, n := range list.Items { + if err := walkSQLite(n, sc); err != nil { + return err + } + } + continue + } + + if f.CanAddr() { + if pl, ok := f.Addr().Interface().(**ast.List); ok && *pl != nil { + for _, n := range (*pl).Items { + if err := walkSQLite(n, sc); err != nil { + return err + } + } + continue + } + } + + if f.CanInterface() { + if n, ok := f.Interface().(ast.Node); ok { + if err := walkSQLite(n, sc); err != nil { + return err + } + continue + } + } + + if f.Kind() == reflect.Slice { + for j := 0; j < f.Len(); j++ { + elem := f.Index(j) + if elem.Kind() == reflect.Pointer && elem.IsNil() { + continue + } + if elem.CanInterface() { + if n, ok := elem.Interface().(ast.Node); ok { + if err := walkSQLite(n, sc); err != nil { + return err + } + } + } + } + } + + next: + } + return nil +}