
Commit 04f83b0

Implement graceful shutdown with proper prediction completion
This implements a comprehensive graceful shutdown mechanism that waits for in-flight predictions to complete before stopping runners and the service. Key changes:

**Runner-level graceful shutdown:**
- Add shutdownWhenIdle atomic flag and readyForShutdown channel to Runner
- GracefulShutdown() signals runners to shut down when idle
- updateStatus() automatically closes readyForShutdown when becoming READY with no pending predictions
- Add nil check with warning for test compatibility

**Handler-level prediction rejection:**
- Add gracefulShutdown atomic flag to reject new predictions during shutdown
- Handler.Stop() sets the flag and waits for manager shutdown
- Predict() returns 503 Service Unavailable during shutdown

**Manager-level coordinated shutdown:**
- Manager.Stop() signals all runners for graceful shutdown
- Use WaitGroup.Go() for independent, parallel runner shutdowns
- Respect the RunnerShutdownGracePeriod timeout before force-stopping
- Wait on the runner.readyForShutdown channel or the timeout

**Service-level errgroup coordination:**
- Fix errgroup goroutines to exit on the shutdown signal
- Add shutdown case to the force-shutdown monitor goroutine
- Signal handler already had a proper shutdown case
- Add contextcheck nolint for the long-lived errgroup context

**Test coverage:**
- Add E2E test for 503 rejection of new predictions during shutdown
- Verify graceful shutdown waits for in-flight predictions
- Test that the service properly stops after shutdown completes

This restores the graceful shutdown behavior from commit 575d218 that was lost during the server refactor, ensuring predictions complete naturally during the grace period rather than being immediately force-killed.
1 parent b4f70de commit 04f83b0
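The mechanism described above boils down to a "shutdown when idle" handshake: an atomic flag records that shutdown was requested, a channel is closed exactly once when the runner is READY with no pending predictions, and the manager waits on that channel or a grace-period timeout. Below is a minimal, self-contained sketch of that pattern; the names (`worker`, `pending`) are illustrative rather than the repository's actual types, and it uses `sync.Once` where the diff uses a guarded `select`/`default` close.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// worker is a stand-in for Runner: it tracks pending work and signals
// readiness for shutdown by closing a channel exactly once.
type worker struct {
	mu               sync.Mutex
	pending          int
	shutdownWhenIdle atomic.Bool
	readyForShutdown chan struct{}
	closeOnce        sync.Once
}

func newWorker() *worker {
	return &worker{readyForShutdown: make(chan struct{})}
}

// GracefulShutdown requests shutdown; if the worker is already idle, the
// readiness channel is closed immediately.
func (w *worker) GracefulShutdown() {
	w.shutdownWhenIdle.Store(true)
	w.maybeSignal()
}

// done marks one unit of work finished and re-checks the idle condition,
// mirroring what an updateStatus-style hook would do on READY.
func (w *worker) done() {
	w.mu.Lock()
	w.pending--
	w.mu.Unlock()
	w.maybeSignal()
}

func (w *worker) maybeSignal() {
	w.mu.Lock()
	idle := w.pending == 0
	w.mu.Unlock()
	if idle && w.shutdownWhenIdle.Load() {
		w.closeOnce.Do(func() { close(w.readyForShutdown) })
	}
}

func main() {
	w := newWorker()
	w.pending = 1 // one in-flight prediction

	w.GracefulShutdown() // requested, but not idle yet
	go func() {
		time.Sleep(50 * time.Millisecond)
		w.done() // prediction completes, channel closes
	}()

	select {
	case <-w.readyForShutdown:
		fmt.Println("worker idle, safe to stop")
	case <-time.After(time.Second):
		fmt.Println("grace period expired, force stopping")
	}
}
```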

File tree

9 files changed: +542 additions, -158 deletions


internal/runner/manager.go

Lines changed: 34 additions & 30 deletions
@@ -17,7 +17,6 @@ import (
 	"time"
 
 	"go.uber.org/zap"
-	"golang.org/x/sync/errgroup"
 
 	"github.com/replicate/cog-runtime/internal/config"
 	"github.com/replicate/cog-runtime/internal/webhook"
@@ -747,43 +746,48 @@ func (m *Manager) Stop() error {
 		log.Info("stopping runner manager")
 
 		m.mu.Lock()
-		defer m.mu.Unlock()
-
-		// Stop all runners
-		for i, runner := range m.runners {
+		runnerList := make([]*Runner, 0, len(m.runners))
+		for _, runner := range m.runners {
 			if runner != nil {
-				log.Infow("stopping runner", "name", runner.runnerCtx.id, "slot", i)
-				if err := runner.Stop(); err != nil {
-					log.Errorw("error stopping runner", "name", runner.runnerCtx.id, "error", err)
-					if stopErr == nil {
-						stopErr = err
-					}
-				}
+				runnerList = append(runnerList, runner)
 			}
 		}
+		m.mu.Unlock()
 
-		// Wait for runners to stop concurrently
-		eg := errgroup.Group{}
-		for i, runner := range m.runners {
-			if runner != nil {
-				name := runner.runnerCtx.id
-				eg.Go(func() error {
-					log.Infow("waiting for runner to stop", "name", name, "slot", i)
-					runner.WaitForStop()
-					return nil
-				})
-			}
+		// Signal all runners for graceful shutdown
+		for _, runner := range runnerList {
+			runner.GracefulShutdown()
 		}
 
-		if err := eg.Wait(); err != nil {
-			log.Errorw("error waiting for runners to stop", "error", err)
-			if stopErr == nil {
-				stopErr = err
-			}
-		} else {
-			log.Info("all runners stopped successfully")
+		// Wait for runners to become idle or timeout using WaitGroup
+		gracePeriod := m.cfg.RunnerShutdownGracePeriod
+		log.Infow("grace period configuration", "grace_period", gracePeriod)
+		graceCtx, cancel := context.WithTimeout(m.ctx, gracePeriod)
+		defer cancel()
+
+		var wg sync.WaitGroup
+		for _, runner := range runnerList {
+			wg.Go(func() {
+				log.Debugw("waiting for runner to become idle", "name", runner.runnerCtx.id, "grace_period", gracePeriod)
+				// Wait for this runner to become idle OR timeout
+				select {
+				case <-runner.readyForShutdown:
+					log.Infow("runner became idle naturally", "name", runner.runnerCtx.id)
+				case <-graceCtx.Done():
+					log.Warnw("grace period expired for runner", "name", runner.runnerCtx.id, "context_err", graceCtx.Err())
+				}
+
+				// Always try to stop, handle errors independently
+				if err := runner.Stop(); err != nil {
+					log.Errorw("failed to stop runner gracefully", "name", runner.runnerCtx.id, "error", err)
+				}
+			})
 		}
 
+		// Wait for all runners to complete shutdown (success or failure)
+		wg.Wait()
+
+		log.Info("all runners stopped successfully")
 		close(m.stopped)
 	})
 
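One detail worth noting: `wg.Go(...)` is the `sync.WaitGroup.Go` helper added in Go 1.25, which wraps the usual `Add`/`Done` bookkeeping. On an older toolchain the same per-worker wait can be spelled out explicitly. The following is a self-contained sketch of that shape under illustrative names (`stopWorkers`, `ready`), not the repository's manager code:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// stopWorkers shows the classic Add/Done form of the per-runner wait loop in
// the hunk above. Each worker is waited on independently: it is stopped
// either when its ready channel closes or when the shared grace period ends.
func stopWorkers(ready []chan struct{}, gracePeriod time.Duration) {
	graceCtx, cancel := context.WithTimeout(context.Background(), gracePeriod)
	defer cancel()

	var wg sync.WaitGroup
	for i, ch := range ready {
		wg.Add(1)
		go func(i int, ch chan struct{}) {
			defer wg.Done()
			select {
			case <-ch:
				fmt.Printf("worker %d became idle naturally\n", i)
			case <-graceCtx.Done():
				fmt.Printf("grace period expired for worker %d\n", i)
			}
			// stop the worker here regardless of which case fired
		}(i, ch)
	}
	wg.Wait()
}

func main() {
	ready := []chan struct{}{make(chan struct{}), make(chan struct{})}
	go func() { close(ready[0]) }() // first worker drains immediately
	stopWorkers(ready, 100*time.Millisecond)
}
```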

internal/runner/runner.go

Lines changed: 60 additions & 29 deletions
@@ -13,6 +13,7 @@ import (
 	"regexp"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -303,6 +304,8 @@ type Runner struct {
 	procedureHash      string
 	mu                 sync.RWMutex
 	stopped            chan bool
+	shutdownWhenIdle   atomic.Bool
+	readyForShutdown   chan struct{} // closed when idle and ready to be stopped
 	setupComplete      chan struct{} // closed on first READY after setup
 	webhookSender      webhook.Sender
 	logCaptureComplete chan struct{} // closed when both stdout/stderr capture complete
@@ -345,6 +348,34 @@ func (r *Runner) WaitForStop() {
 	<-r.stopped
 }
 
+func (r *Runner) GracefulShutdown() {
+	log := r.logger.Sugar()
+	if !r.shutdownWhenIdle.CompareAndSwap(false, true) {
+		log.Debugw("graceful shutdown already initiated", "runner_id", r.runnerCtx.id)
+		return
+	}
+
+	r.mu.RLock()
+	shouldSignal := (r.status == StatusReady && len(r.pending) == 0)
+	r.mu.RUnlock()
+
+	log.Debugw("graceful shutdown initiated", "runner_id", r.runnerCtx.id, "status", r.status, "pending_count", len(r.pending), "should_signal", shouldSignal)
+
+	if shouldSignal {
+		if r.readyForShutdown == nil {
+			log.Warnw("readyForShutdown channel is nil, cannot signal shutdown readiness", "runner_id", r.runnerCtx.id)
+		} else {
+			select {
+			case <-r.readyForShutdown:
+				log.Debugw("readyForShutdown already closed", "runner_id", r.runnerCtx.id)
+			default:
+				log.Debugw("closing readyForShutdown channel", "runner_id", r.runnerCtx.id)
+				close(r.readyForShutdown)
+			}
+		}
+	}
+}
+
 func (r *Runner) Start(ctx context.Context) error {
 	log := r.logger.Sugar()
 	r.mu.Lock()
@@ -676,44 +707,32 @@ func (r *Runner) ForceKill() {
 
 func (r *Runner) verifyProcessCleanup(pid int) {
 	log := r.logger.Sugar()
-	const checkInterval = 10 * time.Millisecond
-
 	log.Infow("starting process cleanup verification", "pid", pid)
 
 	timeout := r.cleanupTimeout
 	if timeout == 0 {
-		timeout = 10 * time.Second // Default fallback
+		timeout = 10 * time.Second
 	}
 
 	timer := time.NewTimer(timeout)
 	defer timer.Stop()
 
-	ticker := time.NewTicker(checkInterval)
-	defer ticker.Stop()
-
-	for {
+	select {
+	case <-r.stopped:
+		log.Infow("process cleanup verified successfully", "pid", pid)
 		select {
-		case <-timer.C:
-			log.Errorw("process cleanup timeout exceeded, forcing server exit",
-				"pid", pid, "timeout", timeout)
-			// Signal forced shutdown - this is idempotent and never blocks
-			if r.forceShutdown.TriggerForceShutdown() {
-				log.Errorw("triggered force shutdown signal")
-			}
-			return
+		case r.cleanupSlot <- struct{}{}:
+		default:
+		}
+		return
 
-		case <-ticker.C:
-			// Verify if process group has been terminated
-			if err := r.verifyFn(pid); err == nil {
-				log.Infow("process cleanup verified successfully", "pid", pid)
-				// Return cleanup token to allow future cleanup
-				select {
-				case r.cleanupSlot <- struct{}{}:
-				default:
-				}
-				return
-			}
+	case <-timer.C:
+		log.Errorw("process cleanup timeout exceeded, forcing server exit",
+			"pid", pid, "timeout", timeout)
+		if r.forceShutdown.TriggerForceShutdown() {
+			log.Errorw("triggered force shutdown signal")
 		}
+		return
 	}
 }
 
@@ -806,6 +825,17 @@ func (r *Runner) updateStatus(statusStr string) error {
 		return err
 	}
 	r.status = status
+
+	// Close readyForShutdown channel when idle and shutdown requested
+	if status == StatusReady && r.shutdownWhenIdle.Load() && len(r.pending) == 0 {
+		select {
+		case <-r.readyForShutdown:
+			// Already closed
+		default:
+			close(r.readyForShutdown)
+		}
+	}
+
 	return nil
 }
 
@@ -909,14 +939,14 @@ func (r *Runner) updateSetupResult() {
 	switch r.setupResult.Status {
 	case SetupSucceeded:
 		r.status = StatusReady
-		log.Debug("setup succeeded", "status", r.status.String())
+		log.Debugw("setup succeeded", "status", r.status.String())
 	case SetupFailed:
 		r.status = StatusSetupFailed
-		log.Debug("setup failed", "status", r.status.String())
+		log.Debugw("setup failed", "status", r.status.String())
 	default:
 		r.setupResult.Status = SetupFailed
 		r.status = StatusSetupFailed
-		log.Debug("unknown setup status, defaulting to failed", "status", r.status.String())
+		log.Debugw("unknown setup status, defaulting to failed", "status", r.status.String())
 	}
 }
 
@@ -972,6 +1002,7 @@ func NewRunner(ctx context.Context, ctxCancel context.CancelFunc, runnerCtx Runn
 		verifyFn:           verifyProcessGroupTerminated,
 		cleanupSlot:        make(chan struct{}, 1),
 		stopped:            make(chan bool),
+		readyForShutdown:   make(chan struct{}),
 		setupComplete:      make(chan struct{}),
 		logCaptureComplete: make(chan struct{}),
 		cleanupTimeout:     cleanupTimeout,

internal/runner/runner_test.go

Lines changed: 9 additions & 14 deletions
@@ -1013,7 +1013,7 @@ func TestRunnerTempDirectoryCleanup(t *testing.T) {
 	err = r.Stop()
 	require.NoError(t, err, "Runner.Stop() should not error")
 
-	time.Sleep(100 * time.Millisecond)
+	time.Sleep(1 * time.Millisecond)
 
 	_, err = os.Stat(tmpDir)
 	assert.True(t, os.IsNotExist(err), "Temp directory should be cleaned up after Stop()")
@@ -1424,30 +1424,25 @@ func TestForceKillCleanupFailures(t *testing.T) {
 
 	t.Run("procedure mode cleanup success returns token", func(t *testing.T) {
 		forceShutdown := config.NewForceShutdownSignal()
-		verifyCallCount := 0
 
 		r := &Runner{
-			cleanupTimeout: 100 * time.Millisecond,
+			cleanupTimeout: 1 * time.Millisecond,
 			forceShutdown:  forceShutdown,
 			cleanupSlot:    make(chan struct{}, 1),
-			verifyFn: func(pid int) error {
-				verifyCallCount++
-				if verifyCallCount >= 3 {
-					return nil // Success after a few attempts
-				}
-				return fmt.Errorf("process still exists")
-			},
-			logger: zaptest.NewLogger(t),
+			stopped:        make(chan bool),
+			logger:         zaptest.NewLogger(t),
 		}
 
 		// Start verification process
 		go r.verifyProcessCleanup(12345)
 
-		// Wait for verification to complete
-		time.Sleep(150 * time.Millisecond)
+		// Signal process stopped to trigger cleanup completion
+		close(r.stopped)
+
+		// Give a moment for cleanup to complete
+		time.Sleep(1 * time.Millisecond)
 
 		// Verify cleanup token was returned
 		assert.Len(t, r.cleanupSlot, 1)
-		assert.GreaterOrEqual(t, verifyCallCount, 3)
 	})
 }
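The rewritten subtest above still relies on a 1 ms sleep after close(r.stopped) before asserting on the cleanup token. Where a fixed sleep proves flaky, testify's require.Eventually gives a bounded wait on the same condition. A self-contained sketch of that pattern with a stand-in channel (not the repository's test):

```go
package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// TestEventuallyPattern shows the bounded-wait alternative to a fixed
// time.Sleep: poll the observable effect (here, a buffered channel gaining
// a token) until it holds or the deadline expires.
func TestEventuallyPattern(t *testing.T) {
	slot := make(chan struct{}, 1) // stand-in for r.cleanupSlot

	go func() {
		time.Sleep(10 * time.Millisecond) // simulate cleanup finishing
		slot <- struct{}{}                // token returned
	}()

	require.Eventually(t, func() bool {
		return len(slot) == 1
	}, time.Second, time.Millisecond, "cleanup token should be returned")
}
```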

internal/server/mux.go

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ func NewServeMux(handler *Handler, useProcedureMode bool) *http.ServeMux {
 	serveMux.HandleFunc("GET /{$}", handler.Root)
 	serveMux.HandleFunc("GET /health-check", handler.HealthCheck)
 	serveMux.HandleFunc("GET /openapi.json", handler.OpenAPI)
-	serveMux.HandleFunc("POST /shutdown", handler.Shutdown)
 
 	if useProcedureMode {
 		serveMux.HandleFunc("POST /procedures", handler.Predict)

internal/server/server.go

Lines changed: 22 additions & 24 deletions
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"os"
 	"path"
+	"sync/atomic"
 	"time"
 
 	"go.uber.org/zap"
@@ -37,22 +38,21 @@ type IPC struct {
 }
 
 type Handler struct {
-	cfg           config.Config
-	shutdown      context.CancelFunc
-	startedAt     time.Time
-	runnerManager *runner.Manager
+	cfg              config.Config
+	startedAt        time.Time
+	runnerManager    *runner.Manager
+	gracefulShutdown atomic.Bool
 
 	cwd string
 
 	logger *zap.Logger
 }
 
-func NewHandler(ctx context.Context, cfg config.Config, shutdown context.CancelFunc, baseLogger *zap.Logger) (*Handler, error) {
+func NewHandler(ctx context.Context, cfg config.Config, baseLogger *zap.Logger) (*Handler, error) {
 	runnerManager := runner.NewManager(ctx, cfg, baseLogger)
 
 	h := &Handler{
 		cfg:           cfg,
-		shutdown:      shutdown,
 		startedAt:     time.Now(),
 		runnerManager: runnerManager,
 		cwd:           cfg.WorkingDirectory,
@@ -133,30 +133,21 @@ func (h *Handler) OpenAPI(w http.ResponseWriter, r *http.Request) {
 	h.writeBytes(w, []byte(schema))
 }
 
-func (h *Handler) Shutdown(w http.ResponseWriter, r *http.Request) {
-	err := h.Stop()
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-	} else {
-		w.WriteHeader(http.StatusOK)
-	}
-}
-
 // ForceKillAll immediately force-kills all runners (for test cleanup)
 func (h *Handler) ForceKillAll() {
 	h.runnerManager.ForceKillAll()
 }
 
 func (h *Handler) Stop() error {
-	// Stop the runner manager and handle shutdown in background
-	go func() {
-		log := h.logger.Sugar()
-		if err := h.runnerManager.Stop(); err != nil {
-			log.Errorw("failed to stop runner manager", "error", err)
-			os.Exit(1)
-		}
-		h.shutdown()
-	}()
+	// Set graceful shutdown flag to reject new predictions
+	h.gracefulShutdown.Store(true)
+
+	// Stop the runner manager synchronously
+	log := h.logger.Sugar()
+	if err := h.runnerManager.Stop(); err != nil {
+		log.Errorw("failed to stop runner manager", "error", err)
+		return err
+	}
 	return nil
 }
 
@@ -207,6 +198,13 @@ func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) {
 
 func (h *Handler) Predict(w http.ResponseWriter, r *http.Request) {
 	log := h.logger.Sugar()
+
+	// Reject new predictions during graceful shutdown
+	if h.gracefulShutdown.Load() {
+		http.Error(w, "server shutting down", http.StatusServiceUnavailable)
+		return
+	}
+
 	if r.Header.Get("Content-Type") != "application/json" {
 		http.Error(w, "invalid content type", http.StatusUnsupportedMediaType)
 		return