Skip to content

Commit b7f601c

Browse files
authored
feat: support sse in api files (#5074)
Signed-off-by: Kevin Wan <[email protected]>
1 parent 1ebbc6f commit b7f601c

File tree

2 files changed

+186
-1
lines changed

2 files changed

+186
-1
lines changed

tools/goctl/api/gogen/genroutes.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
4040
`
4141
routesAdditionTemplate = `
4242
server.AddRoutes(
43-
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}}
43+
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}} {{.sse}}
4444
)
4545
`
4646
timeoutThreshold = time.Millisecond
@@ -63,6 +63,7 @@ type (
6363
routes []route
6464
jwtEnabled bool
6565
signatureEnabled bool
66+
sseEnabled bool
6667
authName string
6768
timeout string
6869
middlewares []string
@@ -123,10 +124,17 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
123124
if len(g.jwtTrans) > 0 {
124125
jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s.PrevSecret,serverCtx.Config.%s.Secret),", g.jwtTrans, g.jwtTrans)
125126
}
127+
126128
var signature, prefix string
127129
if g.signatureEnabled {
128130
signature = "\n rest.WithSignature(serverCtx.Config.Signature),"
129131
}
132+
133+
var sse string
134+
if g.sseEnabled {
135+
sse = "\n rest.WithSSE(),"
136+
}
137+
130138
if len(g.prefix) > 0 {
131139
prefix = fmt.Sprintf(`
132140
rest.WithPrefix("%s"),`, g.prefix)
@@ -172,6 +180,7 @@ rest.WithPrefix("%s"),`, g.prefix)
172180
"routes": routes,
173181
"jwt": jwt,
174182
"signature": signature,
183+
"sse": sse,
175184
"prefix": prefix,
176185
"timeout": timeout,
177186
"maxBytes": maxBytes,
@@ -281,6 +290,10 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
281290
if signature == "true" {
282291
groupedRoutes.signatureEnabled = true
283292
}
293+
sse := g.GetAnnotation("sse")
294+
if sse == "true" {
295+
groupedRoutes.sseEnabled = true
296+
}
284297
middleware := g.GetAnnotation("middleware")
285298
if len(middleware) > 0 {
286299
groupedRoutes.middlewares = append(groupedRoutes.middlewares,

tools/goctl/api/gogen/genroutes_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package gogen
33
import (
44
"testing"
55
"time"
6+
7+
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
68
)
79

810
func Test_formatDuration(t *testing.T) {
@@ -25,3 +27,173 @@ func Test_formatDuration(t *testing.T) {
2527
}
2628
}
2729
}
30+
31+
func TestSSESupport(t *testing.T) {
32+
// Test API spec with SSE enabled
33+
apiSpec := &spec.ApiSpec{
34+
Service: spec.Service{
35+
Groups: []spec.Group{
36+
{
37+
Annotation: spec.Annotation{
38+
Properties: map[string]string{
39+
"sse": "true",
40+
"prefix": "/api/v1",
41+
},
42+
},
43+
Routes: []spec.Route{
44+
{
45+
Method: "get",
46+
Path: "/events",
47+
Handler: "StreamEvents",
48+
},
49+
},
50+
},
51+
},
52+
},
53+
}
54+
55+
groups, err := getRoutes(apiSpec)
56+
if err != nil {
57+
t.Fatalf("getRoutes failed: %v", err)
58+
}
59+
60+
if len(groups) != 1 {
61+
t.Fatalf("Expected 1 group, got %d", len(groups))
62+
}
63+
64+
group := groups[0]
65+
if !group.sseEnabled {
66+
t.Error("Expected SSE to be enabled")
67+
}
68+
69+
if group.prefix != "/api/v1" {
70+
t.Errorf("Expected prefix '/api/v1', got '%s'", group.prefix)
71+
}
72+
73+
if len(group.routes) != 1 {
74+
t.Fatalf("Expected 1 route, got %d", len(group.routes))
75+
}
76+
77+
route := group.routes[0]
78+
if route.method != "http.MethodGet" {
79+
t.Errorf("Expected method 'http.MethodGet', got '%s'", route.method)
80+
}
81+
82+
if route.path != "/events" {
83+
t.Errorf("Expected path '/events', got '%s'", route.path)
84+
}
85+
}
86+
87+
func TestSSEWithOtherFeatures(t *testing.T) {
88+
// Test API spec with SSE and other features
89+
apiSpec := &spec.ApiSpec{
90+
Service: spec.Service{
91+
Groups: []spec.Group{
92+
{
93+
Annotation: spec.Annotation{
94+
Properties: map[string]string{
95+
"sse": "true",
96+
"jwt": "Auth",
97+
"signature": "true",
98+
"prefix": "/api/v1",
99+
"timeout": "30s",
100+
"middleware": "AuthMiddleware,LogMiddleware",
101+
},
102+
},
103+
Routes: []spec.Route{
104+
{
105+
Method: "get",
106+
Path: "/events",
107+
Handler: "StreamEvents",
108+
},
109+
},
110+
},
111+
},
112+
},
113+
}
114+
115+
groups, err := getRoutes(apiSpec)
116+
if err != nil {
117+
t.Fatalf("getRoutes failed: %v", err)
118+
}
119+
120+
if len(groups) != 1 {
121+
t.Fatalf("Expected 1 group, got %d", len(groups))
122+
}
123+
124+
group := groups[0]
125+
126+
// Verify all features are enabled
127+
if !group.sseEnabled {
128+
t.Error("Expected SSE to be enabled")
129+
}
130+
131+
if !group.jwtEnabled {
132+
t.Error("Expected JWT to be enabled")
133+
}
134+
135+
if !group.signatureEnabled {
136+
t.Error("Expected signature to be enabled")
137+
}
138+
139+
if group.authName != "Auth" {
140+
t.Errorf("Expected authName 'Auth', got '%s'", group.authName)
141+
}
142+
143+
if group.prefix != "/api/v1" {
144+
t.Errorf("Expected prefix '/api/v1', got '%s'", group.prefix)
145+
}
146+
147+
if group.timeout != "30s" {
148+
t.Errorf("Expected timeout '30s', got '%s'", group.timeout)
149+
}
150+
151+
expectedMiddlewares := []string{"AuthMiddleware", "LogMiddleware"}
152+
if len(group.middlewares) != len(expectedMiddlewares) {
153+
t.Errorf("Expected %d middlewares, got %d", len(expectedMiddlewares), len(group.middlewares))
154+
}
155+
156+
for i, expected := range expectedMiddlewares {
157+
if group.middlewares[i] != expected {
158+
t.Errorf("Expected middleware[%d] '%s', got '%s'", i, expected, group.middlewares[i])
159+
}
160+
}
161+
}
162+
163+
func TestSSEDisabled(t *testing.T) {
164+
// Test API spec without SSE
165+
apiSpec := &spec.ApiSpec{
166+
Service: spec.Service{
167+
Groups: []spec.Group{
168+
{
169+
Annotation: spec.Annotation{
170+
Properties: map[string]string{
171+
"prefix": "/api/v1",
172+
},
173+
},
174+
Routes: []spec.Route{
175+
{
176+
Method: "get",
177+
Path: "/status",
178+
Handler: "GetStatus",
179+
},
180+
},
181+
},
182+
},
183+
},
184+
}
185+
186+
groups, err := getRoutes(apiSpec)
187+
if err != nil {
188+
t.Fatalf("getRoutes failed: %v", err)
189+
}
190+
191+
if len(groups) != 1 {
192+
t.Fatalf("Expected 1 group, got %d", len(groups))
193+
}
194+
195+
group := groups[0]
196+
if group.sseEnabled {
197+
t.Error("Expected SSE to be disabled")
198+
}
199+
}

0 commit comments

Comments
 (0)