Skip to content

Commit d1aa97e

Browse files
committed
refactor: add reconnect as a backend field, always create RecoveryManager
1 parent 8153cf3 commit d1aa97e

File tree

6 files changed

+41
-51
lines changed

6 files changed

+41
-51
lines changed

agent/runner.go

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,28 @@ func (r *Runner) Run(runnerCtx, shutdownCtx context.Context) error {
9494
workflowCtx, cancelWorkflowCtx := context.WithCancelCause(workflowCtx)
9595
defer cancelWorkflowCtx(nil)
9696

97-
// recoveryManager is declared here so the cancel listener can mark it as canceled.
98-
// It will be initialized later after workflow state is set up.
99-
var recoveryManager *pipeline.RecoveryManager
100-
101-
// Add sigterm support for internal context.
102-
// Required to be able to terminate the running workflow by external signals.
97+
// Handle SIGTERM (k8s, docker, system shutdown)
10398
workflowCtx = utils.WithContextSigtermCallback(workflowCtx, func() {
10499
logger.Error().Msg("received sigterm termination signal")
105100
// WithContextSigtermCallback would cancel the context too, but we want our own custom error
106101
cancelWorkflowCtx(pipeline.ErrCancel)
107102
})
108103

104+
state := rpc.WorkflowState{
105+
Started: time.Now().Unix(),
106+
}
107+
if err := r.client.Init(runnerCtx, workflow.ID, state); err != nil {
108+
logger.Error().Err(err).Msg("workflow initialization failed")
109+
// TODO: should we return here?
110+
}
111+
112+
// Initialize recovery manager before launching goroutines that reference it
113+
recoveryManager := pipeline.NewRecoveryManager(r.client, workflow.ID, true)
114+
if err := recoveryManager.InitRecoveryState(runnerCtx, workflow.Config, int64(timeout.Seconds())); err != nil {
115+
logger.Warn().Err(err).Msg("failed to initialize recovery state, continuing without recovery")
116+
recoveryManager = pipeline.NewRecoveryManager(r.client, workflow.ID, false)
117+
}
118+
109119
// Listen for remote cancel events (UI / API).
110120
// When canceled, we MUST cancel the workflow context
111121
// so that workflow execution stops immediately.
@@ -118,9 +128,7 @@ func (r *Runner) Run(runnerCtx, shutdownCtx context.Context) error {
118128
} else {
119129
if canceled {
120130
logger.Debug().Msg("server side cancel signal received")
121-
if recoveryManager != nil {
122-
recoveryManager.SetCanceled()
123-
}
131+
recoveryManager.SetCanceled()
124132
cancelWorkflowCtx(pipeline.ErrCancel)
125133
}
126134
logger.Debug().Msg("cancel listener exited normally")
@@ -144,25 +152,6 @@ func (r *Runner) Run(runnerCtx, shutdownCtx context.Context) error {
144152
}
145153
}()
146154

147-
state := rpc.WorkflowState{
148-
Started: time.Now().Unix(),
149-
}
150-
151-
if err := r.client.Init(runnerCtx, workflow.ID, state); err != nil {
152-
logger.Error().Err(err).Msg("signaling workflow initialization to server failed")
153-
// We have an error, maybe the server is currently unreachable or other server-side errors occurred.
154-
// So let's clean up and end this not-yet-started workflow run.
155-
cancelWorkflowCtx(err)
156-
return err
157-
}
158-
159-
// Initialize recovery manager; if not enabled on server, it will be a no-op
160-
recoveryManager = pipeline.NewRecoveryManager(r.client, workflow.ID, true)
161-
if err := recoveryManager.InitRecoveryState(runnerCtx, workflow.Config, int64(timeout.Seconds())); err != nil {
162-
logger.Warn().Err(err).Msg("failed to initialize recovery state, continuing without recovery")
163-
recoveryManager = pipeline.NewRecoveryManager(nil, workflow.ID, false)
164-
}
165-
166155
var uploads sync.WaitGroup
167156

168157
// Run pipeline
@@ -204,7 +193,7 @@ func (r *Runner) Run(runnerCtx, shutdownCtx context.Context) error {
204193

205194
// If workflow is recoverable (context canceled, recovery enabled, not user cancel),
206195
// skip marking as done. The workflow will be picked up by a new agent after restart.
207-
if recoveryManager != nil && recoveryManager.IsRecoverable(runnerCtx) {
196+
if recoveryManager.IsRecoverable(runnerCtx) {
208197
logger.Info().Msg("workflow is recoverable, not marking as done")
209198
return nil
210199
}

pipeline/backend/dummy/dummy.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ func (e *dummy) DestroyStep(_ context.Context, step *backend.Step, taskUUID stri
216216
return nil
217217
}
218218

219+
func (e *dummy) Reconnect(_ context.Context, _ *backend.Step, _ string) error {
220+
return fmt.Errorf("reconnect not supported")
221+
}
222+
219223
func (e *dummy) DestroyWorkflow(_ context.Context, _ *backend.Config, taskUUID string) error {
220224
log.Trace().Str("taskUUID", taskUUID).Msgf("delete workflow environment")
221225

pipeline/backend/local/local.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ func (e *local) DestroyStep(_ context.Context, step *types.Step, taskUUID string
256256
return nil
257257
}
258258

259+
func (e *local) Reconnect(_ context.Context, _ *types.Step, _ string) error {
260+
return fmt.Errorf("reconnect not supported")
261+
}
262+
259263
func (e *local) DestroyWorkflow(_ context.Context, _ *types.Config, taskUUID string) error {
260264
log.Trace().Str("taskUUID", taskUUID).Msg("delete workflow environment")
261265

pipeline/backend/types/backend.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,15 @@ type Backend interface {
161161
// This function may be called concurrently for different workflows
162162
// and must be thread-safe.
163163
DestroyWorkflow(ctx context.Context, conf *Config, taskUUID string) error
164+
165+
// Reconnect attempts to reconnect to a running step after agent restart.
166+
// Returns nil if reconnection is possible, error otherwise.
167+
// After successful reconnect, TailStep and WaitStep can be used normally.
168+
// Backends that do not support reconnection should return an error.
169+
Reconnect(ctx context.Context, step *Step, taskUUID string) error
164170
}
165171

166172
// BackendInfo represents the reported information of a loaded backend.
167173
type BackendInfo struct {
168174
Platform string
169175
}
170-
171-
// Reconnector is an optional interface that backends can implement to support
172-
// reconnecting to running steps after agent restart.
173-
type Reconnector interface {
174-
// Reconnect attempts to reconnect to a running step.
175-
// Returns nil if reconnection is possible, error otherwise.
176-
// After successful reconnect, TailStep and WaitStep can be used normally.
177-
Reconnect(ctx context.Context, step *Step, taskUUID string) error
178-
}

pipeline/pipeline.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,16 +256,12 @@ func (r *Runtime) execAll(runnerCtx context.Context, steps []*backend.Step) <-ch
256256
return nil
257257
} else if r.recoveryManager.ShouldReconnect(recoveryState) {
258258
// Attempt to reconnect to a running step
259-
if reconnector, ok := r.engine.(backend.Reconnector); ok {
260-
reconnectErr := reconnector.Reconnect(r.ctx, step, r.taskUUID)
261-
if reconnectErr == nil {
262-
logger.Info().Str("step", step.Name).Msg("reconnecting to existing step")
263-
return r.execReconnected(step)
264-
}
265-
logger.Debug().Err(reconnectErr).Str("step", step.Name).Msg("cannot reconnect, re-executing step")
266-
} else {
267-
logger.Debug().Str("step", step.Name).Msg("backend does not support reconnection, re-executing step")
259+
reconnectErr := r.engine.Reconnect(r.ctx, step, r.taskUUID)
260+
if reconnectErr == nil {
261+
logger.Info().Str("step", step.Name).Msg("reconnecting to existing step")
262+
return r.execReconnected(step)
268263
}
264+
logger.Debug().Err(reconnectErr).Str("step", step.Name).Msg("cannot reconnect, re-executing step")
269265
}
270266

271267
// Mark step as running in recovery state

pipeline/recovery.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func NewRecoveryManager(client RecoveryClient, workflowID string, enabled bool)
5050
// On first run, creates recovery states for all steps.
5151
// On agent restart, loads existing states into cache.
5252
func (m *RecoveryManager) InitRecoveryState(ctx context.Context, config *backend.Config, timeoutSeconds int64) error {
53-
if !m.enabled || m.client == nil {
53+
if !m.enabled {
5454
return nil
5555
}
5656

@@ -78,7 +78,7 @@ func (m *RecoveryManager) GetStepState(step *backend.Step) *rpc.RecoveryState {
7878

7979
// MarkStepRunning marks a step as running.
8080
func (m *RecoveryManager) MarkStepRunning(ctx context.Context, step *backend.Step) error {
81-
if !m.enabled || m.client == nil {
81+
if !m.enabled {
8282
return nil
8383
}
8484

@@ -87,7 +87,7 @@ func (m *RecoveryManager) MarkStepRunning(ctx context.Context, step *backend.Ste
8787

8888
// MarkStepSuccess marks a step as successfully completed.
8989
func (m *RecoveryManager) MarkStepSuccess(ctx context.Context, step *backend.Step) error {
90-
if !m.enabled || m.client == nil {
90+
if !m.enabled {
9191
return nil
9292
}
9393

@@ -96,7 +96,7 @@ func (m *RecoveryManager) MarkStepSuccess(ctx context.Context, step *backend.Ste
9696

9797
// MarkStepFailed marks a step as failed.
9898
func (m *RecoveryManager) MarkStepFailed(ctx context.Context, step *backend.Step, exitCode int) error {
99-
if !m.enabled || m.client == nil {
99+
if !m.enabled {
100100
return nil
101101
}
102102

0 commit comments

Comments
 (0)