Skip to content

Commit 9419d11

Browse files
authored
feat: Implement scenario-based stateful command responses (#79)
* feat: Implement scenario-based stateful command responses Add scenario support to transcript_map.yaml — a scenario defines an ordered sequence of (command, transcript) pairs layered on top of a platform. Each SSH session gets its own sequence pointer that advances as commands match the next expected step; non-matching commands fall through to normal command_transcripts behavior. Schema: scenarios: csr1000v-add-interface: platform: csr1000v sequence: - command: show running-config transcript: transcripts/scenarios/.../before.txt - command: interface GigabitEthernet0/0/2 transcript: transcripts/generic_empty_return.txt - command: show running-config transcript: transcripts/scenarios/.../after.txt Inventory entries can now reference either platform or scenario (mutually exclusive). Validation checks scenario paths and platform references at startup. Sequence pointer resets on every new SSH connection. Closes #27 * fix: Extract resolveListeners helper to reduce run() cyclomatic complexity Extracts inventory/platform resolution into resolveListeners() to bring run() complexity below the gocyclo threshold of 15. Renames listenerSpec to listenerConfig for clarity. Also fixes gofmt formatting in ciscohandlers_test.go.
1 parent 945624d commit 9419d11

File tree

13 files changed

+396
-59
lines changed

13 files changed

+396
-59
lines changed

cissh.go

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,54 @@ import (
2222
"github.com/tbotnz/cisshgo/utils"
2323
)
2424

25-
type listenerSpec struct {
26-
fd *fakedevices.FakeDevice
27-
port int
25+
type listenerConfig struct {
26+
fd *fakedevices.FakeDevice
27+
port int
28+
sequence []utils.SequenceStep // nil for platform-only listeners
29+
}
30+
31+
func resolveListeners(cli utils.CLI, tm utils.TranscriptMap, baseDir string) ([]listenerConfig, error) {
32+
if cli.Inventory == "" {
33+
fd, err := fakedevices.InitGeneric(cli.Platform, tm, baseDir)
34+
if err != nil {
35+
return nil, err
36+
}
37+
var configs []listenerConfig
38+
for port := cli.StartingPort; port < cli.StartingPort+cli.Listeners; port++ {
39+
configs = append(configs, listenerConfig{fd, port, nil})
40+
}
41+
return configs, nil
42+
}
43+
44+
inv, err := utils.LoadInventory(cli.Inventory)
45+
if err != nil {
46+
return nil, err
47+
}
48+
49+
var configs []listenerConfig
50+
port := cli.StartingPort
51+
for _, entry := range inv.Devices {
52+
if entry.Scenario != "" {
53+
fd, seq, err := fakedevices.InitScenario(entry.Scenario, tm, baseDir)
54+
if err != nil {
55+
return nil, err
56+
}
57+
for i := 0; i < entry.Count; i++ {
58+
configs = append(configs, listenerConfig{fd, port, seq})
59+
port++
60+
}
61+
} else {
62+
fd, err := fakedevices.InitGeneric(entry.Platform, tm, baseDir)
63+
if err != nil {
64+
return nil, err
65+
}
66+
for i := 0; i < entry.Count; i++ {
67+
configs = append(configs, listenerConfig{fd, port, nil})
68+
port++
69+
}
70+
}
71+
}
72+
return configs, nil
2873
}
2974

3075
func run(ctx context.Context, cli utils.CLI) error {
@@ -38,47 +83,30 @@ func run(ctx context.Context, cli utils.CLI) error {
3883
return err
3984
}
4085

41-
var specs []listenerSpec
86+
configs, err := resolveListeners(cli, myTranscriptMap, baseDir)
87+
if err != nil {
88+
return err
89+
}
4290

43-
if cli.Inventory != "" {
44-
inv, err := utils.LoadInventory(cli.Inventory)
45-
if err != nil {
46-
return err
47-
}
48-
port := cli.StartingPort
49-
for _, entry := range inv.Devices {
50-
fd, err := fakedevices.InitGeneric(entry.Platform, myTranscriptMap, baseDir)
51-
if err != nil {
52-
return err
91+
var wg sync.WaitGroup
92+
for _, config := range configs { // coverage-ignore
93+
wg.Add(1)
94+
go func(cfg listenerConfig) {
95+
defer wg.Done()
96+
var err error
97+
if cfg.sequence != nil {
98+
err = sshlisteners.ScenarioListener(ctx, cfg.fd, cfg.sequence, cfg.port)
99+
} else {
100+
err = sshlisteners.GenericListener(ctx, cfg.fd, cfg.port, handlers.GenericCiscoHandler)
53101
}
54-
for i := 0; i < entry.Count; i++ {
55-
specs = append(specs, listenerSpec{fd, port})
56-
port++
102+
if err != nil {
103+
log.Printf("listener on port %d: %v", cfg.port, err)
57104
}
58-
}
59-
} else {
60-
fd, err := fakedevices.InitGeneric(cli.Platform, myTranscriptMap, baseDir)
61-
if err != nil {
62-
return err
63-
}
64-
for port := cli.StartingPort; port < cli.StartingPort+cli.Listeners; port++ {
65-
specs = append(specs, listenerSpec{fd, port})
66-
}
105+
}(config)
67106
}
68107

69-
var wg sync.WaitGroup
70-
for _, spec := range specs { // coverage-ignore
71-
wg.Add(1) // coverage-ignore
72-
go func(fd *fakedevices.FakeDevice, port int) { // coverage-ignore
73-
defer wg.Done() // coverage-ignore
74-
if err := sshlisteners.GenericListener(ctx, fd, port, handlers.GenericCiscoHandler); err != nil { // coverage-ignore
75-
log.Printf("listener on port %d: %v", port, err) // coverage-ignore
76-
} // coverage-ignore
77-
}(spec.fd, spec.port) // coverage-ignore
78-
} // coverage-ignore
79-
80-
wg.Wait() // coverage-ignore
81-
return nil // coverage-ignore
108+
wg.Wait()
109+
return nil
82110
}
83111

84112
func main() { // coverage-ignore

fakedevices/genericFakeDevice.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,32 @@ func readFile(filename string) (string, error) {
5050
return string(content), nil
5151
}
5252

53+
// InitScenario builds a FakeDevice for a named scenario, loading the base platform
54+
// and returning the pre-loaded sequence steps alongside the device.
55+
func InitScenario(scenarioName string, tm utils.TranscriptMap, baseDir string) (*FakeDevice, []utils.SequenceStep, error) {
56+
s, ok := tm.Scenarios[scenarioName]
57+
if !ok {
58+
return nil, nil, fmt.Errorf("scenario %q not found in transcript map", scenarioName)
59+
}
60+
fd, err := InitGeneric(s.Platform, tm, baseDir)
61+
if err != nil {
62+
return nil, nil, err
63+
}
64+
steps := make([]utils.SequenceStep, len(s.Sequence))
65+
for i, step := range s.Sequence {
66+
path := step.Transcript
67+
if !filepath.IsAbs(path) {
68+
path = filepath.Join(baseDir, path)
69+
}
70+
content, err := readFile(path)
71+
if err != nil {
72+
return nil, nil, err
73+
}
74+
steps[i] = utils.SequenceStep{Command: step.Command, Transcript: content}
75+
}
76+
return fd, steps, nil
77+
}
78+
5379
// InitGeneric builds a FakeDevice struct for use with cisshgo.
5480
// baseDir is the directory from which transcript paths are resolved (typically
5581
// the directory containing the transcript map file).

fakedevices/genericFakeDevice_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ func TestTranscriptMapIntegrity(t *testing.T) {
158158
}
159159
}
160160
}
161+
for name, s := range tm.Scenarios {
162+
for i, step := range s.Sequence {
163+
referenced[step.Transcript] = true
164+
if _, err := os.Stat(step.Transcript); err != nil {
165+
t.Errorf("scenario %q step %d (%q): file not found: %s", name, i, step.Command, step.Transcript)
166+
}
167+
}
168+
}
161169

162170
// Walk transcripts/ and flag unreferenced .txt files
163171
err = filepath.WalkDir("transcripts", func(path string, d fs.DirEntry, err error) error {
@@ -177,3 +185,37 @@ func TestTranscriptMapIntegrity(t *testing.T) {
177185
t.Fatalf("walking transcripts dir: %v", err)
178186
}
179187
}
188+
189+
func TestInitScenario(t *testing.T) {
190+
if err := os.Chdir(".."); err != nil {
191+
t.Fatal(err)
192+
}
193+
t.Cleanup(func() { os.Chdir("fakedevices") })
194+
195+
tm, err := utils.LoadTranscriptMap("transcripts/transcript_map.yaml")
196+
if err != nil {
197+
t.Fatal(err)
198+
}
199+
200+
fd, steps, err := InitScenario("csr1000v-add-interface", tm, ".")
201+
if err != nil {
202+
t.Fatalf("InitScenario: %v", err)
203+
}
204+
if fd.Platform != "csr1000v" {
205+
t.Errorf("Platform = %q, want csr1000v", fd.Platform)
206+
}
207+
if len(steps) != 5 {
208+
t.Errorf("steps len = %d, want 5", len(steps))
209+
}
210+
if steps[0].Command != "show running-config" {
211+
t.Errorf("steps[0].Command = %q, want 'show running-config'", steps[0].Command)
212+
}
213+
}
214+
215+
func TestInitScenario_UnknownScenario(t *testing.T) {
216+
tm := utils.TranscriptMap{Platforms: map[string]utils.TranscriptMapPlatform{}}
217+
_, _, err := InitScenario("nonexistent", tm, ".")
218+
if err == nil {
219+
t.Error("expected error for unknown scenario")
220+
}
221+
}

ssh_server/handlers/ciscohandlers.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ import (
1616

1717
// GenericCiscoHandler function handles generic Cisco style sessions
1818
func GenericCiscoHandler(myFakeDevice *fakedevices.FakeDevice) ssh.Handler {
19+
return genericCiscoSession(myFakeDevice, nil)
20+
}
21+
22+
// GenericCiscoScenarioHandler returns an ssh.Handler that plays back a scenario sequence.
23+
func GenericCiscoScenarioHandler(myFakeDevice *fakedevices.FakeDevice, sequence []utils.SequenceStep) ssh.Handler {
24+
return genericCiscoSession(myFakeDevice, sequence)
25+
}
26+
27+
func genericCiscoSession(myFakeDevice *fakedevices.FakeDevice, sequence []utils.SequenceStep) ssh.Handler {
1928
return func(s ssh.Session) {
2029

2130
// Exec mode: client sent a command directly (e.g., ssh host "show version")
@@ -34,7 +43,8 @@ func GenericCiscoHandler(myFakeDevice *fakedevices.FakeDevice) ssh.Handler {
3443
return
3544
}
3645

37-
// Interactive shell mode
46+
// Interactive shell mode — sequence pointer resets per session
47+
seqIdx := 0
3848
contextState := myFakeDevice.ContextSearch["base"]
3949
t := term.NewTerminal(s, myFakeDevice.Hostname+contextState)
4050

@@ -45,7 +55,7 @@ func GenericCiscoHandler(myFakeDevice *fakedevices.FakeDevice) ssh.Handler {
4555
}
4656
log.Println(userInput)
4757

48-
done := handleShellInput(t, userInput, myFakeDevice, &contextState)
58+
done := handleShellInput(t, userInput, myFakeDevice, &contextState, sequence, &seqIdx)
4959
if done {
5060
break
5161
}
@@ -56,7 +66,7 @@ func GenericCiscoHandler(myFakeDevice *fakedevices.FakeDevice) ssh.Handler {
5666

5767
// handleShellInput processes a single line of user input in interactive shell mode.
5868
// Returns true if the session should be terminated.
59-
func handleShellInput(t *term.Terminal, userInput string, fd *fakedevices.FakeDevice, contextState *string) bool {
69+
func handleShellInput(t *term.Terminal, userInput string, fd *fakedevices.FakeDevice, contextState *string, sequence []utils.SequenceStep, seqIdx *int) bool {
6070
if userInput == "" {
6171
t.Write([]byte(""))
6272
return false
@@ -102,12 +112,28 @@ func handleShellInput(t *term.Terminal, userInput string, fd *fakedevices.FakeDe
102112
}
103113

104114
// Match against supported commands
105-
return dispatchCommand(t, userInput, fd)
115+
return dispatchCommand(t, userInput, fd, sequence, seqIdx)
106116
}
107117

108-
// dispatchCommand matches userInput against supported commands and writes the response.
118+
// dispatchCommand matches userInput against the active sequence step first, then supported commands.
109119
// Returns true if the session should be terminated.
110-
func dispatchCommand(t *term.Terminal, userInput string, fd *fakedevices.FakeDevice) bool {
120+
func dispatchCommand(t *term.Terminal, userInput string, fd *fakedevices.FakeDevice, sequence []utils.SequenceStep, seqIdx *int) bool {
121+
// Check if the next sequence step matches
122+
if seqIdx != nil && *seqIdx < len(sequence) {
123+
step := sequence[*seqIdx]
124+
match, _, multipleMatches, _ := utils.CmdMatch(userInput, map[string]string{step.Command: ""})
125+
if match && !multipleMatches {
126+
output, err := fakedevices.TranscriptReader(step.Transcript, fd)
127+
if err != nil {
128+
log.Println(err)
129+
return true
130+
}
131+
t.Write(append([]byte(output), '\n'))
132+
*seqIdx++
133+
return false
134+
}
135+
}
136+
111137
match, matchedCommand, multipleMatches, err := utils.CmdMatch(userInput, fd.SupportedCommands)
112138
if err != nil {
113139
log.Println(err) // coverage-ignore // CmdMatch never returns errors

ssh_server/handlers/ciscohandlers_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/gliderlabs/ssh"
1313
"github.com/tbotnz/cisshgo/fakedevices"
14+
"github.com/tbotnz/cisshgo/utils"
1415
)
1516

1617
// newTestDevice creates a FakeDevice for testing without reading files from disk.
@@ -362,3 +363,51 @@ func TestHandler_TranscriptReaderError(t *testing.T) {
362363
// Should not crash, just close the session
363364
_ = interact(t, addr, []string{"show bad"})
364365
}
366+
367+
func TestHandler_ScenarioSequence(t *testing.T) {
368+
fd := newTestDevice()
369+
sequence := []utils.SequenceStep{
370+
{Command: "show running-config", Transcript: "config before\n"},
371+
{Command: "interface GigabitEthernet0/0/2", Transcript: ""},
372+
{Command: "show running-config", Transcript: "config after\n"},
373+
}
374+
375+
ln, err := net.Listen("tcp", "127.0.0.1:0")
376+
if err != nil {
377+
t.Fatal(err)
378+
}
379+
addr := ln.Addr().String()
380+
ln.Close()
381+
382+
srv := &ssh.Server{
383+
Addr: addr,
384+
Handler: GenericCiscoScenarioHandler(fd, sequence),
385+
PasswordHandler: func(ctx ssh.Context, pass string) bool {
386+
return pass == fd.Password
387+
},
388+
}
389+
go func() { _ = srv.ListenAndServe() }()
390+
391+
for i := 0; i < 20; i++ {
392+
conn, dialErr := net.DialTimeout("tcp", addr, 100*time.Millisecond)
393+
if dialErr == nil {
394+
conn.Close()
395+
break
396+
}
397+
time.Sleep(50 * time.Millisecond)
398+
}
399+
defer srv.Close()
400+
401+
// All three steps in a single session — pointer advances across commands
402+
out := interact(t, addr, []string{
403+
"show running-config", // step 0 → "config before"
404+
"interface GigabitEthernet0/0/2", // step 1 → ""
405+
"show running-config", // step 2 → "config after"
406+
})
407+
if !strings.Contains(out, "config before") {
408+
t.Errorf("expected 'config before' in output, got:\n%s", out)
409+
}
410+
if !strings.Contains(out, "config after") {
411+
t.Errorf("expected 'config after' in output, got:\n%s", out)
412+
}
413+
}

ssh_server/handlers/handlers.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ import (
44
"github.com/gliderlabs/ssh"
55

66
"github.com/tbotnz/cisshgo/fakedevices"
7+
"github.com/tbotnz/cisshgo/utils"
78
)
89

910
// PlatformHandler defines a default type for all platform handlers
1011
type PlatformHandler func(*fakedevices.FakeDevice) ssh.Handler
12+
13+
// ScenarioHandler defines a handler type that includes a sequence of steps
14+
type ScenarioHandler func(*fakedevices.FakeDevice, []utils.SequenceStep) ssh.Handler

ssh_server/sshlisteners/sshlisteners.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/tbotnz/cisshgo/fakedevices"
1313
"github.com/tbotnz/cisshgo/ssh_server/handlers"
14+
"github.com/tbotnz/cisshgo/utils"
1415
)
1516

1617
// GenericListener starts an SSH server on the given port and blocks until ctx is cancelled.
@@ -20,12 +21,26 @@ func GenericListener(
2021
portNumber int,
2122
myHandler handlers.PlatformHandler,
2223
) error {
24+
return listen(ctx, myFakeDevice, portNumber, myHandler(myFakeDevice.Copy()))
25+
}
26+
27+
// ScenarioListener starts an SSH server that plays back a scenario sequence.
28+
func ScenarioListener(
29+
ctx context.Context,
30+
myFakeDevice *fakedevices.FakeDevice,
31+
sequence []utils.SequenceStep,
32+
portNumber int,
33+
) error {
34+
return listen(ctx, myFakeDevice, portNumber, handlers.GenericCiscoScenarioHandler(myFakeDevice.Copy(), sequence))
35+
}
36+
37+
func listen(ctx context.Context, myFakeDevice *fakedevices.FakeDevice, portNumber int, handler ssh.Handler) error {
2338
portString := ":" + strconv.Itoa(portNumber)
2439
log.Printf("Starting cissh.go ssh server on port %s\n", portString)
2540

2641
srv := &ssh.Server{
2742
Addr: portString,
28-
Handler: myHandler(myFakeDevice.Copy()),
43+
Handler: handler,
2944
PasswordHandler: func(sshCtx ssh.Context, pass string) bool {
3045
return pass == myFakeDevice.Password
3146
},

0 commit comments

Comments
 (0)