Skip to content

Commit 6cb0aab

Browse files
committed
feat(agent,ssh,pkg): add connection versioning to support v1 and v2 tunnels
This change refactors the agent to support multiple connection protocol versions, introducing a new configuration option, `ConnectionVersion`, to select between them. Version 1 uses an Echo-based tunnel with HTTP-like handlers, while version 2 uses a Yamux-based tunnel integrated with the client. The agent's `Listen` method now delegates to version-specific listeners, and the related handlers were updated accordingly. The tunnel abstraction was also revised to adopt generics, renaming `Handler` to `HandlerFunc` and splitting out separate implementations for v1 and v2. Additionally, reverse listener creation in the API client was modernized with explicit `NewReverseListenerV1` and `NewReverseListenerV2` methods, each with its own configuration management. These changes simplify the agent's architecture while allowing incremental migration to the new tunnel protocol.
1 parent 563c366 commit 6cb0aab

File tree

13 files changed

+819
-440
lines changed

13 files changed

+819
-440
lines changed

pkg/agent/agent.go

Lines changed: 88 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,16 @@ import (
4242
"context"
4343
"crypto/rsa"
4444
"fmt"
45-
"io"
4645
"math/rand"
4746
"net"
48-
"net/netip"
4947
"net/url"
5048
"os"
5149
"runtime"
5250
"strings"
53-
"sync"
5451
"sync/atomic"
5552
"time"
5653

5754
"github.com/Masterminds/semver"
58-
dockerclient "github.com/docker/docker/client"
5955
"github.com/pkg/errors"
6056
"github.com/shellhub-io/shellhub/pkg/agent/pkg/keygen"
6157
"github.com/shellhub-io/shellhub/pkg/agent/pkg/sysinfo"
@@ -125,6 +121,10 @@ type Config struct {
125121
// MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
126122
// before attempting to reconnect to the ShellHub server. Default is 60 seconds.
127123
MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`
124+
125+
// ConnectionVersion specifies the version of the connection protocol to use.
126+
// Supported values are 1 and 2. Default is 1.
127+
ConnectionVersion int `env:"CONNECTION_VERSION,default=1"`
128128
}
129129

130130
func LoadConfigFromEnv() (*Config, map[string]interface{}, error) {
@@ -174,12 +174,11 @@ type Agent struct {
174174
cli client.Client
175175
serverInfo *models.Info
176176
server *server.Server
177-
tunnel *tunnel.Tunnel
178177
listening chan bool
179178
closed atomic.Bool
180179
mode Mode
181-
// conn is the current connection to the server.
182-
conn net.Conn
180+
// listener is the current reverse listener connected to the server; it is nil until a connection is established.
181+
listener atomic.Pointer[net.Listener]
183182
}
184183

185184
// NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant
@@ -385,237 +384,109 @@ func (a *Agent) isClosed() bool {
385384
func (a *Agent) Close() error {
386385
a.closed.Store(true)
387386

388-
return a.conn.Close()
389-
}
390-
391-
// httpProxyHandler handlers proxy connections to the required address.
392-
func httpProxyHandler(agent *Agent) tunnel.Handler {
393-
const ProxyHandlerNetwork = "tcp"
387+
l := a.listener.Load()
388+
if l == nil {
389+
return nil
390+
}
394391

395-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
396-
headers, err := ctx.Headers()
397-
if err != nil {
398-
log.WithError(err).Error("failed to get the headers from the connection")
392+
return (*l).Close()
393+
}
399394

400-
return err
401-
}
395+
func (a *Agent) Listen(ctx context.Context) error {
396+
a.mode.Serve(a)
402397

403-
id := headers["id"]
404-
host := headers["host"]
405-
port := headers["port"]
398+
switch a.config.ConnectionVersion {
399+
case 1:
400+
return a.listenV1(ctx)
401+
case 2:
402+
return a.listenV2(ctx)
403+
default:
404+
return fmt.Errorf("unsupported connection version: %d", a.config.ConnectionVersion)
405+
}
406+
}
406407

407-
logger := log.WithFields(log.Fields{
408-
"id": id,
409-
"host": host,
410-
"port": port,
411-
})
408+
func (a *Agent) listenV1(ctx context.Context) error {
409+
tun := tunnel.NewTunnelV1()
412410

413-
if _, ok := agent.mode.(*ConnectorMode); ok {
414-
cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
415-
if err != nil {
416-
log.WithError(err).Error("failed to create the Docker client")
411+
tun.Handle(HandleSSHOpenV1, sshHandlerV1(a))
412+
tun.Handle(HandleSSHCloseV1, sshCloseHandlerV1(a))
413+
tun.Handle(HandleHTTPProxyV1, httpProxyHandlerV1(a))
417414

418-
return ctx.Error(errors.New("failed to connect to the Docker Engine"))
419-
}
420-
421-
container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID)
422-
if err != nil {
423-
log.WithError(err).Error("failed to inspect the container")
415+
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
424416

425-
return ctx.Error(errors.New("failed to inspect the container"))
426-
}
417+
logger := log.WithFields(log.Fields{
418+
"version": AgentVersion,
419+
"tenant_id": a.authData.Namespace,
420+
"server_address": a.config.ServerAddress,
421+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
422+
"api_endpoint": a.serverInfo.Endpoints.API,
423+
"connection_version": a.config.ConnectionVersion,
424+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
425+
})
427426

428-
var target string
427+
ctx, cancel := context.WithCancel(ctx)
428+
go func() {
429+
for {
430+
if a.isClosed() {
431+
logger.Info("Stopped listening for connections")
429432

430-
addr, err := netip.ParseAddr(host)
431-
if err != nil {
432-
log.WithError(err).Error("failed to parse the address on proxy")
433+
cancel()
433434

434-
return ctx.Error(errors.New("failed to parse the address on proxy"))
435+
return
435436
}
436437

437-
if addr.IsLoopback() {
438-
log.Trace("host is a loopback address, using the container IP address")
439-
440-
for _, network := range container.NetworkSettings.Networks {
441-
target = network.IPAddress
442-
443-
break
444-
}
445-
} else {
446-
for _, network := range container.NetworkSettings.Networks {
447-
subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen))
448-
if err != nil {
449-
logger.WithError(err).Error("failed to parse the gateway on proxy")
450-
451-
continue
452-
}
453-
454-
ip, err := netip.ParseAddr(host)
455-
if err != nil {
456-
logger.WithError(err).Error("failed to parse the address on proxy")
438+
ShellHubConnectV1Path := "/ssh/connection"
457439

458-
continue
459-
}
440+
logger.Debug("Using tunnel version 1")
460441

461-
if subnet.Contains(ip) {
462-
target = ip.String()
442+
listener, err := a.cli.NewReverseListenerV1(
443+
ctx,
444+
a.authData.Token,
445+
ShellHubConnectV1Path,
446+
)
447+
if err != nil {
448+
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
463449

464-
break
465-
}
466-
}
467-
}
450+
time.Sleep(time.Second * 10)
468451

469-
if target == "" {
470-
return ctx.Error(errors.New("address not found on the device"))
452+
continue
471453
}
454+
a.listener.Store(&listener)
472455

473-
host = target
474-
}
475-
476-
ErrFailedDialToAddressAndPort := errors.New("failed to dial to the address and port")
477-
478-
logger.Trace("proxy handler connecting to the address")
479-
480-
in, err := net.Dial(ProxyHandlerNetwork, net.JoinHostPort(host, port))
481-
if err != nil {
482-
logger.WithError(err).Error("proxy handler failed to dial to the address")
483-
484-
return ctx.Error(ErrFailedDialToAddressAndPort)
485-
}
486-
487-
defer in.Close()
488-
489-
logger.Trace("proxy handler dialed to the address")
490-
491-
// TODO: Add consts for status values.
492-
if err := ctx.Status("ok"); err != nil {
493-
logger.WithError(err).Error("proxy handler failed to send status response")
494-
495-
return err
496-
}
497-
498-
wg := new(sync.WaitGroup)
499-
done := sync.OnceFunc(func() {
500-
defer in.Close()
501-
defer rwc.Close()
502-
503-
logger.Trace("close called on in and out connections")
504-
})
505-
506-
wg.Add(1)
507-
go func() {
508-
defer done()
509-
defer wg.Done()
510-
511-
if _, err := io.Copy(in, rwc); err != nil && err != io.EOF {
512-
logger.WithError(err).Error("proxy handler copy from rwc to in failed")
513-
}
514-
}()
456+
logger.Info("Server connection established")
515457

516-
wg.Add(1)
517-
go func() {
518-
defer done()
519-
defer wg.Done()
458+
a.listening <- true
520459

521-
if _, err := io.Copy(rwc, in); err != nil && err != io.EOF {
522-
logger.WithError(err).Error("proxy handler copy from in to rwc failed")
460+
if err := tun.Listen(ctx, listener); err != nil {
461+
logger.WithError(err).Error("Tunnel listener exited with error")
523462
}
524-
}()
525-
526-
logger.WithError(err).Info("proxy handler waiting for data pipe")
527-
528-
wg.Wait()
529-
530-
logger.WithError(err).Info("proxy handler done")
531-
532-
return nil
533-
}
534-
}
535-
536-
func sshHandler(agent *Agent) tunnel.Handler {
537-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
538-
defer rwc.Close()
539-
540-
headers, err := ctx.Headers()
541-
if err != nil {
542-
log.WithError(err).Error("failed to get the headers from the connection")
543-
544-
return err
545-
}
546-
547-
id := headers["id"]
548-
549-
conn, ok := rwc.(net.Conn)
550-
if !ok {
551-
log.Error("failed to cast the ReadWriteCloser to net.Conn")
552-
553-
return errors.New("failed to cast the ReadWriteCloser to net.Conn")
554-
}
555463

556-
agent.server.Sessions.Store(id, conn)
557-
agent.server.HandleConn(conn)
558-
559-
return nil
560-
}
561-
}
562-
563-
func sshCloseHandler(agent *Agent) tunnel.Handler {
564-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
565-
defer rwc.Close()
566-
567-
headers, err := ctx.Headers()
568-
if err != nil {
569-
log.WithError(err).Error("failed to get the headers from the connection")
570-
571-
return err
464+
a.listening <- false
572465
}
466+
}()
573467

574-
id := headers["id"]
575-
576-
agent.server.CloseSession(id)
577-
578-
log.WithFields(
579-
log.Fields{
580-
"id": id,
581-
"version": AgentVersion,
582-
"tenant_id": agent.authData.Namespace,
583-
"server_address": agent.config.ServerAddress,
584-
},
585-
).Info("A tunnel connection was closed")
468+
<-ctx.Done()
586469

587-
return nil
588-
}
470+
return a.Close()
589471
}
590472

591-
const (
592-
// HandleSSHOpen is the protocol used to open a new SSH connection.
593-
HandleSSHOpen = "/ssh/open/1.0.0"
594-
// HandleSSHClose is the protocol used to close an existing SSH connection.
595-
HandleSSHClose = "/ssh/close/1.0.0"
596-
// HandleHTTPProxy is the protocol used to open a new HTTP proxy connection.
597-
HandleHTTPProxy = "/http/proxy/1.0.0"
598-
)
599-
600-
// Listen creates the SSH server and listening for connections.
601-
func (a *Agent) Listen(ctx context.Context) error {
602-
a.mode.Serve(a)
603-
604-
a.tunnel = tunnel.NewTunnel()
473+
func (a *Agent) listenV2(ctx context.Context) error {
474+
tun := tunnel.NewTunnelV2(a.cli)
605475

606-
a.tunnel.Handle(HandleSSHOpen, sshHandler(a))
607-
a.tunnel.Handle(HandleSSHClose, sshCloseHandler(a))
608-
a.tunnel.Handle(HandleHTTPProxy, httpProxyHandler(a))
476+
tun.Handle(HandleSSHOpenV2, sshHandlerV2(a))
477+
tun.Handle(HandleSSHCloseV2, sshCloseHandlerV2(a))
478+
tun.Handle(HandleHTTPProxyV2, httpProxyHandlerV2(a))
609479

610480
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
611481

612482
logger := log.WithFields(log.Fields{
613-
"version": AgentVersion,
614-
"tenant_id": a.authData.Namespace,
615-
"server_address": a.config.ServerAddress,
616-
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
617-
"api_endpoint": a.serverInfo.Endpoints.API,
618-
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
483+
"version": AgentVersion,
484+
"tenant_id": a.authData.Namespace,
485+
"server_address": a.config.ServerAddress,
486+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
487+
"api_endpoint": a.serverInfo.Endpoints.API,
488+
"connection_version": a.config.ConnectionVersion,
489+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
619490
})
620491

621492
ctx, cancel := context.WithCancel(ctx)
@@ -629,24 +500,30 @@ func (a *Agent) Listen(ctx context.Context) error {
629500
return
630501
}
631502

632-
DefaultShellHubConnectPath := "/connection"
503+
ShellHubConnectV2Path := "/connection"
504+
505+
logger.Debug("Using tunnel version 2")
633506

634-
conn, err := a.cli.Connect(ctx, a.authData.Token, DefaultShellHubConnectPath)
507+
listener, err := a.cli.NewReverseListenerV2(
508+
ctx,
509+
a.authData.Token,
510+
ShellHubConnectV2Path,
511+
client.NewReverseV2ConfigFromMap(a.authData.Config),
512+
)
635513
if err != nil {
636514
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
637515

638516
time.Sleep(time.Second * 10)
639517

640518
continue
641519
}
642-
643-
a.conn = conn
520+
a.listener.Store(&listener)
644521

645522
logger.Info("Server connection established")
646523

647524
a.listening <- true
648525

649-
if err := a.tunnel.Listen(conn, tunnel.NewConfigFromMap(a.authData.Config)); err != nil {
526+
if err := tun.Listen(ctx, listener); err != nil {
650527
logger.WithError(err).Error("Tunnel listener exited with error")
651528
}
652529

0 commit comments

Comments
 (0)