Skip to content

Commit b9a00b6

Browse files
author
cici
committed
feat: fix session termination and add per-session task serialization
- Add per-session mutex to serialize tasks: stop command fully completes before new message starts, preventing race conditions - Poll task-level status (not session-level) so cancellation only affects the specific task, not new tasks for the same session - Send SIGTERM instead of SIGKILL on context cancel so CLI processes can save session state before exiting (10s grace period) - Add GET /v1/machines/{id}/tasks/{id}/status endpoint for task-level status polling
1 parent 3427c95 commit b9a00b6

6 files changed

Lines changed: 84 additions & 4 deletions

File tree

internal/app/runner_claude.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"os/exec"
1111
"path/filepath"
1212
"strings"
13+
"syscall"
1314
"time"
1415

1516
"github.com/shotforward/codewithphone/internal/config"
@@ -143,6 +144,12 @@ func (r *claudeRunner) RunTurn(ctx context.Context, dispatch taskDispatch, provi
143144
cmd := exec.CommandContext(ctx, r.claudeBin, args...)
144145
cmd.Dir = dispatch.WorkspaceRoot
145146
cmd.Env = os.Environ()
147+
// Send SIGTERM instead of SIGKILL on context cancel so the CLI
148+
// process can save its session state before exiting.
149+
cmd.Cancel = func() error {
150+
return cmd.Process.Signal(syscall.SIGTERM)
151+
}
152+
cmd.WaitDelay = 10 * time.Second // SIGKILL after 10s if still alive
146153

147154
stdout, err := cmd.StdoutPipe()
148155
if err != nil {

internal/app/runner_codex.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"strconv"
1616
"strings"
1717
"sync/atomic"
18+
"syscall"
1819
"time"
1920

2021
"github.com/shotforward/codewithphone/internal/config"
@@ -789,6 +790,12 @@ type codexRPCClient struct {
789790

790791
func startCodexRPC(ctx context.Context, codexBin string) (*codexRPCClient, error) {
791792
cmd := exec.CommandContext(ctx, codexBin, "app-server", "--listen", "stdio://")
793+
// Send SIGTERM instead of SIGKILL on context cancel so the CLI
794+
// process can save its session/thread state before exiting.
795+
cmd.Cancel = func() error {
796+
return cmd.Process.Signal(syscall.SIGTERM)
797+
}
798+
cmd.WaitDelay = 10 * time.Second
792799
stdin, err := cmd.StdinPipe()
793800
if err != nil {
794801
return nil, err

internal/app/runner_gemini.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"path/filepath"
1919
"strings"
2020
"sync"
21+
"syscall"
2122
"time"
2223

2324
"github.com/shotforward/codewithphone/internal/config"
@@ -261,6 +262,12 @@ deny_message = "No user interactive console is available. Use mcp_pocketcode_run
261262
cmd := exec.CommandContext(ctx, r.geminiBin, args...)
262263
cmd.Dir = dispatch.WorkspaceRoot
263264
cmd.Env = append(os.Environ(), "GEMINI_CLI_HOME="+geminiHome)
265+
// Send SIGTERM instead of SIGKILL on context cancel so the CLI
266+
// process can save its session state before exiting.
267+
cmd.Cancel = func() error {
268+
return cmd.Process.Signal(syscall.SIGTERM)
269+
}
270+
cmd.WaitDelay = 10 * time.Second
264271

265272
// Ensure GEMINI_API_KEY is set when auth type is "gemini-api-key".
266273
// The Gemini CLI's validateAuthMethod checks for this env var before

internal/app/server_client.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,33 @@ func (c serverClient) confirmBinding(ctx context.Context, code, confirmNonce str
405405
return result.MachineToken, nil
406406
}
407407

408+
func (c serverClient) fetchTaskStatus(ctx context.Context, taskRunID string) (string, error) {
409+
url := strings.TrimRight(c.BaseURL, "/") + "/v1/machines/" + c.MachineID + "/tasks/" + taskRunID + "/status"
410+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
411+
if err != nil {
412+
return "", err
413+
}
414+
req.Header.Set("Accept", "application/json")
415+
if c.MachineToken != "" {
416+
req.Header.Set("X-Machine-Token", c.MachineToken)
417+
}
418+
resp, err := c.httpClient().Do(req)
419+
if err != nil {
420+
return "", err
421+
}
422+
defer resp.Body.Close()
423+
if resp.StatusCode != http.StatusOK {
424+
return "", newHTTPStatusError("fetch task status", resp)
425+
}
426+
var result struct {
427+
Status string `json:"status"`
428+
}
429+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
430+
return "", err
431+
}
432+
return strings.TrimSpace(result.Status), nil
433+
}
434+
408435
func (c serverClient) fetchSessionStatus(ctx context.Context, sessionID string) (string, error) {
409436
url := strings.TrimRight(c.BaseURL, "/") + "/v1/sessions/" + sessionID
410437
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)

internal/app/service.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type Service struct {
5656
actualAddr string // resolved listen address (useful when port=0)
5757
providerSessions map[string]string
5858
sessionWorkspaces map[string]string // sessionID -> first workspaceRoot
59+
sessionLocks map[string]*sync.Mutex // per-session lock for serial task execution
5960
taskWorkspaces map[string]string // taskRunID -> workspaceRoot
6061
taskWorkspaceSnapshots map[string]string // taskRunID -> snapshot.Root for "vs turn start" diffs
6162
taskProfiles map[string]turnExecutionProfile
@@ -129,6 +130,7 @@ func New(cfg config.Config, opts ...Option) *Service {
129130
interactive: true, // default: foreground with terminal
130131
providerSessions: map[string]string{},
131132
sessionWorkspaces: map[string]string{},
133+
sessionLocks: map[string]*sync.Mutex{},
132134
taskWorkspaces: map[string]string{},
133135
taskWorkspaceSnapshots: map[string]string{},
134136
taskProfiles: map[string]turnExecutionProfile{},
@@ -367,6 +369,20 @@ func (s *Service) getProviderSession(sessionID string) string {
367369
return s.providerSessions[sessionID]
368370
}
369371

372+
// getSessionLock returns a per-session mutex, creating one if needed.
373+
// Tasks for the same session are serialized through this lock so that
374+
// a stop command is fully processed before a new message starts.
375+
func (s *Service) getSessionLock(sessionID string) *sync.Mutex {
376+
s.mu.Lock()
377+
defer s.mu.Unlock()
378+
lock, ok := s.sessionLocks[sessionID]
379+
if !ok {
380+
lock = &sync.Mutex{}
381+
s.sessionLocks[sessionID] = lock
382+
}
383+
return lock
384+
}
385+
370386
func (s *Service) setProviderSession(sessionID, providerSessionRef string) {
371387
s.mu.Lock()
372388
defer s.mu.Unlock()

internal/app/task_loop.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ func (s *Service) handleCodexDispatch(ctx context.Context, dispatch taskDispatch
8585
}
8686

8787
func (s *Service) handleRunnerDispatch(ctx context.Context, dispatch taskDispatch, runner turnRunner) error {
88+
// Serialize tasks for the same session: wait for any previous task
89+
// (including a cancelled one being drained) to finish before starting.
90+
sessionLock := s.getSessionLock(dispatch.SessionID)
91+
sessionLock.Lock()
92+
defer sessionLock.Unlock()
93+
8894
dispatchStart := time.Now()
8995
profile := planTurnExecution(dispatch.Prompt)
9096
// Gemini session is directory-scoped; pin workspace before any task-scoped
@@ -129,7 +135,7 @@ func (s *Service) handleRunnerDispatch(ctx context.Context, dispatch taskDispatc
129135
done := make(chan struct{})
130136
defer close(done)
131137
terminated := make(chan struct{})
132-
go s.watchSessionTermination(taskCtx, dispatch.SessionID, cancel, done, terminated)
138+
go s.watchSessionTermination(taskCtx, dispatch.SessionID, dispatch.TaskRunID, cancel, done, terminated)
133139

134140
t0 := time.Now()
135141
if err := s.serverClient.postEvent(ctx, daemonEvent{
@@ -335,7 +341,7 @@ func (s *Service) handleRunnerDispatch(ctx context.Context, dispatch taskDispatc
335341
}
336342

337343

338-
func (s *Service) watchSessionTermination(ctx context.Context, sessionID string, cancel context.CancelFunc, done <-chan struct{}, terminated chan struct{}) {
344+
func (s *Service) watchSessionTermination(ctx context.Context, sessionID, taskRunID string, cancel context.CancelFunc, done <-chan struct{}, terminated chan struct{}) {
339345
ticker := time.NewTicker(sessionTerminationPollInterval)
340346
defer ticker.Stop()
341347
for {
@@ -345,11 +351,21 @@ func (s *Service) watchSessionTermination(ctx context.Context, sessionID string,
345351
case <-ctx.Done():
346352
return
347353
case <-ticker.C:
348-
status, err := s.serverClient.fetchSessionStatus(ctx, sessionID)
354+
status, err := s.serverClient.fetchTaskStatus(ctx, taskRunID)
349355
if err != nil {
356+
// Fallback to session-level check for backward compatibility.
357+
sessStatus, sessErr := s.serverClient.fetchSessionStatus(ctx, sessionID)
358+
if sessErr != nil {
359+
continue
360+
}
361+
if sessStatus == "terminated" {
362+
cancel()
363+
close(terminated)
364+
return
365+
}
350366
continue
351367
}
352-
if status == "terminated" {
368+
if status == "cancelled" {
353369
cancel()
354370
close(terminated)
355371
return

0 commit comments

Comments
 (0)