Skip to content

Commit 66acf1d

Browse files
authored
Merge pull request #62 from qmuntal/concstring
Support concurrent state machine reads
2 parents fc6e1a7 + 7bc2a0b commit 66acf1d

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

statemachine.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ type StateMachine struct {
7575
eventQueue list.List
7676
firingMode FiringMode
7777
firingMutex sync.Mutex
78+
stateMutex sync.RWMutex
7879
}
7980

8081
func newStateMachine() *StateMachine {
@@ -295,22 +296,28 @@ func (sm *StateMachine) setState(ctx context.Context, state State) error {
295296
return sm.stateMutator(ctx, state)
296297
}
297298

298-
func (sm *StateMachine) currentState(ctx context.Context) (sr *stateRepresentation, err error) {
299-
var state State
300-
state, err = sm.State(ctx)
301-
if err == nil {
302-
sr = sm.stateRepresentation(state)
299+
func (sm *StateMachine) currentState(ctx context.Context) (*stateRepresentation, error) {
300+
state, err := sm.State(ctx)
301+
if err != nil {
302+
return nil, err
303303
}
304-
return
304+
return sm.stateRepresentation(state), nil
305305
}
306306

307-
func (sm *StateMachine) stateRepresentation(state State) (sr *stateRepresentation) {
308-
var ok bool
309-
if sr, ok = sm.stateConfig[state]; !ok {
310-
sr = newstateRepresentation(state)
311-
sm.stateConfig[state] = sr
307+
func (sm *StateMachine) stateRepresentation(state State) *stateRepresentation {
308+
sm.stateMutex.RLock()
309+
sr, ok := sm.stateConfig[state]
310+
sm.stateMutex.RUnlock()
311+
if !ok {
312+
sm.stateMutex.Lock()
313+
defer sm.stateMutex.Unlock()
314+
// Check again, since another goroutine may have added it while we were waiting for the lock.
315+
if sr, ok = sm.stateConfig[state]; !ok {
316+
sr = newstateRepresentation(state)
317+
sm.stateConfig[state] = sr
318+
}
312319
}
313-
return
320+
return sr
314321
}
315322

316323
func (sm *StateMachine) internalFire(ctx context.Context, trigger Trigger, args ...any) error {
@@ -354,7 +361,7 @@ func (sm *StateMachine) internalFireQueued(ctx context.Context, trigger Trigger,
354361
return nil
355362
}
356363

357-
func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, args ...any) (err error) {
364+
func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, args ...any) error {
358365
sm.ops.Add(1)
359366
defer sm.ops.Add(^uint64(0))
360367
var (
@@ -366,7 +373,7 @@ func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, ar
366373
}
367374
source, err := sm.State(ctx)
368375
if err != nil {
369-
return
376+
return err
370377
}
371378
representativeState := sm.stateRepresentation(source)
372379
var result triggerBehaviourResult
@@ -397,7 +404,7 @@ func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, ar
397404
err = sr.InternalAction(ctx, transition, args...)
398405
}
399406
}
400-
return
407+
return err
401408
}
402409

403410
func (sm *StateMachine) handleReentryTrigger(ctx context.Context, sr *stateRepresentation, transition Transition, args ...any) error {

statemachine_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,21 @@ func TestStateMachine_String(t *testing.T) {
15131513
}
15141514
}
15151515

1516+
func TestStateMachine_String_Concurrent(t *testing.T) {
1517+
// Test that race mode doesn't complain about concurrent access to the state machine.
1518+
sm := NewStateMachine(stateA)
1519+
const n = 10
1520+
var wg sync.WaitGroup
1521+
wg.Add(n)
1522+
for i := 0; i < n; i++ {
1523+
go func() {
1524+
defer wg.Done()
1525+
_ = sm.String()
1526+
}()
1527+
}
1528+
wg.Wait()
1529+
}
1530+
15161531
func TestStateMachine_Firing_Queued(t *testing.T) {
15171532
sm := NewStateMachine(stateA)
15181533

0 commit comments

Comments
 (0)