Skip to content

Commit 90a2046

Browse files
committed
feat(agent,ssh,pkg): add connection versioning to support v1 and v2 tunnels
This change refactors the agent to support multiple connection protocol versions, introducing a new configuration option ConnectionVersion to select between them. Version 1 uses an Echo-based tunnel with HTTP-like handlers, while version 2 leverages a Yamux-based tunnel integrated with the client. The agent’s Listen method now delegates to version-specific listeners, and related handlers were updated accordingly. The tunnel abstraction was also revised to adopt generics, renaming Handler to HandlerFunc and splitting out implementations for v1 and v2. Additionally, reverse listener creation was modernized in the API client with explicit NewReverseListenerV1 and NewReverseListenerV2 methods, each with its own configuration management. These changes simplify the agent’s architecture while allowing incremental migration to the new tunnel protocol.
1 parent 171ce07 commit 90a2046

File tree

13 files changed

+803
-437
lines changed

13 files changed

+803
-437
lines changed

pkg/agent/agent.go

Lines changed: 78 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,16 @@ import (
4242
"context"
4343
"crypto/rsa"
4444
"fmt"
45-
"io"
4645
"math/rand"
4746
"net"
48-
"net/netip"
4947
"net/url"
5048
"os"
5149
"runtime"
5250
"strings"
53-
"sync"
5451
"sync/atomic"
5552
"time"
5653

5754
"github.com/Masterminds/semver"
58-
dockerclient "github.com/docker/docker/client"
5955
"github.com/pkg/errors"
6056
"github.com/shellhub-io/shellhub/pkg/agent/pkg/keygen"
6157
"github.com/shellhub-io/shellhub/pkg/agent/pkg/sysinfo"
@@ -125,6 +121,10 @@ type Config struct {
125121
// MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
126122
// before attempting to reconnect to the ShellHub server. Default is 60 seconds.
127123
MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`
124+
125+
// ConnectionVersion specifies the version of the connection protocol to use.
126+
// Supported values are 1 and 2. Default is 1.
127+
ConnectionVersion int `env:"CONNECTION_VERSION,default=1" validate:"oneof=1 2"`
128128
}
129129

130130
func LoadConfigFromEnv() (*Config, map[string]interface{}, error) {
@@ -174,7 +174,6 @@ type Agent struct {
174174
cli client.Client
175175
serverInfo *models.Info
176176
server *server.Server
177-
tunnel *tunnel.Tunnel
178177
listening chan bool
179178
closed atomic.Bool
180179
mode Mode
@@ -388,234 +387,100 @@ func (a *Agent) Close() error {
388387
return a.conn.Close()
389388
}
390389

391-
// httpProxyHandler handlers proxy connections to the required address.
392-
func httpProxyHandler(agent *Agent) tunnel.Handler {
393-
const ProxyHandlerNetwork = "tcp"
394-
395-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
396-
headers, err := ctx.Headers()
397-
if err != nil {
398-
log.WithError(err).Error("failed to get the headers from the connection")
399-
400-
return err
401-
}
402-
403-
id := headers["id"]
404-
host := headers["host"]
405-
port := headers["port"]
390+
func (a *Agent) Listen(ctx context.Context) error {
391+
a.mode.Serve(a)
406392

407-
logger := log.WithFields(log.Fields{
408-
"id": id,
409-
"host": host,
410-
"port": port,
411-
})
393+
switch a.config.ConnectionVersion {
394+
case 1:
395+
return a.listenV1(ctx)
396+
case 2:
397+
return a.listenV2(ctx)
398+
default:
399+
return fmt.Errorf("unsupported connection version: %d", a.config.ConnectionVersion)
400+
}
401+
}
412402

413-
if _, ok := agent.mode.(*ConnectorMode); ok {
414-
cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
415-
if err != nil {
416-
log.WithError(err).Error("failed to create the Docker client")
403+
func (a *Agent) listenV1(ctx context.Context) error {
404+
tun := tunnel.NewTunnelV1()
417405

418-
return ctx.Error(errors.New("failed to connect to the Docker Engine"))
419-
}
406+
tun.Handle(HandleSSHOpenV1, sshHandlerV1(a))
407+
tun.Handle(HandleSSHCloseV1, sshCloseHandlerV1(a))
408+
tun.Handle(HandleHTTPProxyV1, httpProxyHandlerV1(a))
420409

421-
container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID)
422-
if err != nil {
423-
log.WithError(err).Error("failed to inspect the container")
410+
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
424411

425-
return ctx.Error(errors.New("failed to inspect the container"))
426-
}
412+
logger := log.WithFields(log.Fields{
413+
"version": AgentVersion,
414+
"tenant_id": a.authData.Namespace,
415+
"server_address": a.config.ServerAddress,
416+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
417+
"api_endpoint": a.serverInfo.Endpoints.API,
418+
"connection_version": a.config.ConnectionVersion,
419+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
420+
})
427421

428-
var target string
422+
ctx, cancel := context.WithCancel(ctx)
423+
go func() {
424+
for {
425+
if a.isClosed() {
426+
logger.Info("Stopped listening for connections")
429427

430-
addr, err := netip.ParseAddr(host)
431-
if err != nil {
432-
log.WithError(err).Error("failed to parse the address on proxy")
428+
cancel()
433429

434-
return ctx.Error(errors.New("failed to parse the address on proxy"))
430+
return
435431
}
436432

437-
if addr.IsLoopback() {
438-
log.Trace("host is a loopback address, using the container IP address")
439-
440-
for _, network := range container.NetworkSettings.Networks {
441-
target = network.IPAddress
442-
443-
break
444-
}
445-
} else {
446-
for _, network := range container.NetworkSettings.Networks {
447-
subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen))
448-
if err != nil {
449-
logger.WithError(err).Error("failed to parse the gateway on proxy")
450-
451-
continue
452-
}
453-
454-
ip, err := netip.ParseAddr(host)
455-
if err != nil {
456-
logger.WithError(err).Error("failed to parse the address on proxy")
433+
ShellHubConnectV1Path := "/ssh/connection"
457434

458-
continue
459-
}
435+
logger.Debug("Using tunnel version 1")
460436

461-
if subnet.Contains(ip) {
462-
target = ip.String()
437+
listener, err := a.cli.NewReverseListenerV1(
438+
ctx,
439+
a.authData.Token,
440+
ShellHubConnectV1Path,
441+
)
442+
if err != nil {
443+
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
463444

464-
break
465-
}
466-
}
467-
}
445+
time.Sleep(time.Second * 10)
468446

469-
if target == "" {
470-
return ctx.Error(errors.New("address not found on the device"))
447+
continue
471448
}
472449

473-
host = target
474-
}
475-
476-
ErrFailedDialToAddressAndPort := errors.New("failed to dial to the address and port")
477-
478-
logger.Trace("proxy handler connecting to the address")
479-
480-
in, err := net.Dial(ProxyHandlerNetwork, net.JoinHostPort(host, port))
481-
if err != nil {
482-
logger.WithError(err).Error("proxy handler failed to dial to the address")
483-
484-
return ctx.Error(ErrFailedDialToAddressAndPort)
485-
}
486-
487-
defer in.Close()
488-
489-
logger.Trace("proxy handler dialed to the address")
490-
491-
// TODO: Add consts for status values.
492-
if err := ctx.Status("ok"); err != nil {
493-
logger.WithError(err).Error("proxy handler failed to send status response")
494-
495-
return err
496-
}
497-
498-
wg := new(sync.WaitGroup)
499-
done := sync.OnceFunc(func() {
500-
defer in.Close()
501-
defer rwc.Close()
502-
503-
logger.Trace("close called on in and out connections")
504-
})
505-
506-
wg.Add(1)
507-
go func() {
508-
defer done()
509-
defer wg.Done()
510-
511-
if _, err := io.Copy(in, rwc); err != nil && err != io.EOF {
512-
logger.WithError(err).Error("proxy handler copy from rwc to in failed")
513-
}
514-
}()
450+
logger.Info("Server connection established")
515451

516-
wg.Add(1)
517-
go func() {
518-
defer done()
519-
defer wg.Done()
452+
a.listening <- true
520453

521-
if _, err := io.Copy(rwc, in); err != nil && err != io.EOF {
522-
logger.WithError(err).Error("proxy handler copy from in to rwc failed")
454+
if err := tun.Listen(ctx, listener); err != nil {
455+
logger.WithError(err).Error("Tunnel listener exited with error")
523456
}
524-
}()
525-
526-
logger.WithError(err).Info("proxy handler waiting for data pipe")
527-
528-
wg.Wait()
529-
530-
logger.WithError(err).Info("proxy handler done")
531-
532-
return nil
533-
}
534-
}
535-
536-
func sshHandler(agent *Agent) tunnel.Handler {
537-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
538-
defer rwc.Close()
539-
540-
headers, err := ctx.Headers()
541-
if err != nil {
542-
log.WithError(err).Error("failed to get the headers from the connection")
543-
544-
return err
545-
}
546-
547-
id := headers["id"]
548-
549-
conn, ok := rwc.(net.Conn)
550-
if !ok {
551-
log.Error("failed to cast the ReadWriteCloser to net.Conn")
552-
553-
return errors.New("failed to cast the ReadWriteCloser to net.Conn")
554-
}
555-
556-
agent.server.Sessions.Store(id, conn)
557-
agent.server.HandleConn(conn)
558457

559-
return nil
560-
}
561-
}
562-
563-
func sshCloseHandler(agent *Agent) tunnel.Handler {
564-
return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error {
565-
defer rwc.Close()
566-
567-
headers, err := ctx.Headers()
568-
if err != nil {
569-
log.WithError(err).Error("failed to get the headers from the connection")
570-
571-
return err
458+
a.listening <- false
572459
}
460+
}()
573461

574-
id := headers["id"]
575-
576-
agent.server.CloseSession(id)
577-
578-
log.WithFields(
579-
log.Fields{
580-
"id": id,
581-
"version": AgentVersion,
582-
"tenant_id": agent.authData.Namespace,
583-
"server_address": agent.config.ServerAddress,
584-
},
585-
).Info("A tunnel connection was closed")
462+
<-ctx.Done()
586463

587-
return nil
588-
}
464+
return a.Close()
589465
}
590466

591-
const (
592-
// HandleSSHOpen is the protocol used to open a new SSH connection.
593-
HandleSSHOpen = "/ssh/open/1.0.0"
594-
// HandleSSHClose is the protocol used to close an existing SSH connection.
595-
HandleSSHClose = "/ssh/close/1.0.0"
596-
// HandleHTTPProxy is the protocol used to open a new HTTP proxy connection.
597-
HandleHTTPProxy = "/http/proxy/1.0.0"
598-
)
467+
func (a *Agent) listenV2(ctx context.Context) error {
468+
tun := tunnel.NewTunnelV2(a.cli)
599469

600-
// Listen creates the SSH server and listening for connections.
601-
func (a *Agent) Listen(ctx context.Context) error {
602-
a.mode.Serve(a)
603-
604-
a.tunnel = tunnel.NewTunnel()
605-
606-
a.tunnel.Handle(HandleSSHOpen, sshHandler(a))
607-
a.tunnel.Handle(HandleSSHClose, sshCloseHandler(a))
608-
a.tunnel.Handle(HandleHTTPProxy, httpProxyHandler(a))
470+
tun.Handle(HandleSSHOpenV2, sshHandlerV2(a))
471+
tun.Handle(HandleSSHCloseV2, sshCloseHandlerV2(a))
472+
tun.Handle(HandleHTTPProxyV2, httpProxyHandlerV2(a))
609473

610474
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck
611475

612476
logger := log.WithFields(log.Fields{
613-
"version": AgentVersion,
614-
"tenant_id": a.authData.Namespace,
615-
"server_address": a.config.ServerAddress,
616-
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
617-
"api_endpoint": a.serverInfo.Endpoints.API,
618-
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
477+
"version": AgentVersion,
478+
"tenant_id": a.authData.Namespace,
479+
"server_address": a.config.ServerAddress,
480+
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
481+
"api_endpoint": a.serverInfo.Endpoints.API,
482+
"connection_version": a.config.ConnectionVersion,
483+
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
619484
})
620485

621486
ctx, cancel := context.WithCancel(ctx)
@@ -629,9 +494,16 @@ func (a *Agent) Listen(ctx context.Context) error {
629494
return
630495
}
631496

632-
DefaultShellHubConnectPath := "/connection"
497+
ShellHubConnectV2Path := "/connection"
633498

634-
conn, err := a.cli.Connect(ctx, a.authData.Token, DefaultShellHubConnectPath)
499+
logger.Debug("Using tunnel version 2")
500+
501+
listener, err := a.cli.NewReverseListenerV2(
502+
ctx,
503+
a.authData.Token,
504+
ShellHubConnectV2Path,
505+
client.NewReverseV2ConfigFromMap(a.authData.Config),
506+
)
635507
if err != nil {
636508
logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
637509

@@ -640,13 +512,11 @@ func (a *Agent) Listen(ctx context.Context) error {
640512
continue
641513
}
642514

643-
a.conn = conn
644-
645515
logger.Info("Server connection established")
646516

647517
a.listening <- true
648518

649-
if err := a.tunnel.Listen(conn, tunnel.NewConfigFromMap(a.authData.Config)); err != nil {
519+
if err := tun.Listen(ctx, listener); err != nil {
650520
logger.WithError(err).Error("Tunnel listener exited with error")
651521
}
652522

0 commit comments

Comments
 (0)