Skip to content

Commit 8752ce8

Browse files
Implement UID-based cleanup for one-shot mode in procedure isolation (#213)
Add comprehensive cleanup functionality for setUID-isolated procedure runners:

- Clean up /tmp files owned by isolated UIDs after predictions complete
- Use os.Root for secure file operations to prevent path traversal
- Implement one-shot mode with graceful Stop() and ForceKill() fallback
- Add timeout-based cleanup with configurable CleanupTimeout
- Skip cleanup of workingdir/tmpdir when they're under /tmp
- Add test procedures and comprehensive test script
- Add CI job for cleanup testing alongside existing setuid tests

The cleanup ensures that files created by isolated procedure runners are properly removed, preventing accumulation of orphaned files in containerized environments like Docker and Kubernetes.
1 parent 04f83b0 commit 8752ce8

File tree

13 files changed

+417
-41
lines changed

13 files changed

+417
-41
lines changed

.github/workflows/ci.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,22 @@ jobs:
127127
- run: ./script/init.sh
128128
- run: ./script/test-setuid.sh
129129

130+
test-setuid-cleanup:
131+
name: Test Set UID cleanup in one-shot mode
132+
runs-on: ubuntu-latest
133+
strategy:
134+
fail-fast: false
135+
steps:
136+
- uses: actions/checkout@v4
137+
with:
138+
fetch-depth: 0
139+
- uses: astral-sh/setup-uv@v6
140+
- run: ./script/init.sh
141+
- name: Run setuid cleanup test
142+
run: ./script/test-setuid-cleanup.sh
143+
env:
144+
PORT: 8080
145+
130146
build-python:
131147
name: Build & verify python package
132148
runs-on: ubuntu-latest
@@ -178,6 +194,8 @@ jobs:
178194
- lint-go
179195
- test-go
180196
- test-python
197+
- test-set-uid
198+
- test-setuid-cleanup
181199
if: startsWith(github.ref, 'refs/tags/')
182200
runs-on: ubuntu-latest
183201
permissions:

cmd/cog/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ func buildServiceConfig(s *ServerCmd) (config.Config, error) {
9898
MaxRunners: s.MaxRunners,
9999
RunnerShutdownGracePeriod: s.RunnerShutdownGracePeriod,
100100
CleanupTimeout: s.CleanupTimeout,
101+
CleanupDirectories: []string{"/tmp"},
101102
}
102103

103104
log.Infow("service configuration",

internal/config/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ type Config struct {
2727
RunnerShutdownGracePeriod time.Duration
2828

2929
// Cleanup configuration
30-
CleanupTimeout time.Duration
30+
CleanupTimeout time.Duration
31+
CleanupDirectories []string // Directories to walk for cleanup of files owned by isolated UIDs
3132

3233
// Environment configuration
3334
EnvSet map[string]string

internal/runner/manager.go

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"io/fs"
1111
"os"
1212
"os/exec"
13-
"path/filepath"
1413
"runtime"
1514
"sync"
1615
"syscall"
@@ -348,12 +347,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) {
348347
tmpDir: tmpDir,
349348
uploader: uploader,
350349
}
351-
// Only enable forced shutdown for procedure mode
352-
var forceShutdown *config.ForceShutdownSignal
353-
if m.cfg.UseProcedureMode {
354-
forceShutdown = m.cfg.ForceShutdown
355-
}
356-
runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, cogYaml.Concurrency.Max, m.cfg.CleanupTimeout, forceShutdown, m.baseLogger)
350+
runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, cogYaml.Concurrency.Max, m.cfg, m.baseLogger)
357351
if err != nil {
358352
return nil, err
359353
}
@@ -419,6 +413,36 @@ func (m *Manager) allocatePrediction(runner *Runner, req PredictionRequest) { //
419413
delete(runner.pending, req.ID)
420414
runner.mu.Unlock()
421415

416+
// In one-shot mode, stop runner after prediction completes to trigger cleanup
417+
if m.cfg.OneShot && finalResponse.Status.IsCompleted() {
418+
go func() {
419+
logger := m.logger.Sugar()
420+
logger.Infow("one-shot mode: stopping runner after prediction completion", "prediction_id", req.ID, "runner_id", runner.runnerCtx.id)
421+
422+
// Try graceful stop with timeout
423+
stopDone := make(chan error, 1)
424+
go func() {
425+
stopDone <- runner.Stop()
426+
}()
427+
428+
timeout := m.cfg.CleanupTimeout
429+
if timeout == 0 {
430+
timeout = 10 * time.Second // Default timeout
431+
}
432+
433+
select {
434+
case err := <-stopDone:
435+
if err != nil {
436+
logger.Errorw("failed to stop runner in one-shot mode", "error", err, "runner_id", runner.runnerCtx.id)
437+
}
438+
runner.ForceKill()
439+
case <-time.After(timeout):
440+
logger.Warnw("stop timeout exceeded in one-shot mode, falling back to force kill", "timeout", timeout, "runner_id", runner.runnerCtx.id)
441+
runner.ForceKill()
442+
}
443+
}()
444+
}
445+
422446
if cancel != nil {
423447
cancel()
424448
}
@@ -626,20 +650,28 @@ func (m *Manager) createProcedureRunner(runnerName, procedureHash string) (*Runn
626650
env = append(env, "TMPDIR="+tmpDir)
627651
cmd.Env = env
628652

629-
// Apply setUID isolation for procedure runners if needed
653+
var allocatedUID *int
630654
if m.shouldUseSetUID() {
631655
uid, err := AllocateUID()
632656
if err != nil {
633657
runtimeCancel()
634658
return nil, fmt.Errorf("failed to allocate UID: %w", err)
635659
}
660+
allocatedUID = &uid
661+
662+
// Use os.Root for secure ownership changes
663+
workingRoot, err := os.OpenRoot(workingDir)
664+
if err != nil {
665+
runtimeCancel()
666+
return nil, fmt.Errorf("failed to open working directory root: %w", err)
667+
}
668+
defer func() { _ = workingRoot.Close() }()
636669

637-
// Change ownership of source directory (workingDir)
638-
err = filepath.WalkDir(workingDir, func(path string, d fs.DirEntry, err error) error {
670+
err = fs.WalkDir(workingRoot.FS(), ".", func(path string, d fs.DirEntry, err error) error {
639671
if err != nil {
640672
return err
641673
}
642-
if lchownErr := os.Lchown(path, uid, NoGroupGID); lchownErr != nil {
674+
if lchownErr := workingRoot.Lchown(path, uid, NoGroupGID); lchownErr != nil {
643675
log.Errorw("failed to change ownership", "path", path, "uid", uid, "error", lchownErr)
644676
return lchownErr
645677
}
@@ -650,19 +682,24 @@ func (m *Manager) createProcedureRunner(runnerName, procedureHash string) (*Runn
650682
return nil, fmt.Errorf("failed to change ownership of source directory: %w", err)
651683
}
652684

653-
// Make working dir writable by unprivileged Python process
654-
if err := os.Lchown(workingDir, uid, NoGroupGID); err != nil {
685+
if err := workingRoot.Lchown(".", uid, NoGroupGID); err != nil {
655686
log.Errorw("failed to change ownership of working directory", "path", workingDir, "uid", uid, "error", err)
656687
runtimeCancel()
657688
return nil, fmt.Errorf("failed to change ownership of working directory: %w", err)
658689
}
659-
// Change ownership of temp directory
660-
if err := os.Lchown(tmpDir, uid, NoGroupGID); err != nil {
690+
691+
tmpRoot, err := os.OpenRoot(tmpDir)
692+
if err != nil {
693+
runtimeCancel()
694+
return nil, fmt.Errorf("failed to open temp directory root: %w", err)
695+
}
696+
defer func() { _ = tmpRoot.Close() }()
697+
698+
if err := tmpRoot.Lchown(".", uid, NoGroupGID); err != nil {
661699
log.Errorw("failed to change ownership of temp directory", "path", tmpDir, "uid", uid, "error", err)
662700
runtimeCancel()
663701
return nil, fmt.Errorf("failed to change ownership of temp directory: %w", err)
664702
}
665-
// Use syscall.Credential to run process as unprivileged user from start
666703
cmd.SysProcAttr.Credential = &syscall.Credential{
667704
Uid: uint32(uid), //nolint:gosec // this is guarded in isolation .allocate, cannot exceed const MaxUID
668705
Gid: uint32(NoGroupGID),
@@ -675,19 +712,17 @@ func (m *Manager) createProcedureRunner(runnerName, procedureHash string) (*Runn
675712
if m.cfg.UploadURL != "" {
676713
uploader = newUploader(m.cfg.UploadURL)
677714
}
715+
678716
runnerCtx := RunnerContext{
679-
id: runnerName,
680-
workingdir: workingDir,
681-
tmpDir: tmpDir,
682-
uploader: uploader,
717+
id: runnerName,
718+
workingdir: workingDir,
719+
tmpDir: tmpDir,
720+
uploader: uploader,
721+
uid: allocatedUID,
722+
cleanupDirectories: m.cfg.CleanupDirectories,
683723
}
684724

685-
// Only enable forced shutdown for procedure mode
686-
var forceShutdown *config.ForceShutdownSignal
687-
if m.cfg.UseProcedureMode {
688-
forceShutdown = m.cfg.ForceShutdown
689-
}
690-
runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, 1, m.cfg.CleanupTimeout, forceShutdown, m.baseLogger)
725+
runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, 1, m.cfg, m.baseLogger)
691726
if err != nil {
692727
return nil, fmt.Errorf("failed to create runner: %w", err)
693728
}

internal/runner/runner.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ func verifyProcessGroupTerminated(pid int) error {
983983
}
984984

985985
// NewRunner creates a new runner instance with the given context
986-
func NewRunner(ctx context.Context, ctxCancel context.CancelFunc, runnerCtx RunnerContext, command *exec.Cmd, maxConcurrency int, cleanupTimeout time.Duration, forceShutdown *config.ForceShutdownSignal, logger *zap.Logger) (*Runner, error) {
986+
func NewRunner(ctx context.Context, ctxCancel context.CancelFunc, runnerCtx RunnerContext, command *exec.Cmd, maxConcurrency int, cfg config.Config, logger *zap.Logger) (*Runner, error) {
987987
if maxConcurrency <= 0 {
988988
maxConcurrency = 1
989989
}
@@ -1005,8 +1005,8 @@ func NewRunner(ctx context.Context, ctxCancel context.CancelFunc, runnerCtx Runn
10051005
readyForShutdown: make(chan struct{}),
10061006
setupComplete: make(chan struct{}),
10071007
logCaptureComplete: make(chan struct{}),
1008-
cleanupTimeout: cleanupTimeout,
1009-
forceShutdown: forceShutdown,
1008+
cleanupTimeout: cfg.CleanupTimeout,
1009+
forceShutdown: cfg.ForceShutdown,
10101010
logger: runnerLogger,
10111011
}
10121012

internal/runner/runner_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,8 @@ func TestNewRunner(t *testing.T) {
789789

790790
ctx, cancel := context.WithCancel(t.Context())
791791
defer cancel()
792-
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, 0, nil, zaptest.NewLogger(t))
792+
cfg := config.Config{}
793+
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, cfg, zaptest.NewLogger(t))
793794
require.NoError(t, err)
794795

795796
assert.Equal(t, "test-runner", r.runnerCtx.id)
@@ -825,7 +826,8 @@ func TestNewRunner(t *testing.T) {
825826

826827
ctx, cancel := context.WithCancel(t.Context())
827828
defer cancel()
828-
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, 0, nil, zaptest.NewLogger(t))
829+
cfg := config.Config{}
830+
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, cfg, zaptest.NewLogger(t))
829831
require.NoError(t, err)
830832

831833
// Should store the command correctly
@@ -849,7 +851,8 @@ func TestNewRunner(t *testing.T) {
849851

850852
ctx, cancel := context.WithCancel(t.Context())
851853
defer cancel()
852-
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, 0, nil, zaptest.NewLogger(t))
854+
cfg := config.Config{}
855+
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, cfg, zaptest.NewLogger(t))
853856
require.NoError(t, err)
854857
require.NotNil(t, r)
855858

@@ -887,7 +890,8 @@ func TestProcedureRunnerCreation(t *testing.T) {
887890

888891
ctx, cancel := context.WithCancel(t.Context())
889892
defer cancel()
890-
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, 0, nil, zaptest.NewLogger(t))
893+
cfg := config.Config{}
894+
r, err := NewRunner(ctx, cancel, runnerCtx, cmd, 1, cfg, zaptest.NewLogger(t))
891895
require.NoError(t, err)
892896

893897
assert.Equal(t, "proc-runner", r.runnerCtx.id)

internal/runner/types.go

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ import (
77
"encoding/base32"
88
"encoding/json"
99
"fmt"
10+
"io/fs"
1011
"os"
12+
"path/filepath"
1113
"strings"
1214
"sync"
1315
"sync/atomic"
16+
"syscall"
1417
"time"
1518

1619
"github.com/replicate/cog-runtime/internal/util"
@@ -210,16 +213,98 @@ func (r RunnerID) String() string {
210213

211214
// RunnerContext contains everything a runner needs to operate
212215
type RunnerContext struct {
213-
id string
214-
workingdir string
215-
tmpDir string
216-
uploader *uploader
216+
id string
217+
workingdir string
218+
tmpDir string
219+
uploader *uploader
220+
uid *int // UID used for setUID isolation, nil if not using setUID
221+
cleanupDirectories []string // Directories to walk for cleanup of files owned by isolated UIDs
217222
}
218223

219224
func (rc *RunnerContext) Cleanup() error {
220225
if rc.tmpDir != "" {
221-
return os.RemoveAll(rc.tmpDir)
226+
if err := os.RemoveAll(rc.tmpDir); err != nil {
227+
return err
228+
}
229+
}
230+
231+
// Clean up files in configured directories owned by this UID when using setUID isolation
232+
if rc.uid != nil && len(rc.cleanupDirectories) > 0 {
233+
return rc.cleanupDirectoriesFiles()
234+
}
235+
236+
return nil
237+
}
238+
239+
// cleanupDirectoriesFiles removes files in configured directories owned by the isolated UID
240+
func (rc *RunnerContext) cleanupDirectoriesFiles() error {
241+
if rc.uid == nil {
242+
return nil
222243
}
244+
245+
// Avoid cleaning our own workingdir/tmpdir if they're in the cleanup directories
246+
skipPaths := make(map[string]bool)
247+
for _, cleanupDir := range rc.cleanupDirectories {
248+
if strings.HasPrefix(rc.workingdir, cleanupDir+"/") {
249+
skipPaths[rc.workingdir] = true
250+
}
251+
if strings.HasPrefix(rc.tmpDir, cleanupDir+"/") {
252+
skipPaths[rc.tmpDir] = true
253+
}
254+
}
255+
256+
for _, cleanupDir := range rc.cleanupDirectories {
257+
// Use os.OpenRoot to create a secure chrooted view of the cleanup directory
258+
root, err := os.OpenRoot(cleanupDir)
259+
if err != nil {
260+
continue // Skip directories we can't root into
261+
}
262+
263+
err = fs.WalkDir(root.FS(), ".", func(path string, d fs.DirEntry, err error) error {
264+
if err != nil {
265+
return nil // Continue walking on errors
266+
}
267+
268+
// Convert relative path back to absolute for skipPaths check
269+
absPath := filepath.Join(cleanupDir, path)
270+
if skipPaths[absPath] {
271+
return filepath.SkipDir
272+
}
273+
274+
// Don't follow symlinks
275+
if d.Type()&fs.ModeSymlink != 0 {
276+
return nil
277+
}
278+
279+
// Check if file is owned by our UID using root.Stat
280+
info, err := root.Stat(path)
281+
if err != nil {
282+
return nil // Continue on stat errors
283+
}
284+
285+
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
286+
if int(stat.Uid) == *rc.uid {
287+
if err := root.RemoveAll(path); err != nil {
288+
// Log error but continue cleanup
289+
return nil
290+
}
291+
if d.IsDir() {
292+
return filepath.SkipDir
293+
}
294+
}
295+
}
296+
297+
return nil
298+
})
299+
300+
// Close the root after processing this directory
301+
_ = root.Close()
302+
303+
if err != nil {
304+
return err
305+
}
306+
}
307+
223308
return nil
224309
}
225310

0 commit comments

Comments
 (0)