Skip to content

Commit 275b4d7

Browse files
committed
feat(agent,ssh,pkg): add connection versioning to support for v1 and v2 tunnels
This change refactors the agent to support multiple connection protocol versions, introducing a new configuration option ConnectionVersion to select between them. Version 1 uses an Echo-based tunnel with HTTP-like handlers, while version 2 leverages a Yamux-based tunnel integrated with the client. The agent’s Listen method now delegates to version-specific listeners, and related handlers were updated accordingly. The tunnel abstraction was also revised to adopt generics, renaming Handler to HandlerFunc and splitting out implementations for v1 and v2. Additionally, reverse listener creation was modernized in the API client with explicit NewReverseListenerV1 and NewReverseListenerV2 methods, each with their own configuration management. These changes simplify the agent’s architecture while allowing incremental migration to the new tunnel protocol.
1 parent 159c614 commit 275b4d7

File tree

18 files changed

+849
-442
lines changed

18 files changed

+849
-442
lines changed

agent/agent.go

Lines changed: 88 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,16 @@ import (
4242
"context"
4343
"crypto/rsa"
4444
"fmt"
45-
"io"
4645
"math/rand"
4746
"net"
48-
"net/netip"
4947
"net/url"
5048
"os"
5149
"runtime"
5250
"strings"
53-
"sync"
5451
"sync/atomic"
5552
"time"
5653

5754
"github.com/Masterminds/semver"
58-
dockerclient "github.com/docker/docker/client"
5955
"github.com/pkg/errors"
6056
"github.com/shellhub-io/shellhub/agent/pkg/keygen"
6157
"github.com/shellhub-io/shellhub/agent/pkg/sysinfo"
@@ -110,6 +106,10 @@ type Config struct {
110106
// MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
111107
// before attempting to reconnect to the ShellHub server. Default is 60 seconds.
112108
MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`
109+
110+
// ConnectionVersion specifies the version of the connection protocol to use.
111+
// Supported values are 1 and 2. Default is 1.
112+
ConnectionVersion int `env:"CONNECTION_VERSION,default=1"`
113113
}
114114

115115
func LoadConfigFromEnv() (*Config, map[string]interface{}, error) {
@@ -159,12 +159,11 @@ type Agent struct {
159159
cli client.Client
160160
serverInfo *models.Info
161161
server *server.Server
162-
tunnel *tunnel.Tunnel
163162
listening chan bool
164163
closed atomic.Bool
165164
mode Mode
166-
// conn is the current connection to the server.
167-
conn net.Conn
165+
// listener is the current connection to the server.
166+
listener atomic.Pointer[net.Listener]
168167
}
169168

170169
// NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant
@@ -370,237 +369,109 @@ func (a *Agent) isClosed() bool {
370369
func (a *Agent) Close() error {
371370
a.closed.Store(true)
372371

373-
return a.conn.Close()
374-
}
375-
376-
// httpProxyHandler handlers proxy connections to the required address.
377-
func httpProxyHandler(agent *Agent) tunnel.Handler {
378-
const ProxyHandlerNetwork = "tcp"
372+
l := a.listener.Load()
373+
if l == nil {
374+
return nil
375+
}
379376

380-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
381-
headers, err := ctx.Headers()
382-
if err != nil {
383-
log.WithError(err).Error("failed to get the headers from the connection")
377+
return (*l).Close()
378+
}
384379

385-
return err
386-
}
380+
func (a *Agent) Listen(ctx context.Context) error {
381+
a.mode.Serve(a)
387382

388-
id := headers["id"]
389-
host := headers["host"]
390-
port := headers["port"]
383+
switch a.config.ConnectionVersion {
384+
case 1:
385+
return a.listenV1(ctx)
386+
case 2:
387+
return a.listenV2(ctx)
388+
default:
389+
return fmt.Errorf("unsupported connection version: %d", a.config.ConnectionVersion)
390+
}
391+
}
391392

392-
logger := log.WithFields(log.Fields{
393-
"id": id,
394-
"host": host,
395-
"port": port,
396-
})
393+
func (a *Agent) listenV1(ctx context.Context) error {
394+
tun := tunnel.NewTunnelV1()
397395

398-
if _, ok := agent.mode.(*ConnectorMode); ok {
399-
cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
400-
if err != nil {
401-
log.WithError(err).Error("failed to create the Docker client")
396+
tun.Handle(HandleSSHOpenV1, sshHandlerV1(a))
397+
tun.Handle(HandleSSHCloseV1, sshCloseHandlerV1(a))
398+
tun.Handle(HandleHTTPProxyV1, httpProxyHandlerV1(a))
402399

403-
return ctx.Error(errors.New("failed to connect to the Docker Engine"))
404-
}
405-
406-
container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID)
407-
if err != nil {
408-
log.WithError(err).Error("failed to inspect the container")
400+
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
409401

410-
return ctx.Error(errors.New("failed to inspect the container"))
411-
}
402+
logger := log.WithFields(log.Fields{
403+
"version": AgentVersion,
404+
"tenant_id": a.authData.Namespace,
405+
"server_address": a.config.ServerAddress,
406+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
407+
"api_endpoint": a.serverInfo.Endpoints.API,
408+
"connection_version": a.config.ConnectionVersion,
409+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
410+
})
412411

413-
var target string
412+
ctx, cancel := context.WithCancel(ctx)
413+
go func() {
414+
for {
415+
if a.isClosed() {
416+
logger.Info("Stopped listening for connections")
414417

415-
addr, err := netip.ParseAddr(host)
416-
if err != nil {
417-
log.WithError(err).Error("failed to parse the address on proxy")
418+
cancel()
418419

419-
return ctx.Error(errors.New("failed to parse the address on proxy"))
420+
return
420421
}
421422

422-
if addr.IsLoopback() {
423-
log.Trace("host is a loopback address, using the container IP address")
424-
425-
for _, network := range container.NetworkSettings.Networks {
426-
target = network.IPAddress
427-
428-
break
429-
}
430-
} else {
431-
for _, network := range container.NetworkSettings.Networks {
432-
subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen))
433-
if err != nil {
434-
logger.WithError(err).Error("failed to parse the gateway on proxy")
435-
436-
continue
437-
}
438-
439-
ip, err := netip.ParseAddr(host)
440-
if err != nil {
441-
logger.WithError(err).Error("failed to parse the address on proxy")
423+
ShellHubConnectV1Path := "/ssh/connection"
442424

443-
continue
444-
}
425+
logger.Debug("Using tunnel version 1")
445426

446-
if subnet.Contains(ip) {
447-
target = ip.String()
427+
listener, err := a.cli.NewReverseListenerV1(
428+
ctx,
429+
a.authData.Token,
430+
ShellHubConnectV1Path,
431+
)
432+
if err != nil {
433+
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
448434

449-
break
450-
}
451-
}
452-
}
435+
time.Sleep(time.Second * 10)
453436

454-
if target == "" {
455-
return ctx.Error(errors.New("address not found on the device"))
437+
continue
456438
}
439+
a.listener.Store(&listener)
457440

458-
host = target
459-
}
460-
461-
ErrFailedDialToAddressAndPort := errors.New("failed to dial to the address and port")
462-
463-
logger.Trace("proxy handler connecting to the address")
464-
465-
in, err := net.Dial(ProxyHandlerNetwork, net.JoinHostPort(host, port))
466-
if err != nil {
467-
logger.WithError(err).Error("proxy handler failed to dial to the address")
468-
469-
return ctx.Error(ErrFailedDialToAddressAndPort)
470-
}
471-
472-
defer in.Close()
473-
474-
logger.Trace("proxy handler dialed to the address")
475-
476-
// TODO: Add consts for status values.
477-
if err := ctx.Status("ok"); err != nil {
478-
logger.WithError(err).Error("proxy handler failed to send status response")
479-
480-
return err
481-
}
482-
483-
wg := new(sync.WaitGroup)
484-
done := sync.OnceFunc(func() {
485-
defer in.Close()
486-
defer rwc.Close()
487-
488-
logger.Trace("close called on in and out connections")
489-
})
490-
491-
wg.Add(1)
492-
go func() {
493-
defer done()
494-
defer wg.Done()
495-
496-
if _, err := io.Copy(in, rwc); err != nil && err != io.EOF {
497-
logger.WithError(err).Error("proxy handler copy from rwc to in failed")
498-
}
499-
}()
441+
logger.Info("Server connection established")
500442

501-
wg.Add(1)
502-
go func() {
503-
defer done()
504-
defer wg.Done()
443+
a.listening <- true
505444

506-
if _, err := io.Copy(rwc, in); err != nil && err != io.EOF {
507-
logger.WithError(err).Error("proxy handler copy from in to rwc failed")
445+
if err := tun.Listen(ctx, listener); err != nil {
446+
logger.WithError(err).Error("Tunnel listener exited with error")
508447
}
509-
}()
510-
511-
logger.WithError(err).Info("proxy handler waiting for data pipe")
512-
513-
wg.Wait()
514-
515-
logger.WithError(err).Info("proxy handler done")
516-
517-
return nil
518-
}
519-
}
520-
521-
func sshHandler(agent *Agent) tunnel.Handler {
522-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
523-
defer rwc.Close()
524-
525-
headers, err := ctx.Headers()
526-
if err != nil {
527-
log.WithError(err).Error("failed to get the headers from the connection")
528-
529-
return err
530-
}
531-
532-
id := headers["id"]
533-
534-
conn, ok := rwc.(net.Conn)
535-
if !ok {
536-
log.Error("failed to cast the ReadWriteCloser to net.Conn")
537-
538-
return errors.New("failed to cast the ReadWriteCloser to net.Conn")
539-
}
540448

541-
agent.server.Sessions.Store(id, conn)
542-
agent.server.HandleConn(conn)
543-
544-
return nil
545-
}
546-
}
547-
548-
func sshCloseHandler(agent *Agent) tunnel.Handler {
549-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
550-
defer rwc.Close()
551-
552-
headers, err := ctx.Headers()
553-
if err != nil {
554-
log.WithError(err).Error("failed to get the headers from the connection")
555-
556-
return err
449+
a.listening <- false
557450
}
451+
}()
558452

559-
id := headers["id"]
560-
561-
agent.server.CloseSession(id)
562-
563-
log.WithFields(
564-
log.Fields{
565-
"id": id,
566-
"version": AgentVersion,
567-
"tenant_id": agent.authData.Namespace,
568-
"server_address": agent.config.ServerAddress,
569-
},
570-
).Info("A tunnel connection was closed")
453+
<-ctx.Done()
571454

572-
return nil
573-
}
455+
return a.Close()
574456
}
575457

576-
const (
577-
// HandleSSHOpen is the protocol used to open a new SSH connection.
578-
HandleSSHOpen = "/ssh/open/1.0.0"
579-
// HandleSSHClose is the protocol used to close an existing SSH connection.
580-
HandleSSHClose = "/ssh/close/1.0.0"
581-
// HandleHTTPProxy is the protocol used to open a new HTTP proxy connection.
582-
HandleHTTPProxy = "/http/proxy/1.0.0"
583-
)
584-
585-
// Listen creates the SSH server and listening for connections.
586-
func (a *Agent) Listen(ctx context.Context) error {
587-
a.mode.Serve(a)
588-
589-
a.tunnel = tunnel.NewTunnel()
458+
func (a *Agent) listenV2(ctx context.Context) error {
459+
tun := tunnel.NewTunnelV2(a.cli)
590460

591-
a.tunnel.Handle(HandleSSHOpen, sshHandler(a))
592-
a.tunnel.Handle(HandleSSHClose, sshCloseHandler(a))
593-
a.tunnel.Handle(HandleHTTPProxy, httpProxyHandler(a))
461+
tun.Handle(HandleSSHOpenV2, sshHandlerV2(a))
462+
tun.Handle(HandleSSHCloseV2, sshCloseHandlerV2(a))
463+
tun.Handle(HandleHTTPProxyV2, httpProxyHandlerV2(a))
594464

595465
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
596466

597467
logger := log.WithFields(log.Fields{
598-
"version": AgentVersion,
599-
"tenant_id": a.authData.Namespace,
600-
"server_address": a.config.ServerAddress,
601-
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
602-
"api_endpoint": a.serverInfo.Endpoints.API,
603-
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
468+
"version": AgentVersion,
469+
"tenant_id": a.authData.Namespace,
470+
"server_address": a.config.ServerAddress,
471+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
472+
"api_endpoint": a.serverInfo.Endpoints.API,
473+
"connection_version": a.config.ConnectionVersion,
474+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
604475
})
605476

606477
ctx, cancel := context.WithCancel(ctx)
@@ -614,24 +485,30 @@ func (a *Agent) Listen(ctx context.Context) error {
614485
return
615486
}
616487

617-
DefaultShellHubConnectPath := "/connection"
488+
ShellHubConnectV2Path := "/connection"
489+
490+
logger.Debug("Using tunnel version 2")
618491

619-
conn, err := a.cli.Connect(ctx, a.authData.Token, DefaultShellHubConnectPath)
492+
listener, err := a.cli.NewReverseListenerV2(
493+
ctx,
494+
a.authData.Token,
495+
ShellHubConnectV2Path,
496+
client.NewReverseV2ConfigFromMap(a.authData.Config),
497+
)
620498
if err != nil {
621499
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
622500

623501
time.Sleep(time.Second * 10)
624502

625503
continue
626504
}
627-
628-
a.conn = conn
505+
a.listener.Store(&listener)
629506

630507
logger.Info("Server connection established")
631508

632509
a.listening <- true
633510

634-
if err := a.tunnel.Listen(conn, tunnel.NewConfigFromMap(a.authData.Config)); err != nil {
511+
if err := tun.Listen(ctx, listener); err != nil {
635512
logger.WithError(err).Error("Tunnel listener exited with error")
636513
}
637514

0 commit comments

Comments
 (0)