Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ require (
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)

// TODO: remove this once sdk changes are released
replace github.com/memphisdev/memphis.go => ../memphis.go
106 changes: 106 additions & 0 deletions handlers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"encoding/json"
"errors"
"rest-gateway/conf"
"rest-gateway/logger"
"rest-gateway/memphisSingleton"
Expand All @@ -13,6 +14,7 @@ import (
"time"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/golang-jwt/jwt/v4"
"github.com/memphisdev/memphis.go"
"github.com/nats-io/nats.go"
Expand Down Expand Up @@ -40,6 +42,51 @@ type refreshTokenExpiration struct {

var ConnectionsCache = map[string]map[string]Connection{}

func getConnectionForUserData(userData models.AuthSchema) (*memphis.Conn, int, error) {
var err error
username := userData.Username
accountId := userData.AccountId
accountIdStr := strconv.Itoa(int(accountId))

var conn *memphis.Conn
if userData.AccessKeyID != "" && userData.SecretKey != "" {
conn = ConnectionsCache[accountIdStr][username].Connection
} else {
conn = ConnectionsCache[accountIdStr][userData.AccessKeyID].Connection
}

if conn == nil {
conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId))
if err != nil {
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) {
log.Warnf("Could not establish new connection with the broker: Authentication error")
return nil, fiber.StatusUnauthorized, errors.New("Unauthorized")
}

log.Errorf("Could not establish new connection with the broker: %s", err.Error())
return nil, fiber.StatusInternalServerError, errors.New("Server error")
}

if ConnectionsCache[accountIdStr] == nil {
ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr] = make(map[string]Connection)
ConnectionsCacheLock.Unlock()
}

ConnectionsCacheLock.Lock()
if userData.AccessKeyID != "" && userData.SecretKey != "" {
ConnectionsCache[accountIdStr][userData.AccessKeyID] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry}
} else {
ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry}
}

ConnectionsCacheLock.Unlock()
}

return conn, 0, nil
}

func Connect(password, username, connectionToken string, accountId int) (*memphis.Conn, error) {
if configuration.USER_PASS_BASED_AUTH {
if accountId == 0 {
Expand Down Expand Up @@ -282,6 +329,65 @@ func (ah AuthHandler) RefreshToken(c *fiber.Ctx) error {
})
}

func (ah AuthHandler) GenerateAccessToken(c *fiber.Ctx) error {
log := logger.GetLogger(c)
var body models.GenerateAccessTokenSchema
if err := c.BodyParser(&body); err != nil {
log.Errorf("GenerateAccessToken: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}
if err := utils.Validate(body); err != nil {
return c.Status(400).JSON(fiber.Map{
"message": err,
})
}
userData, ok := c.Locals("userData").(models.AuthSchema)
if !ok {
log.Errorf("GenerateAccessToken: failed to get the user data from the middleware")
c.Status(fiber.StatusInternalServerError)
return c.JSON(&fiber.Map{
"success": false,
"error": "Server error",
})
}

username := userData.Username
accountId := int(userData.AccountId)
password := userData.Password
connectionToken := userData.ConnectionToken

conn, err := Connect(password, username, connectionToken, accountId)
if err != nil {
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) {
log.Warnf("GenerateAccessToken: Authentication error")
return c.Status(401).JSON(fiber.Map{
"message": "Unauthorized",
})
}

log.Errorf("GenerateAccessToken: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}

generatedTokenData, err := conn.GenerateAccessToken(username, body.Description)
if err != nil {
log.Errorf("GenerateAccessToken: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}

return c.Status(fiber.StatusOK).JSON(fiber.Map{
"access_key_id": generatedTokenData.AccessKeyID,
"secret_key": generatedTokenData.SecretKey,
})
}

func CleanConnectionsCache() {
for range time.Tick(time.Second * 30) {
for t, tenant := range ConnectionsCache {
Expand Down
35 changes: 6 additions & 29 deletions handlers/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"rest-gateway/logger"
"rest-gateway/models"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -65,36 +64,14 @@ func ConsumeHandleMessage() func(*fiber.Ctx) error {
"error": "Server error",
})
}
username := userData.Username
accountId := userData.AccountId
accountIdStr := strconv.Itoa(int(accountId))
conn := ConnectionsCache[accountIdStr][username].Connection
if conn == nil {
conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId))
if err != nil {
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) {
log.Warnf("Could not establish new connection with the broker: Authentication error")
return c.Status(401).JSON(fiber.Map{
"message": "Unauthorized",
})
}

log.Errorf("Could not establish new connection with the broker: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}
if ConnectionsCache[accountIdStr] == nil {
ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr] = make(map[string]Connection)
ConnectionsCacheLock.Unlock()
}

ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry}
ConnectionsCacheLock.Unlock()
conn, errorCode, err := getConnectionForUserData(userData)
if err != nil {
return c.Status(errorCode).JSON(fiber.Map{
"message": err.Error(),
})
}

reqBody.initializeDefaults()
msgs, err := conn.FetchMessages(stationName, reqBody.ConsumerName,
memphis.FetchBatchSize(reqBody.BatchSize),
Expand Down
69 changes: 11 additions & 58 deletions handlers/producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handlers
import (
"encoding/json"
"errors"
"strconv"

"rest-gateway/logger"
"rest-gateway/models"
Expand Down Expand Up @@ -66,36 +65,14 @@ func CreateHandleMessage() func(*fiber.Ctx) error {
"error": "Server error",
})
}
username := userData.Username
accountId := userData.AccountId
accountIdStr := strconv.Itoa(int(accountId))
conn := ConnectionsCache[accountIdStr][username].Connection
if conn == nil {
conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId))
if err != nil {
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) {
log.Warnf("Could not establish new connection with the broker: Authentication error")
return c.Status(401).JSON(fiber.Map{
"message": "Unauthorized",
})
}

log.Errorf("Could not establish new connection with the broker: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}
if ConnectionsCache[accountIdStr] == nil {
ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr] = make(map[string]Connection)
ConnectionsCacheLock.Unlock()
}

ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry}
ConnectionsCacheLock.Unlock()
conn, errorCode, err := getConnectionForUserData(userData)
if err != nil {
return c.Status(errorCode).JSON(fiber.Map{
"message": err.Error(),
})
}

err = conn.Produce(stationName, "rest-gateway", message, []memphis.ProducerOpt{}, []memphis.ProduceOpt{memphis.MsgHeaders(hdrs)})
if err != nil {
log.Errorf("CreateHandleMessage - produce: %s", err.Error())
Expand Down Expand Up @@ -157,35 +134,11 @@ func CreateHandleBatch() func(*fiber.Ctx) error {
})
}

username := userData.Username
accountId := userData.AccountId
accountIdStr := strconv.Itoa(int(accountId))
conn := ConnectionsCache[accountIdStr][username].Connection
if conn == nil {
conn, err = Connect(userData.Password, username, userData.ConnectionToken, int(accountId))
if err != nil {
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, ErrorMsgAuthorizationViolation) || strings.Contains(errMsg, "token") || strings.Contains(errMsg, ErrorMsgMissionAccountId) {
log.Warnf("Could not establish new connection with the broker: Authentication error")
return c.Status(401).JSON(fiber.Map{
"message": "Unauthorized",
})
}

log.Errorf("Could not establish new connection with the broker: %s", err.Error())
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"message": "Server error",
})
}
if ConnectionsCache[accountIdStr] == nil {
ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr] = make(map[string]Connection)
ConnectionsCacheLock.Unlock()
}

ConnectionsCacheLock.Lock()
ConnectionsCache[accountIdStr][username] = Connection{Connection: conn, ExpirationTime: userData.TokenExpiry}
ConnectionsCacheLock.Unlock()
conn, errorCode, err := getConnectionForUserData(userData)
if err != nil {
return c.Status(errorCode).JSON(fiber.Map{
"message": err.Error(),
})
}

errCount := 0
Expand Down
Loading