11package handlers
22
33import (
4+ "encoding/json"
45 "rest-gateway/conf"
56 "rest-gateway/logger"
7+ "rest-gateway/memphisSingleton"
68 "rest-gateway/models"
79 "rest-gateway/utils"
810 "strconv"
@@ -13,6 +15,7 @@ import (
1315 "github.com/gofiber/fiber/v2"
1416 "github.com/golang-jwt/jwt/v4"
1517 "github.com/memphisdev/memphis.go"
18+ "github.com/nats-io/nats.go"
1619)
1720
1821var configuration = conf .GetConfig ()
@@ -82,6 +85,7 @@ func (ah AuthHandler) Authenticate(c *fiber.Ctx) error {
8285 accountId = accId
8386 }
8487 }
88+
8589 conn , err := Connect (body .Password , body .Username , body .ConnectionToken , accountId )
8690 if err != nil {
8791 errMsg := strings .ToLower (err .Error ())
@@ -119,6 +123,43 @@ func (ah AuthHandler) Authenticate(c *fiber.Ctx) error {
119123 ConnectionsCacheLock .Lock ()
120124 ConnectionsCache [accountIdStr ][username ] = Connection {Connection : conn , ExpirationTime : tokenExpiry }
121125 ConnectionsCacheLock .Unlock ()
126+
127+ mc , err := memphisSingleton .GetMemphisConnection ("" , "" , "" ) // already initialized on logger creation
128+ if err != nil {
129+ log .Errorf ("Authenticate: %s" , err .Error ())
130+ return c .Status (fiber .StatusInternalServerError ).JSON (fiber.Map {
131+ "message" : "Server error" ,
132+ })
133+ }
134+
135+ update := models.RestGwUpdate {
136+ Type : "update_connection" ,
137+ Update : map [string ]interface {}{
138+ "password" : body .Password ,
139+ "username" : body .Username ,
140+ "connection_token" : body .ConnectionToken ,
141+ "account_id" : accountId ,
142+ "token_expiry" : tokenExpiry ,
143+ },
144+ }
145+
146+ msg , err := json .Marshal (update )
147+ if err != nil {
148+ log .Errorf ("Authenticate: %s" , err .Error ())
149+ return c .Status (fiber .StatusInternalServerError ).JSON (fiber.Map {
150+ "message" : "Server error" ,
151+ })
152+ }
153+
154+ // send to other rest GWs to update their cache
155+ err = mc .Publish (configuration .REST_GW_UPDATES_SUBJ , msg )
156+ if err != nil {
157+ log .Errorf ("Authenticate: %s" , err .Error ())
158+ return c .Status (fiber .StatusInternalServerError ).JSON (fiber.Map {
159+ "message" : "Server error" ,
160+ })
161+ }
162+
122163 return c .Status (fiber .StatusOK ).JSON (fiber.Map {
123164 "jwt" : token ,
124165 "expires_in" : tokenExpiry * 60 * 1000 ,
@@ -242,7 +283,11 @@ func CleanConnectionsCache() {
242283 currentTime := time .Now ()
243284 unixTimeNow := currentTime .Unix ()
244285 conn := ConnectionsCache [t ][u ].Connection
245- if unixTimeNow > int64 (user .ExpirationTime ) {
286+ if ! conn .IsConnected () {
287+ ConnectionsCacheLock .Lock ()
288+ delete (ConnectionsCache [t ], u )
289+ ConnectionsCacheLock .Unlock ()
290+ } else if unixTimeNow > int64 (user .ExpirationTime ) {
246291 conn .Close ()
247292 ConnectionsCacheLock .Lock ()
248293 delete (ConnectionsCache [t ], u )
@@ -257,3 +302,59 @@ func CleanConnectionsCache() {
257302 }
258303 }
259304}
305+
306+ func ListenForUpdates (log * logger.Logger ) error {
307+ mc , err := memphisSingleton .GetMemphisConnection ("" , "" , "" ) // already initialized on logger creation
308+ if err != nil {
309+ return err
310+ }
311+
312+ _ , err = mc .Subscribe (configuration .REST_GW_UPDATES_SUBJ , func (msg * nats.Msg ) {
313+ var update models.RestGwUpdate
314+ err := json .Unmarshal (msg .Data , & update )
315+ if err != nil {
316+ log .Errorf ("update unmarshal error: %v\n " , err .Error ())
317+ return
318+ }
319+
320+ switch update .Type {
321+ case "update_connection" :
322+ username := update .Update ["username" ].(string )
323+ accountId := int (update .Update ["account_id" ].(float64 ))
324+ username = strings .ToLower (username )
325+ accountIdStr := strconv .Itoa (accountId )
326+
327+ if ConnectionsCache [accountIdStr ] != nil {
328+ _ , exists := ConnectionsCache [accountIdStr ][username ]
329+ if exists {
330+ return // connection already exists, nothing to update
331+ }
332+ }
333+
334+ conn , err := Connect (update .Update ["password" ].(string ), username , update .Update ["connection_token" ].(string ), accountId )
335+ if err != nil {
336+ errMsg := strings .ToLower (err .Error ())
337+ if strings .Contains (errMsg , ErrorMsgAuthorizationViolation ) || strings .Contains (errMsg , "token" ) || strings .Contains (errMsg , ErrorMsgMissionAccountId ) {
338+ return
339+ }
340+
341+ log .Errorf ("ListenForUpdates: %s" , err .Error ())
342+ return
343+ }
344+
345+ if ConnectionsCache [accountIdStr ] == nil {
346+ ConnectionsCacheLock .Lock ()
347+ ConnectionsCache [accountIdStr ] = make (map [string ]Connection )
348+ ConnectionsCacheLock .Unlock ()
349+ }
350+
351+ ConnectionsCacheLock .Lock ()
352+ ConnectionsCache [accountIdStr ][username ] = Connection {Connection : conn , ExpirationTime : int64 (update .Update ["token_expiry" ].(float64 ))}
353+ ConnectionsCacheLock .Unlock ()
354+ }
355+ })
356+ if err != nil {
357+ return err
358+ }
359+ return nil
360+ }
0 commit comments