Skip to content

Commit 9460c76

Browse files
authored
Merge pull request #3 from schovi/fix/phase1-critical-correctness
Fix 4 critical correctness bugs in daemon server
2 parents 210496f + 7eccca4 commit 9460c76

File tree

3 files changed

+67
-27
lines changed

3 files changed

+67
-27
lines changed

cmd/search.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ func runSearch(cmd *cobra.Command, args []string) error {
5252
after = searchAroundFlag
5353
}
5454

55+
if before < 0 || after < 0 {
56+
return fmt.Errorf("--before, --after, and --around must be non-negative")
57+
}
58+
5559
client := daemon.NewClient()
5660
if err := client.EnsureDaemon(); err != nil {
5761
return fmt.Errorf("daemon: %w", err)

internal/daemon/server.go

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,21 @@ import (
1717
"github.com/schovi/shelli/internal/ansi"
1818
)
1919

20+
type ptyHandle struct {
21+
f *os.File
22+
closeOnce sync.Once
23+
}
24+
25+
func (p *ptyHandle) Close() {
26+
p.closeOnce.Do(func() {
27+
p.f.Close()
28+
})
29+
}
30+
31+
func (p *ptyHandle) File() *os.File {
32+
return p.f
33+
}
34+
2035
type Session struct {
2136
Name string `json:"name"`
2237
PID int `json:"pid"`
@@ -38,7 +53,7 @@ type SessionInfo struct {
3853
type Server struct {
3954
mu sync.Mutex
4055
sessions map[string]*Session
41-
ptys map[string]*os.File
56+
ptys map[string]*ptyHandle
4257
cmds map[string]*exec.Cmd
4358
doneChans map[string]chan struct{}
4459
frameDetectors map[string]*ansi.FrameDetector
@@ -87,7 +102,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
87102

88103
s := &Server{
89104
sessions: make(map[string]*Session),
90-
ptys: make(map[string]*os.File),
105+
ptys: make(map[string]*ptyHandle),
91106
cmds: make(map[string]*exec.Cmd),
92107
doneChans: make(map[string]chan struct{}),
93108
frameDetectors: make(map[string]*ansi.FrameDetector),
@@ -162,7 +177,10 @@ func (s *Server) Start() error {
162177
for {
163178
conn, err := listener.Accept()
164179
if err != nil {
165-
if s.listener == nil {
180+
s.mu.Lock()
181+
isShutdown := s.listener == nil
182+
s.mu.Unlock()
183+
if isShutdown {
166184
return nil
167185
}
168186
return err
@@ -211,8 +229,8 @@ func (s *Server) Shutdown() {
211229
if done, ok := s.doneChans[name]; ok {
212230
close(done)
213231
}
214-
if ptmx, ok := s.ptys[name]; ok {
215-
ptmx.Close()
232+
if handle, ok := s.ptys[name]; ok {
233+
handle.Close()
216234
}
217235
if cmd, ok := s.cmds[name]; ok {
218236
cmd.Process.Kill()
@@ -380,16 +398,17 @@ func (s *Server) handleCreate(req Request) Response {
380398
CreatedAt: now,
381399
}
382400

401+
handle := &ptyHandle{f: ptmx}
383402
s.sessions[req.Name] = sess
384-
s.ptys[req.Name] = ptmx
403+
s.ptys[req.Name] = handle
385404
s.cmds[req.Name] = cmd
386405
s.doneChans[req.Name] = make(chan struct{})
387406
if req.TUIMode {
388407
s.frameDetectors[req.Name] = ansi.NewFrameDetector(ansi.DefaultTUIStrategy())
389408
s.responders[req.Name] = ansi.NewTerminalResponder(ptmx)
390409
}
391410

392-
go s.captureOutput(req.Name, ptmx, cmd)
411+
go s.captureOutput(req.Name, handle, cmd)
393412

394413
return Response{Success: true, Data: map[string]interface{}{
395414
"name": sess.Name,
@@ -401,14 +420,16 @@ func (s *Server) handleCreate(req Request) Response {
401420
}}
402421
}
403422

404-
func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
423+
func (s *Server) captureOutput(name string, handle *ptyHandle, cmd *exec.Cmd) {
405424
s.mu.Lock()
406425
done := s.doneChans[name]
407426
detector := s.frameDetectors[name]
408427
responder := s.responders[name]
409428
storage := s.storage
410429
s.mu.Unlock()
411430

431+
f := handle.File()
432+
412433
if detector != nil {
413434
defer func() {
414435
if pending := detector.Flush(); len(pending) > 0 {
@@ -425,8 +446,8 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
425446
default:
426447
}
427448

428-
ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
429-
n, err := ptmx.Read(buf)
449+
f.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
450+
n, err := f.Read(buf)
430451
if n > 0 {
431452
data := buf[:n]
432453
if responder != nil {
@@ -450,7 +471,7 @@ func (s *Server) captureOutput(name string, ptmx *os.File, cmd *exec.Cmd) {
450471
}
451472

452473
cmd.Wait()
453-
ptmx.Close()
474+
handle.Close()
454475

455476
s.mu.Lock()
456477
defer s.mu.Unlock()
@@ -626,11 +647,12 @@ func (s *Server) handleSnapshot(req Request) Response {
626647
s.mu.Unlock()
627648
return Response{Success: false, Error: fmt.Sprintf("session %q is not in TUI mode (snapshot requires --tui)", req.Name)}
628649
}
629-
ptmx, ok := s.ptys[req.Name]
650+
handle, ok := s.ptys[req.Name]
630651
if !ok {
631652
s.mu.Unlock()
632653
return Response{Success: false, Error: fmt.Sprintf("session %q PTY not available", req.Name)}
633654
}
655+
ptmx := handle.File()
634656
cmd := s.cmds[req.Name]
635657
storage := s.storage
636658
s.mu.Unlock()
@@ -793,7 +815,7 @@ func (s *Server) handleSend(req Request) Response {
793815
s.mu.Unlock()
794816
return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)}
795817
}
796-
ptmx, ok := s.ptys[req.Name]
818+
handle, ok := s.ptys[req.Name]
797819
s.mu.Unlock()
798820

799821
if !ok {
@@ -805,7 +827,7 @@ func (s *Server) handleSend(req Request) Response {
805827
data += "\n"
806828
}
807829

808-
if _, err := ptmx.WriteString(data); err != nil {
830+
if _, err := handle.File().WriteString(data); err != nil {
809831
return Response{Success: false, Error: err.Error()}
810832
}
811833

@@ -830,8 +852,8 @@ func (s *Server) handleStop(req Request) Response {
830852
delete(s.doneChans, req.Name)
831853
}
832854

833-
if ptmx, ok := s.ptys[req.Name]; ok {
834-
ptmx.Close()
855+
if handle, ok := s.ptys[req.Name]; ok {
856+
handle.Close()
835857
delete(s.ptys, req.Name)
836858
}
837859

@@ -863,29 +885,25 @@ func (s *Server) handleStop(req Request) Response {
863885

864886
func (s *Server) handleKill(req Request) Response {
865887
s.mu.Lock()
866-
defer s.mu.Unlock()
867888

868889
sess, exists := s.sessions[req.Name]
869890
if !exists {
891+
s.mu.Unlock()
870892
return Response{Success: false, Error: fmt.Sprintf("session %q not found", req.Name)}
871893
}
872894

895+
var proc *os.Process
873896
if sess.State == StateRunning {
874897
if done, ok := s.doneChans[req.Name]; ok {
875898
close(done)
876899
delete(s.doneChans, req.Name)
877900
}
878-
879-
if ptmx, ok := s.ptys[req.Name]; ok {
880-
ptmx.Close()
901+
if handle, ok := s.ptys[req.Name]; ok {
902+
handle.Close()
881903
delete(s.ptys, req.Name)
882904
}
883-
884905
if cmd, ok := s.cmds[req.Name]; ok {
885-
cmd.Process.Signal(syscall.SIGTERM)
886-
time.Sleep(KillGracePeriod)
887-
cmd.Process.Signal(syscall.SIGKILL)
888-
cmd.Wait()
906+
proc = cmd.Process
889907
delete(s.cmds, req.Name)
890908
}
891909
}
@@ -894,6 +912,16 @@ func (s *Server) handleKill(req Request) Response {
894912
delete(s.sessions, req.Name)
895913
delete(s.frameDetectors, req.Name)
896914
delete(s.responders, req.Name)
915+
s.mu.Unlock()
916+
917+
if proc != nil {
918+
go func() {
919+
proc.Signal(syscall.SIGTERM)
920+
time.Sleep(KillGracePeriod)
921+
proc.Signal(syscall.SIGKILL)
922+
proc.Wait()
923+
}()
924+
}
897925

898926
return Response{Success: true}
899927
}
@@ -916,6 +944,10 @@ func (s *Server) handleSize(req Request) Response {
916944
}
917945

918946
func (s *Server) handleSearch(req Request) Response {
947+
if req.Before < 0 || req.After < 0 {
948+
return Response{Success: false, Error: "before and after must be non-negative"}
949+
}
950+
919951
s.mu.Lock()
920952
_, exists := s.sessions[req.Name]
921953
if !exists {
@@ -1056,7 +1088,7 @@ func (s *Server) handleResize(req Request) Response {
10561088
return Response{Success: false, Error: fmt.Sprintf("session %q is stopped", req.Name)}
10571089
}
10581090

1059-
ptmx, ok := s.ptys[req.Name]
1091+
handle, ok := s.ptys[req.Name]
10601092
if !ok {
10611093
s.mu.Unlock()
10621094
return Response{Success: false, Error: fmt.Sprintf("session %q not running", req.Name)}
@@ -1078,7 +1110,7 @@ func (s *Server) handleResize(req Request) Response {
10781110
rows = meta.Rows
10791111
}
10801112

1081-
if err := pty.Setsize(ptmx, &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil {
1113+
if err := pty.Setsize(handle.File(), &pty.Winsize{Cols: clampUint16(cols), Rows: clampUint16(rows)}); err != nil {
10821114
return Response{Success: false, Error: fmt.Sprintf("resize: %v", err)}
10831115
}
10841116

internal/mcp/tools.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,10 @@ func (r *ToolRegistry) callSearch(args json.RawMessage) (*CallToolResult, error)
831831
after = a.Around
832832
}
833833

834+
if before < 0 || after < 0 {
835+
return nil, fmt.Errorf("before, after, and around must be non-negative")
836+
}
837+
834838
resp, err := r.client.Search(daemon.SearchRequest{
835839
Name: a.Name,
836840
Pattern: a.Pattern,

0 commit comments

Comments
 (0)