Skip to content

Commit 781bc15

Browse files
fix(oauth): prevent stale session timeouts after login
- stop callback forwarders by instance to avoid cross-session shutdowns - clear pending sessions for a provider after successful auth
1 parent 05d201e commit 781bc15

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

internal/api/handlers/management/auth_files.go

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) {
197197
stopForwarderInstance(port, forwarder)
198198
}
199199

200+
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
201+
if forwarder == nil {
202+
return
203+
}
204+
callbackForwardersMu.Lock()
205+
if current := callbackForwarders[port]; current == forwarder {
206+
delete(callbackForwarders, port)
207+
}
208+
callbackForwardersMu.Unlock()
209+
210+
stopForwarderInstance(port, forwarder)
211+
}
212+
200213
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
201214
if forwarder == nil || forwarder.server == nil {
202215
return
@@ -785,14 +798,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
785798
RegisterOAuthSession(state, "anthropic")
786799

787800
isWebUI := isWebUIRequest(c)
801+
var forwarder *callbackForwarder
788802
if isWebUI {
789803
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
790804
if errTarget != nil {
791805
log.WithError(errTarget).Error("failed to compute anthropic callback target")
792806
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
793807
return
794808
}
795-
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
809+
var errStart error
810+
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
796811
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
797812
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
798813
return
@@ -801,14 +816,17 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
801816

802817
go func() {
803818
if isWebUI {
804-
defer stopCallbackForwarder(anthropicCallbackPort)
819+
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
805820
}
806821

807822
// Helper: wait for callback file
808823
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
809824
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
810825
deadline := time.Now().Add(timeout)
811826
for {
827+
if !IsOAuthSessionPending(state, "anthropic") {
828+
return nil, errOAuthSessionNotPending
829+
}
812830
if time.Now().After(deadline) {
813831
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
814832
return nil, fmt.Errorf("timeout waiting for OAuth callback")
@@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
828846
// Wait up to 5 minutes
829847
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
830848
if errWait != nil {
849+
if errors.Is(errWait, errOAuthSessionNotPending) {
850+
return
851+
}
831852
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
832853
log.Error(claude.GetUserFriendlyMessage(authErr))
833854
return
@@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
933954
}
934955
fmt.Println("You can now use Claude services through this CLI")
935956
CompleteOAuthSession(state)
957+
CompleteOAuthSessionsByProvider("anthropic")
936958
}()
937959

938960
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -968,14 +990,16 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
968990
RegisterOAuthSession(state, "gemini")
969991

970992
isWebUI := isWebUIRequest(c)
993+
var forwarder *callbackForwarder
971994
if isWebUI {
972995
targetURL, errTarget := h.managementCallbackURL("/google/callback")
973996
if errTarget != nil {
974997
log.WithError(errTarget).Error("failed to compute gemini callback target")
975998
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
976999
return
9771000
}
978-
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
1001+
var errStart error
1002+
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
9791003
log.WithError(errStart).Error("failed to start gemini callback forwarder")
9801004
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
9811005
return
@@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
9841008

9851009
go func() {
9861010
if isWebUI {
987-
defer stopCallbackForwarder(geminiCallbackPort)
1011+
defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
9881012
}
9891013

9901014
// Wait for callback file written by server route
@@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
9931017
deadline := time.Now().Add(5 * time.Minute)
9941018
var authCode string
9951019
for {
1020+
if !IsOAuthSessionPending(state, "gemini") {
1021+
return
1022+
}
9961023
if time.Now().After(deadline) {
9971024
log.Error("oauth flow timed out")
9981025
SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1168,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
11681195
}
11691196

11701197
CompleteOAuthSession(state)
1198+
CompleteOAuthSessionsByProvider("gemini")
11711199
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
11721200
}()
11731201

@@ -1209,14 +1237,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
12091237
RegisterOAuthSession(state, "codex")
12101238

12111239
isWebUI := isWebUIRequest(c)
1240+
var forwarder *callbackForwarder
12121241
if isWebUI {
12131242
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
12141243
if errTarget != nil {
12151244
log.WithError(errTarget).Error("failed to compute codex callback target")
12161245
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
12171246
return
12181247
}
1219-
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
1248+
var errStart error
1249+
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
12201250
log.WithError(errStart).Error("failed to start codex callback forwarder")
12211251
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
12221252
return
@@ -1225,14 +1255,17 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
12251255

12261256
go func() {
12271257
if isWebUI {
1228-
defer stopCallbackForwarder(codexCallbackPort)
1258+
defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
12291259
}
12301260

12311261
// Wait for callback file
12321262
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
12331263
deadline := time.Now().Add(5 * time.Minute)
12341264
var code string
12351265
for {
1266+
if !IsOAuthSessionPending(state, "codex") {
1267+
return
1268+
}
12361269
if time.Now().After(deadline) {
12371270
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
12381271
log.Error(codex.GetUserFriendlyMessage(authErr))
@@ -1348,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
13481381
}
13491382
fmt.Println("You can now use Codex services through this CLI")
13501383
CompleteOAuthSession(state)
1384+
CompleteOAuthSessionsByProvider("codex")
13511385
}()
13521386

13531387
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -1393,14 +1427,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
13931427
RegisterOAuthSession(state, "antigravity")
13941428

13951429
isWebUI := isWebUIRequest(c)
1430+
var forwarder *callbackForwarder
13961431
if isWebUI {
13971432
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
13981433
if errTarget != nil {
13991434
log.WithError(errTarget).Error("failed to compute antigravity callback target")
14001435
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
14011436
return
14021437
}
1403-
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
1438+
var errStart error
1439+
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
14041440
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
14051441
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
14061442
return
@@ -1409,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
14091445

14101446
go func() {
14111447
if isWebUI {
1412-
defer stopCallbackForwarder(antigravityCallbackPort)
1448+
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
14131449
}
14141450

14151451
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
14161452
deadline := time.Now().Add(5 * time.Minute)
14171453
var authCode string
14181454
for {
1455+
if !IsOAuthSessionPending(state, "antigravity") {
1456+
return
1457+
}
14191458
if time.Now().After(deadline) {
14201459
log.Error("oauth flow timed out")
14211460
SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1578,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
15781617
}
15791618

15801619
CompleteOAuthSession(state)
1620+
CompleteOAuthSessionsByProvider("antigravity")
15811621
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
15821622
if projectID != "" {
15831623
fmt.Printf("Using GCP project: %s\n", projectID)
@@ -1655,14 +1695,16 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
16551695
RegisterOAuthSession(state, "iflow")
16561696

16571697
isWebUI := isWebUIRequest(c)
1698+
var forwarder *callbackForwarder
16581699
if isWebUI {
16591700
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
16601701
if errTarget != nil {
16611702
log.WithError(errTarget).Error("failed to compute iflow callback target")
16621703
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
16631704
return
16641705
}
1665-
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
1706+
var errStart error
1707+
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
16661708
log.WithError(errStart).Error("failed to start iflow callback forwarder")
16671709
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
16681710
return
@@ -1671,14 +1713,17 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
16711713

16721714
go func() {
16731715
if isWebUI {
1674-
defer stopCallbackForwarder(iflowauth.CallbackPort)
1716+
defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
16751717
}
16761718
fmt.Println("Waiting for authentication...")
16771719

16781720
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
16791721
deadline := time.Now().Add(5 * time.Minute)
16801722
var resultMap map[string]string
16811723
for {
1724+
if !IsOAuthSessionPending(state, "iflow") {
1725+
return
1726+
}
16821727
if time.Now().After(deadline) {
16831728
SetOAuthSessionError(state, "Authentication failed")
16841729
fmt.Println("Authentication failed: timeout waiting for callback")
@@ -1745,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
17451790
}
17461791
fmt.Println("You can now use iFlow services through this CLI")
17471792
CompleteOAuthSession(state)
1793+
CompleteOAuthSessionsByProvider("iflow")
17481794
}()
17491795

17501796
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})

internal/api/handlers/management/oauth_sessions.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) {
111111
delete(s.sessions, state)
112112
}
113113

114+
func (s *oauthSessionStore) CompleteProvider(provider string) int {
115+
provider = strings.ToLower(strings.TrimSpace(provider))
116+
if provider == "" {
117+
return 0
118+
}
119+
now := time.Now()
120+
121+
s.mu.Lock()
122+
defer s.mu.Unlock()
123+
124+
s.purgeExpiredLocked(now)
125+
removed := 0
126+
for state, session := range s.sessions {
127+
if strings.EqualFold(session.Provider, provider) {
128+
delete(s.sessions, state)
129+
removed++
130+
}
131+
}
132+
return removed
133+
}
134+
114135
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
115136
state = strings.TrimSpace(state)
116137
now := time.Now()
@@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state,
153174

154175
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
155176

177+
func CompleteOAuthSessionsByProvider(provider string) int {
178+
return oauthSessions.CompleteProvider(provider)
179+
}
180+
156181
func GetOAuthSession(state string) (provider string, status string, ok bool) {
157182
session, ok := oauthSessions.Get(state)
158183
if !ok {

0 commit comments

Comments
 (0)