Skip to content

Commit 057e220

Browse files
committed
fix issues
1 parent c7b83a0 commit 057e220

File tree

6 files changed

+90
-25
lines changed

6 files changed

+90
-25
lines changed

common/types/response.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@ type Response struct {
1515
}
1616

1717
func (resp *Response) DecodeData(out interface{}) error {
18-
return mapstructure.Decode(resp.Data, out)
18+
// Decode generically unmarshaled JSON (map[string]any, []any) into a typed struct
19+
// honoring `json` tags and allowing weak type conversions.
20+
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
21+
TagName: "json",
22+
Result: out,
23+
})
24+
if err != nil {
25+
return err
26+
}
27+
return dec.Decode(resp.Data)
1928
}
2029

2130
// RenderJSON renders response with json

coordinator/internal/controller/proxy/controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ func InitController(cfg *config.ProxyConfig, reg prometheus.Registerer) {
3939

4040
Auth = NewAuthController(cfg, clients, proverManager)
4141
GetTask = NewGetTaskController(cfg, clients, proverManager, priorityManager, reg)
42-
SubmitProof = NewSubmitProofController(cfg, clients, proverManager, reg)
42+
SubmitProof = NewSubmitProofController(cfg, clients, proverManager, priorityManager, reg)
4343
}

coordinator/internal/controller/proxy/get_task.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
146146
// TODO: log error
147147
}
148148
}
149+
ptc.priorityUpstream.Delete(publicKey)
149150

150151
// Create a slice to hold the keys
151152
keys := make([]string, 0, len(ptc.clients))

coordinator/internal/controller/proxy/prover_session.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ type proverSession struct {
6565
completionCtx context.Context
6666
}
6767

68-
func (c *proverSession) maintainLogin(ctx context.Context, cliMgr Client, up string, param *types.LoginParameter, phase uint) (*types.LoginSchema, error) {
68+
func (c *proverSession) maintainLogin(ctx context.Context, cliMgr Client, up string, param *types.LoginParameter, phase uint) (result *types.LoginSchema, nerr error) {
6969
c.Lock()
7070
curPhase := c.proverToken[up].phase
7171
if c.completionCtx != nil {
@@ -89,6 +89,17 @@ func (c *proverSession) maintainLogin(ctx context.Context, cliMgr Client, up str
8989
completeCtx, cf := context.WithCancel(ctx)
9090
defer cf()
9191
c.completionCtx = completeCtx
92+
defer func() {
93+
c.Lock()
94+
c.completionCtx = nil
95+
if result != nil {
96+
c.proverToken[up] = loginToken{
97+
LoginSchema: result,
98+
phase: curPhase + 1,
99+
}
100+
}
101+
c.Unlock()
102+
}()
92103
c.Unlock()
93104

94105
cli := cliMgr.Client(ctx)
@@ -124,18 +135,9 @@ func (c *proverSession) maintainLogin(ctx context.Context, cliMgr Client, up str
124135
return nil, err
125136
}
126137

127-
c.Lock()
128-
defer c.Unlock()
129-
130-
c.proverToken[up] = loginToken{
131-
LoginSchema: &types.LoginSchema{
132-
Token: loginResult.Token,
133-
},
134-
phase: curPhase + 1,
135-
}
136-
c.completionCtx = nil
137-
138-
return c.proverToken[up].LoginSchema, nil
138+
return &types.LoginSchema{
139+
Token: loginResult.Token,
140+
}, nil
139141
}
140142

141143
const expireTolerant = 10 * time.Minute

coordinator/internal/controller/proxy/submit_proof.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,26 @@ import (
1414

1515
// SubmitProofController the submit proof api controller
1616
type SubmitProofController struct {
17-
proverMgr *ProverManager
18-
clients Clients
17+
proverMgr *ProverManager
18+
clients Clients
19+
priorityUpstream *PriorityUpstreamManager
1920
}
2021

2122
// NewSubmitProofController create the submit proof api controller instance
22-
func NewSubmitProofController(cfg *config.ProxyConfig, clients Clients, proverMgr *ProverManager, reg prometheus.Registerer) *SubmitProofController {
23+
func NewSubmitProofController(cfg *config.ProxyConfig, clients Clients, proverMgr *ProverManager, priorityMgr *PriorityUpstreamManager, reg prometheus.Registerer) *SubmitProofController {
2324
return &SubmitProofController{
24-
proverMgr: proverMgr,
25-
clients: clients,
25+
proverMgr: proverMgr,
26+
clients: clients,
27+
priorityUpstream: priorityMgr,
2628
}
2729
}
2830

29-
func upstreamFromTaskName(taskID string) string {
30-
parts, _, found := strings.Cut(taskID, ":")
31+
func upstreamFromTaskName(taskID string) (string, string) {
32+
parts, rest, found := strings.Cut(taskID, ":")
3133
if found {
32-
return parts
34+
return parts, rest
3335
}
34-
return ""
36+
return "", parts
3537
}
3638

3739
func formUpstreamWithTaskName(upstream string, taskID string) string {
@@ -40,6 +42,7 @@ func formUpstreamWithTaskName(upstream string, taskID string) string {
4042

4143
// SubmitProof prover submit the proof to coordinator
4244
func (spc *SubmitProofController) SubmitProof(ctx *gin.Context) {
45+
4346
var submitParameter coordinatorType.SubmitProofParameter
4447
if err := ctx.ShouldBind(&submitParameter); err != nil {
4548
nerr := fmt.Errorf("prover submitProof parameter invalid, err:%w", err)
@@ -53,14 +56,15 @@ func (spc *SubmitProofController) SubmitProof(ctx *gin.Context) {
5356
}
5457

5558
session := spc.proverMgr.Get(publicKey)
56-
upstream := upstreamFromTaskName(submitParameter.TaskID)
59+
upstream, realTaskID := upstreamFromTaskName(submitParameter.TaskID)
5760
cli, existed := spc.clients[upstream]
5861
if !existed {
5962
// TODO: log error
6063
nerr := fmt.Errorf("Invalid upstream name (%s) from taskID %s", upstream, submitParameter.TaskID)
6164
types.RenderFailure(ctx, types.ErrCoordinatorParameterInvalidNo, nerr)
6265
return
6366
}
67+
submitParameter.TaskID = realTaskID
6468

6569
resp, err := session.SubmitProof(ctx, &submitParameter, cli, upstream)
6670
if err != nil {
@@ -71,6 +75,7 @@ func (spc *SubmitProofController) SubmitProof(ctx *gin.Context) {
7175
types.RenderFailure(ctx, resp.ErrCode, fmt.Errorf("%s", resp.ErrMsg))
7276
return
7377
} else {
78+
spc.priorityUpstream.Delete(upstream)
7479
types.RenderSuccess(ctx, resp.Data)
7580
return
7681
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package types
2+
3+
import (
4+
"encoding/json"
5+
"reflect"
6+
"testing"
7+
8+
"scroll-tech/common/types"
9+
)
10+
11+
func TestResponseDecodeData_GetTaskSchema(t *testing.T) {
12+
// Arrange: build a dummy payload and wrap it in Response
13+
in := GetTaskSchema{
14+
UUID: "uuid-123",
15+
TaskID: "task-abc",
16+
TaskType: 1,
17+
UseSnark: true,
18+
TaskData: "dummy-data",
19+
HardForkName: "cancun",
20+
}
21+
22+
resp := types.Response{
23+
ErrCode: 0,
24+
ErrMsg: "",
25+
Data: in,
26+
}
27+
28+
// Act: JSON round-trip the Response to simulate real HTTP encoding/decoding
29+
b, err := json.Marshal(resp)
30+
if err != nil {
31+
t.Fatalf("marshal response: %v", err)
32+
}
33+
34+
var decoded types.Response
35+
if err := json.Unmarshal(b, &decoded); err != nil {
36+
t.Fatalf("unmarshal response: %v", err)
37+
}
38+
39+
var out GetTaskSchema
40+
if err := decoded.DecodeData(&out); err != nil {
41+
t.Fatalf("DecodeData error: %v", err)
42+
}
43+
44+
// Assert: structs match after decode
45+
if !reflect.DeepEqual(in, out) {
46+
t.Fatalf("decoded struct mismatch:\nwant: %+v\n got: %+v", in, out)
47+
}
48+
}

0 commit comments

Comments
 (0)