Skip to content

Commit 322766f

Browse files
committed
WIP
1 parent 5614ec3 commit 322766f

File tree

4 files changed

+176
-38
lines changed

4 files changed

+176
-38
lines changed

coordinator/internal/controller/proxy/auth.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@ import (
2020

2121
// AuthController is login API
2222
type AuthController struct {
23-
apiLogin *api.AuthController
24-
clients Clients
25-
userTokenCache *UserTokenCache
26-
tokenCacheUpdate chan<- *TokenUpdate
23+
apiLogin *api.AuthController
24+
clients Clients
25+
userTokenCache *UserTokenCache
2726
}
2827

2928
type TokenUpdate struct {
@@ -42,7 +41,8 @@ type UpstreamTokens struct {
4241

4342
type UserTokenCache struct {
4443
sync.RWMutex
45-
data map[string]UpstreamTokens
44+
data map[string]UpstreamTokens
45+
tokenCacheUpdate chan<- *TokenUpdate
4646
}
4747

4848
func newUserTokens() UpstreamTokens {
@@ -51,8 +51,11 @@ func newUserTokens() UpstreamTokens {
5151
}
5252
}
5353

54-
func newUserCache() *UserTokenCache {
55-
return &UserTokenCache{data: make(map[string]UpstreamTokens)}
54+
func newUserCache(tokenCacheUpdate chan<- *TokenUpdate) *UserTokenCache {
55+
return &UserTokenCache{
56+
data: make(map[string]UpstreamTokens),
57+
tokenCacheUpdate: tokenCacheUpdate,
58+
}
5659
}
5760

5861
// get retrieves UpstreamTokens for a given user key, returns empty if still not exists
@@ -117,10 +120,9 @@ func NewAuthController(cfg *config.ProxyConfig, clients Clients, vf *verifier.Ve
117120
tokenCacheUpdateChan := make(chan *TokenUpdate)
118121

119122
authController := &AuthController{
120-
apiLogin: api.NewAuthControllerWithLogic(loginLogic),
121-
clients: clients,
122-
userTokenCache: newUserCache(),
123-
tokenCacheUpdate: tokenCacheUpdateChan,
123+
apiLogin: api.NewAuthControllerWithLogic(loginLogic),
124+
clients: clients,
125+
userTokenCache: newUserCache(tokenCacheUpdateChan),
124126
}
125127

126128
// Launch token cache manager in a separate goroutine
@@ -129,6 +131,8 @@ func NewAuthController(cfg *config.ProxyConfig, clients Clients, vf *verifier.Ve
129131
return authController
130132
}
131133

134+
func (a *AuthController) TokenCache() *UserTokenCache { return a.userTokenCache }
135+
132136
func (a *AuthController) doUpdateRequest(ctx context.Context, req *TokenUpdate) (ret *types.LoginSchema) {
133137
if req.CompleteNotify != nil {
134138
defer func(ctx context.Context) {
@@ -143,13 +147,13 @@ func (a *AuthController) doUpdateRequest(ctx context.Context, req *TokenUpdate)
143147
cli := a.clients[req.Upstream]
144148
if cli := cli.Client(ctx); cli != nil {
145149
var err error
146-
if ret, err = cli.ProxyLogin(ctx, req.LoginParam); err == nil {
150+
if ret, err = cli.ProxyLogin(ctx, &req.LoginParam); err == nil {
147151
a.userTokenCache.partialSet(req.PublicKey, req.Upstream, ret, req.Phase)
148152
} else {
149-
log.Error("proxy login failed during token cache update",
150-
"userKey", req.PublicKey,
151-
"upstream", req.Upstream,
152-
"phase", req.Phase,
153+
log.Error("proxy login failed during token cache update",
154+
"userKey", req.PublicKey,
155+
"upstream", req.Upstream,
156+
"phase", req.Phase,
153157
"error", err)
154158
}
155159
}
@@ -241,7 +245,7 @@ func (a *AuthController) Login(c *gin.Context) (interface{}, error) {
241245
defer close(notify)
242246
select {
243247
case <-c.Done():
244-
case a.tokenCacheUpdate <- &request:
248+
case a.userTokenCache.tokenCacheUpdate <- &request:
245249
}
246250

247251
}
@@ -319,7 +323,7 @@ func (a *AuthController) IdentityHandler(c *gin.Context) interface{} {
319323
a.userTokenCache.RUnlock()
320324
a.userTokenCache.Lock()
321325
if _, exists := a.userTokenCache.data[loginParam.PublicKey]; !exists {
322-
log.Info("creating token cache for user after proxy restart",
326+
log.Info("creating token cache for user after proxy restart",
323327
"publicKey", loginParam.PublicKey,
324328
"proverName", loginParam.Message.ProverName,
325329
"reason", "prover using JWT token from before proxy restart")

coordinator/internal/controller/proxy/client.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@ import (
88
"net/http"
99
"time"
1010

11+
ctypes "scroll-tech/common/types"
1112
"scroll-tech/coordinator/internal/config"
1213
"scroll-tech/coordinator/internal/types"
14+
15+
"github.com/mitchellh/mapstructure"
1316
)
1417

1518
type ClientHelper interface {
@@ -90,19 +93,26 @@ func (c *upClient) Login(ctx context.Context) (*types.LoginSchema, error) {
9093

9194
// Parse login response as LoginSchema and store the token
9295
if loginResp.StatusCode == http.StatusOK {
93-
var loginResult types.LoginSchema
94-
if err := json.NewDecoder(loginResp.Body).Decode(&loginResult); err == nil {
96+
var respWithData ctypes.Response
97+
// Note: Body is consumed after decoding, caller should not read it again
98+
if err := json.NewDecoder(loginResp.Body).Decode(&respWithData); err == nil {
99+
var loginResult types.LoginSchema
100+
err = mapstructure.Decode(respWithData.Data, &loginResult)
101+
if err != nil {
102+
return nil, fmt.Errorf("login parsing data fail, get %v", respWithData.Data)
103+
}
95104
c.loginToken = loginResult.Token
105+
return &loginResult, nil
106+
} else {
107+
return nil, fmt.Errorf("login parsing response failed: %v", err)
96108
}
97-
// Note: Body is consumed after decoding, caller should not read it again
98-
return &loginResult, nil
99109
}
100110

101111
return nil, fmt.Errorf("login request failed with status: %d", loginResp.StatusCode)
102112
}
103113

104114
// ProxyLogin makes a POST request to /v1/proxy_login with LoginParameter
105-
func (c *upClient) ProxyLogin(ctx context.Context, param types.LoginParameter) (*types.LoginSchema, error) {
115+
func (c *upClient) ProxyLogin(ctx context.Context, param *types.LoginParameter) (*types.LoginSchema, error) {
106116
url := fmt.Sprintf("%s/coordinator/v1/proxy_login", c.baseURL)
107117

108118
jsonData, err := json.Marshal(param)
@@ -141,7 +151,7 @@ func (c *upClient) ProxyLogin(ctx context.Context, param types.LoginParameter) (
141151
}
142152

143153
// GetTask makes a POST request to /v1/get_task with GetTaskParameter
144-
func (c *upClient) GetTask(ctx context.Context, param types.GetTaskParameter, token string) (*http.Response, error) {
154+
func (c *upClient) GetTask(ctx context.Context, param *types.GetTaskParameter, token string) (*http.Response, error) {
145155
url := fmt.Sprintf("%s/coordinator/v1/get_task", c.baseURL)
146156

147157
jsonData, err := json.Marshal(param)
@@ -163,7 +173,7 @@ func (c *upClient) GetTask(ctx context.Context, param types.GetTaskParameter, to
163173
}
164174

165175
// SubmitProof makes a POST request to /v1/submit_proof with SubmitProofParameter
166-
func (c *upClient) SubmitProof(ctx context.Context, param types.SubmitProofParameter, token string) (*http.Response, error) {
176+
func (c *upClient) SubmitProof(ctx context.Context, param *types.SubmitProofParameter, token string) (*http.Response, error) {
167177
url := fmt.Sprintf("%s/coordinator/v1/submit_proof", c.baseURL)
168178

169179
jsonData, err := json.Marshal(param)

coordinator/internal/controller/proxy/client_manager.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
type Client interface {
2020
Client(context.Context) *upClient
21+
PeekClient() *upClient
2122
}
2223

2324
type ClientManager struct {
@@ -95,6 +96,13 @@ func (cliMgr *ClientManager) doLogin(ctx context.Context, loginCli *upClient) ti
9596
}
9697
}
9798

99+
func (cliMgr *ClientManager) PeekClient() *upClient {
100+
cliMgr.cachedCli.RLock()
101+
defer cliMgr.cachedCli.RUnlock()
102+
103+
return cliMgr.cachedCli.cli
104+
}
105+
98106
func (cliMgr *ClientManager) Client(ctx context.Context) *upClient {
99107
cliMgr.cachedCli.RLock()
100108
if cliMgr.cachedCli.cli != nil {
Lines changed: 129 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,62 @@
11
package proxy
22

33
import (
4+
"encoding/json"
5+
"fmt"
6+
"net/http"
7+
48
"github.com/gin-gonic/gin"
9+
"github.com/mitchellh/mapstructure"
510
"github.com/prometheus/client_golang/prometheus"
6-
"github.com/scroll-tech/go-ethereum/params"
7-
"gorm.io/gorm"
11+
"github.com/scroll-tech/go-ethereum/log"
812

9-
"scroll-tech/common/types/message"
13+
"scroll-tech/common/types"
1014

1115
"scroll-tech/coordinator/internal/config"
12-
"scroll-tech/coordinator/internal/logic/provertask"
13-
"scroll-tech/coordinator/internal/logic/verifier"
1416
coordinatorType "scroll-tech/coordinator/internal/types"
1517
)
1618

19+
func getSessionData(ctx *gin.Context) (string, *coordinatorType.LoginParameter) {
20+
21+
publicKeyData, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
22+
publicKey, castOk := publicKeyData.(string)
23+
if !publicKeyExist || !castOk {
24+
nerr := fmt.Errorf("no public key binding: %v", publicKeyData)
25+
log.Warn("get_task parameter fail", "error", nerr)
26+
27+
types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
28+
return "", nil
29+
}
30+
31+
loginParamData, publicKeyExist := ctx.Get(LoginParamCache)
32+
loginParam, castOk := loginParamData.(*coordinatorType.LoginParameter)
33+
if !publicKeyExist || !castOk {
34+
nerr := fmt.Errorf("no login param binding: %v", loginParamData)
35+
log.Warn("get_task parameter fail", "error", nerr)
36+
37+
types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
38+
return "", nil
39+
}
40+
41+
return publicKey, loginParam
42+
}
43+
1744
// GetTaskController the get prover task api controller
1845
type GetTaskController struct {
19-
proverTasks map[message.ProofType]provertask.ProverTask
46+
tokenCache *UserTokenCache
47+
clients Clients
48+
priorityUpstream map[string]string
2049

2150
getTaskAccessCounter *prometheus.CounterVec
2251
}
2352

2453
// NewGetTaskController create a get prover task controller
25-
func NewGetTaskController(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, verifier *verifier.Verifier, reg prometheus.Registerer) *GetTaskController {
54+
func NewGetTaskController(cfg *config.Config, clients Clients, tokenCache *UserTokenCache, reg prometheus.Registerer) *GetTaskController {
2655
// TODO: implement proxy get task controller initialization
2756
return &GetTaskController{
28-
proverTasks: make(map[message.ProofType]provertask.ProverTask),
57+
priorityUpstream: make(map[string]string),
58+
tokenCache: tokenCache,
59+
clients: clients,
2960
}
3061
}
3162

@@ -36,10 +67,95 @@ func (ptc *GetTaskController) incGetTaskAccessCounter(ctx *gin.Context) error {
3667

3768
// GetTasks get assigned chunk/batch task
3869
func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
39-
// TODO: implement proxy get tasks logic
70+
var getTaskParameter coordinatorType.GetTaskParameter
71+
if err := ctx.ShouldBind(&getTaskParameter); err != nil {
72+
nerr := fmt.Errorf("prover task parameter invalid, err:%w", err)
73+
types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
74+
return
75+
}
76+
77+
publicKey, loginParam := getSessionData(ctx)
78+
if publicKey == "" || loginParam == nil {
79+
return
80+
}
81+
82+
tokens := ptc.tokenCache.Get(publicKey)
83+
84+
onClientFail := func(upstream string) {
85+
//TODO: log re-connect request in info level
86+
87+
request := TokenUpdate{
88+
PublicKey: publicKey,
89+
Upstream: upstream,
90+
Phase: tokens.LoginPhase,
91+
LoginParam: *loginParam,
92+
CompleteNotify: nil,
93+
}
94+
select {
95+
case <-ctx.Done():
96+
case ptc.tokenCache.tokenCacheUpdate <- &request:
97+
}
98+
99+
}
100+
101+
priorityUpstream, exist := ptc.priorityUpstream[publicKey]
102+
if exist {
103+
cli := ptc.clients[priorityUpstream]
104+
loginSchema := tokens.LoginData[priorityUpstream]
105+
if loginSchema == nil {
106+
onClientFail(priorityUpstream)
107+
} else {
108+
ret, triggerUpdate := getTaskFromClient(ctx, cli, &getTaskParameter, loginSchema.Token)
109+
if ret != nil {
110+
111+
} else if triggerUpdate {
112+
onClientFail(priorityUpstream)
113+
}
114+
}
115+
types.RenderFailure(ctx, types.ErrCoordinatorEmptyProofData, fmt.Errorf("get empty prover task"))
116+
}
117+
118+
for n, cli := range ptc.clients {
119+
120+
}
40121
}
41122

42-
func (ptc *GetTaskController) proofType(para *coordinatorType.GetTaskParameter) message.ProofType {
43-
// TODO: implement proxy proof type logic
44-
return message.ProofTypeChunk
45-
}
123+
func getTaskFromClient(ctx *gin.Context, cli Client, param *coordinatorType.GetTaskParameter, token string) (*coordinatorType.GetTaskSchema, bool) {
124+
125+
theCli := cli.PeekClient()
126+
if theCli == nil {
127+
return nil, true
128+
}
129+
130+
resp, err := theCli.GetTask(ctx, param, token)
131+
if err != nil {
132+
// log the err in error level
133+
return nil, false
134+
}
135+
136+
// Parse response
137+
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusUnauthorized {
138+
unAuth := resp.StatusCode == http.StatusUnauthorized
139+
var respWithData types.Response
140+
// Note: Body is consumed after decoding, caller should not read it again
141+
if err := json.NewDecoder(resp.Body).Decode(&respWithData); err == nil {
142+
if unAuth && respWithData.ErrCode == types.ErrJWTTokenExpired {
143+
return nil, true
144+
}
145+
146+
var getTaskResult coordinatorType.GetTaskSchema
147+
err = mapstructure.Decode(respWithData.Data, &getTaskResult)
148+
if err != nil {
149+
log.Error("parse get task data fail", "respdata", respWithData.Data)
150+
return nil, false
151+
}
152+
return &getTaskResult, false
153+
} else {
154+
log.Error("parse get task response failed", "error", err)
155+
//fmt.Errorf("login parsing response failed: %v", err)
156+
return nil, false
157+
}
158+
}
159+
160+
return nil, false
161+
}

0 commit comments

Comments
 (0)