Skip to content

Commit 4ce0fe7

Browse files
bugfix: Ensure schema is loaded before processing (#221)
* bugfix: Ensure schema is loaded before processing Schema wasn't loaded prior to processing the input paths; this resulted in a noop nil response. To ensure we always load the schema first we now block on runner ready state before processing. Additionally, a nil doc is considered an error. In the cases we call ProcessInputPaths we can accept that error; in manager we explicitly check for doc being nil and emit an error log. * Ensure that we always 202 for async Async predictions should not 500. However, since we now know that the runner failed before we send the prediction request down, we end up needing to work around the error and still send a 202 *then* let the webhook indicate failure. While this behavior is somewhat crazy, it is the behavioral contract we can revisit in the future.
1 parent 078c77f commit 4ce0fe7

19 files changed

+496
-97
lines changed

internal/runner/manager.go

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ var (
2929
ErrRunnerNotFound = errors.New("runner not found")
3030
ErrNoEmptySlot = errors.New("no empty slot available")
3131
ErrInvalidRunnerStatus = errors.New("invalid runner status for new prediction")
32+
// ErrAsyncPrediction is a sentinel error used to indicate that a prediction is being served asynchronously, it is not surfaced outside of runner
33+
ErrAsyncPrediction = errors.New("async prediction")
3234
)
3335

3436
// Manager manages the lifecycle and capacity of prediction runners
@@ -184,6 +186,7 @@ func (m *Manager) PredictAsync(ctx context.Context, req PredictionRequest) (*Pre
184186

185187
// predict is the internal implementation shared by both sync and async predictions
186188
func (m *Manager) predict(ctx context.Context, req PredictionRequest, async bool) (chan PredictionResponse, *PredictionResponse, error) {
189+
log := m.logger.Sugar()
187190
if err := m.claimSlot(); err != nil {
188191
return nil, nil, err
189192
}
@@ -202,7 +205,40 @@ func (m *Manager) predict(ctx context.Context, req PredictionRequest, async bool
202205
return nil, nil, fmt.Errorf("runner not ready: %s", runner.status)
203206
}
204207

205-
respChan, initialResponse, err := runner.predict(req)
208+
runner.mu.RLock()
209+
pending, exists := runner.pending[req.ID]
210+
setupCompletedChan := runner.setupComplete
211+
runner.mu.RUnlock()
212+
if !exists {
213+
m.releaseSlot()
214+
return nil, nil, fmt.Errorf("failed to find pending prediction after allocation: %s", req.ID)
215+
}
216+
select {
217+
case <-setupCompletedChan:
218+
// We need to wait for setup to complete before proceeding so that we can ensure that
219+
// the OpenAPI schema is available for input processing
220+
log.Tracew("runner setup complete, proceeding with prediction", "prediction_id", req.ID, "runner", runner.runnerCtx.id)
221+
case <-pending.ctx.Done():
222+
// Prediction was canceled, watcher will perform cleanup, we need to abort
223+
// the rest of the prediction processing
224+
log.Tracew("prediction was canceled before setup complete, aborting", "prediction_id", req.ID, "runner", runner.runnerCtx.id)
225+
m.releaseSlot()
226+
return nil, nil, fmt.Errorf("prediction %s was canceled: %w", req.ID, pending.ctx.Err())
227+
}
228+
229+
// Check for setup failure before calling predict
230+
runner.mu.Lock()
231+
status := runner.status
232+
runner.mu.Unlock()
233+
if status == StatusSetupFailed {
234+
// Setup failure will be handled by async webhook machinery
235+
// Return sentinel error to indicate async handling
236+
log.Tracew("setup failed, using async handling", "prediction_id", req.ID, "runner", runner.runnerCtx.id)
237+
m.releaseSlot()
238+
return nil, nil, ErrAsyncPrediction
239+
}
240+
241+
respChan, initialResponse, err := runner.predict(req.ID)
206242
if err != nil {
207243
m.releaseSlot()
208244
return nil, nil, err
@@ -330,19 +366,28 @@ func (m *Manager) allocatePrediction(runner *Runner, req PredictionRequest) { //
330366
// NOTE(morgan): by design we do not use the passed in context, as the passed
331367
// in context is tied to the http request, and would cause the prediction to
332368
// fail at the end of the http request's lifecycle.
333-
watcherCtx, cancel := context.WithCancel(m.ctx)
369+
predictionCtx, cancel := context.WithCancel(m.ctx)
334370

335371
pending := &PendingPrediction{
336372
request: req,
337373
outputCache: make(map[string]string),
338374
c: make(chan PredictionResponse, 1),
339375
cancel: cancel, // Manager can cancel this watcher explicitly
376+
ctx: predictionCtx,
340377
watcherDone: make(chan struct{}),
341378
outputNotify: make(chan struct{}, 1),
342379
webhookSender: m.webhookSender,
343380
}
344381
runner.pending[req.ID] = pending
345382

383+
now := time.Now().Format(config.TimeFormat)
384+
if pending.request.CreatedAt == "" {
385+
pending.request.CreatedAt = now
386+
}
387+
if pending.request.StartedAt == "" {
388+
pending.request.StartedAt = now
389+
}
390+
346391
// Start per-prediction response watcher with cleanup wrapper
347392
go func() {
348393
defer func() {
@@ -399,7 +444,7 @@ func (m *Manager) allocatePrediction(runner *Runner, req PredictionRequest) { //
399444
}
400445
}()
401446

402-
runner.watchPredictionResponses(watcherCtx, req.ID, pending)
447+
runner.watchPredictionResponses(predictionCtx, req.ID, pending)
403448
}()
404449
}
405450

internal/runner/path.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package runner
33
import (
44
"bytes"
55
"encoding/base64"
6+
"errors"
67
"fmt"
78
"io"
89
"net/http"
@@ -18,17 +19,22 @@ import (
1819
"github.com/replicate/go/httpclient"
1920
)
2021

21-
var Base64Regex = regexp.MustCompile(`^data:.*;base64,(?P<base64>.*)$`)
22+
var (
23+
Base64Regex = regexp.MustCompile(`^data:.*;base64,(?P<base64>.*)$`)
24+
ErrSchemaNotAvailable = errors.New("OpenAPI schema not available for input processing")
25+
)
2226

2327
func isURI(s *openapi3.SchemaRef) bool {
2428
return s.Value.Type.Is("string") && s.Value.Format == "uri"
2529
}
2630

2731
// ProcessInputPaths processes the input paths and discards the now unused paths from the input.
28-
// Note that we return the input, but the expectation is that input will be mutated in-place.
32+
// Note that we return the input, but the expectation is that input will be mutated in-place. This function
33+
// returns ErrSchemaNotAvailable if the OpenAPI schema is not available. It is up to the caller to decide how
34+
// to handle this error (e.g. log a warning and proceed without path processing).
2935
func ProcessInputPaths(input any, doc *openapi3.T, paths *[]string, fn func(string, *[]string) (string, error)) (any, error) {
3036
if doc == nil {
31-
return input, nil
37+
return input, ErrSchemaNotAvailable
3238
}
3339

3440
schema, ok := doc.Components.Schemas["Input"]

internal/runner/path_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ func TestProcessInputPaths(t *testing.T) {
576576
}
577577

578578
result, err := ProcessInputPaths(input, nil, &paths, mockFn)
579-
require.NoError(t, err)
579+
require.ErrorIs(t, err, ErrSchemaNotAvailable)
580580
assert.Equal(t, input, result)
581581
assert.Empty(t, paths)
582582
})

internal/runner/runner.go

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -812,40 +812,53 @@ func (r *Runner) verifyProcessCleanup(pid int) {
812812

813813
// predict returns a channel that will receive the prediction response and an initial prediction response
814814
// populated with the relevant fields from the request
815-
func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, *PredictionResponse, error) {
815+
func (r *Runner) predict(reqID string) (chan PredictionResponse, *PredictionResponse, error) {
816816
log := r.logger.Sugar()
817817
r.mu.Lock()
818818
defer r.mu.Unlock()
819819

820-
log.Tracew("runner.predict called", "prediction_id", req.ID, "status", r.status)
820+
log.Tracew("runner.predict called", "prediction_id", reqID, "status", r.status)
821821

822822
// Prediction must be pre-allocated by manager
823-
pending, exists := r.pending[req.ID]
823+
pending, exists := r.pending[reqID]
824824
if !exists {
825-
return nil, nil, fmt.Errorf("prediction %s not allocated", req.ID)
825+
return nil, nil, fmt.Errorf("prediction %s not allocated", reqID)
826826
}
827827

828-
log.Tracew("prediction found in pending", "prediction_id", req.ID)
828+
if pending.request.ID != reqID {
829+
return nil, nil, fmt.Errorf("prediction ID mismatch: expected %s, got %s", reqID, pending.request.ID)
830+
}
831+
832+
pending.mu.Lock()
833+
defer pending.mu.Unlock()
834+
835+
log.Tracew("prediction found in pending", "prediction_id", reqID)
829836

830837
// Process input paths (base64 and URL inputs)
831838
inputPaths := make([]string, 0)
832-
input, err := ProcessInputPaths(req.Input, r.doc, &inputPaths, Base64ToInput)
833-
if err != nil {
834-
return nil, nil, fmt.Errorf("failed to process base64 inputs: %w", err)
835-
}
836-
input, err = ProcessInputPaths(input, r.doc, &inputPaths, URLToInput)
837-
if err != nil {
838-
return nil, nil, fmt.Errorf("failed to process URL inputs: %w", err)
839+
if r.doc == nil {
840+
log.Errorw("OpenAPI schema not available for input processing - cannot convert base64 or URL inputs", "prediction_id", reqID)
841+
} else {
842+
// Process base64 inputs first, then URL inputs (to allow URL inputs to reference base64-decoded files)
843+
input, err := ProcessInputPaths(pending.request.Input, r.doc, &inputPaths, Base64ToInput)
844+
if err != nil {
845+
return nil, nil, fmt.Errorf("failed to process base64 inputs: %w", err)
846+
}
847+
848+
input, err = ProcessInputPaths(input, r.doc, &inputPaths, URLToInput)
849+
if err != nil {
850+
return nil, nil, fmt.Errorf("failed to process URL inputs: %w", err)
851+
}
852+
pending.request.Input = input
839853
}
840-
req.Input = input
841854
pending.inputPaths = inputPaths
842855

843-
// Write prediction request to file (async like original)
844-
requestFile := fmt.Sprintf("request-%s.json", req.ID)
845-
log.Debugw("writing prediction request file", "prediction_id", req.ID, "file", requestFile)
856+
// Write prediction request to file
857+
requestFile := fmt.Sprintf("request-%s.json", reqID)
858+
log.Debugw("writing prediction request file", "prediction_id", reqID, "file", requestFile)
846859
requestPath := path.Join(r.runnerCtx.workingdir, requestFile)
847860

848-
requestData, err := json.Marshal(req)
861+
requestData, err := json.Marshal(pending.request)
849862
if err != nil {
850863
return nil, nil, fmt.Errorf("failed to marshal request: %w", err)
851864
}
@@ -854,13 +867,13 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, *Predi
854867
return nil, nil, fmt.Errorf("failed to write request file: %w", err)
855868
}
856869

857-
log.Tracew("wrote prediction request file", "prediction_id", req.ID, "path", requestPath, "working_dir", r.runnerCtx.workingdir, "request_data", string(requestData))
870+
log.Tracew("wrote prediction request file", "prediction_id", reqID, "path", requestPath, "working_dir", r.runnerCtx.workingdir, "request_data", string(requestData))
858871

859872
// Debug: Check if file actually exists and list directory contents
860873
if _, err := os.Stat(requestPath); err != nil {
861-
log.Tracew("ERROR: written request file does not exist", "prediction_id", req.ID, "path", requestPath, "error", err)
874+
log.Tracew("ERROR: written request file does not exist", "prediction_id", reqID, "path", requestPath, "error", err)
862875
} else {
863-
log.Tracew("confirmed request file exists", "prediction_id", req.ID, "path", requestPath)
876+
log.Tracew("confirmed request file exists", "prediction_id", reqID, "path", requestPath)
864877
}
865878

866879
// Debug: List all files in working directory
@@ -869,25 +882,14 @@ func (r *Runner) predict(req PredictionRequest) (chan PredictionResponse, *Predi
869882
for i, entry := range entries {
870883
fileNames[i] = entry.Name()
871884
}
872-
log.Tracew("working directory contents after write", "prediction_id", req.ID, "working_dir", r.runnerCtx.workingdir, "files", fileNames)
873-
}
874-
875-
// Update pending prediction with request details
876-
pending.request = req
877-
878-
now := time.Now().Format(config.TimeFormat)
879-
if pending.request.CreatedAt == "" {
880-
pending.request.CreatedAt = now
881-
}
882-
if pending.request.StartedAt == "" {
883-
pending.request.StartedAt = now
885+
log.Tracew("working directory contents after write", "prediction_id", reqID, "working_dir", r.runnerCtx.workingdir, "files", fileNames)
884886
}
885887

886-
log.Tracew("returning prediction channel", "prediction_id", req.ID)
888+
log.Tracew("returning prediction channel", "prediction_id", reqID)
887889
initialResponse := &PredictionResponse{
888890
Status: PredictionStarting,
889891
}
890-
initialResponse.populateFromRequest(req)
892+
initialResponse.populateFromRequest(pending.request)
891893
return pending.c, initialResponse, nil
892894
}
893895

@@ -970,14 +972,23 @@ func (r *Runner) updateSchema() {
970972
r.mu.Lock()
971973
defer r.mu.Unlock()
972974

975+
log := r.logger.Sugar()
973976
schemaPath := filepath.Join(r.runnerCtx.workingdir, "openapi.json")
977+
log.Tracew("attempting to read openapi.json", "path", schemaPath)
978+
974979
if schemaData, err := os.ReadFile(schemaPath); err == nil { //nolint:gosec // expected dynamic path
975980
r.schema = string(schemaData)
981+
log.Tracew("successfully read openapi.json", "schema_length", len(schemaData))
976982

977983
// Parse the schema for use in ProcessInputPaths
978984
if doc, parseErr := openapi3.NewLoader().LoadFromData(schemaData); parseErr == nil {
979985
r.doc = doc
986+
log.Tracew("successfully parsed openapi schema for ProcessInputPaths")
987+
} else {
988+
log.Errorw("failed to parse openapi schema", "error", parseErr)
980989
}
990+
} else {
991+
log.Tracew("failed to read openapi.json", "error", err)
981992
}
982993
}
983994

internal/runner/runner_test.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -492,29 +492,36 @@ func TestRunnerPredict(t *testing.T) {
492492
runnerCtx: RunnerContext{workingdir: tempDir},
493493
logger: loggingtest.NewTestLogger(t),
494494
}
495-
496-
// Pre-allocate prediction
497-
r.pending["test-id"] = &PendingPrediction{
498-
c: make(chan PredictionResponse, 1),
499-
}
495+
predictionID, _ := PredictionID()
500496

501497
req := PredictionRequest{
502-
ID: "test-id",
498+
ID: predictionID,
503499
Input: map[string]any{"key": "value"},
500+
// CreatedAt and StartedAt would be set in the manager allocatePrediction step
501+
// so we need to set them directly here
502+
CreatedAt: time.Now().Format(config.TimeFormat),
503+
StartedAt: time.Now().Format(config.TimeFormat),
504+
}
505+
// Pre-allocate prediction
506+
r.pending[predictionID] = &PendingPrediction{
507+
c: make(chan PredictionResponse, 1),
508+
request: req,
504509
}
505510

506-
ch, initialResponse, err := r.predict(req)
507-
assert.NotNil(t, initialResponse)
511+
ch, initialResponse, err := r.predict(req.ID)
512+
require.NoError(t, err)
513+
require.NotNil(t, initialResponse)
508514
assert.Equal(t, PredictionStarting, initialResponse.Status)
509515
assert.NotEmpty(t, initialResponse.ID)
510516
assert.Equal(t, req.Input, initialResponse.Input)
517+
assert.NotEmpty(t, initialResponse.CreatedAt)
518+
assert.NotEmpty(t, initialResponse.StartedAt)
511519
assert.Equal(t, req.CreatedAt, initialResponse.CreatedAt)
512520
assert.Equal(t, req.StartedAt, initialResponse.StartedAt)
513-
require.NoError(t, err)
514521
assert.NotNil(t, ch)
515522

516523
// Check request file was created
517-
requestFile := path.Join(tempDir, "request-test-id.json")
524+
requestFile := path.Join(tempDir, fmt.Sprintf("request-%s.json", predictionID))
518525
_, err = os.Stat(requestFile)
519526
assert.NoError(t, err)
520527
})
@@ -528,12 +535,14 @@ func TestRunnerPredict(t *testing.T) {
528535
logger: loggingtest.NewTestLogger(t),
529536
}
530537

531-
req := PredictionRequest{ID: "test-id"}
532-
ch, initialResponse, err := r.predict(req)
538+
predictionID, _ := PredictionID()
539+
540+
req := PredictionRequest{ID: predictionID}
541+
ch, initialResponse, err := r.predict(req.ID)
533542
require.Error(t, err)
534543
assert.Nil(t, ch)
535544
assert.Nil(t, initialResponse)
536-
assert.Contains(t, err.Error(), "prediction test-id not allocated")
545+
assert.Contains(t, err.Error(), fmt.Sprintf("prediction %s not allocated", predictionID))
537546
})
538547
}
539548

@@ -1118,7 +1127,7 @@ func TestPerPredictionWatcher(t *testing.T) {
11181127

11191128
// Setup temp directory with response files
11201129
tempDir := t.TempDir()
1121-
predictionID := "test-prediction-123"
1130+
predictionID, _ := PredictionID()
11221131

11231132
// Create response files - one for our prediction, one for another
11241133
responseFile1 := fmt.Sprintf("response-%s-00001.json", predictionID)
@@ -1170,7 +1179,7 @@ func TestPerPredictionWatcher(t *testing.T) {
11701179
t.Parallel()
11711180

11721181
tempDir := t.TempDir()
1173-
predictionID := "test-prediction-456"
1182+
predictionID, _ := PredictionID()
11741183
filename := fmt.Sprintf("response-%s-00001.json", predictionID)
11751184
filePath := filepath.Join(tempDir, filename)
11761185

@@ -1223,7 +1232,7 @@ func TestPerPredictionWatcher(t *testing.T) {
12231232
t.Parallel()
12241233

12251234
tempDir := t.TempDir()
1226-
predictionID := "test-prediction-789"
1235+
predictionID, _ := PredictionID()
12271236

12281237
// Setup runner
12291238
logger := loggingtest.NewTestLogger(t)
@@ -1290,7 +1299,7 @@ func TestPerPredictionWatcher(t *testing.T) {
12901299
t.Parallel()
12911300

12921301
tempDir := t.TempDir()
1293-
predictionID := "test-prediction-abc"
1302+
predictionID, _ := PredictionID()
12941303

12951304
// Setup runner
12961305
logger := loggingtest.NewTestLogger(t)

0 commit comments

Comments
 (0)