Skip to content

Commit 3436dde

Browse files
committed
fix(compiler): robustly strip psql meta commands without breaking SQL
Replace naive line-based removal with a single-pass state machine that correctly distinguishes psql meta-commands from backslashes in SQL code, literals, and comments.

The previous implementation would incorrectly strip any line starting with a backslash, breaking valid SQL containing:
- Backslashes in string literals (E'\\n', escape sequences)
- Meta-command text in comments or documentation
- Dollar-quoted function bodies with backslash content

Changes:
- Track parsing state for single quotes, dollar quotes, and block comments
- Only remove backslash commands at true line starts outside any literal context
- Properly handle escaped quotes (''), nested block comments (/* /* */ */)
- Support dollar-quoted tags with identifiers ($tag$...$tag$)
- Add comprehensive test suite covering:
  * All documented psql meta-commands (\connect, \set, \d*, etc.)
  * String literals with backslashes and nested quotes
  * Dollar-quoted blocks with various tag formats
  * Nested block comments containing meta-command text
  * Edge cases: empty input, whitespace-only, missing newlines

Performance improvements:
- Pre-allocate output buffer with strings.Builder.Grow()
- Single pass eliminates redundant string operations
- Reduces allocations by avoiding intermediate line slice
1 parent 29bbb81 commit 3436dde

File tree

2 files changed

+303
-8
lines changed

2 files changed

+303
-8
lines changed

internal/compiler/compile.go

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package compiler
22

33
import (
4-
"bufio"
54
"errors"
65
"fmt"
76
"io"
@@ -60,16 +59,153 @@ func (c *Compiler) parseCatalog(schemas []string) error {
6059
}
6160

6261
// removePsqlMetaCommands returns contents with psql meta-command lines removed.
//
// A line is treated as a meta-command when its first byte after any leading
// spaces, tabs or carriage returns is a backslash AND the line does not begin
// inside a single-quoted string, a double-quoted identifier, a dollar-quoted
// string or a (possibly nested) block comment. The text of such a line,
// including its leading whitespace, is dropped; its terminating newline, if
// present, is kept, so the lines that follow stay at their original line
// numbers. All other bytes are copied through unchanged.
//
// Line comments ("-- ...") are copied verbatim up to the end of the line so
// that an apostrophe, "$$" or "/*" inside one cannot open a string, a
// dollar-quoted block or a block comment.
//
// Backslash escapes inside E'...' strings are not interpreted: a quote always
// terminates the string unless it is doubled ('').
func removePsqlMetaCommands(contents string) string {
	if contents == "" {
		return contents
	}

	var out strings.Builder
	out.Grow(len(contents))

	lineStart := true
	var inSingle, inDouble, inDollar bool
	var dollarTag string // full delimiter, including both '$', while inDollar
	blockDepth := 0      // nesting depth of /* ... */ comments

	n := len(contents)
	for i := 0; i < n; {
		if lineStart && !inSingle && !inDouble && !inDollar && blockDepth == 0 {
			start := i
			for i < n && (contents[i] == ' ' || contents[i] == '\t' || contents[i] == '\r') {
				i++
			}
			if i < n && contents[i] == '\\' {
				// Meta-command: discard through end of line but keep the
				// newline itself. lineStart stays true for the next line.
				for i < n && contents[i] != '\n' {
					i++
				}
				if i < n {
					out.WriteByte('\n')
					i++
				}
				continue
			}
			// Not a meta-command: the leading whitespace is ordinary content.
			out.WriteString(contents[start:i])
			if i >= n {
				break
			}
		}

		c := contents[i]
		var next byte // stays 0 at end of input, which matches no delimiter
		if i+1 < n {
			next = contents[i+1]
		}

		// width is how many bytes, starting at i, are copied as one token.
		width := 1
		switch {
		case inSingle:
			if c == '\'' {
				if next == '\'' {
					width = 2 // '' is an escaped quote, still inside the string
				} else {
					inSingle = false
				}
			}
		case inDouble:
			if c == '"' {
				if next == '"' {
					width = 2 // "" is an escaped quote inside the identifier
				} else {
					inDouble = false
				}
			}
		case inDollar:
			if strings.HasPrefix(contents[i:], dollarTag) {
				width = len(dollarTag)
				inDollar = false
			}
		case blockDepth > 0:
			if c == '/' && next == '*' {
				blockDepth++
				width = 2
			} else if c == '*' && next == '/' {
				blockDepth--
				width = 2
			}
		case c == '\'':
			inSingle = true
		case c == '"':
			inDouble = true
		case c == '$':
			tagEnd := i + 1
			for tagEnd < n && isDollarTagChar(contents[tagEnd]) {
				tagEnd++
			}
			// Only "$tag$" (tag may be empty) opens a dollar-quoted block; a
			// lone '$' such as a "$1" placeholder is copied as a plain byte.
			if tagEnd < n && contents[tagEnd] == '$' {
				dollarTag = contents[i : tagEnd+1]
				inDollar = true
				width = len(dollarTag)
			}
		case c == '/' && next == '*':
			blockDepth = 1
			width = 2
		case c == '-' && next == '-':
			// Line comment: copy up to, but not including, the newline so the
			// newline is handled by the next iteration like any other.
			if nl := strings.IndexByte(contents[i:], '\n'); nl >= 0 {
				width = nl
			} else {
				width = n - i
			}
		}

		out.WriteString(contents[i : i+width])
		// Multi-byte tokens never end in '\n', so this is only true after a
		// newline has been copied on its own.
		lineStart = contents[i+width-1] == '\n'
		i += width
	}
	return out.String()
}

// isDollarTagChar reports whether b may appear between the two '$' characters
// of a dollar-quote delimiter: an ASCII letter, an ASCII digit or '_'.
func isDollarTagChar(b byte) bool {
	return b == '_' || (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
74210

75211
func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package compiler
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
"testing"
7+
)
8+
9+
// allPsqlMetaCommands lists the psql backslash commands that the
// CoversDocumentedMetaCommands subtest feeds through removePsqlMetaCommands,
// one subtest per entry. Entries are kept in their original order; the line
// breaks below only group them by command family for readability.
var allPsqlMetaCommands = []string{
	// \a through \crosstabview
	`\a`, `\bind`, `\bind_named`, `\c`, `\connect`, `\C`,
	`\cd`, `\close_prepared`, `\conninfo`, `\copy`, `\copyright`, `\crosstabview`,

	// \d family
	`\d`, `\da`, `\dA`, `\dAc`, `\dAf`, `\dAo`, `\dAp`, `\db`,
	`\dc`, `\dconfig`, `\dC`, `\dd`, `\dD`, `\ddp`, `\dE`, `\di`,
	`\dm`, `\ds`, `\dt`, `\dv`, `\des`, `\det`, `\deu`, `\dew`,
	`\df`, `\dF`, `\dFd`, `\dFp`, `\dFt`, `\dg`, `\dl`, `\dL`,
	`\dn`, `\do`, `\dO`, `\dp`, `\dP`, `\drds`, `\drg`, `\dRp`,
	`\dRs`, `\dT`, `\du`, `\dx`, `\dX`, `\dy`,

	// \e through \gx
	`\e`, `\edit`, `\echo`, `\ef`, `\encoding`, `\ev`, `\f`,
	`\g`, `\gdesc`, `\getenv`, `\gexec`, `\gset`, `\gx`,

	// help, include and conditional commands
	`\h`, `\help`, `\H`, `\html`, `\i`, `\include`,
	`\if`, `\elif`, `\else`, `\endif`, `\ir`, `\include_relative`,

	// \l through \out
	`\l`, `\list`, `\lo_export`, `\lo_import`, `\lo_list`, `\lo_unlink`, `\o`, `\out`,

	// \p through \qecho
	`\p`, `\print`, `\parse`, `\password`, `\prompt`, `\pset`, `\q`, `\quit`, `\qecho`,

	// \r through \sv, then the pipeline commands
	`\r`, `\reset`, `\restrict`, `\s`, `\set`, `\setenv`, `\sf`, `\sv`,
	`\startpipeline`, `\sendpipeline`, `\syncpipeline`, `\endpipeline`,
	`\flushrequest`, `\flush`, `\getresults`,

	// \t through \z, then the punctuation commands
	`\t`, `\T`, `\timing`, `\unrestrict`, `\unset`,
	`\w`, `\write`, `\warn`, `\watch`, `\x`, `\z`,
	`\!`, `\?`, `\;`,
}
24+
25+
func TestRemovePsqlMetaCommands_TableDriven(t *testing.T) {
26+
inDoubleQuoted := "CREATE TABLE \"foo\\bar\" (id int);\nSELECT \"foo\\bar\"." +
27+
"id FROM \"foo\\bar\";\n"
28+
inValidSQL := "CREATE TABLE t (id int);\nINSERT INTO t VALUES (1);\n"
29+
inWhitespaceOnly := " \t "
30+
inNoTrailingNewline := "SELECT 1"
31+
inBackslashNotAtStart := "SELECT '\\not_meta' AS col;\n SELECT '\\still_not_meta';\n"
32+
inDoubleSingleQuotes := "INSERT INTO t VALUES ('It''s fine');\n"
33+
34+
tests := []struct {
35+
name string
36+
in string
37+
want string
38+
}{
39+
{
40+
name: "RemovesTopLevelMetaCommands",
41+
in: "CREATE TABLE public.authors();\n\\connect test\n \\set ON_ERROR_STOP on\nSELECT 1;\n",
42+
want: "CREATE TABLE public.authors();\n\n\nSELECT 1;\n",
43+
},
44+
{
45+
name: "IgnoresBackslashesInStrings",
46+
in: "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\\connect nope\n",
47+
want: "SELECT E'\\n' || E'\\' || '\n\\restrict inside';\nSELECT E'\n\\still_string\n';\n\n",
48+
},
49+
{
50+
name: "PreservesDollarQuotedBlocks",
51+
in: "DO $$\n\\this_should_stay\n$$;\n\\connect other\n",
52+
want: "DO $$\n\\this_should_stay\n$$;\n\n",
53+
},
54+
{
55+
name: "IgnoresBlockComments",
56+
in: "/*\n\\comment_not_meta\n*/\n\\set x 1\nSELECT 1;\n",
57+
want: "/*\n\\comment_not_meta\n*/\n\nSELECT 1;\n",
58+
},
59+
{
60+
name: "LeavesValidSqlUntouched",
61+
in: inValidSQL,
62+
want: inValidSQL,
63+
},
64+
{
65+
name: "HandlesEmptyInput",
66+
in: "",
67+
want: "",
68+
},
69+
{
70+
name: "PreservesWhitespaceOnlyInput",
71+
in: inWhitespaceOnly,
72+
want: inWhitespaceOnly,
73+
},
74+
{
75+
name: "PreservesFinalLineWithoutNewline",
76+
in: inNoTrailingNewline,
77+
want: inNoTrailingNewline,
78+
},
79+
{
80+
name: "BackslashInDoubleQuotedIdentifier",
81+
in: inDoubleQuoted,
82+
want: inDoubleQuoted,
83+
},
84+
{
85+
name: "BackslashNotAtLineStart",
86+
in: inBackslashNotAtStart,
87+
want: inBackslashNotAtStart,
88+
},
89+
{
90+
name: "DoubleSingleQuotesRemain",
91+
in: inDoubleSingleQuotes,
92+
want: inDoubleSingleQuotes,
93+
},
94+
{
95+
name: "MetaCommandTextInsideLiteral",
96+
in: `INSERT INTO logs VALUES ('Remember to run \connect later');
97+
SELECT E'\n\connect\n' as literal;` + "\n",
98+
want: `INSERT INTO logs VALUES ('Remember to run \connect later');
99+
SELECT E'\n\connect\n' as literal;` + "\n",
100+
},
101+
{
102+
name: "BlockCommentsPreserveMetaText",
103+
in: `/* outer block begins
104+
/* nested: run \connect test_db for interactive work */
105+
documenting with \connect text shouldn't strip SQL
106+
*/
107+
SELECT 1;
108+
/* Change instructions:
109+
\connect reporting
110+
111+
Reason: run maintenance scripts as reporting user.
112+
*/
113+
\connect should_go
114+
`,
115+
want: `/* outer block begins
116+
/* nested: run \connect test_db for interactive work */
117+
documenting with \connect text shouldn't strip SQL
118+
*/
119+
SELECT 1;
120+
/* Change instructions:
121+
\connect reporting
122+
123+
Reason: run maintenance scripts as reporting user.
124+
*/
125+
126+
`,
127+
},
128+
{
129+
name: "DollarTagWithIdentifier",
130+
in: "DO $foo$\n\\inside\n$foo$;\n\\set should_go\n",
131+
want: "DO $foo$\n\\inside\n$foo$;\n\n",
132+
},
133+
}
134+
135+
for _, tc := range tests {
136+
t.Run(tc.name, func(t *testing.T) {
137+
got := removePsqlMetaCommands(tc.in)
138+
if got != tc.want {
139+
t.Fatalf("unexpected output after stripping meta commands:\nwant=%q\ngot =%q", tc.want, got)
140+
}
141+
})
142+
}
143+
144+
t.Run("CoversDocumentedMetaCommands", func(t *testing.T) {
145+
for _, cmd := range allPsqlMetaCommands {
146+
t.Run(fmt.Sprintf("strip_%s", strings.TrimPrefix(cmd, `\`)), func(t *testing.T) {
147+
input := fmt.Sprintf("%s -- meta command\nSELECT 42;\n", cmd)
148+
got := removePsqlMetaCommands(input)
149+
150+
if strings.Contains(got, cmd+" -- meta command") {
151+
t.Fatalf("meta command %q line was not removed", cmd)
152+
}
153+
if !strings.Contains(got, "SELECT 42;") {
154+
t.Fatalf("SQL content was unexpectedly removed for %q", cmd)
155+
}
156+
})
157+
}
158+
})
159+
}

0 commit comments

Comments
 (0)