Skip to content

Commit 03e4a4b

Browse files
authored
refactor(jwt): support singing method (#10)
1 parent 1f27ba0 commit 03e4a4b

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

jwt/jwt.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ import (
1212
// JWT represents a JWT handler using a shared secret and generic claims data.
1313
type JWT[T any] struct {
1414
secret []byte
15+
method jwt.SigningMethod
1516
}
1617

1718
// New creates a new JWT instance using the given secret string.
18-
func New[T any](secret string) *JWT[T] {
19+
func New[T any](secret string, method jwt.SigningMethod) *JWT[T] {
1920
return &JWT[T]{
2021
secret: []byte(secret),
22+
method: method,
2123
}
2224
}
2325

@@ -31,15 +33,19 @@ type Claims[T any] struct {
3133

3234
// Generate creates and signs a JWT token using the provided claims.
3335
func (x *JWT[T]) Generate(claims *Claims[T]) (string, error) {
34-
v := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
36+
v := jwt.NewWithClaims(x.method, claims)
3537

3638
return v.SignedString(x.secret)
3739
}
3840

3941
// Parse parses and validates a JWT token string and returns the claims
4042
// if the token is valid.
4143
func (x *JWT[T]) Parse(tokenString string) (*Claims[T], error) {
42-
token, err := jwt.ParseWithClaims(tokenString, &Claims[T]{}, func(_ *jwt.Token) (any, error) {
44+
token, err := jwt.ParseWithClaims(tokenString, &Claims[T]{}, func(t *jwt.Token) (any, error) {
45+
if t.Method.Alg() != x.method.Alg() {
46+
return nil, jwt.ErrTokenSignatureInvalid
47+
}
48+
4349
return x.secret, nil
4450
})
4551
if err != nil {

jwt/jwt_test.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type userClaims struct {
1313
Name string
1414
}
1515

16-
var j = mjwt.New[userClaims]("secret")
16+
var j = mjwt.New[userClaims]("secret", jwt.SigningMethodHS256)
1717

1818
func TestJWT_GenerateAndParse(t *testing.T) {
1919
token, err := j.Generate(&mjwt.Claims[userClaims]{
@@ -59,3 +59,20 @@ func TestJWT_ParseInvalid(t *testing.T) {
5959
t.Fatalf("expected parsed to be nil, got %+v", parsed)
6060
}
6161
}
62+
63+
func TestJWT_WrongSigningMethod(t *testing.T) {
64+
j1 := mjwt.New[userClaims]("secret", jwt.SigningMethodHS256)
65+
j2 := mjwt.New[userClaims]("secret", jwt.SigningMethodHS384)
66+
67+
token, err := j1.Generate(&mjwt.Claims[userClaims]{
68+
Data: &userClaims{ID: 1},
69+
})
70+
if err != nil {
71+
t.Fatal(err)
72+
}
73+
74+
_, err = j2.Parse(token)
75+
if err == nil {
76+
t.Fatal("expected signing method mismatch error")
77+
}
78+
}

0 commit comments

Comments
 (0)