@@ -21,28 +21,8 @@ type engineContextKeyType struct{}
2121
2222var engineContextKey = engineContextKeyType {}
2323
24- type xormContextType struct {
25- context.Context
26- engine Engine
27- }
28-
29- var xormContext * xormContextType
30-
31- func newContext (ctx context.Context , e Engine ) * xormContextType {
32- return & xormContextType {Context : ctx , engine : e }
33- }
34-
35- // Value shadows Value for context.Context but allows us to get ourselves and an Engined object
36- func (ctx * xormContextType ) Value (key any ) any {
37- if key == engineContextKey {
38- return ctx
39- }
40- return ctx .Context .Value (key )
41- }
42-
43- // WithContext returns this engine tied to this context
44- func (ctx * xormContextType ) WithContext (other context.Context ) * xormContextType {
45- return newContext (ctx , ctx .engine .Context (other ))
24+ func withContextEngine (ctx context.Context , e Engine ) context.Context {
25+ return context .WithValue (ctx , engineContextKey , e )
4626}
4727
4828var (
@@ -81,35 +61,26 @@ func contextSafetyCheck(e Engine) {
8161 callerNum := runtime .Callers (3 , callers ) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
8262 for i := range callerNum {
8363 if slices .Contains (contextSafetyDeniedFuncPCs , callers [i ]) {
84- panic (errors .New ("using database context in an iterator would cause corrupted results" ))
64+ panic (errors .New ("using session context in an iterator would cause corrupted results" ))
8565 }
8666 }
8767}
8868
8969// GetEngine gets an existing db Engine/Statement or creates a new Session
90- func GetEngine (ctx context.Context ) (e Engine ) {
91- defer func () { contextSafetyCheck (e ) }()
92- if e := getExistingEngine (ctx ); e != nil {
93- return e
70+ func GetEngine (ctx context.Context ) Engine {
71+ if engine , ok := ctx .Value (engineContextKey ).(Engine ); ok {
72+ // if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session
73+ contextSafetyCheck (engine )
74+ return engine
9475 }
76+ // no need to do "contextSafetyCheck" because it's a new Session
9577 return xormEngine .Context (ctx )
9678}
9779
9880func GetXORMEngineForTesting () * xorm.Engine {
9981 return xormEngine
10082}
10183
102- // getExistingEngine gets an existing db Engine/Statement from this context or returns nil
103- func getExistingEngine (ctx context.Context ) (e Engine ) {
104- if engined , ok := ctx .(* xormContextType ); ok {
105- return engined .engine
106- }
107- if engined , ok := ctx .Value (engineContextKey ).(* xormContextType ); ok {
108- return engined .engine
109- }
110- return nil
111- }
112-
11384// Committer represents an interface to Commit or Close the Context
11485type Committer interface {
11586 Commit () error
@@ -152,24 +123,23 @@ func (c *halfCommitter) Close() error {
152123// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
153124// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
154125func TxContext (parentCtx context.Context ) (context.Context , Committer , error ) {
155- if sess , ok := inTransaction (parentCtx ); ok {
156- return newContext (parentCtx , sess ), & halfCommitter {committer : sess }, nil
126+ if sess := getTransactionSession (parentCtx ); sess != nil {
127+ return withContextEngine (parentCtx , sess ), & halfCommitter {committer : sess }, nil
157128 }
158129
159130 sess := xormEngine .NewSession ()
160131 if err := sess .Begin (); err != nil {
161132 _ = sess .Close ()
162133 return nil , nil , err
163134 }
164-
165- return newContext (xormContext , sess ), sess , nil
135+ return withContextEngine (parentCtx , sess ), sess , nil
166136}
167137
168138// WithTx represents executing database operations on a transaction, if the transaction exist,
169139// this function will reuse it otherwise will create a new one and close it when finished.
170140func WithTx (parentCtx context.Context , f func (ctx context.Context ) error ) error {
171- if sess , ok := inTransaction (parentCtx ); ok {
172- err := f (newContext (parentCtx , sess ))
141+ if sess := getTransactionSession (parentCtx ); sess != nil {
142+ err := f (withContextEngine (parentCtx , sess ))
173143 if err != nil {
174144 // rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
175145 _ = sess .Close ()
@@ -195,7 +165,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
195165 return err
196166 }
197167
198- if err := f (newContext (parentCtx , sess )); err != nil {
168+ if err := f (withContextEngine (parentCtx , sess )); err != nil {
199169 return err
200170 }
201171
@@ -333,32 +303,15 @@ func CountByBean(ctx context.Context, bean any) (int64, error) {
333303 return GetEngine (ctx ).Count (bean )
334304}
335305
336- // TableName returns the table name according a bean object
337- func TableName (bean any ) string {
338- return xormEngine .TableName (bean )
339- }
340-
341306// InTransaction returns true if the engine is in a transaction otherwise return false
342307func InTransaction (ctx context.Context ) bool {
343- _ , ok := inTransaction (ctx )
344- return ok
308+ return getTransactionSession (ctx ) != nil
345309}
346310
347- func inTransaction (ctx context.Context ) (* xorm.Session , bool ) {
348- e := getExistingEngine (ctx )
349- if e == nil {
350- return nil , false
351- }
352-
353- switch t := e .(type ) {
354- case * xorm.Engine :
355- return nil , false
356- case * xorm.Session :
357- if t .IsInTx () {
358- return t , true
359- }
360- return nil , false
361- default :
362- return nil , false
311+ func getTransactionSession (ctx context.Context ) * xorm.Session {
312+ e , _ := ctx .Value (engineContextKey ).(Engine )
313+ if sess , ok := e .(* xorm.Session ); ok && sess .IsInTx () {
314+ return sess
363315 }
316+ return nil
364317}
0 commit comments