// ---- internal/compiler/compile.go ----

// removePsqlMetaCommands returns contents with every psql meta-command line
// (a line whose first non-blank byte is a backslash) replaced by an empty
// line, so that line numbers in the remaining SQL are unchanged.
//
// A backslash only counts as a meta-command when it appears at the top level
// of the script. Text inside the following constructs is copied verbatim:
//
//   - single-quoted string literals ('' is an embedded quote)
//   - double-quoted identifiers ("" is an embedded quote)
//   - dollar-quoted strings ($$ ... $$ and $tag$ ... $tag$)
//   - block comments, including nested ones (/* ... /* ... */ ... */)
//   - line comments (-- up to the end of the line)
//
// Backslash escapes inside E'...' strings are not interpreted: a quote always
// terminates the literal unless it is doubled. A meta-command on the final
// line without a trailing newline is dropped without emitting a newline.
// Unterminated constructs extend to the end of the input.
func removePsqlMetaCommands(contents string) string {
	if contents == "" {
		return contents
	}
	var out strings.Builder
	out.Grow(len(contents))

	n := len(contents)
	lineStart := true
	for i := 0; i < n; {
		if lineStart {
			// Look past leading blanks to see whether this line is a
			// meta-command.
			j := i
			for j < n && (contents[j] == ' ' || contents[j] == '\t' || contents[j] == '\r') {
				j++
			}
			if j < n && contents[j] == '\\' {
				nl := strings.IndexByte(contents[j:], '\n')
				if nl < 0 {
					// Meta-command on the last, unterminated line.
					break
				}
				// Keep the newline so later line numbers stay stable;
				// lineStart remains true for the following line.
				out.WriteByte('\n')
				i = j + nl + 1
				continue
			}
			out.WriteString(contents[i:j])
			i = j
			if i >= n {
				break
			}
		}

		// Determine how many bytes belong to the construct starting at i.
		// The default is a single ordinary byte.
		end := i + 1
		switch contents[i] {
		case '\'', '"':
			end = skipQuoted(contents, i)
		case '$':
			if tag := dollarTagAt(contents, i); tag != "" {
				end = skipDollarQuoted(contents, i, tag)
			}
		case '/':
			if strings.HasPrefix(contents[i:], "/*") {
				end = skipBlockComment(contents, i)
			}
		case '-':
			if strings.HasPrefix(contents[i:], "--") {
				end = skipLineComment(contents, i)
			}
		}
		out.WriteString(contents[i:end])
		// Only a bare top-level newline begins a new line; the end of a
		// multi-byte construct never does.
		lineStart = end == i+1 && contents[i] == '\n'
		i = end
	}
	return out.String()
}

// skipQuoted returns the index just past the quoted token that opens at s[i].
// The quote byte is s[i] itself; a doubled quote byte inside the token is an
// embedded quote. An unterminated token runs to len(s).
func skipQuoted(s string, i int) int {
	quote := s[i]
	for i++; i < len(s); i++ {
		if s[i] != quote {
			continue
		}
		if i+1 < len(s) && s[i+1] == quote {
			i++ // step over the second half of the doubled quote
			continue
		}
		return i + 1
	}
	return len(s)
}

// dollarTagAt returns the dollar-quote delimiter (including both '$' bytes)
// that starts at s[i], or "" when s[i] does not open a dollar-quoted string.
// A tag may not start with a digit, so positional parameters such as $1 are
// never mistaken for delimiters.
func dollarTagAt(s string, i int) string {
	j := i + 1
	if j < len(s) && s[j] >= '0' && s[j] <= '9' {
		return ""
	}
	for j < len(s) && isDollarTagChar(s[j]) {
		j++
	}
	if j < len(s) && s[j] == '$' {
		return s[i : j+1]
	}
	return ""
}

// skipDollarQuoted returns the index just past the dollar-quoted string that
// opens with tag at s[i]. An unterminated string runs to len(s).
func skipDollarQuoted(s string, i int, tag string) int {
	body := i + len(tag)
	if k := strings.Index(s[body:], tag); k >= 0 {
		return body + k + len(tag)
	}
	return len(s)
}

// skipBlockComment returns the index just past the block comment that opens
// with "/*" at s[i], honouring nested comments. An unterminated comment runs
// to len(s).
func skipBlockComment(s string, i int) int {
	depth := 1
	i += 2
	for i < len(s) {
		switch {
		case strings.HasPrefix(s[i:], "/*"):
			depth++
			i += 2
		case strings.HasPrefix(s[i:], "*/"):
			depth--
			i += 2
			if depth == 0 {
				return i
			}
		default:
			i++
		}
	}
	return len(s)
}

// skipLineComment returns the index of the newline that ends the "--" comment
// starting at s[i], or len(s) when the comment is on the last line. The
// newline itself is not consumed.
func skipLineComment(s string, i int) int {
	if k := strings.IndexByte(s[i:], '\n'); k >= 0 {
		return i + k
	}
	return len(s)
}

// isDollarTagChar reports whether b may appear inside a dollar-quote tag:
// ASCII letters, digits, underscore, or any byte of a multi-byte UTF-8
// sequence (non-ASCII letters are permitted in tags).
func isDollarTagChar(b byte) bool {
	return b == '_' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') || b >= 0x80
}

// ---- internal/compiler/psql_meta_test.go ----

// allPsqlMetaCommands lists the psql meta-commands that the stripping logic
// is expected to remove when they appear at the start of a line.
var allPsqlMetaCommands = []string{
	`\a`, `\bind`, `\bind_named`, `\c`, `\connect`, `\C`, `\cd`, `\close_prepared`, `\conninfo`, `\copy`,
	`\copyright`, `\crosstabview`, `\d`, `\da`, `\dA`, `\dAc`, `\dAf`, `\dAo`, `\dAp`, `\db`,
	`\dc`, `\dconfig`, `\dC`, `\dd`, `\dD`, `\ddp`, `\dE`, `\di`, `\dm`, `\ds`,
	`\dt`, `\dv`, `\des`, `\det`, `\deu`, `\dew`, `\df`, `\dF`, `\dFd`, `\dFp`,
	`\dFt`, `\dg`, `\dl`, `\dL`, `\dn`, `\do`, `\dO`, `\dp`, `\dP`, `\drds`,
	`\drg`, `\dRp`, `\dRs`, `\dT`, `\du`, `\dx`, `\dX`, `\dy`, `\e`, `\edit`,
	`\echo`, `\ef`, `\encoding`, `\ev`, `\f`, `\g`, `\gdesc`, `\getenv`, `\gexec`, `\gset`,
	`\gx`, `\h`, `\help`, `\H`, `\html`, `\i`, `\include`, `\if`, `\elif`, `\else`,
	`\endif`, `\ir`, `\include_relative`, `\l`, `\list`, `\lo_export`, `\lo_import`, `\lo_list`, `\lo_unlink`, `\o`,
	`\out`, `\p`, `\print`, `\parse`, `\password`, `\prompt`, `\pset`, `\q`, `\quit`, `\qecho`,
	`\r`, `\reset`, `\restrict`, `\s`, `\set`, `\setenv`, `\sf`, `\sv`, `\startpipeline`, `\sendpipeline`,
	`\syncpipeline`, `\endpipeline`, `\flushrequest`, `\flush`, `\getresults`, `\t`, `\T`, `\timing`, `\unrestrict`, `\unset`,
	`\w`, `\write`, `\warn`, `\watch`, `\x`, `\z`, `\!`, `\?`, `\;`,
}

// TestRemovePsqlMetaCommands_TableDriven checks that meta-command lines are
// blanked only at the top level of a script and that SQL text inside string
// literals, quoted identifiers, dollar-quoted bodies and comments is left
// byte-for-byte intact.
func TestRemovePsqlMetaCommands_TableDriven(t *testing.T) {
	inDoubleQuoted := "CREATE TABLE \"foo\\bar\" (id int);\nSELECT \"foo\\bar\"." +
		"id FROM \"foo\\bar\";\n"
	inValidSQL := "CREATE TABLE t (id int);\nINSERT INTO t VALUES (1);\n"
	inWhitespaceOnly := " \t "
	inNoTrailingNewline := "SELECT 1"
	inBackslashNotAtStart := "SELECT '\\not_meta' AS col;\n SELECT '\\still_not_meta';\n"
	inDoubleSingleQuotes := "INSERT INTO t VALUES ('It''s fine');\n"
	inMetaTextInLiteral := `INSERT INTO logs VALUES ('Remember to run \connect later');
 SELECT E'\n\connect\n' as literal;` + "\n"

	tests := []struct {
		name string
		in   string
		want string
	}{
		{
			name: "RemovesTopLevelMetaCommands",
			in:   "CREATE TABLE public.authors();\n\\connect test\n \\set ON_ERROR_STOP on\nSELECT 1;\n",
			want: "CREATE TABLE public.authors();\n\n\nSELECT 1;\n",
		},
		{
			name: "IgnoresBackslashesInStrings",
			in:   "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\\connect nope\n",
			want: "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\n",
		},
		{
			name: "PreservesDollarQuotedBlocks",
			in:   "DO $$\n\\this_should_stay\n$$;\n\\connect other\n",
			want: "DO $$\n\\this_should_stay\n$$;\n\n",
		},
		{
			name: "IgnoresBlockComments",
			in:   "/*\n\\comment_not_meta\n*/\n\\set x 1\nSELECT 1;\n",
			want: "/*\n\\comment_not_meta\n*/\n\nSELECT 1;\n",
		},
		{
			name: "LeavesValidSqlUntouched",
			in:   inValidSQL,
			want: inValidSQL,
		},
		{
			name: "HandlesEmptyInput",
			in:   "",
			want: "",
		},
		{
			name: "PreservesWhitespaceOnlyInput",
			in:   inWhitespaceOnly,
			want: inWhitespaceOnly,
		},
		{
			name: "PreservesFinalLineWithoutNewline",
			in:   inNoTrailingNewline,
			want: inNoTrailingNewline,
		},
		{
			name: "BackslashInDoubleQuotedIdentifier",
			in:   inDoubleQuoted,
			want: inDoubleQuoted,
		},
		{
			name: "BackslashNotAtLineStart",
			in:   inBackslashNotAtStart,
			want: inBackslashNotAtStart,
		},
		{
			name: "DoubleSingleQuotesRemain",
			in:   inDoubleSingleQuotes,
			want: inDoubleSingleQuotes,
		},
		{
			name: "MetaCommandTextInsideLiteral",
			in:   inMetaTextInLiteral,
			want: inMetaTextInLiteral,
		},
		{
			name: "BlockCommentsPreserveMetaText",
			in: `/* outer block begins
/* nested: run \connect test_db for interactive work */
documenting with \connect text shouldn't strip SQL
*/
SELECT 1;
/* Change instructions:
\connect reporting

Reason: run maintenance scripts as reporting user.
*/
\connect should_go
`,
			want: `/* outer block begins
/* nested: run \connect test_db for interactive work */
documenting with \connect text shouldn't strip SQL
*/
SELECT 1;
/* Change instructions:
\connect reporting

Reason: run maintenance scripts as reporting user.
*/

`,
		},
		{
			name: "DollarTagWithIdentifier",
			in:   "DO $foo$\n\\inside\n$foo$;\n\\set should_go\n",
			want: "DO $foo$\n\\inside\n$foo$;\n\n",
		},
		{
			// An apostrophe in a line comment must not open a string.
			name: "ApostropheInLineComment",
			in:   "-- don't strip the wrong thing\n\\connect test\nSELECT 1;\n",
			want: "-- don't strip the wrong thing\n\nSELECT 1;\n",
		},
		{
			// A block-comment opener in a line comment is just text.
			name: "BlockCommentOpenerInLineComment",
			in:   "-- see /* note\n\\set x 1\nSELECT 1;\n",
			want: "-- see /* note\n\nSELECT 1;\n",
		},
		{
			// An apostrophe in a quoted identifier must not open a string.
			name: "ApostropheInDoubleQuotedIdentifier",
			in:   "CREATE TABLE \"it's\" (id int);\n\\connect test\n",
			want: "CREATE TABLE \"it's\" (id int);\n\n",
		},
		{
			// Dollar-quote tags cannot begin with a digit.
			name: "DigitLeadingDollarIsNotAQuote",
			in:   "SELECT $1$\n\\connect db\n",
			want: "SELECT $1$\n\n",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := removePsqlMetaCommands(tc.in)
			if got != tc.want {
				t.Fatalf("unexpected output after stripping meta commands:\nwant=%q\ngot =%q", tc.want, got)
			}
		})
	}

	t.Run("CoversDocumentedMetaCommands", func(t *testing.T) {
		for _, cmd := range allPsqlMetaCommands {
			t.Run(fmt.Sprintf("strip_%s", strings.TrimPrefix(cmd, `\`)), func(t *testing.T) {
				input := fmt.Sprintf("%s -- meta command\nSELECT 42;\n", cmd)
				got := removePsqlMetaCommands(input)

				if strings.Contains(got, cmd+" -- meta command") {
					t.Fatalf("meta command %q line was not removed", cmd)
				}
				if !strings.Contains(got, "SELECT 42;") {
					t.Fatalf("SQL content was unexpectedly removed for %q", cmd)
				}
			})
		}
	})
}