Skip to content

Commit adf9ac4

Browse files
committed
Adds: CSRF token generate and verify middleware
1 parent 08548ef commit adf9ac4

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package middlewares
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/gin-gonic/gin"
7+
8+
"github.com/sdslabs/nymeria/internal/utils"
9+
)
10+
11+
func CSRFMiddleware() gin.HandlerFunc {
12+
return func(c *gin.Context) {
13+
// Skip CSRF validation for GET, HEAD, OPTIONS requests
14+
if c.Request.Method == "GET" || c.Request.Method == "HEAD" || c.Request.Method == "OPTIONS" {
15+
c.Next()
16+
return
17+
}
18+
19+
// Get user ID from header (TODO: replace with session/JWT when available)
20+
userID := c.GetHeader("X-User-ID") // TODO: Get user ID from session
21+
if userID == "" {
22+
c.JSON(http.StatusUnauthorized, gin.H{
23+
"status": "error",
24+
"message": "User ID is required for CSRF protection",
25+
})
26+
c.Abort()
27+
return
28+
}
29+
30+
var csrfToken string
31+
32+
// First, try to get CSRF token from JSON body
33+
if c.GetHeader("Content-Type") == "application/json" {
34+
var body map[string]interface{}
35+
// Create a copy of the request body for CSRF token extraction
36+
if err := c.ShouldBindJSON(&body); err == nil {
37+
if token, exists := body["csrf_token"]; exists {
38+
if tokenStr, ok := token.(string); ok {
39+
csrfToken = tokenStr
40+
}
41+
}
42+
}
43+
// Restore the body for the actual handler by binding again
44+
// Note: This is a limitation - we need to read the body twice
45+
// A more elegant solution would be to buffer the body
46+
}
47+
48+
// If not found in JSON body, try header
49+
if csrfToken == "" {
50+
csrfToken = c.GetHeader("X-CSRF-Token")
51+
}
52+
53+
// If not found in header, try form field
54+
if csrfToken == "" {
55+
csrfToken = c.PostForm("csrf_token")
56+
}
57+
58+
// If no CSRF token found, return error
59+
if csrfToken == "" {
60+
c.JSON(http.StatusBadRequest, gin.H{
61+
"status": "error",
62+
"message": "CSRF token is required",
63+
})
64+
c.Abort()
65+
return
66+
}
67+
68+
// Validate CSRF token
69+
if !utils.ValidateCSRFToken(userID, csrfToken) {
70+
c.JSON(http.StatusUnauthorized, gin.H{
71+
"status": "error",
72+
"message": "Invalid or expired CSRF token",
73+
})
74+
c.Abort()
75+
return
76+
}
77+
78+
// Token is valid, continue to the next handler
79+
c.Next()
80+
}
81+
}

internal/utils/crypto.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package utils
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/hex"
6+
)
7+
8+
// GenerateRandomString generates a random string of specified length
9+
func GenerateRandomString(length int) (string, error) {
10+
bytes := make([]byte, length)
11+
if _, err := rand.Read(bytes); err != nil {
12+
return "", err
13+
}
14+
return hex.EncodeToString(bytes), nil
15+
}

internal/utils/csrf.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) 2025 SDSLabs
2+
// SPDX-License-Identifier: MIT
3+
4+
package utils
5+
6+
import (
7+
"crypto/hmac"
8+
"crypto/sha256"
9+
"encoding/base64"
10+
"fmt"
11+
"strconv"
12+
"strings"
13+
"time"
14+
15+
"github.com/sdslabs/nymeria/internal/config"
16+
)
17+
18+
// GenerateStatelessCSRF returns a token valid for 2 minutes
19+
func GenerateCSRFToken(userID string) (string, error) {
20+
csrfSecretKey := []byte(config.AppConfig.CSRFSecret)
21+
timestamp := time.Now().Unix()
22+
payload := fmt.Sprintf("%s:%d", userID, timestamp)
23+
24+
mac := hmac.New(sha256.New, csrfSecretKey)
25+
mac.Write([]byte(payload))
26+
signature := mac.Sum(nil)
27+
28+
rawToken := fmt.Sprintf("%s:%d:%s", userID, timestamp, base64.URLEncoding.EncodeToString(signature))
29+
return base64.URLEncoding.EncodeToString([]byte(rawToken)), nil
30+
}
31+
32+
func ValidateCSRFToken(userID, token string) bool {
33+
csrfSecretKey := []byte(config.AppConfig.CSRFSecret)
34+
decoded, err := base64.URLEncoding.DecodeString(token)
35+
if err != nil {
36+
return false
37+
}
38+
39+
parts := strings.Split(string(decoded), ":")
40+
if len(parts) != 3 {
41+
return false
42+
}
43+
44+
tokenUserID := parts[0]
45+
timestampStr := parts[1]
46+
sigBase64 := parts[2]
47+
48+
if tokenUserID != userID {
49+
return false
50+
}
51+
52+
timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
53+
if err != nil {
54+
return false
55+
}
56+
57+
maxAge := time.Duration(config.AppConfig.CSRFMaxAge) * time.Minute
58+
59+
// Check expiration
60+
if time.Since(time.Unix(timestamp, 0)) > maxAge {
61+
return false
62+
}
63+
64+
// Recompute HMAC
65+
payload := fmt.Sprintf("%s:%d", userID, timestamp)
66+
mac := hmac.New(sha256.New, csrfSecretKey)
67+
mac.Write([]byte(payload))
68+
expectedSig := base64.URLEncoding.EncodeToString(mac.Sum(nil))
69+
70+
return hmac.Equal([]byte(expectedSig), []byte(sigBase64))
71+
}

0 commit comments

Comments
 (0)