Skip to content

Commit 76f4666

Browse files
authored
chore: Upgrade to aws-sdk-go-v2 (#37)
The conversion is essentially complete. The only gap is in the implementation of `Rows.Next()`, where a context is not available from the caller for use when fetching more results from Athena, because the function is defined in its interface to not include a context. Related to contexts, the driver does not yet implement `DriverContext`, but should be, separately.
1 parent 3d220fb commit 76f4666

File tree

9 files changed

+190
-121
lines changed

9 files changed

+190
-121
lines changed

api.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package athena
2+
3+
import (
4+
"context"
5+
6+
"github.com/aws/aws-sdk-go-v2/service/athena"
7+
)
8+
9+
type athenaAPI interface {
10+
GetQueryExecution(context.Context, *athena.GetQueryExecutionInput, ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error)
11+
GetQueryResults(context.Context, *athena.GetQueryResultsInput, ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error)
12+
StartQueryExecution(context.Context, *athena.StartQueryExecutionInput, ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error)
13+
StopQueryExecution(context.Context, *athena.StopQueryExecutionInput, ...func(*athena.Options)) (*athena.StopQueryExecutionOutput, error)
14+
}

conn.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ import (
66
"errors"
77
"time"
88

9-
"github.com/aws/aws-sdk-go/aws"
10-
"github.com/aws/aws-sdk-go/service/athena"
11-
"github.com/aws/aws-sdk-go/service/athena/athenaiface"
9+
"github.com/aws/aws-sdk-go-v2/aws"
10+
"github.com/aws/aws-sdk-go-v2/service/athena"
11+
"github.com/aws/aws-sdk-go-v2/service/athena/types"
1212
)
1313

1414
type conn struct {
15-
athena athenaiface.AthenaAPI
15+
athena athenaAPI
1616
db string
1717
OutputLocation string
1818

@@ -38,7 +38,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
3838
}
3939

4040
func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) {
41-
queryID, err := c.startQuery(query)
41+
queryID, err := c.startQuery(ctx, query)
4242
if err != nil {
4343
return nil, err
4444
}
@@ -47,7 +47,7 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
4747
return nil, err
4848
}
4949

50-
return newRows(rowsConfig{
50+
return newRows(ctx, rowsConfig{
5151
Athena: c.athena,
5252
QueryID: queryID,
5353
// todo add check for ddl queries to not skip header(#10)
@@ -56,13 +56,13 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error)
5656
}
5757

5858
// startQuery starts an Athena query and returns its ID.
59-
func (c *conn) startQuery(query string) (string, error) {
60-
resp, err := c.athena.StartQueryExecution(&athena.StartQueryExecutionInput{
59+
func (c *conn) startQuery(ctx context.Context, query string) (string, error) {
60+
resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{
6161
QueryString: aws.String(query),
62-
QueryExecutionContext: &athena.QueryExecutionContext{
62+
QueryExecutionContext: &types.QueryExecutionContext{
6363
Database: aws.String(c.db),
6464
},
65-
ResultConfiguration: &athena.ResultConfiguration{
65+
ResultConfiguration: &types.ResultConfiguration{
6666
OutputLocation: aws.String(c.OutputLocation),
6767
},
6868
})
@@ -76,28 +76,28 @@ func (c *conn) startQuery(query string) (string, error) {
7676
// waitOnQuery blocks until a query finishes, returning an error if it failed.
7777
func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
7878
for {
79-
statusResp, err := c.athena.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
79+
statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
8080
QueryExecutionId: aws.String(queryID),
8181
})
8282
if err != nil {
8383
return err
8484
}
8585

86-
switch *statusResp.QueryExecution.Status.State {
87-
case athena.QueryExecutionStateCancelled:
86+
switch statusResp.QueryExecution.Status.State {
87+
case types.QueryExecutionStateCancelled:
8888
return context.Canceled
89-
case athena.QueryExecutionStateFailed:
89+
case types.QueryExecutionStateFailed:
9090
reason := *statusResp.QueryExecution.Status.StateChangeReason
9191
return errors.New(reason)
92-
case athena.QueryExecutionStateSucceeded:
92+
case types.QueryExecutionStateSucceeded:
9393
return nil
94-
case athena.QueryExecutionStateQueued:
95-
case athena.QueryExecutionStateRunning:
94+
case types.QueryExecutionStateQueued:
95+
case types.QueryExecutionStateRunning:
9696
}
9797

9898
select {
9999
case <-ctx.Done():
100-
c.athena.StopQueryExecution(&athena.StopQueryExecutionInput{
100+
c.athena.StopQueryExecution(ctx, &athena.StopQueryExecutionInput{
101101
QueryExecutionId: aws.String(queryID),
102102
})
103103

@@ -109,7 +109,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
109109
}
110110

111111
func (c *conn) Prepare(query string) (driver.Stmt, error) {
112-
panic("Athena doesn't support prepared statements")
112+
panic("The go-athena driver doesn't support prepared statements yet")
113113
}
114114

115115
func (c *conn) Begin() (driver.Tx, error) {

db_test.go

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import (
1111
"testing"
1212
"time"
1313

14-
"github.com/aws/aws-sdk-go/aws"
15-
"github.com/aws/aws-sdk-go/aws/session"
16-
"github.com/aws/aws-sdk-go/service/s3"
14+
"github.com/aws/aws-sdk-go-v2/aws"
15+
"github.com/aws/aws-sdk-go-v2/config"
16+
"github.com/aws/aws-sdk-go-v2/service/s3"
1717
uuid "github.com/satori/go.uuid"
1818
"github.com/stretchr/testify/assert"
1919
"github.com/stretchr/testify/require"
@@ -35,8 +35,9 @@ func init() {
3535
}
3636

3737
func TestQuery(t *testing.T) {
38-
harness := setup(t)
39-
// defer harness.teardown()
38+
ctx := context.Background()
39+
harness := setup(ctx, t)
40+
// defer harness.teardown(ctx)
4041

4142
expected := []dummyRow{
4243
{
@@ -77,9 +78,9 @@ func TestQuery(t *testing.T) {
7778
},
7879
}
7980
expectedTypeNames := []string{"varchar", "smallint", "integer", "bigint", "boolean", "float", "double", "varchar", "timestamp", "date", "decimal"}
80-
harness.uploadData(expected)
81+
harness.uploadData(ctx, expected)
8182

82-
rows := harness.mustQuery("select * from %s", harness.table)
83+
rows := harness.mustQuery(ctx, "select * from %s", harness.table)
8384
index := -1
8485
for rows.Next() {
8586
index++
@@ -115,8 +116,10 @@ func TestQuery(t *testing.T) {
115116
}
116117

117118
func TestOpen(t *testing.T) {
118-
db, err := Open(Config{
119-
Session: session.Must(session.NewSession()),
119+
awsConfig, err := config.LoadDefaultConfig(context.Background())
120+
require.NoError(t, err, "LoadDefaultConfig")
121+
db, err := Open(DriverConfig{
122+
Config: &awsConfig,
120123
Database: AthenaDatabase,
121124
OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket),
122125
})
@@ -143,28 +146,29 @@ type dummyRow struct {
143146
type athenaHarness struct {
144147
t *testing.T
145148
db *sql.DB
146-
s3 *s3.S3
149+
s3 *s3.Client
147150

148151
table string
149152
}
150153

151-
func setup(t *testing.T) *athenaHarness {
152-
harness := athenaHarness{t: t, s3: s3.New(session.New())}
154+
func setup(ctx context.Context, t *testing.T) *athenaHarness {
155+
awsConfig, err := config.LoadDefaultConfig(ctx)
156+
require.NoError(t, err)
157+
harness := athenaHarness{t: t, s3: s3.NewFromConfig(awsConfig)}
153158

154-
var err error
155159
harness.db, err = sql.Open("athena", fmt.Sprintf("db=%s&output_location=s3://%s/output", AthenaDatabase, S3Bucket))
156160
require.NoError(t, err)
157161

158-
harness.setupTable()
162+
harness.setupTable(ctx)
159163

160164
return &harness
161165
}
162166

163-
func (a *athenaHarness) setupTable() {
167+
func (a *athenaHarness) setupTable(ctx context.Context) {
164168
// tables cannot start with numbers or contain dashes
165169
id := uuid.NewV4()
166170
a.table = "t_" + strings.Replace(id.String(), "-", "_", -1)
167-
a.mustExec(`CREATE EXTERNAL TABLE %[1]s (
171+
a.mustExec(ctx, `CREATE EXTERNAL TABLE %[1]s (
168172
nullValue string,
169173
smallintType smallint,
170174
intType int,
@@ -184,32 +188,32 @@ WITH SERDEPROPERTIES (
184188
fmt.Printf("created table: %s", a.table)
185189
}
186190

187-
func (a *athenaHarness) teardown() {
188-
a.mustExec("drop table %s", a.table)
191+
func (a *athenaHarness) teardown(ctx context.Context) {
192+
a.mustExec(ctx, "drop table %s", a.table)
189193
}
190194

191-
func (a *athenaHarness) mustExec(sql string, args ...interface{}) {
195+
func (a *athenaHarness) mustExec(ctx context.Context, sql string, args ...interface{}) {
192196
query := fmt.Sprintf(sql, args...)
193-
_, err := a.db.ExecContext(context.TODO(), query)
197+
_, err := a.db.ExecContext(ctx, query)
194198
require.NoError(a.t, err, query)
195199
}
196200

197-
func (a *athenaHarness) mustQuery(sql string, args ...interface{}) *sql.Rows {
201+
func (a *athenaHarness) mustQuery(ctx context.Context, sql string, args ...interface{}) *sql.Rows {
198202
query := fmt.Sprintf(sql, args...)
199-
rows, err := a.db.QueryContext(context.TODO(), query)
203+
rows, err := a.db.QueryContext(ctx, query)
200204
require.NoError(a.t, err, query)
201205
return rows
202206
}
203207

204-
func (a *athenaHarness) uploadData(rows []dummyRow) {
208+
func (a *athenaHarness) uploadData(ctx context.Context, rows []dummyRow) {
205209
var buf bytes.Buffer
206210
enc := json.NewEncoder(&buf)
207211
for _, row := range rows {
208212
err := enc.Encode(row)
209213
require.NoError(a.t, err)
210214
}
211215

212-
_, err := a.s3.PutObject(&s3.PutObjectInput{
216+
_, err := a.s3.PutObject(ctx, &s3.PutObjectInput{
213217
Bucket: aws.String(S3Bucket),
214218
Key: aws.String(fmt.Sprintf("%s/fixture.json", a.table)),
215219
Body: bytes.NewReader(buf.Bytes()),

driver.go

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package athena
22

33
import (
4+
"context"
45
"database/sql"
56
"database/sql/driver"
67
"errors"
@@ -9,9 +10,9 @@ import (
910
"sync"
1011
"time"
1112

12-
"github.com/aws/aws-sdk-go/aws"
13-
"github.com/aws/aws-sdk-go/aws/session"
14-
"github.com/aws/aws-sdk-go/service/athena"
13+
"github.com/aws/aws-sdk-go-v2/aws"
14+
"github.com/aws/aws-sdk-go-v2/config"
15+
"github.com/aws/aws-sdk-go-v2/service/athena"
1516
)
1617

1718
var (
@@ -21,15 +22,15 @@ var (
2122

2223
// Driver is a sql.Driver. It's intended for db/sql.Open().
2324
type Driver struct {
24-
cfg *Config
25+
cfg *DriverConfig
2526
}
2627

2728
// NewDriver allows you to register your own driver with `sql.Register`.
2829
// It's useful for more complex use cases. Read more in PR #3.
2930
// https://github.com/segmentio/go-athena/pull/3
3031
//
3132
// Generally, sql.Open() or athena.Open() should suffice.
32-
func NewDriver(cfg *Config) *Driver {
33+
func NewDriver(cfg *DriverConfig) *Driver {
3334
return &Driver{cfg}
3435
}
3536

@@ -65,7 +66,8 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
6566
cfg := d.cfg
6667
if cfg == nil {
6768
var err error
68-
cfg, err = configFromConnectionString(connStr)
69+
// TODO: Implement DriverContext to get proper access to context
70+
cfg, err = configFromConnectionString(context.TODO(), connStr)
6971
if err != nil {
7072
return nil, err
7173
}
@@ -76,7 +78,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
7678
}
7779

7880
return &conn{
79-
athena: athena.New(cfg.Session),
81+
athena: athena.NewFromConfig(*cfg.Config),
8082
db: cfg.Database,
8183
OutputLocation: cfg.OutputLocation,
8284
pollFrequency: cfg.PollFrequency,
@@ -86,7 +88,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
8688
// Open is a more robust version of `db.Open`, as it accepts a raw aws.Session.
8789
// This is useful if you have a complex AWS session since the driver doesn't
8890
// currently attempt to serialize all options into a string.
89-
func Open(cfg Config) (*sql.DB, error) {
91+
func Open(cfg DriverConfig) (*sql.DB, error) {
9092
if cfg.Database == "" {
9193
return nil, errors.New("db is required")
9294
}
@@ -95,8 +97,8 @@ func Open(cfg Config) (*sql.DB, error) {
9597
return nil, errors.New("s3_staging_url is required")
9698
}
9799

98-
if cfg.Session == nil {
99-
return nil, errors.New("session is required")
100+
if cfg.Config == nil {
101+
return nil, errors.New("AWS config is required")
100102
}
101103

102104
// This hack was copied from jackc/pgx. Sorry :(
@@ -111,30 +113,30 @@ func Open(cfg Config) (*sql.DB, error) {
111113
}
112114

113115
// Config is the input to Open().
114-
type Config struct {
115-
Session *session.Session
116+
type DriverConfig struct {
117+
Config *aws.Config
116118
Database string
117119
OutputLocation string
118120

119121
PollFrequency time.Duration
120122
}
121123

122-
func configFromConnectionString(connStr string) (*Config, error) {
124+
func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) {
123125
args, err := url.ParseQuery(connStr)
124126
if err != nil {
125127
return nil, err
126128
}
127129

128-
var cfg Config
130+
var cfg DriverConfig
129131

130-
var acfg []*aws.Config
131-
if region := args.Get("region"); region != "" {
132-
acfg = append(acfg, &aws.Config{Region: aws.String(region)})
133-
}
134-
cfg.Session, err = session.NewSession(acfg...)
132+
awsConfig, err := config.LoadDefaultConfig(ctx)
135133
if err != nil {
136134
return nil, err
137135
}
136+
if region := args.Get("region"); region != "" {
137+
awsConfig.Region = region
138+
}
139+
cfg.Config = &awsConfig
138140

139141
cfg.Database = args.Get("db")
140142
cfg.OutputLocation = args.Get("output_location")

go.mod

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,31 @@ module github.com/segmentio/go-athena
33
go 1.21
44

55
require (
6-
github.com/aws/aws-sdk-go v1.55.5
6+
github.com/aws/aws-sdk-go-v2 v1.30.4
7+
github.com/aws/aws-sdk-go-v2/config v1.27.30
8+
github.com/aws/aws-sdk-go-v2/service/athena v1.44.5
9+
github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1
710
github.com/satori/go.uuid v1.2.0
811
github.com/stretchr/testify v1.9.0
912
)
1013

1114
require (
15+
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect
16+
github.com/aws/aws-sdk-go-v2/credentials v1.17.29 // indirect
17+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect
18+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect
19+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect
20+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
21+
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.16 // indirect
22+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect
23+
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect
24+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect
25+
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect
26+
github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect
27+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect
28+
github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect
29+
github.com/aws/smithy-go v1.20.4 // indirect
1230
github.com/davecgh/go-spew v1.1.1 // indirect
13-
github.com/jmespath/go-jmespath v0.4.0 // indirect
1431
github.com/pmezard/go-difflib v1.0.0 // indirect
1532
gopkg.in/yaml.v3 v3.0.1 // indirect
1633
)

0 commit comments

Comments
 (0)