Skip to content

Commit 10e3f8d

Browse files
authored
internal/dinosql: Overrides can now be basic types (#271)
The `go_type` value in an override can now reference a basic type, such as "string" or "bool".
1 parent d1a6e3c commit 10e3f8d

File tree

4 files changed

+140
-23
lines changed

4 files changed

+140
-23
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
229229

230230
Your favorite PostgreSQL / Go features are supported:
231231
- SQL
232+
- [Query annotations](./docs/annotations.md)
233+
- [Transactions](./docs/transactions.md)
234+
- [Prepared queries](./docs/prepared_query.md)
232235
- [SELECT](./docs/query_one.md)
233236
- [NULL](./docs/null.md)
234237
- [COUNT](./docs/query_count.md)
@@ -237,8 +240,6 @@ Your favorite PostgreSQL / Go features are supported:
237240
- [DELETE](./docs/delete.md)
238241
- [RETURNING](./docs/returning.md)
239242
- [ANY](./docs/any.md)
240-
- [Transactions](./docs/transactions.md)
241-
- [Prepared queries](./docs/prepared_query.md)
242243
- PostgreSQL Types
243244
- [Arrays](./docs/arrays.md)
244245
- [Enums](./docs/enums.md)

internal/dinosql/config.go

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"go/types"
78
"io"
89
"path/filepath"
910
"strings"
@@ -66,10 +67,11 @@ type Override struct {
6667
// fully qualified name of the column, e.g. `accounts.id`
6768
Column string `json:"column"`
6869

69-
columnName string
70-
table pg.FQN
71-
goTypeName string
72-
goPackage string
70+
columnName string
71+
table pg.FQN
72+
goTypeName string
73+
goPackage string
74+
goBasicType bool
7375
}
7476

7577
func (o *Override) Parse() error {
@@ -101,26 +103,49 @@ func (o *Override) Parse() error {
101103

102104
// validate GoType
103105
lastDot := strings.LastIndex(o.GoType, ".")
104-
if lastDot == -1 {
105-
return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType)
106-
}
107106
lastSlash := strings.LastIndex(o.GoType, "/")
108-
if lastSlash == -1 {
109-
return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType)
110-
}
111-
typename := o.GoType[lastSlash+1:]
112-
if strings.HasPrefix(typename, "go-") {
113-
// a package name beginning with "go-" will give syntax errors in
114-
// generated code. We should do the right thing and get the actual
115-
// import name, but in lieu of that, stripping the leading "go-" may get
116-
// us what we want.
117-
typename = typename[len("go-"):]
118-
}
119-
if strings.HasSuffix(typename, "-go") {
120-
typename = typename[:len(typename)-len("-go")]
107+
typename := o.GoType
108+
if lastDot == -1 && lastSlash == -1 {
109+
// if the type name has no slash and no dot, validate that the type is a basic Go type
110+
var found bool
111+
for _, typ := range types.Typ {
112+
info := typ.Info()
113+
if info == 0 {
114+
continue
115+
}
116+
if info&types.IsUntyped != 0 {
117+
continue
118+
}
119+
if typename == typ.Name() {
120+
found = true
121+
}
122+
}
123+
if !found {
124+
return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoType)
125+
}
126+
o.goBasicType = true
127+
} else {
128+
// assume the type lives in a Go package
129+
if lastDot == -1 {
130+
return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType)
131+
}
132+
if lastSlash == -1 {
133+
return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType)
134+
}
135+
typename = o.GoType[lastSlash+1:]
136+
if strings.HasPrefix(typename, "go-") {
137+
// a package name beginning with "go-" will give syntax errors in
138+
// generated code. We should do the right thing and get the actual
139+
// import name, but in lieu of that, stripping the leading "go-" may get
140+
// us what we want.
141+
typename = typename[len("go-"):]
142+
}
143+
if strings.HasSuffix(typename, "-go") {
144+
typename = typename[:len(typename)-len("-go")]
145+
}
146+
o.goPackage = o.GoType[:lastDot]
121147
}
122148
o.goTypeName = typename
123-
o.goPackage = o.GoType[:lastDot]
124149
isPointer := o.GoType[0] == '*'
125150
if isPointer {
126151
o.goPackage = o.goPackage[1:]

internal/dinosql/config_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,88 @@ func TestBadConfigs(t *testing.T) {
6161
})
6262
}
6363
}
64+
65+
func TestTypeOverrides(t *testing.T) {
66+
for _, test := range []struct {
67+
override Override
68+
pkg string
69+
typeName string
70+
basic bool
71+
}{
72+
{
73+
Override{
74+
PostgresType: "uuid",
75+
GoType: "github.com/segmentio/ksuid.KSUID",
76+
},
77+
"github.com/segmentio/ksuid",
78+
"ksuid.KSUID",
79+
false,
80+
},
81+
// TODO: Add test for struct pointers
82+
//
83+
// {
84+
// Override{
85+
// PostgresType: "uuid",
86+
// GoType: "github.com/segmentio/*ksuid.KSUID",
87+
// },
88+
// "github.com/segmentio/ksuid",
89+
// "*ksuid.KSUID",
90+
// false,
91+
// },
92+
{
93+
Override{
94+
PostgresType: "citext",
95+
GoType: "string",
96+
},
97+
"",
98+
"string",
99+
true,
100+
},
101+
} {
102+
tt := test
103+
t.Run(tt.override.GoType, func(t *testing.T) {
104+
if err := tt.override.Parse(); err != nil {
105+
t.Fatalf("override parsing failed; %s", err)
106+
}
107+
if diff := cmp.Diff(tt.typeName, tt.override.goTypeName); diff != "" {
108+
t.Errorf("type name mismatch;\n%s", diff)
109+
}
110+
if diff := cmp.Diff(tt.pkg, tt.override.goPackage); diff != "" {
111+
t.Errorf("package mismatch;\n%s", diff)
112+
}
113+
if diff := cmp.Diff(tt.basic, tt.override.goBasicType); diff != "" {
114+
t.Errorf("basic mismatch;\n%s", diff)
115+
}
116+
})
117+
}
118+
for _, test := range []struct {
119+
override Override
120+
err string
121+
}{
122+
{
123+
Override{
124+
PostgresType: "uuid",
125+
GoType: "Pointer",
126+
},
127+
"Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
128+
},
129+
{
130+
Override{
131+
PostgresType: "uuid",
132+
GoType: "untyped rune",
133+
},
134+
"Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
135+
},
136+
} {
137+
tt := test
138+
t.Run(tt.override.GoType, func(t *testing.T) {
139+
err := tt.override.Parse()
140+
if err == nil {
141+
t.Fatalf("expected pars to fail; got nil")
142+
}
143+
if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
144+
t.Errorf("error mismatch;\n%s", diff)
145+
}
146+
})
147+
}
148+
}

internal/dinosql/gen.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ func ModelImports(r Generateable, settings GenerateSettings) [][]string {
225225
pkg := make(map[string]struct{})
226226
overrideTypes := map[string]string{}
227227
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
228+
if o.goBasicType {
229+
continue
230+
}
228231
overrideTypes[o.goTypeName] = o.goPackage
229232
}
230233

@@ -357,6 +360,9 @@ func QueryImports(r Generateable, settings GenerateSettings, filename string) []
357360
pkg := make(map[string]struct{})
358361
overrideTypes := map[string]string{}
359362
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
363+
if o.goBasicType {
364+
continue
365+
}
360366
overrideTypes[o.goTypeName] = o.goPackage
361367
}
362368

0 commit comments

Comments
 (0)