Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,14 @@ type stmtStage struct {
SubStages []stmtStage `json:"subStages"`
}

type QueryStage struct {
State string `json:"state"`
}

type QueryState string

const FINISHED QueryState = "FINISHED"

type jsonFloat64 float64

func (f *jsonFloat64) UnmarshalJSON(data []byte) error {
Expand Down Expand Up @@ -1399,8 +1407,31 @@ func (qr *driverRows) Close() error {
if qr.stmt.user != "" {
hs.Add(trinoUserHeader, qr.stmt.user)
}

ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout)
defer cancel()

// Special handling for single-row scans triggered by QueryRowContext() + Scan():
// The database/sql library automatically calls Close() after Scan() finishes reading
// a single row. We have no way to distinguish this from a multi-row iteration ahead of time.
//
// Without this check, Close() would always send a DELETE request to the Trino server,
// even if the query had already finished. This would cause a USER CANCELED error.
//
// By checking rowindex == 1, we heuristically detect the single-row scan case.
// We then call getQueryState() to verify if the query is already FINISHED.
// If so, we skip the DELETE request, avoiding unnecessary errors on trino side.
if qr.rowindex == 1 {
queryState, err := qr.getQueryState()
if err != nil {
return err
}

if queryState == string(FINISHED) {
return nil
}
}

req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
if err != nil {
return err
Expand All @@ -1418,6 +1449,41 @@ func (qr *driverRows) Close() error {
return qr.err
}

func (qr *driverRows) getQueryState() (string, error) {
hs := make(http.Header)
if qr.stmt.user != "" {
hs.Add(trinoUserHeader, qr.stmt.user)
}

ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout)
defer cancel()

req, err := qr.stmt.conn.newRequest(ctx, "GET", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
if err != nil {
return "", err
}

resp, err := qr.stmt.conn.roundTrip(ctx, req)

if err != nil {
qferr, ok := err.(*ErrQueryFailed)
if ok && qferr.StatusCode == http.StatusNoContent {
qr.nextURI = ""
return string(FINISHED), nil
}
}

var queryStage QueryStage
if err := json.NewDecoder(resp.Body).Decode(&queryStage); err != nil {
resp.Body.Close()
return "", err
}

resp.Body.Close()

return queryStage.State, nil
}

// Columns returns the names of the columns.
func (qr *driverRows) Columns() []string {
if qr.err != nil {
Expand Down
Loading