Skip to content

Commit e57bf59

Browse files
[feature] Reverse Shell History (#1420)
* feat: add circular buffer history to Mux for reverse shell context Updates the Mux implementation to maintain a per-stream circular buffer of recent messages. When a new stream registers, it receives the buffered history immediately. This ensures that new websocket connections to a reverse shell receive recent output context. The default buffer size is 1KB and is configurable. Only messages with kind "data" (or unspecified) are buffered. * feat: add reordered circular history to Mux Implements a circular buffer for stream history in the Mux to provide context to new connections. Refactors `sessionBuffer` to be reusable for reordering messages before writing to history. Ensures that history is written in the correct order even if pubsub messages arrive out of order. Uses a ring buffer implementation for `CircularBuffer` to minimize allocations. Added comprehensive tests for history replay and ordering. * fix: prevent Mux blocking on registration Refactors the `Mux.Start` loop to use a background poller goroutine and a `select` statement. This ensures that registration and unregistration requests are processed immediately, even if no pubsub messages are arriving. This fixes the issue where new reverse shell connections would not receive context/history until the next message (e.g., user input) was sent. Also verified existing history replay and ordering logic. * history to 5KB * fix: robustify Mux background poller Updates the background poller in `Mux.Start` to retry on errors (logging them) instead of exiting, ensuring the Mux remains active even if transient subscription errors occur. Also removes dead code (`registerStreams` and `unregisterStreams` methods) which are now handled inline within the `select` loop. Verified that history replay and message handling functionality remains correct. * chore: restore flaky-monitor.yml Restores the `.github/workflows/flaky-monitor.yml` file which was inadvertently removed. * revert changes to tests.yml --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: KCarretto <Kcarretto@gmail.com>
1 parent 4f14168 commit e57bf59

File tree

5 files changed

+421
-71
lines changed

5 files changed

+421
-71
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package stream
2+
3+
import (
4+
"sync"
5+
)
6+
7+
// CircularBuffer is a fixed-size byte buffer that overwrites old data when full.
8+
// It is safe for concurrent use.
9+
type CircularBuffer struct {
10+
mu sync.Mutex
11+
data []byte
12+
size int
13+
start int
14+
length int
15+
}
16+
17+
// NewCircularBuffer creates a new circular buffer with the given size.
18+
func NewCircularBuffer(size int) *CircularBuffer {
19+
return &CircularBuffer{
20+
data: make([]byte, size),
21+
size: size,
22+
start: 0,
23+
length: 0,
24+
}
25+
}
26+
27+
// Write appends data to the buffer.
28+
func (cb *CircularBuffer) Write(p []byte) {
29+
cb.mu.Lock()
30+
defer cb.mu.Unlock()
31+
32+
n := len(p)
33+
if n == 0 {
34+
return
35+
}
36+
37+
// If the data being written is larger than the buffer size,
38+
// we only care about the last `size` bytes.
39+
if n >= cb.size {
40+
copy(cb.data, p[n-cb.size:])
41+
cb.start = 0
42+
cb.length = cb.size
43+
return
44+
}
45+
46+
// We are writing n bytes.
47+
// We write starting at (start + length) % size.
48+
writeStart := (cb.start + cb.length) % cb.size
49+
50+
// Check if the write wraps around the end of the buffer
51+
if writeStart+n <= cb.size {
52+
// Contiguous write
53+
copy(cb.data[writeStart:], p)
54+
} else {
55+
// Wrapped write
56+
chunk1 := cb.size - writeStart
57+
copy(cb.data[writeStart:], p[:chunk1])
58+
copy(cb.data[0:], p[chunk1:])
59+
}
60+
61+
// Update length and start
62+
if cb.length+n <= cb.size {
63+
cb.length += n
64+
} else {
65+
// Buffer overflowed
66+
overflow := (cb.length + n) - cb.size
67+
cb.start = (cb.start + overflow) % cb.size
68+
cb.length = cb.size
69+
}
70+
}
71+
72+
// Bytes returns the current content of the buffer.
73+
func (cb *CircularBuffer) Bytes() []byte {
74+
cb.mu.Lock()
75+
defer cb.mu.Unlock()
76+
77+
out := make([]byte, cb.length)
78+
if cb.length == 0 {
79+
return out
80+
}
81+
82+
// If the data is contiguous
83+
if cb.start+cb.length <= cb.size {
84+
copy(out, cb.data[cb.start:cb.start+cb.length])
85+
} else {
86+
// Data wraps around
87+
chunk1 := cb.size - cb.start
88+
copy(out, cb.data[cb.start:])
89+
copy(out[chunk1:], cb.data[:cb.length-chunk1])
90+
}
91+
return out
92+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package stream
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestCircularBuffer(t *testing.T) {
10+
cb := NewCircularBuffer(10)
11+
12+
// Test basic write
13+
cb.Write([]byte("hello"))
14+
assert.Equal(t, []byte("hello"), cb.Bytes())
15+
16+
// Test append within size
17+
cb2 := NewCircularBuffer(10)
18+
cb2.Write([]byte("hello "))
19+
cb2.Write([]byte("world"))
20+
assert.Equal(t, []byte("ello world"), cb2.Bytes())
21+
22+
// Test write larger than size
23+
cb3 := NewCircularBuffer(5)
24+
cb3.Write([]byte("1234567890"))
25+
assert.Equal(t, []byte("67890"), cb3.Bytes())
26+
27+
// Test write exact size
28+
cb4 := NewCircularBuffer(5)
29+
cb4.Write([]byte("12345"))
30+
assert.Equal(t, []byte("12345"), cb4.Bytes())
31+
cb4.Write([]byte("6"))
32+
assert.Equal(t, []byte("23456"), cb4.Bytes())
33+
}

tavern/internal/http/stream/mux.go

Lines changed: 133 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ const (
1414
// maxRegistrationBufSize defines the maximum receivers that can be buffered in the registration / unregistration channel
1515
// before new calls to `mux.Register()` and `mux.Unregister()` will block.
1616
maxRegistrationBufSize = 256
17+
// defaultHistorySize is the default size of the circular buffer for stream history.
18+
defaultHistorySize = 1024
1719
)
1820

1921
var upgrader = websocket.Upgrader{
@@ -22,28 +24,44 @@ var upgrader = websocket.Upgrader{
2224
CheckOrigin: func(r *http.Request) bool { return true },
2325
}
2426

27+
type historyState struct {
28+
buffer *CircularBuffer
29+
sessions map[string]*sessionBuffer
30+
}
31+
2532
// A Mux enables multiplexing subscription messages to multiple Streams.
2633
// Streams will only receive a Message if their configured ID matches the incoming metadata of a Message.
2734
// Additionally, new messages may be published using the Mux.
2835
type Mux struct {
29-
pub *pubsub.Topic
30-
sub *pubsub.Subscription
31-
register chan *Stream
32-
unregister chan *Stream
33-
streams map[*Stream]bool
36+
pub *pubsub.Topic
37+
sub *pubsub.Subscription
38+
register chan *Stream
39+
unregister chan *Stream
40+
streams map[*Stream]bool
41+
history map[string]*historyState
42+
historySize int
3443
}
3544

3645
// A MuxOption is used to provide further configuration to the Mux.
3746
type MuxOption func(*Mux)
3847

48+
// WithHistorySize sets the size of the circular buffer for stream history.
49+
func WithHistorySize(size int) MuxOption {
50+
return func(m *Mux) {
51+
m.historySize = size
52+
}
53+
}
54+
3955
// NewMux initializes and returns a new Mux with the provided pubsub info.
4056
func NewMux(pub *pubsub.Topic, sub *pubsub.Subscription, options ...MuxOption) *Mux {
4157
mux := &Mux{
42-
pub: pub,
43-
sub: sub,
44-
register: make(chan *Stream, maxRegistrationBufSize),
45-
unregister: make(chan *Stream, maxRegistrationBufSize),
46-
streams: make(map[*Stream]bool),
58+
pub: pub,
59+
sub: sub,
60+
register: make(chan *Stream, maxRegistrationBufSize),
61+
unregister: make(chan *Stream, maxRegistrationBufSize),
62+
streams: make(map[*Stream]bool),
63+
history: make(map[string]*historyState),
64+
historySize: defaultHistorySize,
4765
}
4866
for _, opt := range options {
4967
opt(mux)
@@ -66,83 +84,110 @@ func (mux *Mux) Register(s *Stream) {
6684
mux.register <- s
6785
}
6886

69-
// registerStreams inserts all registered streams into the streams map.
70-
func (mux *Mux) registerStreams(ctx context.Context) {
71-
for {
72-
select {
73-
case s := <-mux.register:
74-
slog.DebugContext(ctx, "mux registering new stream", "stream_id", s.id)
75-
mux.streams[s] = true
76-
default:
77-
return
78-
}
79-
}
80-
}
81-
8287
// Unregister a stream when it should no longer receive Messages from the Mux.
8388
// Typically this is done via defer after registering a Stream.
8489
// Unregistering a stream that is not registered will still close the stream channel.
8590
func (mux *Mux) Unregister(s *Stream) {
8691
mux.unregister <- s
8792
}
8893

89-
// unregisterStreams deletes all unregistered streams from the streams map.
90-
func (mux *Mux) unregisterStreams(ctx context.Context) {
91-
for {
92-
select {
93-
case s := <-mux.unregister:
94-
slog.DebugContext(ctx, "mux unregistering stream", "stream_id", s.id)
95-
delete(mux.streams, s)
96-
s.Close()
97-
default:
98-
return
99-
}
100-
}
101-
}
102-
10394
// Start the mux, returning an error if polling ever fails.
10495
func (mux *Mux) Start(ctx context.Context) error {
10596
slog.DebugContext(ctx, "mux starting to manage streams and polling")
106-
for {
107-
// Manage Streams
108-
mux.registerStreams(ctx)
109-
mux.unregisterStreams(ctx)
11097

111-
// Poll for new messages
98+
// Message channel to receive messages from the poller
99+
type pollResult struct {
100+
msg *pubsub.Message
101+
err error
102+
}
103+
msgChan := make(chan pollResult)
104+
105+
// Start poller goroutine
106+
go func() {
107+
defer close(msgChan)
108+
for {
109+
msg, err := mux.sub.Receive(ctx)
110+
select {
111+
case <-ctx.Done():
112+
return
113+
case msgChan <- pollResult{msg: msg, err: err}:
114+
// If context is done, stop.
115+
if ctx.Err() != nil {
116+
return
117+
}
118+
// Otherwise, loop again (retry on error).
119+
}
120+
}
121+
}()
122+
123+
for {
112124
select {
113125
case <-ctx.Done():
114126
slog.DebugContext(ctx, "mux context finished, exiting")
115127
return ctx.Err()
116-
default:
117-
slog.DebugContext(ctx, "mux polling for message")
118-
if err := mux.poll(ctx); err != nil {
119-
slog.ErrorContext(ctx, "mux failed to poll subscription", "error", err)
128+
129+
case s := <-mux.register:
130+
// Handle Registration
131+
slog.DebugContext(ctx, "mux registering new stream", "stream_id", s.id)
132+
mux.streams[s] = true
133+
134+
// Send history to the new stream
135+
if state, ok := mux.history[s.id]; ok && state.buffer != nil {
136+
data := state.buffer.Bytes()
137+
if len(data) > 0 {
138+
slog.DebugContext(ctx, "mux sending history to new stream", "stream_id", s.id, "bytes", len(data))
139+
msg := &pubsub.Message{
140+
Body: data,
141+
Metadata: map[string]string{
142+
metadataID: s.id,
143+
MetadataMsgKind: "data",
144+
// No order key needed for history injection
145+
},
146+
}
147+
s.processOneMessage(ctx, msg)
148+
}
149+
}
150+
151+
case s := <-mux.unregister:
152+
// Handle Unregistration
153+
slog.DebugContext(ctx, "mux unregistering stream", "stream_id", s.id)
154+
delete(mux.streams, s)
155+
s.Close()
156+
157+
case res, ok := <-msgChan:
158+
if !ok {
159+
// Poller exited. If due to context cancel, return that error.
160+
if ctx.Err() != nil {
161+
return ctx.Err()
162+
}
163+
return fmt.Errorf("poller exited unexpectedly")
120164
}
165+
if res.err != nil {
166+
// Log error and continue, matching original behavior (retry loop).
167+
// Unless context is done.
168+
if ctx.Err() != nil {
169+
return ctx.Err()
170+
}
171+
slog.ErrorContext(ctx, "mux failed to poll subscription", "error", res.err)
172+
continue
173+
}
174+
175+
// Handle Message
176+
mux.handleMessage(ctx, res.msg)
121177
}
122178
}
123179
}
124180

125-
// poll for a new message, broadcasting to all registered streams.
126-
// poll will also register & unregister streams after a new message is received.
127-
func (mux *Mux) poll(ctx context.Context) error {
128-
// Block waiting for message
129-
msg, err := mux.sub.Receive(ctx)
130-
if err != nil {
131-
return fmt.Errorf("failed to poll for new message: %w", err)
132-
}
133-
181+
// handleMessage processes a new message, updating history and broadcasting to streams.
182+
func (mux *Mux) handleMessage(ctx context.Context, msg *pubsub.Message) {
134183
// Always acknowledge the message
135184
defer msg.Ack()
136185

137-
// Manage Streams
138-
mux.registerStreams(ctx)
139-
mux.unregisterStreams(ctx)
140-
141186
// Get Message Metadata
142187
msgID, ok := msg.Metadata["id"]
143188
if !ok {
144189
slog.DebugContext(ctx, "mux received message without 'id' for stream, ignoring")
145-
return nil
190+
return
146191
}
147192
msgOrderKey, ok := msg.Metadata[metadataOrderKey]
148193
if !ok {
@@ -153,6 +198,34 @@ func (mux *Mux) poll(ctx context.Context) error {
153198
slog.DebugContext(ctx, "mux received message without msgOrderIndex")
154199
}
155200

201+
// Update History
202+
// Only buffer "data" messages (or messages with no kind specified, which default to data)
203+
kind, hasKind := msg.Metadata[MetadataMsgKind]
204+
if !hasKind || kind == "data" {
205+
state, ok := mux.history[msgID]
206+
if !ok {
207+
state = &historyState{
208+
buffer: NewCircularBuffer(mux.historySize),
209+
sessions: make(map[string]*sessionBuffer),
210+
}
211+
mux.history[msgID] = state
212+
}
213+
214+
// Use sessionBuffer to reorder messages before writing to circular buffer
215+
key := parseOrderKey(msg)
216+
sessBuf, ok := state.sessions[key]
217+
if !ok {
218+
sessBuf = &sessionBuffer{
219+
data: make(map[uint64]*pubsub.Message, maxStreamOrderBuf),
220+
}
221+
state.sessions[key] = sessBuf
222+
}
223+
224+
sessBuf.writeMessage(ctx, msg, func(m *pubsub.Message) {
225+
state.buffer.Write(m.Body)
226+
})
227+
}
228+
156229
// Broadcast Message
157230
slog.DebugContext(ctx, "mux broadcasting received message",
158231
"msg_id", msgID,
@@ -176,6 +249,4 @@ func (mux *Mux) poll(ctx context.Context) error {
176249
s.processOneMessage(ctx, msg)
177250
}
178251
}
179-
180-
return nil
181252
}

0 commit comments

Comments
 (0)