diff --git a/trino/trino.go b/trino/trino.go index d6a14a1..463bfb3 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -963,6 +963,14 @@ type stmtStage struct { SubStages []stmtStage `json:"subStages"` } +type QueryStage struct { + State string `json:"state"` +} + +type QueryState string + +const FINISHED QueryState = "FINISHED" + type jsonFloat64 float64 func (f *jsonFloat64) UnmarshalJSON(data []byte) error { @@ -1399,8 +1407,31 @@ func (qr *driverRows) Close() error { if qr.stmt.user != "" { hs.Add(trinoUserHeader, qr.stmt.user) } + ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) defer cancel() + + // Special handling for single-row scans triggered by QueryRowContext() + Scan(): + // The database/sql library automatically calls Close() after Scan() finishes reading + // a single row. We have no way to distinguish this from a multi-row iteration ahead of time. + // + // Without this check, Close() would always send a DELETE request to the Trino server, + // even if the query had already finished. This would cause a USER CANCELED error. + // + // By checking rowindex == 1, we heuristically detect the single-row scan case. + // We then call getQueryState() to verify if the query is already FINISHED. + // If so, we skip the DELETE request, avoiding unnecessary errors on trino side. + if qr.rowindex == 1 { + queryState, err := qr.getQueryState() + if err != nil { + return err + } + + if queryState == string(FINISHED) { + return nil + } + } + req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) if err != nil { return err @@ -1418,6 +1449,41 @@ func (qr *driverRows) Close() error { return qr.err } +func (qr *driverRows) getQueryState() (string, error) { + hs := make(http.Header) + if qr.stmt.user != "" { + hs.Add(trinoUserHeader, qr.stmt.user) + } + + ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) + defer cancel() + + req, err := qr.stmt.conn.newRequest(ctx, "GET", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) + if err != nil { + return "", err + } + + resp, err := qr.stmt.conn.roundTrip(ctx, req) + + if err != nil { + qferr, ok := err.(*ErrQueryFailed) + if ok && qferr.StatusCode == http.StatusNoContent { + qr.nextURI = "" + return string(FINISHED), nil + } + } + + var queryStage QueryStage + if err := json.NewDecoder(resp.Body).Decode(&queryStage); err != nil { + resp.Body.Close() + return "", err + } + + resp.Body.Close() + + return queryStage.State, nil +} + // Columns returns the names of the columns. func (qr *driverRows) Columns() []string { if qr.err != nil {