diff --git a/cmd/cog/main.go b/cmd/cog/main.go index 9640414..90f4109 100644 --- a/cmd/cog/main.go +++ b/cmd/cog/main.go @@ -23,6 +23,7 @@ type ServerCmd struct { UseProcedureMode bool `help:"Enable procedure mode for concurrent predictions" name:"use-procedure-mode" env:"COG_USE_PROCEDURE_MODE"` AwaitExplicitShutdown bool `help:"Wait for explicit shutdown signal instead of auto-shutdown" name:"await-explicit-shutdown" env:"COG_AWAIT_EXPLICIT_SHUTDOWN"` OneShot bool `help:"Enable one-shot mode (single runner, wait for cleanup before ready)" name:"one-shot" env:"COG_ONE_SHOT"` + SignalMode bool `help:"Enable signal mode (use signals instead of webhooks for IPC communication)" name:"signal-mode" env:"COG_SIGNAL_MODE"` UploadURL string `help:"Base URL for uploading prediction output files" name:"upload-url" env:"COG_UPLOAD_URL"` WorkingDirectory string `help:"Override the working directory for predictions" name:"working-directory" env:"COG_WORKING_DIRECTORY"` RunnerShutdownGracePeriod time.Duration `help:"Grace period before force-killing prediction runners" name:"runner-shutdown-grace-period" default:"600s" env:"COG_RUNNER_SHUTDOWN_GRACE_PERIOD"` @@ -78,6 +79,7 @@ func buildServiceConfig(s *ServerCmd) (config.Config, error) { WorkingDirectory: workingDir, UploadURL: s.UploadURL, IPCUrl: fmt.Sprintf("http://localhost:%d/_ipc", s.Port), + SignalMode: s.SignalMode, MaxRunners: s.MaxRunners, RunnerShutdownGracePeriod: s.RunnerShutdownGracePeriod, CleanupTimeout: s.CleanupTimeout, diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go new file mode 100644 index 0000000..e850789 --- /dev/null +++ b/internal/checkpointer/checkpointer.go @@ -0,0 +1,259 @@ +// There are some commands in here that are susceptible to injection. However, cog +// is a vehicle to let people run their own code... so why go through the hassle of +// injection? Cog is not run with any more permissions than the user code. 
+// +//nolint:gosec // See above +package checkpointer + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/replicate/cog-runtime/internal/logging" +) + +const ( + // Configuration environment variables + locationEnvVar = "R8_LOCATION" + shouldCheckpointEnvVar = "R8_CUDA_CHECKPOINT" + leaseFileEnvVar = "R8_LEASE_FILE" + cudaCheckpointDirEnvVar = "R8_CUDA_CHECKPOINT_DIR" + cudaReadyFileEnvVar = "R8_CUDA_READY_LOCK_FILE" + + // Dependencies for the checkpoint process + cudaCheckpointURLFmtStr = "https://r8-public-assets-%s.cwobject.com/cuda-checkpoint" + criuURLFmtStr = "https://r8-public-assets-%s.cwobject.com/criu.tar.gz" + cudaCheckpointPath = "/tmp/cuda-checkpoint" + criuPath = "/tmp/criu" + + // Metadata storage paths + checkpointSubdirName = "checkpoint" +) + +var errNoCheckpointDir = errors.New("could not find checkpoint directory environment variable") + +type FatalCheckpointError struct { + err error +} + +func (e *FatalCheckpointError) Error() string { + return e.err.Error() +} + +type Checkpointer interface { + Disable() + HasCheckpoint() bool + Prepare(ctx context.Context) error + Checkpoint(ctx context.Context, cmd *exec.Cmd, waitFunc func() error) error + Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) + WriteReadyFile() error +} + +type checkpointer struct { + enabled bool + hasCheckpoint bool + checkpointDir string + leaseFile string + log *logging.SugaredLogger +} + +func NewCheckpointer(ctx context.Context, log *logging.SugaredLogger) Checkpointer { + return &checkpointer{ + enabled: os.Getenv(shouldCheckpointEnvVar) == "true", + checkpointDir: os.Getenv(cudaCheckpointDirEnvVar), + leaseFile: os.Getenv(leaseFileEnvVar), + log: log, + } +} + +func (c *checkpointer) Disable() { + c.enabled = false +} + +func (c *checkpointer) HasCheckpoint() bool { + if !c.enabled { + return false + } + + return c.hasCheckpoint +} + +func (c *checkpointer) Prepare(ctx context.Context) error { + if !c.enabled { + return nil + } + + // Download dependencies + err := downloadCUDACheckpointBinaries(ctx) + if err != nil { + return err + } + + // Wait for IPC lease file to be deleted + if c.leaseFile != "" { + err = pollForFileDeletion(c.leaseFile, 5*time.Minute, 10*time.Second) + if err != nil { + return err + } + } + + empty, err := isDirEmpty(filepath.Join(c.checkpointDir, checkpointSubdirName)) + // If the err is not nil, it probably means the directory does not exist + if err == nil && !empty { + c.hasCheckpoint = true + } + + return nil +} + +func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, waitFunc func() error) error { + if !c.enabled { + return nil + } + + if c.checkpointDir == "" { + return errNoCheckpointDir + } + + if err := waitFunc(); err != nil { + return err + } + + err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) + if err != nil { + return err + } + + pid := strconv.Itoa(cogletCmd.Process.Pid) + + // Find the PID of the command that is actually using the GPU + cudaPIDBytes, err := exec.CommandContext(ctx, "nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader").Output() + if err != nil { + return err + } + + cudaPID := strings.TrimSpace(string(cudaPIDBytes)) + + // Toggle CUDA off + cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + if err := cmd.Run(); err != nil { + return err + } + + // CRIU checkpoint (leaving process running) + cmd = exec.CommandContext(ctx, criuPath, 
"dump", "--shell-job", "--leave-running", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + if err := cmd.Run(); err != nil { + // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process + // will hang indefinitely, so we should kill it and try to start a new one + // without checkpointing + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + if cudaErr := cmd.Run(); cudaErr != nil { + // Return a fatal error so upstream knows we cannot continue in the current state + return &FatalCheckpointError{ + err: cudaErr, + } + } + // Return the original checkpointing error + return err + } + + // Toggle CUDA back on. If we aren't able to restart CUDA, the process + // will hang indefinitely, so we should kill it and try to start a new + // one without checkpointing + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + if err := cmd.Run(); err != nil { + // Return a fatal error so upstream knows we cannot continue in the current state + return &FatalCheckpointError{ + err: err, + } + } + + return nil +} + +func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) { + if !c.enabled { + return nil, nil, nil + } + + // Set up restore command + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + + // Set up callback function once restore is started + callback := func(con context.Context) error { + out, err := exec.CommandContext(con, "ps", "aux").Output() + if err != nil { + c.log.Infow(err.Error()) + } + c.log.Infow(string(out)) + c.log.Infow(strconv.Itoa(restoreCmd.Process.Pid)) + // Toggle CUDA on for the restored process + cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + c.log.Errorw("failed to toggle CUDA on", "error", err) + // If this command failed, we want to best effort try to kill the started process, + // since we'll start a new one + killProcess(restoreCmd) //nolint:errcheck // This is just best effort + + return err + } + + return nil + } + + // The restored command is a running instance of coglet + return restoreCmd, callback, nil +} + +func killProcess(cmd *exec.Cmd) error { + err := cmd.Process.Kill() + if err != nil { + return err + } + + // Wait for the process to exit with a 5 second timeout + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + select { + case err = <-done: + return err + case <-time.After(5 * time.Second): + return nil + } +} + +func (c *checkpointer) WriteReadyFile() error { + // If it isn't expected, make this a no-op + if os.Getenv(shouldCheckpointEnvVar) != "true" { + return nil + } + return writeCudaReadyFile() +} + +func downloadCUDACheckpointBinaries(ctx context.Context) error { + location := os.Getenv("R8_LOCATION") + + // Download the cuda-checkpoint binary + err := downloadAndChmod(fmt.Sprintf(cudaCheckpointURLFmtStr, location), cudaCheckpointPath) + if err != nil { + return fmt.Errorf("failed to download and chmod cuda-checkpoint binary: %w", err) + } + // CRIU gets downloaded as a tar with its dependencies. 
So we need to extract the tar, then
+	// point LD_LIBRARY_PATH at the bundled libraries
+	dir := filepath.Dir(criuPath)
+	err = downloadAndUntar(ctx, fmt.Sprintf(criuURLFmtStr, location), dir)
+	if err != nil {
+		return fmt.Errorf("failed to download and untar CRIU: %w", err)
+	}
+	return updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib"))
+}
diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go
new file mode 100644
index 0000000..a63b5a9
--- /dev/null
+++ b/internal/checkpointer/utils.go
@@ -0,0 +1,168 @@
+// There are some commands in here that are susceptible to injection. However, cog
+// is a vehicle to let people run their own code... so why go through the hassle of
+// injection? Cog is not run with any more permissions than the user code.
+//
+//nolint:gosec // See above
+package checkpointer
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"time"
+)
+
+var errTimedOutPolling = errors.New("timed out while polling for file")
+
+// updateEnvVar updates an environment variable in place, prepending an item to it
+// if it is already set or creating it if it isn't
+func updateEnvVar(envVarName, newItem string) error {
+	old := os.Getenv(envVarName)
+	if old == "" {
+		return os.Setenv(envVarName, newItem)
+	}
+	path := newItem + string(os.PathListSeparator) + old
+	return os.Setenv(envVarName, path)
+}
+
+// downloadFile downloads a file from the URL provided to the path provided
+func downloadFile(url, path string) error {
+	filename := filepath.Base(path)
+	err := os.MkdirAll(filepath.Dir(path), 0o700)
+	if err != nil {
+		return err
+	}
+
+	resp, err := http.Get(url)
+	if err != nil {
+		return fmt.Errorf("failed to download %s: %w", filename, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("failed to download %s: unexpected status %s", filename, resp.Status)
+	}
+
+	file, err := os.Create(path)
+	if err != nil {
+		return fmt.Errorf("failed to create file: %w", err)
+	}
+	defer file.Close() //nolint:errcheck // nothing to do with this error
+
+	_, err = io.Copy(file, resp.Body)
+	if err != nil {
+		return fmt.Errorf("failed to save %s: %w", filename, err)
+	}
+
+	return nil
+}
+
+// downloadAndChmod downloads a file to the path provided and chmods it for
+// execution. This expects the downloaded file to be a binary
+func downloadAndChmod(url, path string) error {
+	err := downloadFile(url, path)
+	if err != nil {
+		return err
+	}
+
+	if err := os.Chmod(path, 0o700); err != nil {
+		return fmt.Errorf("failed to chmod file: %w", err)
+	}
+	return nil
+}
+
+// downloadAndUntar downloads a tar and extracts it to a path. The path is expected
+// to be a directory
+func downloadAndUntar(ctx context.Context, url, path string) error {
+	// Download to `${path}/tmp.tar.gz`
+	downloadPath := filepath.Join(path, "tmp.tar.gz")
+	err := downloadFile(url, downloadPath)
+	if err != nil {
+		return err
+	}
+
+	// Untar into the `${path}` dir
+	cmd := exec.CommandContext(ctx, "tar", "-xf", downloadPath, "-C", path)
+	devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0o755)
+	if err != nil {
+		return err
+	}
+	defer devnull.Close() //nolint:errcheck // What would we do with this error anyways
+	cmd.Stdout = devnull
+	cmd.Stderr = devnull
+
+	if err := cmd.Run(); err != nil {
+		return fmt.Errorf("failed to extract tar: %w", err)
+	}
+
+	return nil
+}
+
+// pollForFileDeletion waits for a file to be deleted, up until a timeout.
It returns an error if the +// timeout is hit +func pollForFileDeletion(target string, timeout, pollInterval time.Duration) error { + deadline := time.After(timeout) + + for { + // Check if the file still exists, if it does keep looping + if _, err := os.Stat(target); errors.Is(err, os.ErrNotExist) { + return nil + } + + // Check for timeout before sleeping for the polling interval + select { + case <-deadline: + return errTimedOutPolling + default: + time.Sleep(pollInterval) + } + } +} + +// https://stackoverflow.com/a/30708914/30548878 +func isDirEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + return false, err + } + defer f.Close() //nolint:errcheck // nothing to do with this error + + _, err = f.Readdirnames(1) + if errors.Is(err, io.EOF) { + return true, nil + } + return false, err +} + +// Touch a file if it doesn't exist, otherwise wipes the contents of the file +func touchFile(name string) error { + // Ensure upstream directory exists for file + err := os.MkdirAll(filepath.Dir(name), 0o644) + if err != nil { + return err + } + + f, err := os.Create(name) + if err != nil { + return err + } + return f.Close() +} + +// writeCudaReadyFile ensures the ready files exist +func writeCudaReadyFile() error { + cudaReadyFilePath := os.Getenv(cudaReadyFileEnvVar) + + // Touch CUDA ready file + err := touchFile(cudaReadyFilePath) + if err != nil { + return err + } + + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 7514e56..2ab1e1a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { UseProcedureMode bool AwaitExplicitShutdown bool OneShot bool + SignalMode bool // Directory configuration WorkingDirectory string diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 5da8658..f5ff3c3 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "os/exec" + "os/signal" "runtime" "sync" "syscall" @@ -15,6 +16,7 @@ import ( "go.uber.org/zap" + "github.com/replicate/cog-runtime/internal/checkpointer" "github.com/replicate/cog-runtime/internal/config" "github.com/replicate/cog-runtime/internal/logging" "github.com/replicate/cog-runtime/internal/webhook" @@ -272,6 +274,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { log.Debugw("creating default runner", "working_dir", workingDir, "ipc_url", m.cfg.IPCUrl, + "signal_mode", m.cfg.SignalMode, "python_bin", m.cfg.PythonBinPath, ) @@ -284,10 +287,26 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { "-u", "-m", "coglet", "--name", DefaultRunnerName, - "--ipc-url", m.cfg.IPCUrl, "--working-dir", workingDir, } + if m.cfg.SignalMode { + args = append(args, "--signal-mode") + // Make sure the signal handling is running + // This runs an infinite loop for handling signals, so we explicitly + // do not want to put it in a wait group of any kind + go m.HandleSignals(m.ctx) //nolint:contextcheck // We want this to live for the lifetime of the manager + } else { + args = append(args, "--ipc-url", m.cfg.IPCUrl) + } + + // This returns an object that does nothing if it is not enabled. 
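Before the construction that follows, it may help to see the intended call order for the new Checkpointer in one place. The sketch below is illustrative rather than the real wiring: the helper name is made up, error handling is trimmed, and it would only compile inside this module because it imports the internal package, but the method signatures match the interface in internal/checkpointer.

```go
package example

import (
	"context"
	"os/exec"

	"github.com/replicate/cog-runtime/internal/checkpointer"
)

// driveCheckpointer sketches the call order used by createDefaultRunner.
// The runner wiring done by setupRunner is elided.
func driveCheckpointer(ctx context.Context, cp checkpointer.Checkpointer, cogletCmd *exec.Cmd, waitForSetup func() error) error {
	// Prepare downloads cuda-checkpoint/CRIU and waits for the lease file to
	// disappear; if it fails we simply fall back to a plain cold start.
	if err := cp.Prepare(ctx); err != nil {
		cp.Disable()
	}

	if cp.HasCheckpoint() {
		// A previous run left a CRIU image: restore it instead of re-running setup.
		restoredCmd, postStart, err := cp.Restore(ctx)
		if err != nil {
			return err
		}
		_ = restoredCmd // started via setupRunner in the real code
		if err := postStart(ctx); err != nil {
			return err // toggling CUDA back on failed; the restore is unusable
		}
	} else {
		// Cold start: once setup has finished, snapshot the process so the next
		// start can restore instead of re-running setup.
		if err := cp.Checkpoint(ctx, cogletCmd, waitForSetup); err != nil {
			return err // a *FatalCheckpointError here means the runner is unusable
		}
	}

	// Touch the CUDA-ready file when R8_CUDA_READY_LOCK_FILE is configured.
	return cp.WriteReadyFile()
}
```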
+	cp := checkpointer.NewCheckpointer(ctx, m.logger.Sugar())
+	err := cp.Prepare(ctx)
+	if err != nil {
+		cp.Disable()
+	}
+
 	log.Debugw("runner command", "python_path", pythonPath, "args", args, "working_dir", workingDir)
 
 	tmpDir, err := os.MkdirTemp("", "cog-runner-tmp-")
@@ -295,11 +314,6 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) {
 		return nil, fmt.Errorf("failed to create temp directory: %w", err)
 	}
 
-	// Derive the runtime context from the manager's context
-	runtimeContext, runtimeCancel := context.WithCancel(ctx)
-	cmd := exec.CommandContext(runtimeContext, pythonPath, args...) //nolint:gosec // expected subprocess launched with variable
-	cmd.Dir = m.cfg.WorkingDirectory
-	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
 
 	env := mergeEnv(os.Environ(), m.cfg.EnvSet, m.cfg.EnvUnset)
 	env = append(env, "TMPDIR="+tmpDir)
@@ -310,8 +324,6 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) {
 		env = append(env, "LOG_LEVEL=debug")
 	}
 
-	cmd.Env = env
-
 	// Read cog.yaml for runner configuration (capacity was already set in newManager)
 	cogYaml, err := ReadCogYaml(workingDir)
 	if err != nil {
@@ -330,14 +342,30 @@
 		tmpDir:   tmpDir,
 		uploader: uploader,
 	}
-	runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, cogYaml.Concurrency.Max, m.cfg, m.baseLogger)
-	if err != nil {
-		return nil, err
+
+	// If there is an existing checkpoint, try to restore from the checkpoint
+	if cp.HasCheckpoint() {
+		runner, err := m.startRunnerFromCheckpoint(ctx, env, runnerCtx, cogYaml.Concurrency.Max, cp)
+		if err == nil {
+			m.runners[0] = runner
+			m.monitoringWG.Go(func() {
+				m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner)
+			})
+
+			return runner, cp.WriteReadyFile()
+		}
+		// If the restore failed, disable the checkpointer and continue
+		cp.Disable()
 	}
-	runner.webhookSender = m.webhookSender
-	if err := runner.Start(ctx); err != nil {
-		return nil, fmt.Errorf("failed to start runner: %w", err)
+	// Derive the runtime context from the manager's context
+	runtimeContext, runtimeCancel := context.WithCancel(ctx)
+
+commandSetup:
+	cmd := exec.CommandContext(runtimeContext, pythonPath, args...) //nolint:gosec // expected subprocess launched with variable
+	runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, cogYaml.Concurrency.Max)
+	if err != nil {
+		return nil, fmt.Errorf("failed to set up runner: %w", err)
 	}
 
 	if err := runner.Config(ctx); err != nil {
@@ -349,13 +377,81 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) {
 	}
 
 	m.runners[0] = runner
+
+	if !cp.HasCheckpoint() {
+		err = cp.Checkpoint(ctx, cmd, func() error { return waitForRunnerSetup(ctx, runner) })
+		var fatalCheckpointErr *checkpointer.FatalCheckpointError
+		// If we saw an error that would leave the runner unusable, turn off the
+		// checkpointer and recreate the command and runner
+		if errors.As(err, &fatalCheckpointErr) {
+			// TODO: Is this bad? Should we just return the error back up?
+			// The main concern is whether `runner.Config` leaves artifacts
+			// between runs, although that should be fine.
+ cp.Disable() + goto commandSetup + } + // If the error is not fatal, we failed to create a checkpoint but are still + // running the cog process successfully, so we can just continue as if we did + // nothing + } + m.monitoringWG.Go(func() { m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) }) + return runner, cp.WriteReadyFile() +} + +func (m *Manager) setupRunner(runtimeContext context.Context, runtimeCancel context.CancelFunc, cmd *exec.Cmd, env []string, runnerCtx RunnerContext, maxConcurrency int) (*Runner, error) { + cmd.Dir = m.cfg.WorkingDirectory + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Env = env + + runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, maxConcurrency, m.cfg, m.baseLogger) + if err != nil { + return nil, err + } + + runner.webhookSender = m.webhookSender + if err := runner.Start(runtimeContext); err != nil { + return nil, fmt.Errorf("failed to start runner: %w", err) + } + return runner, nil } +func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, runnerCtx RunnerContext, maxConcurrency int, cp checkpointer.Checkpointer) (*Runner, error) { + // Derive the runtime context from the manager's context + runtimeContext, runtimeCancel := context.WithCancel(ctx) + + cmd, postSetupCallback, err := cp.Restore(runtimeContext) + if err != nil { + runtimeCancel() + return nil, err + } + + runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) + if err != nil { + m.logger.Sugar().Errorw("failed to set up runner", "error", err) + return nil, fmt.Errorf("failed to set up runner: %w", err) + } + + err = postSetupCallback(runtimeContext) + if err != nil { + return nil, fmt.Errorf("failed callback function: %w", err) + } + + // We checkpointed the model after it ran setup, so we need to manually send the ready signal + // to the runner. We can do this by sending the SigReady signal to the current PID, as signal + // mode should be on if the checkpoint exists + err = syscall.Kill(syscall.Getpid(), SigReady) + if err != nil { + m.logger.Sugar().Errorw("failed to send SIGUSR1", "error", err) + } + + return runner, err +} + // allocatePrediction reserves a slot in the runner for the prediction func (m *Manager) allocatePrediction(runner *Runner, req PredictionRequest) { //nolint:contextcheck // we do not use this context for the prediction see note below log := m.logger.Sugar() @@ -931,6 +1027,35 @@ func (m *Manager) HandleRunnerIPC(runnerName, status string) error { return runner.HandleIPC(status) } +func (m *Manager) HandleRunnerSignal(runnerName string, s os.Signal) error { + runner, _, exists := m.findRunner(runnerName) + if !exists { + return fmt.Errorf("%w: %s", ErrRunnerNotFound, runnerName) + } + return runner.HandleSignal(s) +} + +func (m *Manager) HandleSignals(ctx context.Context) { + log := m.logger.Sugar() + + ch := make(chan os.Signal, 1) + signal.Notify(ch, SigOutput, SigReady, SigBusy) + + for { + select { + case s := <-ch: + err := m.HandleRunnerSignal(DefaultRunnerName, s) + if err != nil { + log.Errorw("failed to handle IPC", "signal", s, "error", err) + // TODO: What do we do with this error? Put it on some error chan + // and ship it somewhere? 
+ } + case <-ctx.Done(): + return + } + } +} + func (m *Manager) cleanupInProgress() bool { if !m.cfg.OneShot { return false diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 69f7842..df353c2 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -30,6 +30,12 @@ import ( "github.com/replicate/cog-runtime/internal/webhook" ) +const ( + SigOutput = syscall.SIGHUP + SigReady = syscall.SIGUSR1 + SigBusy = syscall.SIGUSR2 +) + var ( LogRegex = regexp.MustCompile(`^\[pid=(?P[^]]+)] (?P.*)$`) ResponseRegex = regexp.MustCompile(`^response-(?P\S+)-(?P\d+).json$`) @@ -869,22 +875,6 @@ func (r *Runner) predict(reqID string) (chan PredictionResponse, *PredictionResp log.Tracew("wrote prediction request file", "prediction_id", reqID, "path", requestPath, "working_dir", r.runnerCtx.workingdir, "request_data", string(requestData)) - // Debug: Check if file actually exists and list directory contents - if _, err := os.Stat(requestPath); err != nil { - log.Tracew("ERROR: written request file does not exist", "prediction_id", reqID, "path", requestPath, "error", err) - } else { - log.Tracew("confirmed request file exists", "prediction_id", reqID, "path", requestPath) - } - - // Debug: List all files in working directory - if entries, err := os.ReadDir(r.runnerCtx.workingdir); err == nil { - fileNames := make([]string, len(entries)) - for i, entry := range entries { - fileNames[i] = entry.Name() - } - log.Tracew("working directory contents after write", "prediction_id", reqID, "working_dir", r.runnerCtx.workingdir, "files", fileNames) - } - log.Tracew("returning prediction channel", "prediction_id", reqID) initialResponse := &PredictionResponse{ Status: PredictionStarting, @@ -968,6 +958,47 @@ func (r *Runner) HandleIPC(status string) error { return nil } +// HandleSignal does the exact same things as HandleIPC just using signals +// instead of webhooks. 
This only can be used in non-pipeline use cases +func (r *Runner) HandleSignal(status os.Signal) error { + switch status { + case SigReady: + if r.status == StatusStarting { + r.updateSchema() + r.updateSetupResult() + // Close setupComplete channel to signal first READY after setup + r.mu.Lock() + select { + case <-r.setupComplete: + // Already closed + default: + close(r.setupComplete) + } + r.mu.Unlock() + } + if err := r.updateStatus("READY"); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + case SigBusy: + if err := r.updateStatus("BUSY"); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + case SigOutput: + // Notify all active prediction watchers of OUTPUT event + r.mu.RLock() + for _, pending := range r.pending { + select { + case pending.outputNotify <- struct{}{}: + // Notification sent + default: + // Channel full, skip (watcher will poll anyway) + } + } + r.mu.RUnlock() + } + return nil +} + func (r *Runner) updateSchema() { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/server/server.go b/internal/server/server.go index a26f7b3..2f0796a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,7 +10,9 @@ import ( "net/http" "os" "path" + "path/filepath" "sync/atomic" + "syscall" "time" "github.com/replicate/go/httpclient" @@ -26,6 +28,10 @@ const ( IPCStatusReady IPCStatus = "READY" IPCStatusBUSY IPCStatus = "BUSY" IPCStatusOutput IPCStatus = "OUTPUT" + + SigOutput = syscall.SIGHUP + SigReady = syscall.SIGUSR1 + SigBusy = syscall.SIGUSR2 ) type IPC struct { @@ -394,11 +400,16 @@ func writeReadyFile() error { dir := "/var/run/cog" file := path.Join(dir, "ready") - if _, err := os.Stat(file); os.IsNotExist(err) { + return writeFileIfNotExists(file) +} + +func writeFileIfNotExists(fpath string) error { + dir := filepath.Dir(fpath) + if _, err := os.Stat(fpath); os.IsNotExist(err) { if err := os.MkdirAll(dir, 0o700); err != nil { return err } - if err := os.WriteFile(file, nil, 0o600); err != nil { + if err := os.WriteFile(fpath, nil, 0o600); err != nil { return err } } diff --git a/python/coglet/__main__.py b/python/coglet/__main__.py index 46abede..7499f42 100644 --- a/python/coglet/__main__.py +++ b/python/coglet/__main__.py @@ -53,7 +53,9 @@ def pre_setup(logger: logging.Logger, working_dir: str) -> Optional[file_runner. 
def main() -> int: parser = argparse.ArgumentParser() parser.add_argument('--name', metavar='NAME', required=True, help='name') - parser.add_argument('--ipc-url', metavar='URL', required=True, help='IPC URL') + group = parser.add_mutually_exclusive_group() + group.add_argument('--ipc-url', metavar='URL', help='IPC URL') + group.add_argument('--signal-mode', action='store_true') parser.add_argument( '--working-dir', metavar='DIR', required=True, help='working directory' ) @@ -85,6 +87,7 @@ def main() -> int: logger=logger, name=args.name, ipc_url=args.ipc_url, + signal_mode=args.signal_mode, working_dir=args.working_dir, config=config, ).start() diff --git a/python/coglet/file_runner.py b/python/coglet/file_runner.py index 582253b..ed51d8a 100644 --- a/python/coglet/file_runner.py +++ b/python/coglet/file_runner.py @@ -5,6 +5,7 @@ import pathlib import re import signal +import sys import tempfile import urllib.request from dataclasses import dataclass @@ -25,6 +26,13 @@ class FileRunner: REQUEST_RE = re.compile(r'^request-(?P\S+).json$') RESPONSE_FMT = 'response-{pid}-{epoch:05d}.json' + # Signal parent to scan output + SIG_OUTPUT = signal.SIGHUP + + # Signal ready or busy status + SIG_READY = signal.SIGUSR1 + SIG_BUSY = signal.SIGUSR2 + # IPC status updates to Go server IPC_READY = 'READY' IPC_BUSY = 'BUSY' @@ -35,16 +43,21 @@ def __init__( *, logger: logging.Logger, name: str, - ipc_url: str, + ipc_url: Optional[str], working_dir: str, config: Config, + signal_mode: bool = False, ): + if not signal_mode and not ipc_url: + raise ValueError('IPC URL cannot be null if signal mode is false') + self.signal_mode = signal_mode self.logger = logger self.name = name self.ipc_url = ipc_url self.working_dir = working_dir self.config = config self.runner: Optional[runner.Runner] = None + self.isatty = sys.stdout.isatty() async def start(self) -> int: self.logger.info( @@ -117,7 +130,10 @@ def _cancel_handler(signum, _) -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) ready = True - self._send_ipc(FileRunner.IPC_READY) + if self.signal_mode: + self._signal(FileRunner.SIG_READY) + else: + self._send_ipc(FileRunner.IPC_READY) # Go server cannot receive IPC yet when a procedure is starting # Write a ready file as signal with open(ready_file, 'w') as f: @@ -127,7 +143,10 @@ def _cancel_handler(signum, _) -> None: while True: if not ready and len(pending) < self.config.max_concurrency: ready = True - self._send_ipc(FileRunner.IPC_READY) + if self.signal_mode: + self._signal(FileRunner.SIG_READY) + else: + self._send_ipc(FileRunner.IPC_READY) if os.path.exists(stop_file): self.logger.info('stopping file runner') @@ -172,7 +191,10 @@ def _cancel_handler(signum, _) -> None: if ready and len(pending) + 1 == self.config.max_concurrency: ready = False - self._send_ipc(FileRunner.IPC_BUSY) + if self.signal_mode: + self._signal(FileRunner.SIG_BUSY) + else: + self._send_ipc(FileRunner.IPC_BUSY) pending[pid] = asyncio.create_task(self._predict(pid, req)) self.logger.info('prediction started: id=%s', pid) @@ -284,10 +306,15 @@ def _respond( ) os.rename(temp_path, resp_path) - self._send_ipc(FileRunner.IPC_OUTPUT) + if self.signal_mode: + self._signal(FileRunner.SIG_OUTPUT) + else: + self._send_ipc(FileRunner.IPC_OUTPUT) def _send_ipc(self, status: str) -> None: try: + if not self.ipc_url: + raise RuntimeError('IPC invoked but IPC URL not provided') payload = { 'name': self.name, 'pid': os.getpid(), @@ -297,3 +324,7 @@ def _send_ipc(self, status: str) -> None: urllib.request.urlopen(self.ipc_url, 
data=data).read() except Exception as e: self.logger.exception('IPC failed: %s', e) + + def _signal(self, signum: int) -> None: + if not self.isatty: + os.kill(os.getppid(), signum) diff --git a/python/coglet/file_runner.pyi b/python/coglet/file_runner.pyi index 56f71f5..907555f 100644 --- a/python/coglet/file_runner.pyi +++ b/python/coglet/file_runner.pyi @@ -4,6 +4,7 @@ This type stub file was generated by pyright. import logging from dataclasses import dataclass +from typing import Optional @dataclass(frozen=True) class Config: @@ -17,10 +18,13 @@ class FileRunner: CANCEL_RE = ... REQUEST_RE = ... RESPONSE_FMT = ... + SIG_OUTPUT = ... + SIG_READY = ... + SIG_BUSY = ... IPC_READY = ... IPC_BUSY = ... IPC_OUTPUT = ... - def __init__(self, *, logger: logging.Logger, name: str, ipc_url: str, working_dir: str, config: Config) -> None: + def __init__(self, *, logger: logging.Logger, name: str, ipc_url: Optional[str], working_dir: str, config: Config, signal_mode: bool = ...) -> None: ... async def start(self) -> int:
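Taken together, the signal-mode contract is small: the coglet child signals its parent with SIGHUP when a new response file is ready, SIGUSR1 when it can accept work, and SIGUSR2 when it is at capacity, and the Go manager maps those to the same OUTPUT/READY/BUSY transitions as the HTTP IPC path. A minimal, self-contained illustration of the parent side (simplified from Manager.HandleSignals; the self-delivered SIGUSR1 stands in for the child, much like startRunnerFromCheckpoint does after a restore):

```go
package main

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	// Same mapping as SigOutput/SigReady/SigBusy in internal/runner.
	statuses := map[os.Signal]string{
		syscall.SIGHUP:  "OUTPUT",
		syscall.SIGUSR1: "READY",
		syscall.SIGUSR2: "BUSY",
	}

	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGHUP, syscall.SIGUSR1, syscall.SIGUSR2)

	// Stand-in for the coglet child: deliver SIGUSR1 to this process to
	// announce readiness after setup.
	if err := syscall.Kill(syscall.Getpid(), syscall.SIGUSR1); err != nil {
		panic(err)
	}

	s := <-ch
	fmt.Printf("received %v -> runner status %s\n", s, statuses[s])
}
```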