diff --git a/README.md b/README.md index 71b3ee3..5fb50f8 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,11 @@ type Scheduler interface { // Stop shutdowns the scheduler. Stop() + + // GracefulStop shutdowns the scheduler and blocks until all jobs + // have returned. GracefulStop will return when the context passed + // to it has expired. + GracefulStop(ctx context.Context) error } ``` diff --git a/quartz/scheduler.go b/quartz/scheduler.go index 0b688c2..2c2f68b 100644 --- a/quartz/scheduler.go +++ b/quartz/scheduler.go @@ -64,6 +64,11 @@ type Scheduler interface { // Stop shutdowns the scheduler. Stop() + + // GracefulStop shutdowns the scheduler and blocks until all jobs + // have returned. GracefulStop will return when the context passed + // to it has expired. + GracefulStop(ctx context.Context) error } type dispatchedJob struct { @@ -83,6 +88,7 @@ type StdScheduler struct { wg sync.WaitGroup interrupt chan struct{} + stopCh chan struct{} cancel context.CancelFunc feeder chan ScheduledJob dispatch chan dispatchedJob @@ -288,6 +294,7 @@ func NewStdScheduler(opts ...SchedulerOpt) (Scheduler, error) { // initialize the scheduler with default values scheduler := &StdScheduler{ interrupt: make(chan struct{}, 1), + stopCh: make(chan struct{}), feeder: make(chan ScheduledJob), dispatch: make(chan dispatchedJob), queue: NewJobQueue(), @@ -360,6 +367,8 @@ func (sched *StdScheduler) Start(ctx context.Context) { return } + sched.stopCh = make(chan struct{}) + ctx, sched.cancel = context.WithCancel(ctx) go func() { <-ctx.Done(); sched.Stop() }() @@ -548,6 +557,28 @@ func (sched *StdScheduler) Stop() { sched.started = false } +// GracefulStop shutdowns the scheduler and blocks until all jobs +// have returned. GracefulStop will return when the context passed +// to it has expired. +func (sched *StdScheduler) GracefulStop(ctx context.Context) error { + sched.mtx.Lock() + defer sched.mtx.Unlock() + + if !sched.started { + sched.logger.Info("Scheduler is not running") + return nil + } + + sched.logger.Info("Gracefully closing the scheduler") + + close(sched.stopCh) + + sched.started = false + sched.Wait(ctx) + + return ctx.Err() +} + func (sched *StdScheduler) startExecutionLoop(ctx context.Context) { defer sched.wg.Done() const maxTimerDuration = time.Duration(1<<63 - 1) @@ -573,6 +604,11 @@ func (sched *StdScheduler) startExecutionLoop(ctx context.Context) { sched.logger.Trace("Interrupted waiting for next tick") timer.Stop() + case <-sched.stopCh: + sched.logger.Info("Exit the execution loop") + timer.Stop() + return + case <-ctx.Done(): sched.logger.Info("Exit the execution loop") timer.Stop() @@ -592,6 +628,8 @@ func (sched *StdScheduler) startWorkers(ctx context.Context) { select { case <-ctx.Done(): return + case <-sched.stopCh: + return case dispatched := <-sched.dispatch: sched.executeWithRetries(dispatched.ctx, dispatched.jobDetail) } diff --git a/quartz/scheduler_test.go b/quartz/scheduler_test.go index 9b52305..c5b1445 100644 --- a/quartz/scheduler_test.go +++ b/quartz/scheduler_test.go @@ -648,3 +648,107 @@ func jobCount(sched quartz.Scheduler, matchers ...quartz.Matcher[quartz.Schedule keys, _ := sched.GetJobKeys(matchers...) return len(keys) } + +func TestScheduler_GracefulStop(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sched, err := quartz.NewStdScheduler() + assert.IsNil(t, err) + sched.Start(ctx) + + timerJob := job.NewFunctionJob(func(_ context.Context) (bool, error) { + time.Sleep(250 * time.Millisecond) + return true, nil + }) + err = sched.ScheduleJob( + quartz.NewJobDetail(timerJob, quartz.NewJobKey("funcJob")), + quartz.NewSimpleTrigger(10*time.Millisecond), + ) + assert.IsNil(t, err) + + time.Sleep(50 * time.Millisecond) + + termCtx, termCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer termCancel() + err = sched.GracefulStop(termCtx) + if err != nil && !errors.Is(err, context.Canceled) { + t.Fatalf("graceful stop failed: %v", err) + } + + assert.Equal(t, sched.IsStarted(), false) +} + +func TestScheduler_GracefulStop_DoesNotCancelJob(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sched, err := quartz.NewStdScheduler() + assert.IsNil(t, err) + sched.Start(ctx) + + var canceledCount int64 + timerJob := job.NewFunctionJob(func(ctx context.Context) (bool, error) { + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + select { + case <-ctx.Done(): + atomic.AddInt64(&canceledCount, 1) + return true, nil + case <-timer.C: + return false, nil + } + }) + err = sched.ScheduleJob( + quartz.NewJobDetail(timerJob, quartz.NewJobKey("funcJob")), + quartz.NewSimpleTrigger(10*time.Millisecond), + ) + assert.IsNil(t, err) + + time.Sleep(50 * time.Millisecond) + + termCtx, termCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer termCancel() + err = sched.GracefulStop(termCtx) + if err != nil { + t.Fatalf("graceful stop failed: %v", err) + } + + if got := atomic.LoadInt64(&canceledCount); got != 0 { + t.Error("job was canceled") + } + assert.Equal(t, sched.IsStarted(), false) +} + +func TestScheduler_GracefulStop_WithWorkerLimit(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sched, err := quartz.NewStdScheduler(quartz.WithWorkerLimit(4)) + assert.IsNil(t, err) + sched.Start(ctx) + + timerJob := job.NewFunctionJob(func(_ context.Context) (bool, error) { + time.Sleep(250 * time.Millisecond) + return true, nil + }) + err = sched.ScheduleJob( + quartz.NewJobDetail(timerJob, quartz.NewJobKey("funcJob")), + quartz.NewSimpleTrigger(10*time.Millisecond), + ) + assert.IsNil(t, err) + + time.Sleep(50 * time.Millisecond) + + termCtx, termCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer termCancel() + err = sched.GracefulStop(termCtx) + if err != nil && !errors.Is(err, context.Canceled) { + t.Fatalf("graceful stop failed: %v", err) + } + + assert.Equal(t, sched.IsStarted(), false) +}