Skip to content

Commit aa77393

Browse files
authored
Merge pull request #68 from qmuntal/asyncqueued
Fix race conditions in queued Fire
2 parents 23039c6 + 8af1ab7 commit aa77393

File tree

2 files changed

+101
-53
lines changed

2 files changed

+101
-53
lines changed

modes.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package stateless
2+
3+
import (
4+
"context"
5+
"sync"
6+
"sync/atomic"
7+
)
8+
9+
type fireMode interface {
10+
Fire(ctx context.Context, trigger Trigger, args ...any) error
11+
Firing() bool
12+
}
13+
14+
type fireModeImmediate struct {
15+
ops atomic.Uint64
16+
sm *StateMachine
17+
}
18+
19+
func (f *fireModeImmediate) Firing() bool {
20+
return f.ops.Load() > 0
21+
}
22+
23+
func (f *fireModeImmediate) Fire(ctx context.Context, trigger Trigger, args ...any) error {
24+
f.ops.Add(1)
25+
defer f.ops.Add(^uint64(0))
26+
return f.sm.internalFireOne(ctx, trigger, args...)
27+
}
28+
29+
type queuedTrigger struct {
30+
Context context.Context
31+
Trigger Trigger
32+
Args []any
33+
}
34+
35+
type fireModeQueued struct {
36+
firing atomic.Bool
37+
sm *StateMachine
38+
39+
triggers []queuedTrigger
40+
mu sync.Mutex // guards triggers
41+
}
42+
43+
func (f *fireModeQueued) Firing() bool {
44+
return f.firing.Load()
45+
}
46+
47+
func (f *fireModeQueued) Fire(ctx context.Context, trigger Trigger, args ...any) error {
48+
f.enqueue(ctx, trigger, args...)
49+
for {
50+
et, ok := f.fetch()
51+
if !ok {
52+
break
53+
}
54+
err := f.execute(et)
55+
if err != nil {
56+
return err
57+
}
58+
}
59+
return nil
60+
}
61+
62+
func (f *fireModeQueued) enqueue(ctx context.Context, trigger Trigger, args ...any) {
63+
f.mu.Lock()
64+
defer f.mu.Unlock()
65+
66+
f.triggers = append(f.triggers, queuedTrigger{Context: ctx, Trigger: trigger, Args: args})
67+
}
68+
69+
func (f *fireModeQueued) fetch() (et queuedTrigger, ok bool) {
70+
f.mu.Lock()
71+
defer f.mu.Unlock()
72+
73+
if len(f.triggers) == 0 {
74+
return queuedTrigger{}, false
75+
}
76+
77+
if !f.firing.CompareAndSwap(false, true) {
78+
return queuedTrigger{}, false
79+
}
80+
81+
et, f.triggers = f.triggers[0], f.triggers[1:]
82+
return et, true
83+
}
84+
85+
func (f *fireModeQueued) execute(et queuedTrigger) error {
86+
defer f.firing.Swap(false)
87+
return f.sm.internalFireOne(et.Context, et.Trigger, et.Args...)
88+
}

statemachine.go

Lines changed: 13 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
package stateless
22

33
import (
4-
"container/list"
54
"context"
65
"fmt"
76
"reflect"
87
"sync"
9-
"sync/atomic"
108
)
119

1210
// State is used to to represent the possible machine states.
@@ -64,26 +62,29 @@ func callEvents(events []TransitionFunc, ctx context.Context, transition Transit
6462
// It is safe to use the StateMachine concurrently, but non of the callbacks (state manipulation, actions, events, ...) are guarded,
6563
// so it is up to the client to protect them against race conditions.
6664
type StateMachine struct {
67-
ops atomic.Uint64
6865
stateConfig map[State]*stateRepresentation
6966
triggerConfig map[Trigger]triggerWithParameters
7067
stateAccessor func(context.Context) (State, error)
7168
stateMutator func(context.Context, State) error
7269
unhandledTriggerAction UnhandledTriggerActionFunc
7370
onTransitioningEvents []TransitionFunc
7471
onTransitionedEvents []TransitionFunc
75-
eventQueue list.List
76-
firingMode FiringMode
77-
firingMutex sync.Mutex
7872
stateMutex sync.RWMutex
73+
mode fireMode
7974
}
8075

81-
func newStateMachine() *StateMachine {
82-
return &StateMachine{
76+
func newStateMachine(firingMode FiringMode) *StateMachine {
77+
sm := &StateMachine{
8378
stateConfig: make(map[State]*stateRepresentation),
8479
triggerConfig: make(map[Trigger]triggerWithParameters),
8580
unhandledTriggerAction: UnhandledTriggerActionFunc(DefaultUnhandledTriggerAction),
8681
}
82+
if firingMode == FiringImmediate {
83+
sm.mode = &fireModeImmediate{sm: sm}
84+
} else {
85+
sm.mode = &fireModeQueued{sm: sm}
86+
}
87+
return sm
8788
}
8889

8990
// NewStateMachine returns a queued state machine.
@@ -94,7 +95,7 @@ func NewStateMachine(initialState State) *StateMachine {
9495
// NewStateMachineWithMode returns a state machine with the desired firing mode
9596
func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMachine {
9697
var stateMutex sync.Mutex
97-
sm := newStateMachine()
98+
sm := newStateMachine(firingMode)
9899
reference := &struct {
99100
State State
100101
}{State: initialState}
@@ -109,16 +110,14 @@ func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMa
109110
reference.State = state
110111
return nil
111112
}
112-
sm.firingMode = firingMode
113113
return sm
114114
}
115115

116116
// NewStateMachineWithExternalStorage returns a state machine with external state storage.
117117
func NewStateMachineWithExternalStorage(stateAccessor func(context.Context) (State, error), stateMutator func(context.Context, State) error, firingMode FiringMode) *StateMachine {
118-
sm := newStateMachine()
118+
sm := newStateMachine(firingMode)
119119
sm.stateAccessor = stateAccessor
120120
sm.stateMutator = stateMutator
121-
sm.firingMode = firingMode
122121
return sm
123122
}
124123

@@ -276,7 +275,7 @@ func (sm *StateMachine) Configure(state State) *StateConfiguration {
276275

277276
// Firing returns true when the state machine is processing a trigger.
278277
func (sm *StateMachine) Firing() bool {
279-
return sm.ops.Load() != 0
278+
return sm.mode.Firing()
280279
}
281280

282281
// String returns a human-readable representation of the state machine.
@@ -321,49 +320,10 @@ func (sm *StateMachine) stateRepresentation(state State) *stateRepresentation {
321320
}
322321

323322
func (sm *StateMachine) internalFire(ctx context.Context, trigger Trigger, args ...any) error {
324-
switch sm.firingMode {
325-
case FiringImmediate:
326-
return sm.internalFireOne(ctx, trigger, args...)
327-
case FiringQueued:
328-
fallthrough
329-
default:
330-
return sm.internalFireQueued(ctx, trigger, args...)
331-
}
332-
}
333-
334-
type queuedTrigger struct {
335-
Context context.Context
336-
Trigger Trigger
337-
Args []any
338-
}
339-
340-
func (sm *StateMachine) internalFireQueued(ctx context.Context, trigger Trigger, args ...any) error {
341-
sm.firingMutex.Lock()
342-
sm.eventQueue.PushBack(queuedTrigger{Context: ctx, Trigger: trigger, Args: args})
343-
sm.firingMutex.Unlock()
344-
if sm.Firing() {
345-
return nil
346-
}
347-
348-
for {
349-
sm.firingMutex.Lock()
350-
e := sm.eventQueue.Front()
351-
if e == nil {
352-
sm.firingMutex.Unlock()
353-
break
354-
}
355-
et := sm.eventQueue.Remove(e).(queuedTrigger)
356-
sm.firingMutex.Unlock()
357-
if err := sm.internalFireOne(et.Context, et.Trigger, et.Args...); err != nil {
358-
return err
359-
}
360-
}
361-
return nil
323+
return sm.mode.Fire(ctx, trigger, args...)
362324
}
363325

364326
func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, args ...any) error {
365-
sm.ops.Add(1)
366-
defer sm.ops.Add(^uint64(0))
367327
var (
368328
config triggerWithParameters
369329
ok bool

0 commit comments

Comments
 (0)