Skip to content

Commit 17d6777

Browse files
committed
Wire up checkpointing in cog-runtime
1 parent 29d6d6a commit 17d6777

File tree

6 files changed

+552
-30
lines changed

6 files changed

+552
-30
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package checkpointer
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"os/exec"
9+
"path/filepath"
10+
"strconv"
11+
"strings"
12+
"time"
13+
)
14+
15+
const (
16+
// Configuration environment variables
17+
locationEnvVar = "R8_LOCATION"
18+
shouldCheckpointEnvVar = "R8_CUDA_CHECKPOINT"
19+
leaseFileEnvVar = "R8_LEASE_FILE"
20+
cudaCheckpointDirEnvVar = "R8_CUDA_CHECKPOINT_DIR"
21+
cudaReadyFileEnvVar = "R8_CUDA_READY_LOCK_FILE"
22+
23+
// Dependencies for the checkpoint process
24+
cudaCheckpointURLFmtStr = "https://r8-public-assets-%s.cwobject.com/cuda-checkpoint"
25+
criuURLFmtStr = "https://r8-public-assets-%s.cwobject.com/criu.tar.gz"
26+
cudaCheckpointPath = "/tmp/cuda-checkpoint"
27+
criuPath = "/tmp/criu"
28+
29+
// Metadata storage paths
30+
cudaCmdFileName = "cuda-cmd"
31+
checkpointSubdirName = "checkpoint"
32+
)
33+
34+
var (
35+
errNoCheckpointDir = errors.New("Could not find checkpoint directory environment variable")
36+
)
37+
38+
type FatalCheckpointErr struct {
39+
err error
40+
}
41+
42+
func (e *FatalCheckpointErr) Error() string {
43+
return e.Error()
44+
}
45+
46+
type Checkpointer interface {
47+
Disable()
48+
HasCheckpoint() bool
49+
Prepare(ctx context.Context) error
50+
Checkpoint(ctx context.Context, cmd *exec.Cmd) error
51+
Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error)
52+
}
53+
54+
type checkpointer struct {
55+
enabled bool
56+
hasCheckpoint bool
57+
checkpointDir string
58+
leaseFile string
59+
}
60+
61+
func NewCheckpointer(ctx context.Context) Checkpointer {
62+
return &checkpointer{
63+
enabled: os.Getenv(shouldCheckpointEnvVar) == "true",
64+
checkpointDir: os.Getenv(cudaCheckpointDirEnvVar),
65+
leaseFile: os.Getenv(leaseFileEnvVar),
66+
}
67+
}
68+
69+
func (c *checkpointer) Disable() {
70+
c.enabled = false
71+
}
72+
73+
func (c *checkpointer) HasCheckpoint() bool {
74+
if !c.enabled {
75+
return false
76+
}
77+
78+
return c.hasCheckpoint
79+
}
80+
81+
func (c *checkpointer) Prepare(ctx context.Context) error {
82+
if !c.enabled {
83+
return nil
84+
}
85+
86+
// Download dependencies
87+
err := downloadCUDACheckpointBinaries(ctx)
88+
if err != nil {
89+
return err
90+
}
91+
92+
// Wait for IPC lease file to be deleted
93+
if c.leaseFile != "" {
94+
err = pollForFileDeletion(c.leaseFile, 5*time.Minute, 10*time.Second)
95+
if err != nil {
96+
return err
97+
}
98+
}
99+
100+
empty, err := isDirEmpty(filepath.Join(c.checkpointDir, checkpointSubdirName))
101+
// If the err is not nil, it probably means the directory does not exist
102+
if err == nil && !empty {
103+
c.hasCheckpoint = true
104+
}
105+
106+
return nil
107+
}
108+
109+
func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) error {
110+
if !c.enabled {
111+
return nil
112+
}
113+
114+
if c.checkpointDir == "" {
115+
return errNoCheckpointDir
116+
}
117+
118+
err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666)
119+
if err != nil {
120+
return err
121+
}
122+
123+
pid := strconv.Itoa(cogletCmd.Process.Pid)
124+
125+
// Find the PID of the command that is actually using the GPU
126+
cudaPIDBytes, err := exec.CommandContext(ctx, "nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader").Output()
127+
if err != nil {
128+
return err
129+
}
130+
131+
cudaPID := strings.TrimSpace(string(cudaPIDBytes))
132+
133+
// Get the command for this PID - it is _not_ always the root python process
134+
data, err := exec.CommandContext(ctx, "ps", "-o", "cmd=", cudaPID).Output()
135+
if err != nil {
136+
return err
137+
}
138+
139+
cudaCmd := strings.TrimSpace(string(data))
140+
141+
// Write said command to a file for later
142+
err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o666)
143+
if err != nil {
144+
return err
145+
}
146+
147+
// Toggle CUDA off
148+
cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID))
149+
if err := cmd.Run(); err != nil {
150+
return err
151+
}
152+
153+
// CRIU checkpoint (leaving process running)
154+
cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid)
155+
if err := cmd.Run(); err != nil {
156+
// Try to toggle CUDA back on. If we aren't able to restart CUDA, the process
157+
// will hang indefinitely, so we should kill it and try to start a new one
158+
// without checkpointing
159+
cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID))
160+
if cudaErr := cmd.Run(); cudaErr != nil {
161+
// Return a fatal error so upstream knows we cannot continue in the current state
162+
return &FatalCheckpointErr{
163+
err: cudaErr,
164+
}
165+
}
166+
// Return the original checkpointing error
167+
return err
168+
}
169+
170+
// Toggle CUDA back on. If we aren't able to restart CUDA, the process
171+
// will hang indefinitely, so we should kill it and try to start a new
172+
// one without checkpointing
173+
cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID))
174+
if err := cmd.Run(); err != nil {
175+
// Return a fatal error so upstream knows we cannot continue in the current state
176+
return &FatalCheckpointErr{
177+
err: err,
178+
}
179+
}
180+
181+
return setStatusReady()
182+
}
183+
184+
func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) {
185+
if !c.enabled {
186+
return nil, nil, nil
187+
}
188+
189+
// Read process from sentinel file
190+
cudaCmd, err := os.ReadFile(filepath.Join(c.checkpointDir, cudaCmdFileName))
191+
if err != nil {
192+
return nil, nil, err
193+
}
194+
195+
// Set up restore command
196+
restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName))
197+
198+
// Set up callback function once restore is started
199+
callback := func(con context.Context) error {
200+
// Get the PID for the command
201+
cudaPID, err := exec.CommandContext(con, "pgrep", "-fx", string(cudaCmd)).Output()
202+
if err != nil {
203+
// If this command failed, we want to best effort try to kill the started process,
204+
// since we'll start a new one
205+
restoreCmd.Process.Kill()
206+
207+
return err
208+
}
209+
210+
// Toggle CUDA on for the restored process
211+
cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID))
212+
if err := cmd.Run(); err != nil {
213+
// If this command failed, we want to best effort try to kill the started process,
214+
// since we'll start a new one
215+
restoreCmd.Process.Kill()
216+
217+
return err
218+
}
219+
220+
err = setStatusReady()
221+
if err != nil {
222+
// If this command failed, we want to best effort try to kill the started process,
223+
// since we'll start a new one
224+
restoreCmd.Process.Kill()
225+
226+
return err
227+
}
228+
229+
return nil
230+
}
231+
232+
// The restored command is a running instance of coglet
233+
return restoreCmd, callback, nil
234+
}
235+
236+
func downloadCUDACheckpointBinaries(ctx context.Context) error {
237+
location := os.Getenv("R8_LOCATION")
238+
239+
// Download the cuda-checkpoint binary
240+
err := downloadAndChmod(fmt.Sprintf(cudaCheckpointURLFmtStr, location), cudaCheckpointPath)
241+
if err != nil {
242+
return fmt.Errorf("failed to download and chmod cuda-checkpoint binary: %w", err)
243+
}
244+
// CRIU gets downloaded as a tar with its dependencies. So we need to extract the tar, then
245+
// link the LD_LIBRARY_PATH to the dependencies
246+
dir := filepath.Dir(criuPath)
247+
err = downloadAndUntar(ctx, fmt.Sprintf(criuURLFmtStr, location), dir)
248+
if err != nil {
249+
return fmt.Errorf("failed to download and untar CRIU: %w", err)
250+
}
251+
updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib"))
252+
return nil
253+
}

0 commit comments

Comments
 (0)