Skip to content

Commit 5614ec3

Browse files
committed
WIP
1 parent 5a07a16 commit 5614ec3

File tree

5 files changed

+346
-23
lines changed

5 files changed

+346
-23
lines changed

coordinator/internal/controller/api/auth.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ type AuthController struct {
1919
loginLogic *auth.LoginLogic
2020
}
2121

22+
func NewAuthControllerWithLogic(loginLogic *auth.LoginLogic) *AuthController {
23+
return &AuthController{
24+
loginLogic: loginLogic,
25+
}
26+
}
27+
2228
// NewAuthController returns an LoginController instance
2329
func NewAuthController(db *gorm.DB, cfg *config.Config, vf *verifier.Verifier) *AuthController {
2430
return &AuthController{
@@ -102,10 +108,6 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
102108
c.Set(types.ProverName, proverName)
103109
}
104110

105-
if publicKey, ok := claims[types.PublicKey]; ok {
106-
c.Set(types.PublicKey, publicKey)
107-
}
108-
109111
if proverVersion, ok := claims[types.ProverVersion]; ok {
110112
c.Set(types.ProverVersion, proverVersion)
111113
}
@@ -118,5 +120,9 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
118120
c.Set(types.ProverProviderTypeKey, providerType)
119121
}
120122

123+
if publicKey, ok := claims[types.PublicKey]; ok {
124+
return publicKey
125+
}
126+
121127
return nil
122128
}

coordinator/internal/controller/proxy/auth.go

Lines changed: 304 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ package proxy
22

33
import (
44
"fmt"
5+
"sync"
56

7+
"context"
8+
"time"
9+
10+
jwt "github.com/appleboy/gin-jwt/v2"
611
"github.com/gin-gonic/gin"
12+
"github.com/scroll-tech/go-ethereum/log"
713

814
"scroll-tech/coordinator/internal/config"
915
"scroll-tech/coordinator/internal/controller/api"
@@ -14,33 +20,319 @@ import (
1420

1521
// AuthController is login API
1622
type AuthController struct {
17-
*api.AuthController
18-
clients Clients
23+
apiLogin *api.AuthController
24+
clients Clients
25+
userTokenCache *UserTokenCache
26+
tokenCacheUpdate chan<- *TokenUpdate
27+
}
28+
29+
type TokenUpdate struct {
30+
PublicKey string
31+
Upstream string
32+
Phase uint
33+
LoginParam types.LoginParameter
34+
CompleteNotify chan<- *types.LoginSchema
35+
}
36+
37+
type UpstreamTokens struct {
38+
LoginData map[string]*types.LoginSchema
39+
LoginPhase uint
40+
NextLoginPhase uint
41+
}
42+
43+
type UserTokenCache struct {
44+
sync.RWMutex
45+
data map[string]UpstreamTokens
46+
}
47+
48+
func newUserTokens() UpstreamTokens {
49+
return UpstreamTokens{
50+
LoginData: make(map[string]*types.LoginSchema),
51+
}
52+
}
53+
54+
func newUserCache() *UserTokenCache {
55+
return &UserTokenCache{data: make(map[string]UpstreamTokens)}
56+
}
57+
58+
// get retrieves UpstreamTokens for a given user key, returns empty if still not exists
59+
func (c *UserTokenCache) Get(userKey string) *UpstreamTokens {
60+
c.RLock()
61+
defer c.RUnlock()
62+
63+
tokens, exists := c.data[userKey]
64+
if !exists {
65+
return nil
66+
}
67+
68+
return &tokens
69+
}
70+
71+
// prepare for a total update via Login request
72+
func (c *UserTokenCache) updatePrepare(userKey string) UpstreamTokens {
73+
c.Lock()
74+
defer c.Unlock()
75+
76+
if _, exists := c.data[userKey]; !exists {
77+
log.Info("initializing user token cache", "userKey", userKey)
78+
c.data[userKey] = newUserTokens()
79+
}
80+
updated := c.data[userKey]
81+
updated.NextLoginPhase = updated.LoginPhase + 1
82+
c.data[userKey] = updated
83+
return updated
1984
}
2085

86+
// partialSet updates a single entry in upstreamTokens for a given user
87+
func (c *UserTokenCache) partialSet(userKey string, upstreamName string, loginSchema *types.LoginSchema, phase uint) {
88+
c.Lock()
89+
defer c.Unlock()
90+
91+
// Get existing tokens or create new map
92+
tokens, exists := c.data[userKey]
93+
if exists && tokens.NextLoginPhase == phase {
94+
// Update the specific upstream entry
95+
tokens.LoginData[upstreamName] = loginSchema
96+
}
97+
}
98+
99+
// LoginParameterWithHardForkName constructs new payload for login
100+
type LoginParameterWithUpstreamTokens struct {
101+
*types.LoginParameter
102+
Tokens UpstreamTokens
103+
}
104+
105+
const upstreamConnTimeout = time.Second * 2
106+
const expireTolerant = 10 * time.Minute
107+
const LoginParamCache = "login_param"
108+
const ProverTypesKey = "prover_types"
109+
const SignatureKey = "prover_signature"
110+
21111
// NewAuthController returns an LoginController instance
22112
func NewAuthController(cfg *config.ProxyConfig, clients Clients, vf *verifier.Verifier) *AuthController {
23113

24114
loginLogic := auth.NewLoginLogicWithSimpleDEduplicator(cfg.ProxyManager.Verifier, vf)
25-
auth := api.NewAuthControllerWithLogic(loginLogic)
26-
return &AuthController{
27-
AuthController: auth,
28-
clients: clients,
115+
116+
// Create the token cache update channel
117+
tokenCacheUpdateChan := make(chan *TokenUpdate)
118+
119+
authController := &AuthController{
120+
apiLogin: api.NewAuthControllerWithLogic(loginLogic),
121+
clients: clients,
122+
userTokenCache: newUserCache(),
123+
tokenCacheUpdate: tokenCacheUpdateChan,
124+
}
125+
126+
// Launch token cache manager in a separate goroutine
127+
go authController.toeknCacheManager(tokenCacheUpdateChan)
128+
129+
return authController
130+
}
131+
132+
func (a *AuthController) doUpdateRequest(ctx context.Context, req *TokenUpdate) (ret *types.LoginSchema) {
133+
if req.CompleteNotify != nil {
134+
defer func(ctx context.Context) {
135+
select {
136+
case <-ctx.Done():
137+
case req.CompleteNotify <- ret:
138+
}
139+
140+
}(ctx)
29141
}
142+
143+
cli := a.clients[req.Upstream]
144+
if cli := cli.Client(ctx); cli != nil {
145+
var err error
146+
if ret, err = cli.ProxyLogin(ctx, req.LoginParam); err == nil {
147+
a.userTokenCache.partialSet(req.PublicKey, req.Upstream, ret, req.Phase)
148+
} else {
149+
log.Error("proxy login failed during token cache update",
150+
"userKey", req.PublicKey,
151+
"upstream", req.Upstream,
152+
"phase", req.Phase,
153+
"error", err)
154+
}
155+
}
156+
return
157+
158+
}
159+
160+
func (a *AuthController) toeknCacheManager(request <-chan *TokenUpdate) {
161+
162+
ctx := context.TODO()
163+
var managerStatusLock sync.Mutex
164+
managerStatus := make(map[string]map[string]uint)
165+
166+
for {
167+
req, ok := <-request
168+
if !ok {
169+
return
170+
}
171+
172+
// ensure the manager request is not outdated
173+
tokens := a.userTokenCache.Get(req.PublicKey)
174+
if tokens == nil {
175+
// Highly not possible, if raise, the reason is unknown, just log the Error
176+
continue
177+
}
178+
phase := tokens.NextLoginPhase
179+
if req.Phase < phase {
180+
// drop the out-dated request
181+
continue
182+
}
183+
184+
// ensure only one login request is launched for the same phase
185+
managerStatusLock.Lock()
186+
stat, ok := managerStatus[req.Upstream]
187+
if !ok {
188+
managerStatus[req.Upstream] = make(map[string]uint)
189+
stat = managerStatus[req.Upstream]
190+
}
191+
if phase, running := stat[req.PublicKey]; running && phase >= req.Phase {
192+
managerStatusLock.Unlock()
193+
continue
194+
} else {
195+
stat[req.PublicKey] = req.Phase
196+
}
197+
managerStatusLock.Unlock()
198+
199+
go a.doUpdateRequest(ctx, req)
200+
201+
}
202+
30203
}
31204

32205
// Login extended the Login hander in api controller
33206
func (a *AuthController) Login(c *gin.Context) (interface{}, error) {
34207

35-
ret, err := a.AuthController.Login(c)
208+
loginRes, err := a.apiLogin.Login(c)
36209
if err != nil {
37210
return nil, err
38211
}
39-
loginParam := ret.(types.LoginParameterWithHardForkName)
40-
// band recursive proxy now ...
41-
if loginParam.Message.ProverProviderType == types.ProverProviderTypeProxy {
42-
return nil, fmt.Errorf("do not allow recursive proxy for login %v", loginParam.Message)
212+
loginParam := loginRes.(types.LoginParameterWithHardForkName)
213+
214+
if loginParam.LoginParameter.Message.ProverProviderType == types.ProverProviderTypeProxy {
215+
return nil, fmt.Errorf("proxy do not support recursive login")
216+
}
217+
218+
tokens := a.userTokenCache.updatePrepare(loginParam.PublicKey)
219+
notifies := make([]chan *types.LoginSchema, len(a.clients))
220+
221+
for n := range a.clients {
222+
223+
// Check if we have a valid cached token that hasn't expired
224+
if knownEntry, existed := tokens.LoginData[n]; existed {
225+
timeRemaining := time.Until(knownEntry.Time)
226+
if timeRemaining > expireTolerant {
227+
// Token is still valid enouth, continue to next client
228+
continue
229+
}
230+
}
231+
232+
notify := make(chan *types.LoginSchema)
233+
notifies = append(notifies, notify)
234+
request := TokenUpdate{
235+
PublicKey: loginParam.PublicKey,
236+
Upstream: n,
237+
Phase: tokens.NextLoginPhase,
238+
LoginParam: loginParam.LoginParameter,
239+
CompleteNotify: notify,
240+
}
241+
defer close(notify)
242+
select {
243+
case <-c.Done():
244+
case a.tokenCacheUpdate <- &request:
245+
}
246+
247+
}
248+
249+
// collect all request's compeletions
250+
for _, chn := range notifies {
251+
select {
252+
case <-c.Done():
253+
case <-chn:
254+
}
255+
}
256+
257+
return LoginParameterWithUpstreamTokens{
258+
LoginParameter: &loginParam.LoginParameter,
259+
Tokens: tokens,
260+
}, nil
261+
}
262+
263+
// PayloadFunc returns jwt.MapClaims with {public key, prover name}.
264+
func (a *AuthController) PayloadFunc(data interface{}) jwt.MapClaims {
265+
v, ok := data.(LoginParameterWithUpstreamTokens)
266+
if !ok {
267+
return jwt.MapClaims{}
268+
}
269+
270+
return jwt.MapClaims{
271+
types.PublicKey: v.PublicKey,
272+
types.ProverName: v.Message.ProverName,
273+
types.ProverVersion: v.Message.ProverVersion,
274+
types.ProverProviderTypeKey: v.Message.ProverProviderType,
275+
SignatureKey: v.Signature,
276+
ProverTypesKey: v.Message.ProverTypes,
277+
}
278+
}
279+
280+
// IdentityHandler replies to client for /login
281+
func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
282+
claims := jwt.ExtractClaims(c)
283+
loginParam := &types.LoginParameter{}
284+
285+
if proverName, ok := claims[types.ProverName]; ok {
286+
loginParam.Message.ProverName, _ = proverName.(string)
287+
}
288+
289+
if proverVersion, ok := claims[types.ProverVersion]; ok {
290+
loginParam.Message.ProverVersion, _ = proverVersion.(string)
291+
}
292+
293+
if providerType, ok := claims[types.ProverProviderTypeKey]; ok {
294+
num, _ := providerType.(float64)
295+
loginParam.Message.ProverProviderType = types.ProverProviderType(num)
296+
}
297+
298+
if signature, ok := claims[SignatureKey]; ok {
299+
loginParam.Signature, _ = signature.(string)
300+
}
301+
302+
if proverTypes, ok := claims[ProverTypesKey]; ok {
303+
arr, _ := proverTypes.([]any)
304+
for _, elm := range arr {
305+
num, _ := elm.(float64)
306+
loginParam.Message.ProverTypes = append(loginParam.Message.ProverTypes, types.ProverType(num))
307+
}
308+
}
309+
310+
if publicKey, ok := claims[types.PublicKey]; ok {
311+
loginParam.PublicKey, _ = publicKey.(string)
312+
}
313+
314+
if loginParam.PublicKey != "" {
315+
// ensure tokenCache
316+
a.userTokenCache.RLock()
317+
_, exists := a.userTokenCache.data[loginParam.PublicKey]
318+
if !exists {
319+
a.userTokenCache.RUnlock()
320+
a.userTokenCache.Lock()
321+
if _, exists := a.userTokenCache.data[loginParam.PublicKey]; !exists {
322+
log.Info("creating token cache for user after proxy restart",
323+
"publicKey", loginParam.PublicKey,
324+
"proverName", loginParam.Message.ProverName,
325+
"reason", "prover using JWT token from before proxy restart")
326+
a.userTokenCache.data[loginParam.PublicKey] = newUserTokens()
327+
}
328+
a.userTokenCache.Unlock()
329+
} else {
330+
a.userTokenCache.RUnlock()
331+
}
332+
333+
c.Set(LoginParamCache, loginParam)
334+
return loginParam.PublicKey
43335
}
44336

45-
return loginParam, nil
337+
return nil
46338
}

0 commit comments

Comments
 (0)