Skip to content

Commit 6acb774

Browse files
aalan3nineinchnick
authored andcommitted
Add support for optional query_timeout param in dsn
1 parent 2a51f32 commit 6acb774

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,16 @@ trino.RegisterCustomClient("otel", otelClient)
239239
db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=otel")
240240
```
241241

242+
##### `queryTimeout`
243+
244+
```
245+
Type: time.Duration
246+
Valid values: duration string
247+
Default: nil
248+
```
249+
250+
The `queryTimeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used.
251+
242252
#### Examples
243253

244254
```

trino/trino.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func init() {
8585

8686
var (
8787
// DefaultQueryTimeout is the default timeout for queries executed without a context.
88-
DefaultQueryTimeout = 60 * time.Second
88+
DefaultQueryTimeout = 10 * time.Hour
8989

9090
// DefaultCancelQueryTimeout is the timeout for the request to cancel queries in Trino.
9191
DefaultCancelQueryTimeout = 30 * time.Second
@@ -187,6 +187,7 @@ type Config struct {
187187
SSLCert string // The SSL cert for TLS verification (optional)
188188
AccessToken string // An access token (JWT) for authentication (optional)
189189
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
190+
QueryTimeout *time.Duration // Configurable timeout for query (optional)
190191
}
191192

192193
// FormatDSN returns a DSN string from the configuration.
@@ -266,6 +267,10 @@ func (c *Config) FormatDSN() (string, error) {
266267
sort.Strings(sessionkv)
267268
sort.Strings(credkv)
268269

270+
if c.QueryTimeout != nil {
271+
query.Add("query_timeout", c.QueryTimeout.String())
272+
}
273+
269274
for k, v := range map[string]string{
270275
"catalog": c.Catalog,
271276
"schema": c.Schema,
@@ -295,6 +300,7 @@ type Conn struct {
295300
progressUpdaterPeriod queryProgressCallbackPeriod
296301
useExplicitPrepare bool
297302
forwardAuthorizationHeader bool
303+
queryTimeout *time.Duration
298304
}
299305

300306
var (
@@ -369,6 +375,15 @@ func newConn(dsn string) (*Conn, error) {
369375
}
370376
}
371377

378+
var queryTimeout *time.Duration
379+
if timeoutStr := query.Get("query_timeout"); timeoutStr != "" {
380+
d, err := time.ParseDuration(timeoutStr)
381+
if err != nil {
382+
return nil, fmt.Errorf("trino: invalid timeout: %w", err)
383+
}
384+
queryTimeout = &d
385+
}
386+
372387
c := &Conn{
373388
baseURL: serverURL.Scheme + "://" + serverURL.Host,
374389
httpClient: *httpClient,
@@ -378,6 +393,7 @@ func newConn(dsn string) (*Conn, error) {
378393
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
379394
useExplicitPrepare: useExplicitPrepare,
380395
forwardAuthorizationHeader: forwardAuthorizationHeader,
396+
queryTimeout: queryTimeout,
381397
}
382398

383399
var user string
@@ -963,9 +979,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
963979
}
964980

965981
var cancel context.CancelFunc = func() {}
966-
if _, ok := ctx.Deadline(); !ok {
982+
if st.conn.queryTimeout != nil {
983+
ctx, cancel = context.WithTimeout(ctx, *st.conn.queryTimeout)
984+
} else if _, ok := ctx.Deadline(); !ok {
967985
ctx, cancel = context.WithTimeout(ctx, DefaultQueryTimeout)
968986
}
987+
969988
req, err := st.conn.newRequest(ctx, "POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs)
970989
if err != nil {
971990
cancel()

trino/trino_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,19 @@ func TestRegisterCustomClientReserved(t *testing.T) {
265265
}
266266
}
267267

268+
func TestQueryTimeout(t *testing.T) {
269+
timeout := 10 * time.Second
270+
c := &Config{
271+
ServerURI: "https://foobar@localhost:8090",
272+
QueryTimeout: &timeout,
273+
}
274+
dsn, err := c.FormatDSN()
275+
require.NoError(t, err)
276+
277+
want := "https://foobar@localhost:8090?query_timeout=10s&source=trino-go-client"
278+
assert.Equal(t, want, dsn)
279+
}
280+
268281
func TestRoundTripRetryQueryError(t *testing.T) {
269282
testcases := []struct {
270283
Name string
@@ -1973,3 +1986,45 @@ func TestForwardAuthorizationHeader(t *testing.T) {
19731986

19741987
assert.NoError(t, db.Close())
19751988
}
1989+
1990+
func TestQueryTimeoutDeadline(t *testing.T) {
1991+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1992+
time.Sleep(200 * time.Millisecond) // Simulate slow response
1993+
w.WriteHeader(http.StatusOK)
1994+
}))
1995+
defer ts.Close()
1996+
1997+
testcases := []struct {
1998+
name string
1999+
queryTimeout string
2000+
expectedError string
2001+
}{
2002+
{
2003+
name: "with timeout",
2004+
queryTimeout: "100ms",
2005+
expectedError: "context deadline exceeded",
2006+
},
2007+
{
2008+
name: "without timeout",
2009+
queryTimeout: "10s",
2010+
expectedError: "EOF", // Default server response
2011+
},
2012+
{
2013+
name: "bad timeout",
2014+
queryTimeout: "abc",
2015+
expectedError: "trino: invalid timeout", // Default server response
2016+
},
2017+
}
2018+
2019+
for _, tc := range testcases {
2020+
t.Run(tc.name, func(t *testing.T) {
2021+
println(ts.URL + "?query_timeout=" + tc.queryTimeout)
2022+
db, err := sql.Open("trino", ts.URL+"?query_timeout="+tc.queryTimeout)
2023+
require.NoError(t, err)
2024+
defer db.Close()
2025+
2026+
_, err = db.Query("SELECT 1")
2027+
assert.ErrorContains(t, err, tc.expectedError)
2028+
})
2029+
}
2030+
}

0 commit comments

Comments
 (0)