@@ -42,20 +42,16 @@ import (
4242 "context"
4343 "crypto/rsa"
4444 "fmt"
45- "io"
4645 "math/rand"
4746 "net"
48- "net/netip"
4947 "net/url"
5048 "os"
5149 "runtime"
5250 "strings"
53- "sync"
5451 "sync/atomic"
5552 "time"
5653
5754 "github.com/Masterminds/semver"
58- dockerclient "github.com/docker/docker/client"
5955 "github.com/pkg/errors"
6056 "github.com/shellhub-io/shellhub/pkg/agent/pkg/keygen"
6157 "github.com/shellhub-io/shellhub/pkg/agent/pkg/sysinfo"
@@ -125,6 +121,10 @@ type Config struct {
125121 // MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
126122 // before attempting to reconnect to the ShellHub server. Default is 60 seconds.
127123 MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`
124+
125+ // ConnectionVersion specifies the version of the connection protocol to use.
126+ // Supported values are 1 and 2. Default is 1.
127+ ConnectionVersion int `env:"CONNECTION_VERSION,default=1"`
128128}
129129
130130func LoadConfigFromEnv () (* Config , map [string ]interface {}, error ) {
@@ -174,12 +174,11 @@ type Agent struct {
174174 cli client.Client
175175 serverInfo * models.Info
176176 server * server.Server
177- tunnel * tunnel.Tunnel
178177 listening chan bool
179178 closed atomic.Bool
180179 mode Mode
181- // conn is the current connection to the server.
182- conn net.Conn
180+ // listener is the current connection to the server.
181+ listener atomic. Pointer [ net.Listener ]
183182}
184183
185184// NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant
@@ -385,237 +384,109 @@ func (a *Agent) isClosed() bool {
385384func (a * Agent ) Close () error {
386385 a .closed .Store (true )
387386
388- return a .conn .Close ()
389- }
390-
391- // httpProxyHandler handlers proxy connections to the required address.
392- func httpProxyHandler (agent * Agent ) tunnel.Handler {
393- const ProxyHandlerNetwork = "tcp"
387+ l := a .listener .Load ()
388+ if l == nil {
389+ return nil
390+ }
394391
395- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
396- headers , err := ctx .Headers ()
397- if err != nil {
398- log .WithError (err ).Error ("failed to get the headers from the connection" )
392+ return (* l ).Close ()
393+ }
399394
400- return err
401- }
395+ func ( a * Agent ) Listen ( ctx context. Context ) error {
396+ a . mode . Serve ( a )
402397
403- id := headers ["id" ]
404- host := headers ["host" ]
405- port := headers ["port" ]
398+ switch a .config .ConnectionVersion {
399+ case 1 :
400+ return a .listenV1 (ctx )
401+ case 2 :
402+ return a .listenV2 (ctx )
403+ default :
404+ return fmt .Errorf ("unsupported connection version: %d" , a .config .ConnectionVersion )
405+ }
406+ }
406407
407- logger := log .WithFields (log.Fields {
408- "id" : id ,
409- "host" : host ,
410- "port" : port ,
411- })
408+ func (a * Agent ) listenV1 (ctx context.Context ) error {
409+ tun := tunnel .NewTunnelV1 ()
412410
413- if _ , ok := agent .mode .(* ConnectorMode ); ok {
414- cli , err := dockerclient .NewClientWithOpts (dockerclient .FromEnv , dockerclient .WithAPIVersionNegotiation ())
415- if err != nil {
416- log .WithError (err ).Error ("failed to create the Docker client" )
411+ tun .Handle (HandleSSHOpenV1 , sshHandlerV1 (a ))
412+ tun .Handle (HandleSSHCloseV1 , sshCloseHandlerV1 (a ))
413+ tun .Handle (HandleHTTPProxyV1 , httpProxyHandlerV1 (a ))
417414
418- return ctx .Error (errors .New ("failed to connect to the Docker Engine" ))
419- }
420-
421- container , err := cli .ContainerInspect (context .Background (), agent .server .ContainerID )
422- if err != nil {
423- log .WithError (err ).Error ("failed to inspect the container" )
415+ go a .ping (ctx , AgentPingDefaultInterval ) //nolint:errcheck
424416
425- return ctx .Error (errors .New ("failed to inspect the container" ))
426- }
417+ logger := log .WithFields (log.Fields {
418+ "version" : AgentVersion ,
419+ "tenant_id" : a .authData .Namespace ,
420+ "server_address" : a .config .ServerAddress ,
421+ "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
422+ "api_endpoint" : a .serverInfo .Endpoints .API ,
423+ "connection_version" : a .config .ConnectionVersion ,
424+ "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
425+ })
427426
428- var target string
427+ ctx , cancel := context .WithCancel (ctx )
428+ go func () {
429+ for {
430+ if a .isClosed () {
431+ logger .Info ("Stopped listening for connections" )
429432
430- addr , err := netip .ParseAddr (host )
431- if err != nil {
432- log .WithError (err ).Error ("failed to parse the address on proxy" )
433+ cancel ()
433434
434- return ctx . Error ( errors . New ( "failed to parse the address on proxy" ))
435+ return
435436 }
436437
437- if addr .IsLoopback () {
438- log .Trace ("host is a loopback address, using the container IP address" )
439-
440- for _ , network := range container .NetworkSettings .Networks {
441- target = network .IPAddress
442-
443- break
444- }
445- } else {
446- for _ , network := range container .NetworkSettings .Networks {
447- subnet , err := netip .ParsePrefix (fmt .Sprintf ("%s/%d" , network .Gateway , network .IPPrefixLen ))
448- if err != nil {
449- logger .WithError (err ).Error ("failed to parse the gateway on proxy" )
450-
451- continue
452- }
453-
454- ip , err := netip .ParseAddr (host )
455- if err != nil {
456- logger .WithError (err ).Error ("failed to parse the address on proxy" )
438+ ShellHubConnectV1Path := "/ssh/connection"
457439
458- continue
459- }
440+ logger .Debug ("Using tunnel version 1" )
460441
461- if subnet .Contains (ip ) {
462- target = ip .String ()
442+ listener , err := a .cli .NewReverseListenerV1 (
443+ ctx ,
444+ a .authData .Token ,
445+ ShellHubConnectV1Path ,
446+ )
447+ if err != nil {
448+ logger .Error ("Failed to connect to server through reverse tunnel. Retry in 10 seconds" )
463449
464- break
465- }
466- }
467- }
450+ time .Sleep (time .Second * 10 )
468451
469- if target == "" {
470- return ctx .Error (errors .New ("address not found on the device" ))
452+ continue
471453 }
454+ a .listener .Store (& listener )
472455
473- host = target
474- }
475-
476- ErrFailedDialToAddressAndPort := errors .New ("failed to dial to the address and port" )
477-
478- logger .Trace ("proxy handler connecting to the address" )
479-
480- in , err := net .Dial (ProxyHandlerNetwork , net .JoinHostPort (host , port ))
481- if err != nil {
482- logger .WithError (err ).Error ("proxy handler failed to dial to the address" )
483-
484- return ctx .Error (ErrFailedDialToAddressAndPort )
485- }
486-
487- defer in .Close ()
488-
489- logger .Trace ("proxy handler dialed to the address" )
490-
491- // TODO: Add consts for status values.
492- if err := ctx .Status ("ok" ); err != nil {
493- logger .WithError (err ).Error ("proxy handler failed to send status response" )
494-
495- return err
496- }
497-
498- wg := new (sync.WaitGroup )
499- done := sync .OnceFunc (func () {
500- defer in .Close ()
501- defer rwc .Close ()
502-
503- logger .Trace ("close called on in and out connections" )
504- })
505-
506- wg .Add (1 )
507- go func () {
508- defer done ()
509- defer wg .Done ()
510-
511- if _ , err := io .Copy (in , rwc ); err != nil && err != io .EOF {
512- logger .WithError (err ).Error ("proxy handler copy from rwc to in failed" )
513- }
514- }()
456+ logger .Info ("Server connection established" )
515457
516- wg .Add (1 )
517- go func () {
518- defer done ()
519- defer wg .Done ()
458+ a .listening <- true
520459
521- if _ , err := io . Copy ( rwc , in ); err != nil && err != io . EOF {
522- logger .WithError (err ).Error ("proxy handler copy from in to rwc failed " )
460+ if err := tun . Listen ( ctx , listener ); err != nil {
461+ logger .WithError (err ).Error ("Tunnel listener exited with error " )
523462 }
524- }()
525-
526- logger .WithError (err ).Info ("proxy handler waiting for data pipe" )
527-
528- wg .Wait ()
529-
530- logger .WithError (err ).Info ("proxy handler done" )
531-
532- return nil
533- }
534- }
535-
536- func sshHandler (agent * Agent ) tunnel.Handler {
537- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
538- defer rwc .Close ()
539-
540- headers , err := ctx .Headers ()
541- if err != nil {
542- log .WithError (err ).Error ("failed to get the headers from the connection" )
543-
544- return err
545- }
546-
547- id := headers ["id" ]
548-
549- conn , ok := rwc .(net.Conn )
550- if ! ok {
551- log .Error ("failed to cast the ReadWriteCloser to net.Conn" )
552-
553- return errors .New ("failed to cast the ReadWriteCloser to net.Conn" )
554- }
555463
556- agent .server .Sessions .Store (id , conn )
557- agent .server .HandleConn (conn )
558-
559- return nil
560- }
561- }
562-
563- func sshCloseHandler (agent * Agent ) tunnel.Handler {
564- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
565- defer rwc .Close ()
566-
567- headers , err := ctx .Headers ()
568- if err != nil {
569- log .WithError (err ).Error ("failed to get the headers from the connection" )
570-
571- return err
464+ a .listening <- false
572465 }
466+ }()
573467
574- id := headers ["id" ]
575-
576- agent .server .CloseSession (id )
577-
578- log .WithFields (
579- log.Fields {
580- "id" : id ,
581- "version" : AgentVersion ,
582- "tenant_id" : agent .authData .Namespace ,
583- "server_address" : agent .config .ServerAddress ,
584- },
585- ).Info ("A tunnel connection was closed" )
468+ <- ctx .Done ()
586469
587- return nil
588- }
470+ return a .Close ()
589471}
590472
591- const (
592- // HandleSSHOpen is the protocol used to open a new SSH connection.
593- HandleSSHOpen = "/ssh/open/1.0.0"
594- // HandleSSHClose is the protocol used to close an existing SSH connection.
595- HandleSSHClose = "/ssh/close/1.0.0"
596- // HandleHTTPProxy is the protocol used to open a new HTTP proxy connection.
597- HandleHTTPProxy = "/http/proxy/1.0.0"
598- )
599-
600- // Listen creates the SSH server and listening for connections.
601- func (a * Agent ) Listen (ctx context.Context ) error {
602- a .mode .Serve (a )
603-
604- a .tunnel = tunnel .NewTunnel ()
473+ func (a * Agent ) listenV2 (ctx context.Context ) error {
474+ tun := tunnel .NewTunnelV2 (a .cli )
605475
606- a . tunnel . Handle (HandleSSHOpen , sshHandler (a ))
607- a . tunnel . Handle (HandleSSHClose , sshCloseHandler (a ))
608- a . tunnel . Handle (HandleHTTPProxy , httpProxyHandler (a ))
476+ tun . Handle (HandleSSHOpenV2 , sshHandlerV2 (a ))
477+ tun . Handle (HandleSSHCloseV2 , sshCloseHandlerV2 (a ))
478+ tun . Handle (HandleHTTPProxyV2 , httpProxyHandlerV2 (a ))
609479
610480 go a .ping (ctx , AgentPingDefaultInterval ) //nolint:errcheck
611481
612482 logger := log .WithFields (log.Fields {
613- "version" : AgentVersion ,
614- "tenant_id" : a .authData .Namespace ,
615- "server_address" : a .config .ServerAddress ,
616- "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
617- "api_endpoint" : a .serverInfo .Endpoints .API ,
618- "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
483+ "version" : AgentVersion ,
484+ "tenant_id" : a .authData .Namespace ,
485+ "server_address" : a .config .ServerAddress ,
486+ "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
487+ "api_endpoint" : a .serverInfo .Endpoints .API ,
488+ "connection_version" : a .config .ConnectionVersion ,
489+ "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
619490 })
620491
621492 ctx , cancel := context .WithCancel (ctx )
@@ -629,24 +500,30 @@ func (a *Agent) Listen(ctx context.Context) error {
629500 return
630501 }
631502
632- DefaultShellHubConnectPath := "/connection"
503+ ShellHubConnectV2Path := "/connection"
504+
505+ logger .Debug ("Using tunnel version 2" )
633506
634- conn , err := a .cli .Connect (ctx , a .authData .Token , DefaultShellHubConnectPath )
507+ listener , err := a .cli .NewReverseListenerV2 (
508+ ctx ,
509+ a .authData .Token ,
510+ ShellHubConnectV2Path ,
511+ client .NewReverseV2ConfigFromMap (a .authData .Config ),
512+ )
635513 if err != nil {
636514 logger .Error ("Failed to connect to server through reverse tunnel. Retry in 10 seconds" )
637515
638516 time .Sleep (time .Second * 10 )
639517
640518 continue
641519 }
642-
643- a .conn = conn
520+ a .listener .Store (& listener )
644521
645522 logger .Info ("Server connection established" )
646523
647524 a .listening <- true
648525
649- if err := a . tunnel . Listen (conn , tunnel . NewConfigFromMap ( a . authData . Config ) ); err != nil {
526+ if err := tun . Listen (ctx , listener ); err != nil {
650527 logger .WithError (err ).Error ("Tunnel listener exited with error" )
651528 }
652529
0 commit comments