Skip to content

Commit 3d00607

Browse files
committed
Add tests for workload manager JWT & Fix comments
Signed-off-by: Zhou Zihang <z@mcac.cc>
1 parent ad17424 commit 3d00607

File tree

5 files changed

+302
-71
lines changed

5 files changed

+302
-71
lines changed

pkg/workloadmanager/handlers.go

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/gin-gonic/gin"
1313
redisv9 "github.com/redis/go-redis/v9"
14+
"k8s.io/client-go/dynamic"
1415
sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1"
1516
"sigs.k8s.io/agent-sandbox/controllers"
1617
extensionsv1alpha1 "sigs.k8s.io/agent-sandbox/extensions/api/v1alpha1"
@@ -28,9 +29,27 @@ func (s *Server) handleHealth(c *gin.Context) {
2829

2930
// handleCreateSandbox handles sandbox creation requests
3031
// nolint :gocyclo
32+
// extractUserK8sClient extracts user information from the context and creates a user-specific Kubernetes client.
33+
// It returns the dynamic client for the user and an error if authentication fails or client creation fails.
34+
func (s *Server) extractUserK8sClient(c *gin.Context) (dynamic.Interface, error) {
35+
// Extract user information from context
36+
userToken, userNamespace, _, serviceAccountName := extractUserInfo(c)
37+
if userToken == "" || userNamespace == "" || serviceAccountName == "" {
38+
return nil, errors.New("unable to extract user credentials")
39+
}
40+
41+
// Create sandbox using user's K8s client
42+
userClient, err := s.k8sClient.GetOrCreateUserK8sClient(userToken, userNamespace, serviceAccountName)
43+
if err != nil {
44+
log.Printf("create user client failed: %v", err)
45+
return nil, fmt.Errorf("create user client failed: %w", err)
46+
}
47+
return userClient.dynamicClient, nil
48+
}
49+
3150
func (s *Server) handleCreateSandbox(c *gin.Context) {
32-
createAgentRequest := &types.CreateSandboxRequest{}
33-
if err := c.ShouldBindJSON(createAgentRequest); err != nil {
51+
sandboxReq := &types.CreateSandboxRequest{}
52+
if err := c.ShouldBindJSON(sandboxReq); err != nil {
3453
log.Printf("parse request body failed: %v", err)
3554
respondError(c, http.StatusBadRequest, "INVALID_REQUEST", "Invalid request body")
3655
return
@@ -39,13 +58,13 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
3958
reqPath := c.Request.URL.Path
4059
switch {
4160
case strings.HasSuffix(reqPath, "/agent-runtime"):
42-
createAgentRequest.Kind = types.AgentRuntimeKind
61+
sandboxReq.Kind = types.AgentRuntimeKind
4362
case strings.HasSuffix(reqPath, "/code-interpreter"):
44-
createAgentRequest.Kind = types.CodeInterpreterKind
63+
sandboxReq.Kind = types.CodeInterpreterKind
4564
default:
4665
}
4766

48-
if err := createAgentRequest.Validate(); err != nil {
67+
if err := sandboxReq.Validate(); err != nil {
4968
log.Printf("request body validation failed: %v", err)
5069
respondError(c, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
5170
return
@@ -55,14 +74,14 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
5574
var sandboxClaim *extensionsv1alpha1.SandboxClaim
5675
var externalInfo *sandboxExternalInfo
5776
var err error
58-
switch createAgentRequest.Kind {
77+
switch sandboxReq.Kind {
5978
case types.AgentRuntimeKind:
60-
sandbox, externalInfo, err = buildSandboxByAgentRuntime(createAgentRequest.Namespace, createAgentRequest.Name, s.informers)
79+
sandbox, externalInfo, err = buildSandboxByAgentRuntime(sandboxReq.Namespace, sandboxReq.Name, s.informers)
6180
case types.CodeInterpreterKind:
62-
sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(createAgentRequest.Namespace, createAgentRequest.Name, s.informers)
81+
sandbox, sandboxClaim, externalInfo, err = buildSandboxByCodeInterpreter(sandboxReq.Namespace, sandboxReq.Name, s.informers)
6382
default:
64-
log.Printf("invalid request kind: %v", createAgentRequest.Kind)
65-
respondError(c, http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("invalid request kind: %v", createAgentRequest.Kind))
83+
log.Printf("invalid request kind: %v", sandboxReq.Kind)
84+
respondError(c, http.StatusBadRequest, "INVALID_REQUEST", fmt.Sprintf("invalid request kind: %v", sandboxReq.Kind))
6685
return
6786
}
6887

@@ -78,22 +97,12 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
7897

7998
dynamicClient := s.k8sClient.dynamicClient
8099
if s.config.EnableAuth {
81-
// Extract user information from context
82-
userToken, userNamespace, _, serviceAccountName := extractUserInfo(c)
83-
if userToken == "" || userNamespace == "" || serviceAccountName == "" {
84-
respondError(c, http.StatusUnauthorized, "UNAUTHORIZED", "Unable to extract user credentials")
85-
return
86-
}
87-
88-
// Create sandbox using user's K8s client
89-
userClient, err := s.k8sClient.GetOrCreateUserK8sClient(userToken, userNamespace, serviceAccountName)
100+
userDynamicClient, err := s.extractUserK8sClient(c)
90101
if err != nil {
91-
log.Printf("create user client failed: %v", err)
92-
respondError(c, http.StatusInternalServerError, "CLIENT_CREATION_FAILED", err.Error())
102+
respondError(c, http.StatusUnauthorized, "UNAUTHORIZED", err.Error())
93103
return
94104
}
95-
96-
dynamicClient = userClient.dynamicClient
105+
dynamicClient = userDynamicClient
97106
}
98107

99108
// CRITICAL: Register watcher BEFORE creating sandbox
@@ -182,7 +191,7 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
182191
EntryPoints: redisCacheInfo.EntryPoints,
183192
}
184193

185-
if createAgentRequest.Kind != types.CodeInterpreterKind {
194+
if sandboxReq.Kind != types.CodeInterpreterKind {
186195
err = s.redisClient.UpdateSandbox(c.Request.Context(), redisCacheInfo, RedisNoExpiredTTL)
187196
if err != nil {
188197
log.Printf("update redis cache failed: %v", err)
@@ -196,9 +205,14 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
196205
}
197206

198207
if len(redisCacheInfo.EntryPoints) == 0 {
199-
respondError(c, http.StatusInternalServerError, "SANDBOX_INIT_FAILED",
200-
"No access endpoint found for sandbox initialization")
201-
return
208+
// Fallback to default http://ip:8080
209+
defaultEntryPoint := types.SandboxEntryPoints{
210+
Path: "/",
211+
Protocol: "http",
212+
Endpoint: fmt.Sprintf("%s:8080", podIP),
213+
}
214+
redisCacheInfo.EntryPoints = []types.SandboxEntryPoints{defaultEntryPoint}
215+
response.EntryPoints = redisCacheInfo.EntryPoints
202216
}
203217

204218
// Code Interpreter sandbox created, init code interpreter
@@ -221,10 +235,10 @@ func (s *Server) handleCreateSandbox(c *gin.Context) {
221235
err = s.InitCodeInterpreterSandbox(
222236
c.Request.Context(),
223237
initEndpoint,
224-
sandbox.Labels[SessionIdLabelKey],
225-
createAgentRequest.PublicKey,
226-
createAgentRequest.Metadata,
227-
createAgentRequest.InitTimeoutSeconds,
238+
externalInfo.SessionID,
239+
sandboxReq.PublicKey,
240+
sandboxReq.Metadata,
241+
sandboxReq.InitTimeoutSeconds,
228242
)
229243

230244
if err != nil {
@@ -265,22 +279,12 @@ func (s *Server) handleDeleteSandbox(c *gin.Context) {
265279

266280
dynamicClient := s.k8sClient.dynamicClient
267281
if s.config.EnableAuth {
268-
// Extract user information from context
269-
userToken, userNamespace, _, serviceAccountName := extractUserInfo(c)
270-
271-
if userToken == "" || userNamespace == "" || serviceAccountName == "" {
272-
respondError(c, http.StatusUnauthorized, "UNAUTHORIZED", "Unable to extract user credentials")
273-
return
274-
}
275-
276-
// Delete sandbox using user's K8s client
277-
userClient, clientErr := s.k8sClient.GetOrCreateUserK8sClient(userToken, userNamespace, serviceAccountName)
278-
if clientErr != nil {
279-
respondError(c, http.StatusInternalServerError, "CLIENT_CREATION_FAILED", clientErr.Error())
282+
userDynamicClient, err := s.extractUserK8sClient(c)
283+
if err != nil {
284+
respondError(c, http.StatusUnauthorized, "UNAUTHORIZED", err.Error())
280285
return
281286
}
282-
283-
dynamicClient = userClient.dynamicClient
287+
dynamicClient = userDynamicClient
284288
}
285289

286290
if sandbox.Kind == types.SandboxClaimsKind {

pkg/workloadmanager/jwt.go

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/x509"
88
"encoding/pem"
99
"fmt"
10+
"log"
1011
"os"
1112
"time"
1213

@@ -22,7 +23,7 @@ const (
2223
// JWT token expiration time
2324
jwtExpiration = 5 * time.Minute
2425
// JWTPublicKeySecretName is the name of the secret storing JWT public key
25-
JWTPublicKeySecretName = "agentcube-jwt-public-key"
26+
JWTPublicKeySecretName = "agentcube-jwt-public-key" //nolint:gosec // This is a name reference, not a credential
2627
// JWTPublicKeyDataKey is the key in the secret data map
2728
JWTPublicKeyDataKey = "public-key.pem"
2829
// JWTPublicKeyVolumeName the name of JWT PublicKey volume
@@ -90,24 +91,13 @@ func (jm *JWTManager) GetPublicKeyPEM() ([]byte, error) {
9091
}
9192

9293
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
93-
Type: "RSA PUBLIC KEY",
94+
Type: "PUBLIC KEY",
9495
Bytes: pubKeyBytes,
9596
})
9697

9798
return pubKeyPEM, nil
9899
}
99100

100-
// GetPrivateKeyPEM returns the private key in PEM format (for debugging/backup purposes)
101-
func (jm *JWTManager) GetPrivateKeyPEM() ([]byte, error) {
102-
privKeyBytes := x509.MarshalPKCS1PrivateKey(jm.privateKey)
103-
privKeyPEM := pem.EncodeToMemory(&pem.Block{
104-
Type: "RSA PRIVATE KEY",
105-
Bytes: privKeyBytes,
106-
})
107-
108-
return privKeyPEM, nil
109-
}
110-
111101
// StoreJWTPublicKeyInSecret stores the JWT public key in a Kubernetes secret
112102
//
113103
// Currently, the JWT public key secret is stored in a single namespace (default
@@ -133,7 +123,7 @@ func (c *K8sClient) StoreJWTPublicKeyInSecret(ctx context.Context, publicKeyPEM
133123
}
134124

135125
// Try to get existing secret
136-
existingSecret, err := c.clientset.CoreV1().Secrets(JWTPublicKeySecretNamespace).Get(
126+
_, err := c.clientset.CoreV1().Secrets(JWTPublicKeySecretNamespace).Get(
137127
ctx,
138128
JWTPublicKeySecretName,
139129
metav1.GetOptions{},
@@ -148,24 +138,17 @@ func (c *K8sClient) StoreJWTPublicKeyInSecret(ctx context.Context, publicKeyPEM
148138
metav1.CreateOptions{},
149139
)
150140
if err != nil {
141+
if apierrors.IsAlreadyExists(err) {
142+
return nil
143+
}
151144
return fmt.Errorf("failed to create JWT public key secret: %w", err)
152145
}
153-
fmt.Printf("JWT public key secret %s/%s created", secret.Namespace, secret.Name)
146+
log.Printf("JWT public key secret %s/%s created", secret.Namespace, secret.Name)
154147
return nil
155148
}
156149
return fmt.Errorf("failed to get JWT public key secret: %w", err)
157150
}
158151

159-
// Secret exists, update it
160-
existingSecret.Data = secret.Data
161-
_, err = c.clientset.CoreV1().Secrets(JWTPublicKeySecretNamespace).Update(
162-
ctx,
163-
existingSecret,
164-
metav1.UpdateOptions{},
165-
)
166-
if err != nil {
167-
return fmt.Errorf("failed to update JWT public key secret: %w", err)
168-
}
169-
fmt.Printf("JWT public key secret %s/%s updated", secret.Namespace, secret.Name)
152+
// Secret exists, do not update it
170153
return nil
171154
}

pkg/workloadmanager/jwt_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package workloadmanager
2+
3+
import (
4+
"crypto/x509"
5+
"encoding/pem"
6+
"testing"
7+
"time"
8+
9+
"github.com/golang-jwt/jwt/v5"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestNewJWTManager(t *testing.T) {
15+
jm, err := NewJWTManager()
16+
require.NoError(t, err)
17+
require.NotNil(t, jm)
18+
assert.NotNil(t, jm.privateKey)
19+
assert.NotNil(t, jm.publicKey)
20+
assert.Equal(t, &jm.privateKey.PublicKey, jm.publicKey)
21+
}
22+
23+
func TestGenerateToken(t *testing.T) {
24+
jm, err := NewJWTManager()
25+
require.NoError(t, err)
26+
27+
claims := map[string]interface{}{
28+
"sub": "test-subject",
29+
"role": "admin",
30+
}
31+
32+
tokenString, err := jm.GenerateToken(claims)
33+
require.NoError(t, err)
34+
assert.NotEmpty(t, tokenString)
35+
36+
// Verify token
37+
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
38+
// Validate signing method
39+
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
40+
return nil, jwt.ErrSignatureInvalid
41+
}
42+
return jm.publicKey, nil
43+
})
44+
45+
require.NoError(t, err)
46+
assert.True(t, token.Valid)
47+
48+
// Validate claims
49+
mapClaims, ok := token.Claims.(jwt.MapClaims)
50+
assert.True(t, ok)
51+
assert.Equal(t, "test-subject", mapClaims["sub"])
52+
assert.Equal(t, "admin", mapClaims["role"])
53+
54+
// Validate expiration exists
55+
assert.Contains(t, mapClaims, "exp")
56+
assert.Contains(t, mapClaims, "iat")
57+
}
58+
59+
func TestGetPublicKeyPEM(t *testing.T) {
60+
jm, err := NewJWTManager()
61+
require.NoError(t, err)
62+
63+
pemBytes, err := jm.GetPublicKeyPEM()
64+
require.NoError(t, err)
65+
assert.NotEmpty(t, pemBytes)
66+
67+
// Parse PEM
68+
block, _ := pem.Decode(pemBytes)
69+
require.NotNil(t, block)
70+
assert.Equal(t, "PUBLIC KEY", block.Type)
71+
72+
// Parse Key
73+
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
74+
require.NoError(t, err)
75+
assert.NotNil(t, pubKey)
76+
}
77+
78+
func TestTokenExpiration(t *testing.T) {
79+
jm, err := NewJWTManager()
80+
require.NoError(t, err)
81+
82+
claims := map[string]interface{}{"foo": "bar"}
83+
tokenString, err := jm.GenerateToken(claims)
84+
require.NoError(t, err)
85+
86+
token, err := jwt.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
87+
return jm.publicKey, nil
88+
})
89+
require.NoError(t, err)
90+
91+
mapClaims, ok := token.Claims.(jwt.MapClaims)
92+
assert.True(t, ok)
93+
94+
expFloat, ok := mapClaims["exp"].(float64)
95+
assert.True(t, ok)
96+
exp := int64(expFloat)
97+
98+
iatFloat, ok := mapClaims["iat"].(float64)
99+
assert.True(t, ok)
100+
iat := int64(iatFloat)
101+
// jwtExpiration is 5 minutes
102+
expectedExp := iat + int64(5*time.Minute/time.Second)
103+
104+
// Allow 1 second delta for execution time
105+
assert.InDelta(t, expectedExp, exp, 1)
106+
}
107+
108+
func TestVerifyTokenWithDifferentKey(t *testing.T) {
109+
jm1, err := NewJWTManager()
110+
require.NoError(t, err)
111+
112+
jm2, err := NewJWTManager()
113+
require.NoError(t, err)
114+
115+
tokenString, err := jm1.GenerateToken(map[string]interface{}{"foo": "bar"})
116+
require.NoError(t, err)
117+
118+
// Try to verify with jm2's public key (should fail)
119+
token, err := jwt.Parse(tokenString, func(_ *jwt.Token) (interface{}, error) {
120+
return jm2.publicKey, nil
121+
})
122+
123+
assert.Error(t, err)
124+
if token != nil {
125+
assert.False(t, token.Valid)
126+
}
127+
}

0 commit comments

Comments
 (0)