From 9cfe91da28e790faf6707e68b592523742a037b6 Mon Sep 17 00:00:00 2001 From: Rahul Chocha Date: Thu, 21 Sep 2023 18:25:17 +0530 Subject: [PATCH] feat: add support for access token generation and authentication --- go.mod | 3 ++ handlers/auth.go | 106 +++++++++++++++++++++++++++++++++++++++++++ handlers/consumer.go | 35 +++----------- handlers/producer.go | 69 +++++----------------------- middlewares/auth.go | 90 +++++++++++++++++++++++++++++------- models/auth.go | 6 +++ router/auth.go | 1 + 7 files changed, 207 insertions(+), 103 deletions(-) diff --git a/go.mod b/go.mod index 5804831..1a80301 100644 --- a/go.mod +++ b/go.mod @@ -46,3 +46,6 @@ require ( google.golang.org/protobuf v1.31.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) + +// TODO: remove this once sdk changes are released +replace github.com/memphisdev/memphis.go => ../memphis.go diff --git a/handlers/auth.go b/handlers/auth.go index 6741254..b236878 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "errors" "rest-gateway/conf" "rest-gateway/logger" "rest-gateway/memphisSingleton" @@ -13,6 +14,7 @@ import ( "time" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" "github.com/golang-jwt/jwt/v4" "github.com/memphisdev/memphis.go" "github.com/nats-io/nats.go" @@ -40,6 +42,51 @@ type refreshTokenExpiration struct { var ConnectionsCache = map[string]map[string]Connection{} +func getConnectionForUserData(userData models.AuthSchema) (*memphis.Conn, int, error) { + var err error + username := userData.Username + accountId := userData.AccountId + accountIdStr := strconv.Itoa(int(accountId)) + + var conn *memphis.Conn + if userData.AccessKeyID != "" && userData.SecretKey != "" { + conn = ConnectionsCache[accountIdStr][username].Connection + } else { + conn = ConnectionsCache[accountIdStr][userData.AccessKeyID].Connection + } + + if conn == nil { + conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId)) + if err != nil { + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) { + log.Warnf("Could not establish new connection with the broker: Authentication error") + return nil, fiber.StatusUnauthorized, errors.New("Unauthorized") + } + + log.Errorf("Could not establish new connection with the broker: %s", err.Error()) + return nil, fiber.StatusInternalServerError, errors.New("Server error") + } + + if ConnectionsCache[accountIdStr] == nil { + ConnectionsCacheLock.Lock() + ConnectionsCache[accountIdStr] = make(map[string]Connection) + ConnectionsCacheLock.Unlock() + } + + ConnectionsCacheLock.Lock() + if userData.AccessKeyID != "" && userData.SecretKey != "" { + ConnectionsCache[accountIdStr][userData.AccessKeyID] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry} + } else { + ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry} + } + + ConnectionsCacheLock.Unlock() + } + + return conn, 0, nil +} + func Connect(password, username, connectionToken string, accountId int) (*memphis.Conn, error) { if configuration.USER_PASS_BASED_AUTH { if accountId == 0 { @@ -282,6 +329,65 @@ func (ah AuthHandler) RefreshToken(c *fiber.Ctx) error { }) } +func (ah AuthHandler) GenerateAccessToken(c *fiber.Ctx) error { + log := logger.GetLogger(c) + var body models.GenerateAccessTokenSchema + if err := c.BodyParser(&body); err != nil { + log.Errorf("GenerateAccessToken: %s", err.Error()) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Server error", + }) + } + if err := utils.Validate(body); err != nil { + return c.Status(400).JSON(fiber.Map{ + "message": err, + }) + } + userData, ok := c.Locals("userData").(models.AuthSchema) + if !ok { + log.Errorf("GenerateAccessToken: failed to get the user data from the middleware") + c.Status(fiber.StatusInternalServerError) + return c.JSON(&fiber.Map{ + "success": false, + "error": "Server error", + }) + } + + username := userData.Username + accountId := int(userData.AccountId) + password := userData.Password + connectionToken := userData.ConnectionToken + + conn, err := Connect(password, username, connectionToken, accountId) + if err != nil { + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) { + log.Warnf("GenerateAccessToken: Authentication error") + return c.Status(401).JSON(fiber.Map{ + "message": "Unauthorized", + }) + } + + log.Errorf("GenerateAccessToken: %s", err.Error()) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Server error", + }) + } + + generatedTokenData, err := conn.GenerateAccessToken(username, body.Description) + if err != nil { + log.Errorf("GenerateAccessToken: %s", err.Error()) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "message": "Server error", + }) + } + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "access_key_id": generatedTokenData.AccessKeyID, + "secret_key": generatedTokenData.SecretKey, + }) +} + func CleanConnectionsCache() { for range time.Tick(time.Second * 30) { for t, tenant := range ConnectionsCache { diff --git a/handlers/consumer.go b/handlers/consumer.go index 12757a8..e08051f 100644 --- a/handlers/consumer.go +++ b/handlers/consumer.go @@ -4,7 +4,6 @@ import ( "fmt" "rest-gateway/logger" "rest-gateway/models" - "strconv" "strings" "time" @@ -65,36 +64,14 @@ func ConsumeHandleMessage() func(*fiber.Ctx) error { "error": "Server error", }) } - username := userData.Username - accountId := userData.AccountId - accountIdStr := strconv.Itoa(int(accountId)) - conn := ConnectionsCache[accountIdStr][username].Connection - if conn == nil { - conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId)) - if err != nil { - errMsg := strings.ToLower(err.Error()) - if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) { - log.Warnf("Could not establish new connection with the broker: Authentication error") - return c.Status(401).JSON(fiber.Map{ - "message": "Unauthorized", - }) - } - - log.Errorf("Could not establish new connection with the broker: %s", err.Error()) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ - "message": "Server error", - }) - } - if ConnectionsCache[accountIdStr] == nil { - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr] = make(map[string]Connection) - ConnectionsCacheLock.Unlock() - } - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry} - ConnectionsCacheLock.Unlock() + conn, errorCode, err := getConnectionForUserData(userData) + if err != nil { + return c.Status(errorCode).JSON(fiber.Map{ + "message": err.Error(), + }) } + reqBody.initializeDefaults() msgs, err := conn.FetchMessages(stationName, reqBody.ConsumerName, memphis.FetchBatchSize(reqBody.BatchSize), diff --git a/handlers/producer.go b/handlers/producer.go index ea42ba3..9301d2f 100644 --- a/handlers/producer.go +++ b/handlers/producer.go @@ -3,7 +3,6 @@ package handlers import ( "encoding/json" "errors" - "strconv" "rest-gateway/logger" "rest-gateway/models" @@ -66,36 +65,14 @@ func CreateHandleMessage() func(*fiber.Ctx) error { "error": "Server error", }) } - username := userData.Username - accountId := userData.AccountId - accountIdStr := strconv.Itoa(int(accountId)) - conn := ConnectionsCache[accountIdStr][username].Connection - if conn == nil { - conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId)) - if err != nil { - errMsg := strings.ToLower(err.Error()) - if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) { - log.Warnf("Could not establish new connection with the broker: Authentication error") - return c.Status(401).JSON(fiber.Map{ - "message": "Unauthorized", - }) - } - - log.Errorf("Could not establish new connection with the broker: %s", err.Error()) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ - "message": "Server error", - }) - } - if ConnectionsCache[accountIdStr] == nil { - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr] = make(map[string]Connection) - ConnectionsCacheLock.Unlock() - } - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry} - ConnectionsCacheLock.Unlock() + conn, errorCode, err := getConnectionForUserData(userData) + if err != nil { + return c.Status(errorCode).JSON(fiber.Map{ + "message": err.Error(), + }) } + err = conn.Produce(stationName, "rest-gateway", message, []memphis.ProducerOpt{}, []memphis.ProduceOpt{memphis.MsgHeaders(hdrs)}) if err != nil { log.Errorf("CreateHandleMessage - produce: %s", err.Error()) @@ -157,35 +134,11 @@ func CreateHandleBatch() func(*fiber.Ctx) error { }) } - username := userData.Username - accountId := userData.AccountId - accountIdStr := strconv.Itoa(int(accountId)) - conn := ConnectionsCache[accountIdStr][username].Connection - if conn == nil { - conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId)) - if err != nil { - errMsg := strings.ToLower(err.Error()) - if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) { - log.Warnf("Could not establish new connection with the broker: Authentication error") - return c.Status(401).JSON(fiber.Map{ - "message": "Unauthorized", - }) - } - - log.Errorf("Could not establish new connection with the broker: %s", err.Error()) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ - "message": "Server error", - }) - } - if ConnectionsCache[accountIdStr] == nil { - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr] = make(map[string]Connection) - ConnectionsCacheLock.Unlock() - } - - ConnectionsCacheLock.Lock() - ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry} - ConnectionsCacheLock.Unlock() + conn, errorCode, err := getConnectionForUserData(userData) + if err != nil { + return c.Status(errorCode).JSON(fiber.Map{ + "message": err.Error(), + }) } errCount := 0 diff --git a/middlewares/auth.go b/middlewares/auth.go index 974e739..8159423 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" "github.com/golang-jwt/jwt/v4" ) @@ -46,7 +47,7 @@ func extractToken(authHeader string) (string, error) { return tokenString, nil } -func verifyToken(tokenString string, secret string) (models.AuthSchema, error) { +func verifyJWTToken(tokenString string, secret string) (models.AuthSchema, error) { token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -89,6 +90,45 @@ func verifyToken(tokenString string, secret string) (models.AuthSchema, error) { return user, nil } +func verifyAccessToken(accessKeyId string, secretKey string) (models.AuthSchema, error) { + accountId := 1 + conn, err := handlers.Connect(configuration.ROOT_PASSWORD, configuration.ROOT_USER, configuration.CONNECTION_TOKEN, accountId) + if err != nil { + return models.AuthSchema{}, err + } + + isValid, err := conn.ValidateAccessToken(accessKeyId, secretKey) + if err != nil { + log.Warnf("Authentication error - access key id and secret key is not valid", err) + return models.AuthSchema{}, err + } + + if !isValid { + log.Warnf("Authentication error - invalid access key id or secret key") + return models.AuthSchema{}, errors.New("invalid access key id or secret key!") + } + + var user models.AuthSchema + if !configuration.USER_PASS_BASED_AUTH { + user = models.AuthSchema{ + Username: configuration.ROOT_USER, + ConnectionToken: configuration.CONNECTION_TOKEN, + AccountId: 1, + AccessKeyID: accessKeyId, + SecretKey: secretKey, + } + } else { + user = models.AuthSchema{ + Username: configuration.ROOT_USER, + Password: configuration.ROOT_PASSWORD, + AccountId: 1, + AccessKeyID: accessKeyId, + SecretKey: secretKey, + } + } + return user, nil +} + func Authenticate(c *fiber.Ctx) error { log := logger.GetLogger(c) path := strings.ToLower(string(c.Context().URI().RequestURI())) @@ -97,11 +137,15 @@ func Authenticate(c *fiber.Ctx) error { path = strings.Split(path, "?")[0] if isAuthNeeded(path) { headers := c.GetReqHeaders() - tokenString, err := extractToken(headers["Authorization"]) - if err != nil || tokenString == "" { - tokenString = c.Query("authorization") - if tokenString == "" { // fallback - get the token from the query params - log.Warnf("Authentication error - jwt token is missing") + + accessKayId := headers["Access-Key-Id"] + secretKay := headers["Secret-Key"] + jwtToken := headers["Authorization"] + + if accessKayId != "" && secretKay != "" { + user, err = verifyAccessToken(accessKayId, secretKay) + if err != nil { + log.Warnf("Authentication error - access key id and secret key is not valid") if configuration.DEBUG { fmt.Printf("Method: %s, Path: %s, IP: %s\nBody: %s\n", c.Method(), c.Path(), c.IP(), string(c.Body())) } @@ -109,16 +153,30 @@ func Authenticate(c *fiber.Ctx) error { "message": "Unauthorized", }) } - } - user, err = verifyToken(tokenString, configuration.JWT_SECRET) - if err != nil { - log.Warnf("Authentication error - jwt token validation has failed") - if configuration.DEBUG { - fmt.Printf("Method: %s, Path: %s, IP: %s\nBody: %s\n", c.Method(), c.Path(), c.IP(), string(c.Body())) + } else { + tokenString, err := extractToken(jwtToken) + if err != nil || tokenString == "" { + tokenString = c.Query("authorization") + if tokenString == "" { // fallback - get the token from the query params + log.Warnf("Authentication error - jwt token is missing") + if configuration.DEBUG { + fmt.Printf("Method: %s, Path: %s, IP: %s\nBody: %s\n", c.Method(), c.Path(), c.IP(), string(c.Body())) + } + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "message": "Unauthorized", + }) + } + } + user, err = verifyJWTToken(tokenString, configuration.JWT_SECRET) + if err != nil { + log.Warnf("Authentication error - jwt token validation has failed") + if configuration.DEBUG { + fmt.Printf("Method: %s, Path: %s, IP: %s\nBody: %s\n", c.Method(), c.Path(), c.IP(), string(c.Body())) + } + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "message": "Unauthorized", + }) } - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ - "message": "Unauthorized", - }) } } else if path == "/auth/refreshtoken" { var body models.RefreshTokenSchema @@ -142,7 +200,7 @@ func Authenticate(c *fiber.Ctx) error { }) } - user, err = verifyToken(body.JwtRefreshToken, configuration.REFRESH_JWT_SECRET) + user, err = verifyJWTToken(body.JwtRefreshToken, configuration.REFRESH_JWT_SECRET) if err != nil { log.Warnf("Authentication error - refresh token validation has failed") if configuration.DEBUG { diff --git a/models/auth.go b/models/auth.go index 2199261..3d9bff3 100644 --- a/models/auth.go +++ b/models/auth.go @@ -8,6 +8,8 @@ type AuthSchema struct { RefreshTokenExpiryMins int `json:"refresh_token_expiry_in_minutes"` AccountId float64 `json:"account_id"` TokenExpiry int64 `json:"token_expiry"` + AccessKeyID string `json:"access_key_id"` + SecretKey string `json:"secret_key"` } type RefreshTokenSchema struct { @@ -15,3 +17,7 @@ type RefreshTokenSchema struct { TokenExpiryMins int `json:"token_expiry_in_minutes"` RefreshTokenExpiryMins int `json:"refresh_token_expiry_in_minutes"` } + +type GenerateAccessTokenSchema struct { + Description string `json:"description"` +} diff --git a/router/auth.go b/router/auth.go index 8d6790e..35b99e3 100644 --- a/router/auth.go +++ b/router/auth.go @@ -12,4 +12,5 @@ func InitilizeAuthRoutes(app *fiber.App) { api := app.Group("/auth", logger.New()) api.Post("/authenticate", authHandler.Authenticate) api.Post("/refreshToken", authHandler.RefreshToken) + api.Post("/generateAccessToken", authHandler.GenerateAccessToken) }