@@ -6,6 +6,12 @@ package db
66import (
77 "context"
88 "database/sql"
9+ "errors"
10+ "runtime"
11+ "slices"
12+ "sync"
13+
14+ "code.gitea.io/gitea/modules/setting"
915
1016 "xorm.io/builder"
1117 "xorm.io/xorm"
@@ -15,76 +21,90 @@ import (
1521// will be overwritten by Init with HammerContext
1622var DefaultContext context.Context
1723
18- // contextKey is a value for use with context.WithValue.
19- type contextKey struct {
20- name string
21- }
24+ type engineContextKeyType struct {}
2225
23- // enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context
24- var (
25- enginedContextKey = & contextKey {"engined" }
26- _ Engined = & Context {}
27- )
26+ var engineContextKey = engineContextKeyType {}
2827
2928// Context represents a db context
3029type Context struct {
3130 context.Context
32- e Engine
33- transaction bool
34- }
35-
36- func newContext (ctx context.Context , e Engine , transaction bool ) * Context {
37- return & Context {
38- Context : ctx ,
39- e : e ,
40- transaction : transaction ,
41- }
42- }
43-
44- // InTransaction if context is in a transaction
45- func (ctx * Context ) InTransaction () bool {
46- return ctx .transaction
31+ engine Engine
4732}
4833
49- // Engine returns db engine
50- func (ctx * Context ) Engine () Engine {
51- return ctx .e
34+ func newContext (ctx context.Context , e Engine ) * Context {
35+ return & Context {Context : ctx , engine : e }
5236}
5337
5438// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
5539func (ctx * Context ) Value (key any ) any {
56- if key == enginedContextKey {
40+ if key == engineContextKey {
5741 return ctx
5842 }
5943 return ctx .Context .Value (key )
6044}
6145
6246// WithContext returns this engine tied to this context
6347func (ctx * Context ) WithContext (other context.Context ) * Context {
64- return newContext (ctx , ctx .e .Context (other ), ctx . transaction )
48+ return newContext (ctx , ctx .engine .Context (other ))
6549}
6650
67- // Engined structs provide an Engine
68- type Engined interface {
69- Engine () Engine
51+ var (
52+ contextSafetyOnce sync.Once
53+ contextSafetyDeniedFuncPCs []uintptr
54+ )
55+
56+ func contextSafetyCheck (e Engine ) {
57+ if setting .IsProd && ! setting .IsInTesting {
58+ return
59+ }
60+ if e == nil {
61+ return
62+ }
63+ // Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
64+ contextSafetyOnce .Do (func () {
65+ // try to figure out the bad functions to deny
66+ type m struct {}
67+ _ = e .SQL ("SELECT 1" ).Iterate (& m {}, func (int , any ) error {
68+ callers := make ([]uintptr , 32 )
69+ callerNum := runtime .Callers (1 , callers )
70+ for i := 0 ; i < callerNum ; i ++ {
71+ if funcName := runtime .FuncForPC (callers [i ]).Name (); funcName == "xorm.io/xorm.(*Session).Iterate" {
72+ contextSafetyDeniedFuncPCs = append (contextSafetyDeniedFuncPCs , callers [i ])
73+ }
74+ }
75+ return nil
76+ })
77+ if len (contextSafetyDeniedFuncPCs ) != 1 {
78+ panic (errors .New ("unable to determine the functions to deny" ))
79+ }
80+ })
81+
82+ // it should be very fast: xxxx ns/op
83+ callers := make ([]uintptr , 32 )
84+ callerNum := runtime .Callers (3 , callers ) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
85+ for i := 0 ; i < callerNum ; i ++ {
86+ if slices .Contains (contextSafetyDeniedFuncPCs , callers [i ]) {
87+ panic (errors .New ("using database context in an iterator would cause corrupted results" ))
88+ }
89+ }
7090}
7191
72- // GetEngine will get a db Engine from this context or return an Engine restricted to this context
92+ // GetEngine gets an existing db Engine/Statement or creates a new Session
7393func GetEngine (ctx context.Context ) Engine {
74- if e := getEngine (ctx ); e != nil {
94+ if e := getExistingEngine (ctx ); e != nil {
7595 return e
7696 }
7797 return x .Context (ctx )
7898}
7999
80- // getEngine will get a db Engine from this context or return nil
81- func getEngine (ctx context.Context ) Engine {
82- if engined , ok := ctx .(Engined ); ok {
83- return engined .Engine ()
100+ // getExistingEngine gets an existing db Engine/Statement from this context or returns nil
101+ func getExistingEngine (ctx context.Context ) (e Engine ) {
102+ defer func () { contextSafetyCheck (e ) }()
103+ if engined , ok := ctx .(* Context ); ok {
104+ return engined .engine
84105 }
85- enginedInterface := ctx .Value (enginedContextKey )
86- if enginedInterface != nil {
87- return enginedInterface .(Engined ).Engine ()
106+ if engined , ok := ctx .Value (engineContextKey ).(* Context ); ok {
107+ return engined .engine
88108 }
89109 return nil
90110}
@@ -132,23 +152,23 @@ func (c *halfCommitter) Close() error {
132152// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
133153func TxContext (parentCtx context.Context ) (* Context , Committer , error ) {
134154 if sess , ok := inTransaction (parentCtx ); ok {
135- return newContext (parentCtx , sess , true ), & halfCommitter {committer : sess }, nil
155+ return newContext (parentCtx , sess ), & halfCommitter {committer : sess }, nil
136156 }
137157
138158 sess := x .NewSession ()
139159 if err := sess .Begin (); err != nil {
140- sess .Close ()
160+ _ = sess .Close ()
141161 return nil , nil , err
142162 }
143163
144- return newContext (DefaultContext , sess , true ), sess , nil
164+ return newContext (DefaultContext , sess ), sess , nil
145165}
146166
147167// WithTx represents executing database operations on a transaction, if the transaction exist,
148168// this function will reuse it otherwise will create a new one and close it when finished.
149169func WithTx (parentCtx context.Context , f func (ctx context.Context ) error ) error {
150170 if sess , ok := inTransaction (parentCtx ); ok {
151- err := f (newContext (parentCtx , sess , true ))
171+ err := f (newContext (parentCtx , sess ))
152172 if err != nil {
153173 // rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
154174 _ = sess .Close ()
@@ -165,7 +185,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
165185 return err
166186 }
167187
168- if err := f (newContext (parentCtx , sess , true )); err != nil {
188+ if err := f (newContext (parentCtx , sess )); err != nil {
169189 return err
170190 }
171191
@@ -312,7 +332,7 @@ func InTransaction(ctx context.Context) bool {
312332}
313333
314334func inTransaction (ctx context.Context ) (* xorm.Session , bool ) {
315- e := getEngine (ctx )
335+ e := getExistingEngine (ctx )
316336 if e == nil {
317337 return nil , false
318338 }
0 commit comments