Skip to content

Commit 7d4ff0b

Browse files
authored
Implement support for pgx's CopyFrom (#1352)
1 parent 3edeba8 commit 7d4ff0b

File tree

24 files changed

+422
-26
lines changed

24 files changed

+422
-26
lines changed

docs/guides/development.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ MYSQL_DATABASE dinotest
5757
## Regenerate expected test output
5858

5959
If you need to update a large number of expected test output in the
60-
`internal/endtoend/testdata` directory, run the `regenerate.sh` script.
60+
`internal/endtoend/testdata` directory, run the `regenerate` script.
6161

6262
```
63-
make regen
63+
go build -o ~/go/bin/sqlc-dev ./cmd/sqlc
64+
go run scripts/regenerate/main.go
6465
```
6566

6667
Note that this uses the `sqlc-dev` binary, not `sqlc` so make sure you have an

docs/howto/insert.md

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,41 @@ func (q *Queries) DeleteID(ctx context.Context, id int) (int, error) {
124124
return i, err
125125
}
126126

127-
const deleteAuhtor = `-- name: DeleteAuthor :one
127+
const deleteAuthor = `-- name: DeleteAuthor :one
128128
DELETE FROM authors WHERE id = $1
129129
RETURNING id, bio
130130
`
131131

132132
func (q *Queries) DeleteAuthor(ctx context.Context, id int) (Author, error) {
133-
row := q.db.QueryRowContext(ctx, deleteAuhtor, id)
133+
row := q.db.QueryRowContext(ctx, deleteAuthor, id)
134134
var i Author
135135
err := row.Scan(&i.ID, &i.Bio)
136136
return i, err
137137
}
138138
```
139+
140+
## Using CopyFrom
141+
142+
PostgreSQL supports the Copy Protocol that can insert rows a lot faster than sequential inserts. You can use this easily with sqlc:
143+
144+
```sql
145+
CREATE TABLE authors (
146+
id SERIAL PRIMARY KEY,
147+
name text NOT NULL,
148+
bio text NOT NULL
149+
);
150+
151+
-- name: CreateAuthors :copyfrom
152+
INSERT INTO authors (name, bio) VALUES ($1, $2);
153+
```
154+
155+
```go
156+
type CreateAuthorsParams struct {
157+
Name string
158+
Bio string
159+
}
160+
161+
func (q *Queries) CreateAuthors(ctx context.Context, arg []CreateAuthorsParams) (int64, error) {
162+
return q.db.CopyFrom(ctx, []string{"authors"}, []string{"name", "bio"}, &iteratorForCreateAuthors{rows: arg})
163+
}
164+
```

internal/codegen/golang/field.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import (
99
)
1010

1111
type Field struct {
12-
Name string
12+
Name string // CamelCased name for Go
13+
DBName string // Name as used in the DB
1314
Type string
1415
Tags map[string]string
1516
Comment string

internal/codegen/golang/gen.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package golang
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"fmt"
78
"go/format"
89
"strings"
@@ -11,6 +12,7 @@ import (
1112
"github.com/kyleconroy/sqlc/internal/codegen"
1213
"github.com/kyleconroy/sqlc/internal/compiler"
1314
"github.com/kyleconroy/sqlc/internal/config"
15+
"github.com/kyleconroy/sqlc/internal/metadata"
1416
)
1517

1618
type Generateable interface {
@@ -37,6 +39,7 @@ type tmplCtx struct {
3739
EmitInterface bool
3840
EmitEmptySlices bool
3941
EmitMethodsWithDBArgument bool
42+
UsesCopyFrom bool
4043
}
4144

4245
func (t *tmplCtx) OutputQuery(sourceName string) bool {
@@ -87,6 +90,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
8790
EmitPreparedQueries: golang.EmitPreparedQueries,
8891
EmitEmptySlices: golang.EmitEmptySlices,
8992
EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument,
93+
UsesCopyFrom: usesCopyFrom(queries),
9094
SQLPackage: SQLPackageFromString(golang.SQLPackage),
9195
Q: "`",
9296
Package: golang.Package,
@@ -95,6 +99,10 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
9599
Structs: structs,
96100
}
97101

102+
if tctx.UsesCopyFrom && tctx.SQLPackage != SQLPackagePGX {
103+
return nil, errors.New(":copyfrom is only supported by pgx")
104+
}
105+
98106
output := map[string]string{}
99107

100108
execute := func(name, templateName string) error {
@@ -135,6 +143,8 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
135143
if golang.OutputQuerierFileName != "" {
136144
querierFileName = golang.OutputQuerierFileName
137145
}
146+
copyfromFileName := "copyfrom.go"
147+
// TODO(Jille): Make this configurable.
138148

139149
if err := execute(dbFileName, "dbFile"); err != nil {
140150
return nil, err
@@ -147,6 +157,11 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
147157
return nil, err
148158
}
149159
}
160+
if tctx.UsesCopyFrom {
161+
if err := execute(copyfromFileName, "copyfromFile"); err != nil {
162+
return nil, err
163+
}
164+
}
150165

151166
files := map[string]struct{}{}
152167
for _, gq := range queries {
@@ -160,3 +175,12 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
160175
}
161176
return output, nil
162177
}
178+
179+
func usesCopyFrom(queries []Query) bool {
180+
for _, q := range queries {
181+
if q.Cmd == metadata.CmdCopyFrom {
182+
return true
183+
}
184+
}
185+
return false
186+
}

internal/codegen/golang/imports.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (i *importer) Imports(filename string) [][]ImportSpec {
8989
if i.Settings.Go.OutputQuerierFileName != "" {
9090
querierFileName = i.Settings.Go.OutputQuerierFileName
9191
}
92+
copyfromFileName := "copyfrom.go"
9293

9394
switch filename {
9495
case dbFileName:
@@ -97,6 +98,8 @@ func (i *importer) Imports(filename string) [][]ImportSpec {
9798
return mergeImports(i.modelImports())
9899
case querierFileName:
99100
return mergeImports(i.interfaceImports())
101+
case copyfromFileName:
102+
return mergeImports(i.interfaceImports())
100103
default:
101104
return mergeImports(i.queryImports(filename))
102105
}
@@ -279,9 +282,13 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp
279282

280283
func (i *importer) queryImports(filename string) fileImports {
281284
var gq []Query
285+
anyNonCopyFrom := false
282286
for _, query := range i.Queries {
283287
if query.SourceName == filename {
284288
gq = append(gq, query)
289+
if query.Cmd != metadata.CmdCopyFrom {
290+
anyNonCopyFrom = true
291+
}
285292
}
286293
}
287294

@@ -349,7 +356,9 @@ func (i *importer) queryImports(filename string) fileImports {
349356
return false
350357
}
351358

352-
std["context"] = struct{}{}
359+
if anyNonCopyFrom {
360+
std["context"] = struct{}{}
361+
}
353362

354363
sqlpkg := SQLPackageFromString(i.Settings.Go.SQLPackage)
355364
if sliceScan() && sqlpkg != SQLPackagePGX {

internal/codegen/golang/query.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package golang
22

33
import (
4+
"fmt"
45
"strings"
56

67
"github.com/kyleconroy/sqlc/internal/metadata"
8+
"github.com/kyleconroy/sqlc/internal/sql/ast"
79
)
810

911
type QueryValue struct {
@@ -38,6 +40,13 @@ func (v QueryValue) Pair() string {
3840
return v.Name + " " + v.DefineType()
3941
}
4042

43+
func (v QueryValue) SlicePair() string {
44+
if v.isEmpty() {
45+
return ""
46+
}
47+
return v.Name + " []" + v.DefineType()
48+
}
49+
4150
func (v QueryValue) Type() string {
4251
if v.Typ != "" {
4352
return v.Typ
@@ -105,6 +114,17 @@ func (v QueryValue) Params() string {
105114
return "\n" + strings.Join(out, ",\n")
106115
}
107116

117+
func (v QueryValue) ColumnNames() string {
118+
if v.Struct == nil {
119+
return fmt.Sprintf("[]string{%q}", v.Name)
120+
}
121+
escapedNames := make([]string, len(v.Struct.Fields))
122+
for i, f := range v.Struct.Fields {
123+
escapedNames[i] = fmt.Sprintf("%q", f.DBName)
124+
}
125+
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
126+
}
127+
108128
func (v QueryValue) Scan() string {
109129
var out []string
110130
if v.Struct == nil {
@@ -140,9 +160,21 @@ type Query struct {
140160
SourceName string
141161
Ret QueryValue
142162
Arg QueryValue
163+
// Used for :copyfrom
164+
Table *ast.TableName
143165
}
144166

145167
func (q Query) hasRetType() bool {
146168
scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany
147169
return scanned && !q.Ret.isEmpty()
148170
}
171+
172+
func (q Query) TableIdentifier() string {
173+
escapedNames := make([]string, 0, 3)
174+
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
175+
if p != "" {
176+
escapedNames = append(escapedNames, fmt.Sprintf("%q", p))
177+
}
178+
}
179+
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
180+
}

internal/codegen/golang/result.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
160160
SourceName: query.Filename,
161161
SQL: query.SQL,
162162
Comments: query.Comments,
163+
Table: query.InsertIntoTable,
163164
}
164165
sqlpkg := SQLPackageFromString(settings.Go.SQLPackage)
165166

@@ -295,9 +296,10 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
295296
tags["json:"] = JSONTagName(tagName, settings)
296297
}
297298
gs.Fields = append(gs.Fields, Field{
298-
Name: fieldName,
299-
Type: goType(r, c.Column, settings),
300-
Tags: tags,
299+
Name: fieldName,
300+
DBName: colName,
301+
Type: goType(r, c.Column, settings),
302+
Tags: tags,
301303
})
302304
if _, found := seen[baseFieldName]; !found {
303305
seen[baseFieldName] = []int{i}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{{define "copyfromCodePgx"}}
2+
{{range .GoQueries}}
3+
{{if eq .Cmd ":copyfrom" }}
4+
// iteratorFor{{.MethodName}} implements pgx.CopyFromSource.
5+
type iteratorFor{{.MethodName}} struct {
6+
rows []{{.Arg.DefineType}}
7+
skippedFirstNextCall bool
8+
}
9+
10+
func (r *iteratorFor{{.MethodName}}) Next() bool {
11+
if len(r.rows) == 0 {
12+
return false
13+
}
14+
if !r.skippedFirstNextCall {
15+
r.skippedFirstNextCall = true
16+
return true
17+
}
18+
r.rows = r.rows[1:]
19+
return len(r.rows) > 0
20+
}
21+
22+
func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) {
23+
return []interface{}{
24+
{{- if .Arg.Struct }}
25+
{{- range .Arg.Struct.Fields }}
26+
r.rows[0].{{.Name}},
27+
{{- end }}
28+
{{- else }}
29+
r.rows[0],
30+
{{- end }}
31+
}, nil
32+
}
33+
34+
func (r iteratorFor{{.MethodName}}) Err() error {
35+
return nil
36+
}
37+
38+
{{range .Comments}}//{{.}}
39+
{{end -}}
40+
{{- if $.EmitMethodsWithDBArgument}}
41+
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) {
42+
return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
43+
{{- else}}
44+
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
45+
return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
46+
{{- end}}
47+
}
48+
49+
{{end}}
50+
{{end}}
51+
{{end}}

internal/codegen/golang/templates/pgx/dbCode.tmpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ type DBTX interface {
44
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
55
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
66
QueryRow(context.Context, string, ...interface{}) pgx.Row
7+
{{- if .UsesCopyFrom }}
8+
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
9+
{{- end }}
710
}
811

912
{{ if .EmitMethodsWithDBArgument}}

internal/codegen/golang/templates/pgx/interfaceCode.tmpl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
{{- else if eq .Cmd ":execresult" }}
2828
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error)
2929
{{- end}}
30+
{{- if and (eq .Cmd ":copyfrom") ($dbtxParam) }}
31+
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error)
32+
{{- else if eq .Cmd ":copyfrom" }}
33+
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error)
34+
{{- end}}
3035
{{- end}}
3136
}
3237

0 commit comments

Comments
 (0)