@@ -69,7 +69,7 @@ type conn struct {
6969}
7070
7171func (c * conn ) IsValid () bool {
72- return ! c . isClosed ()
72+ return c . isReady ()
7373}
7474
7575type currentTx interface {
8787 _ driver.QueryerContext = & conn {}
8888 _ driver.Pinger = & conn {}
8989 _ driver.NamedValueChecker = & conn {}
90- _ driver.SessionResetter = & conn {}
9190 _ driver.Validator = & conn {}
9291)
9392
@@ -99,25 +98,12 @@ func newConn(c *Connector, s table.ClosableSession, opts ...connOption) *conn {
9998 for _ , o := range opts {
10099 o (cc )
101100 }
101+ c .attach (cc )
102102 return cc
103103}
104104
105- func (c * conn ) checkClosed (err error ) error {
106- if err = badconn .Map (err ); xerrors .Is (err , driver .ErrBadConn ) {
107- atomic .StoreUint32 (& c .closed , 1 )
108- }
109- return err
110- }
111-
112- func (c * conn ) isClosed () bool {
113- if atomic .LoadUint32 (& c .closed ) == 1 {
114- return true
115- }
116- if c .session .Status () != table .SessionReady {
117- atomic .StoreUint32 (& c .closed , 1 )
118- return true
119- }
120- return false
105+ func (c * conn ) isReady () bool {
106+ return c .session .Status () == table .SessionReady
121107}
122108
123109func (conn ) CheckNamedValue (v * driver.NamedValue ) (err error ) {
@@ -129,8 +115,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt,
129115 defer func () {
130116 onDone (err )
131117 }()
132- if c . isClosed () {
133- return nil , errClosedConn
118+ if ! c . isReady () {
119+ return nil , badconn . Map ( xerrors . WithStackTrace ( errNotReadyConn ))
134120 }
135121 return & stmt {
136122 conn : c ,
@@ -139,6 +125,10 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt,
139125 }, nil
140126}
141127
128+ func (c * conn ) sinceLastUsage () time.Duration {
129+ return time .Since (time .Unix (atomic .LoadInt64 (& c .lastUsage ), 0 ))
130+ }
131+
142132func (c * conn ) execContext (ctx context.Context , query string , args []driver.NamedValue ) (_ driver.Result , err error ) {
143133 m := queryModeFromContext (ctx , c .defaultQueryMode )
144134 onDone := trace .DatabaseSQLOnConnExec (
@@ -147,7 +137,7 @@ func (c *conn) execContext(ctx context.Context, query string, args []driver.Name
147137 query ,
148138 m .String (),
149139 xcontext .IsIdempotent (ctx ),
150- time . Since ( time . Unix ( atomic . LoadInt64 ( & c . lastUsage ), 0 ) ),
140+ c . sinceLastUsage ( ),
151141 )
152142 defer func () {
153143 atomic .StoreInt64 (& c .lastUsage , time .Now ().Unix ())
@@ -163,38 +153,38 @@ func (c *conn) execContext(ctx context.Context, query string, args []driver.Name
163153 dataQueryOptions (ctx )... ,
164154 )
165155 if err != nil {
166- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
156+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
167157 }
168158 defer func () {
169159 _ = res .Close ()
170160 }()
171161 if err = res .NextResultSetErr (ctx ); ! xerrors .Is (err , nil , io .EOF ) {
172- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
162+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
173163 }
174164 if err = res .Err (); err != nil {
175- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
165+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
176166 }
177167 return driver .ResultNoRows , nil
178168 case SchemeQueryMode :
179169 err = c .session .ExecuteSchemeQuery (ctx , query )
180170 if err != nil {
181- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
171+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
182172 }
183173 return driver .ResultNoRows , nil
184174 case ScriptingQueryMode :
185175 var res result.StreamResult
186176 res , err = c .connector .connection .Scripting ().StreamExecute (ctx , query , toQueryParams (args ))
187177 if err != nil {
188- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
178+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
189179 }
190180 defer func () {
191181 _ = res .Close ()
192182 }()
193183 if err = res .NextResultSetErr (ctx ); ! xerrors .Is (err , nil , io .EOF ) {
194- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
184+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
195185 }
196186 if err = res .Err (); err != nil {
197- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
187+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
198188 }
199189 return driver .ResultNoRows , nil
200190 default :
@@ -203,8 +193,8 @@ func (c *conn) execContext(ctx context.Context, query string, args []driver.Name
203193}
204194
205195func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (_ driver.Result , err error ) {
206- if c . isClosed () {
207- return nil , errClosedConn
196+ if ! c . isReady () {
197+ return nil , badconn . Map ( xerrors . WithStackTrace ( errNotReadyConn ))
208198 }
209199 if c .currentTx != nil {
210200 return c .currentTx .ExecContext (ctx , query , args )
@@ -213,8 +203,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
213203}
214204
215205func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (_ driver.Rows , err error ) {
216- if c . isClosed () {
217- return nil , errClosedConn
206+ if ! c . isReady () {
207+ return nil , badconn . Map ( xerrors . WithStackTrace ( errNotReadyConn ))
218208 }
219209 if c .currentTx != nil {
220210 return c .currentTx .QueryContext (ctx , query , args )
@@ -230,7 +220,7 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
230220 query ,
231221 m .String (),
232222 xcontext .IsIdempotent (ctx ),
233- time . Since ( time . Unix ( atomic . LoadInt64 ( & c . lastUsage ), 0 ) ),
223+ c . sinceLastUsage ( ),
234224 )
235225 defer func () {
236226 atomic .StoreInt64 (& c .lastUsage , time .Now ().Unix ())
@@ -246,10 +236,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
246236 dataQueryOptions (ctx )... ,
247237 )
248238 if err != nil {
249- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
239+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
250240 }
251241 if err = res .Err (); err != nil {
252- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
242+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
253243 }
254244 return & rows {
255245 conn : c ,
@@ -263,10 +253,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
263253 scanQueryOptions (ctx )... ,
264254 )
265255 if err != nil {
266- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
256+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
267257 }
268258 if err = res .Err (); err != nil {
269- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
259+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
270260 }
271261 return & rows {
272262 conn : c ,
@@ -276,7 +266,7 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
276266 var exp table.DataQueryExplanation
277267 exp , err = c .session .Explain (ctx , query )
278268 if err != nil {
279- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
269+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
280270 }
281271 return & single {
282272 values : []sql.NamedArg {
@@ -288,10 +278,10 @@ func (c *conn) queryContext(ctx context.Context, query string, args []driver.Nam
288278 var res result.StreamResult
289279 res , err = c .connector .connection .Scripting ().StreamExecute (ctx , query , toQueryParams (args ))
290280 if err != nil {
291- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
281+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
292282 }
293283 if err = res .Err (); err != nil {
294- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
284+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
295285 }
296286 return & rows {
297287 conn : c ,
@@ -307,25 +297,29 @@ func (c *conn) Ping(ctx context.Context) (err error) {
307297 defer func () {
308298 onDone (err )
309299 }()
310- if c . isClosed () {
311- return errClosedConn
300+ if ! c . isReady () {
301+ return badconn . Map ( xerrors . WithStackTrace ( errNotReadyConn ))
312302 }
313303 if err = c .session .KeepAlive (ctx ); err != nil {
314- return c . checkClosed (xerrors .WithStackTrace (err ))
304+ return badconn . Map (xerrors .WithStackTrace (err ))
315305 }
316306 return nil
317307}
318308
319309func (c * conn ) Close () (err error ) {
320- onDone := trace .DatabaseSQLOnConnClose (c .trace )
321- defer func () {
322- onDone (err )
323- }()
324- err = c .session .Close (context .Background ())
325- if err != nil {
326- return c .checkClosed (xerrors .WithStackTrace (err ))
310+ if atomic .CompareAndSwapUint32 (& c .closed , 0 , 1 ) {
311+ c .connector .detach (c )
312+ onDone := trace .DatabaseSQLOnConnClose (c .trace )
313+ defer func () {
314+ onDone (err )
315+ }()
316+ err = c .session .Close (context .Background ())
317+ if err != nil {
318+ return badconn .Map (xerrors .WithStackTrace (err ))
319+ }
320+ return nil
327321 }
328- return nil
322+ return badconn . Map ( xerrors . WithStackTrace ( errConnClosedEarly ))
329323}
330324
331325func (c * conn ) Prepare (string ) (driver.Stmt , error ) {
@@ -342,8 +336,8 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
342336 defer func () {
343337 onDone (transaction , err )
344338 }()
345- if c . isClosed () {
346- return nil , errClosedConn
339+ if ! c . isReady () {
340+ return nil , badconn . Map ( xerrors . WithStackTrace ( errNotReadyConn ))
347341 }
348342 if c .currentTx != nil {
349343 return nil , xerrors .WithStackTrace (
@@ -357,7 +351,7 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
357351 }
358352 transaction , err = c .session .BeginTransaction (ctx , table .TxSettings (txc ))
359353 if err != nil {
360- return nil , c . checkClosed (xerrors .WithStackTrace (err ))
354+ return nil , badconn . Map (xerrors .WithStackTrace (err ))
361355 }
362356 c .currentTx = & tx {
363357 conn : c ,
@@ -366,13 +360,3 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (_ drive
366360 }
367361 return c .currentTx , nil
368362}
369-
370- func (c * conn ) ResetSession (_ context.Context ) error {
371- if c .currentTx != nil {
372- _ = c .currentTx .Rollback ()
373- }
374- if c .isClosed () {
375- return errClosedConn
376- }
377- return nil
378- }
0 commit comments