@@ -42,20 +42,16 @@ import (
4242 "context"
4343 "crypto/rsa"
4444 "fmt"
45- "io"
4645 "math/rand"
4746 "net"
48- "net/netip"
4947 "net/url"
5048 "os"
5149 "runtime"
5250 "strings"
53- "sync"
5451 "sync/atomic"
5552 "time"
5653
5754 "github.com/Masterminds/semver"
58- dockerclient "github.com/docker/docker/client"
5955 "github.com/pkg/errors"
6056 "github.com/shellhub-io/shellhub/agent/pkg/keygen"
6157 "github.com/shellhub-io/shellhub/agent/pkg/sysinfo"
@@ -110,6 +106,10 @@ type Config struct {
110106 // MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
111107 // before attempting to reconnect to the ShellHub server. Default is 60 seconds.
112108 MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`
109+
110+ // ConnectionVersion specifies the version of the connection protocol to use.
111+ // Supported values are 1 and 2. Default is 1.
112+ ConnectionVersion int `env:"CONNECTION_VERSION,default=1"`
113113}
114114
115115func LoadConfigFromEnv () (* Config , map [string ]interface {}, error ) {
@@ -159,12 +159,11 @@ type Agent struct {
159159 cli client.Client
160160 serverInfo * models.Info
161161 server * server.Server
162- tunnel * tunnel.Tunnel
163162 listening chan bool
164163 closed atomic.Bool
165164 mode Mode
166- // conn is the current connection to the server.
167- conn net.Conn
165+ // listener is the current connection to the server.
166+ listener atomic. Pointer [ net.Listener ]
168167}
169168
170169// NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant
@@ -370,237 +369,109 @@ func (a *Agent) isClosed() bool {
370369func (a * Agent ) Close () error {
371370 a .closed .Store (true )
372371
373- return a .conn .Close ()
374- }
375-
376- // httpProxyHandler handlers proxy connections to the required address.
377- func httpProxyHandler (agent * Agent ) tunnel.Handler {
378- const ProxyHandlerNetwork = "tcp"
372+ l := a .listener .Load ()
373+ if l == nil {
374+ return nil
375+ }
379376
380- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
381- headers , err := ctx .Headers ()
382- if err != nil {
383- log .WithError (err ).Error ("failed to get the headers from the connection" )
377+ return (* l ).Close ()
378+ }
384379
385- return err
386- }
380+ func ( a * Agent ) Listen ( ctx context. Context ) error {
381+ a . mode . Serve ( a )
387382
388- id := headers ["id" ]
389- host := headers ["host" ]
390- port := headers ["port" ]
383+ switch a .config .ConnectionVersion {
384+ case 1 :
385+ return a .listenV1 (ctx )
386+ case 2 :
387+ return a .listenV2 (ctx )
388+ default :
389+ return fmt .Errorf ("unsupported connection version: %d" , a .config .ConnectionVersion )
390+ }
391+ }
391392
392- logger := log .WithFields (log.Fields {
393- "id" : id ,
394- "host" : host ,
395- "port" : port ,
396- })
393+ func (a * Agent ) listenV1 (ctx context.Context ) error {
394+ tun := tunnel .NewTunnelV1 ()
397395
398- if _ , ok := agent .mode .(* ConnectorMode ); ok {
399- cli , err := dockerclient .NewClientWithOpts (dockerclient .FromEnv , dockerclient .WithAPIVersionNegotiation ())
400- if err != nil {
401- log .WithError (err ).Error ("failed to create the Docker client" )
396+ tun .Handle (HandleSSHOpenV1 , sshHandlerV1 (a ))
397+ tun .Handle (HandleSSHCloseV1 , sshCloseHandlerV1 (a ))
398+ tun .Handle (HandleHTTPProxyV1 , httpProxyHandlerV1 (a ))
402399
403- return ctx .Error (errors .New ("failed to connect to the Docker Engine" ))
404- }
405-
406- container , err := cli .ContainerInspect (context .Background (), agent .server .ContainerID )
407- if err != nil {
408- log .WithError (err ).Error ("failed to inspect the container" )
400+ go a .ping (ctx , AgentPingDefaultInterval ) //nolint:errcheck
409401
410- return ctx .Error (errors .New ("failed to inspect the container" ))
411- }
402+ logger := log .WithFields (log.Fields {
403+ "version" : AgentVersion ,
404+ "tenant_id" : a .authData .Namespace ,
405+ "server_address" : a .config .ServerAddress ,
406+ "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
407+ "api_endpoint" : a .serverInfo .Endpoints .API ,
408+ "connection_version" : a .config .ConnectionVersion ,
409+ "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
410+ })
412411
413- var target string
412+ ctx , cancel := context .WithCancel (ctx )
413+ go func () {
414+ for {
415+ if a .isClosed () {
416+ logger .Info ("Stopped listening for connections" )
414417
415- addr , err := netip .ParseAddr (host )
416- if err != nil {
417- log .WithError (err ).Error ("failed to parse the address on proxy" )
418+ cancel ()
418419
419- return ctx . Error ( errors . New ( "failed to parse the address on proxy" ))
420+ return
420421 }
421422
422- if addr .IsLoopback () {
423- log .Trace ("host is a loopback address, using the container IP address" )
424-
425- for _ , network := range container .NetworkSettings .Networks {
426- target = network .IPAddress
427-
428- break
429- }
430- } else {
431- for _ , network := range container .NetworkSettings .Networks {
432- subnet , err := netip .ParsePrefix (fmt .Sprintf ("%s/%d" , network .Gateway , network .IPPrefixLen ))
433- if err != nil {
434- logger .WithError (err ).Error ("failed to parse the gateway on proxy" )
435-
436- continue
437- }
438-
439- ip , err := netip .ParseAddr (host )
440- if err != nil {
441- logger .WithError (err ).Error ("failed to parse the address on proxy" )
423+ ShellHubConnectV1Path := "/ssh/connection"
442424
443- continue
444- }
425+ logger .Debug ("Using tunnel version 1" )
445426
446- if subnet .Contains (ip ) {
447- target = ip .String ()
427+ listener , err := a .cli .NewReverseListenerV1 (
428+ ctx ,
429+ a .authData .Token ,
430+ ShellHubConnectV1Path ,
431+ )
432+ if err != nil {
433+ logger .Error ("Failed to connect to server through reverse tunnel. Retry in 10 seconds" )
448434
449- break
450- }
451- }
452- }
435+ time .Sleep (time .Second * 10 )
453436
454- if target == "" {
455- return ctx .Error (errors .New ("address not found on the device" ))
437+ continue
456438 }
439+ a .listener .Store (& listener )
457440
458- host = target
459- }
460-
461- ErrFailedDialToAddressAndPort := errors .New ("failed to dial to the address and port" )
462-
463- logger .Trace ("proxy handler connecting to the address" )
464-
465- in , err := net .Dial (ProxyHandlerNetwork , net .JoinHostPort (host , port ))
466- if err != nil {
467- logger .WithError (err ).Error ("proxy handler failed to dial to the address" )
468-
469- return ctx .Error (ErrFailedDialToAddressAndPort )
470- }
471-
472- defer in .Close ()
473-
474- logger .Trace ("proxy handler dialed to the address" )
475-
476- // TODO: Add consts for status values.
477- if err := ctx .Status ("ok" ); err != nil {
478- logger .WithError (err ).Error ("proxy handler failed to send status response" )
479-
480- return err
481- }
482-
483- wg := new (sync.WaitGroup )
484- done := sync .OnceFunc (func () {
485- defer in .Close ()
486- defer rwc .Close ()
487-
488- logger .Trace ("close called on in and out connections" )
489- })
490-
491- wg .Add (1 )
492- go func () {
493- defer done ()
494- defer wg .Done ()
495-
496- if _ , err := io .Copy (in , rwc ); err != nil && err != io .EOF {
497- logger .WithError (err ).Error ("proxy handler copy from rwc to in failed" )
498- }
499- }()
441+ logger .Info ("Server connection established" )
500442
501- wg .Add (1 )
502- go func () {
503- defer done ()
504- defer wg .Done ()
443+ a .listening <- true
505444
506- if _ , err := io . Copy ( rwc , in ); err != nil && err != io . EOF {
507- logger .WithError (err ).Error ("proxy handler copy from in to rwc failed " )
445+ if err := tun . Listen ( ctx , listener ); err != nil {
446+ logger .WithError (err ).Error ("Tunnel listener exited with error " )
508447 }
509- }()
510-
511- logger .WithError (err ).Info ("proxy handler waiting for data pipe" )
512-
513- wg .Wait ()
514-
515- logger .WithError (err ).Info ("proxy handler done" )
516-
517- return nil
518- }
519- }
520-
521- func sshHandler (agent * Agent ) tunnel.Handler {
522- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
523- defer rwc .Close ()
524-
525- headers , err := ctx .Headers ()
526- if err != nil {
527- log .WithError (err ).Error ("failed to get the headers from the connection" )
528-
529- return err
530- }
531-
532- id := headers ["id" ]
533-
534- conn , ok := rwc .(net.Conn )
535- if ! ok {
536- log .Error ("failed to cast the ReadWriteCloser to net.Conn" )
537-
538- return errors .New ("failed to cast the ReadWriteCloser to net.Conn" )
539- }
540448
541- agent .server .Sessions .Store (id , conn )
542- agent .server .HandleConn (conn )
543-
544- return nil
545- }
546- }
547-
548- func sshCloseHandler (agent * Agent ) tunnel.Handler {
549- return func (ctx tunnel.Context , rwc io.ReadWriteCloser ) error {
550- defer rwc .Close ()
551-
552- headers , err := ctx .Headers ()
553- if err != nil {
554- log .WithError (err ).Error ("failed to get the headers from the connection" )
555-
556- return err
449+ a .listening <- false
557450 }
451+ }()
558452
559- id := headers ["id" ]
560-
561- agent .server .CloseSession (id )
562-
563- log .WithFields (
564- log.Fields {
565- "id" : id ,
566- "version" : AgentVersion ,
567- "tenant_id" : agent .authData .Namespace ,
568- "server_address" : agent .config .ServerAddress ,
569- },
570- ).Info ("A tunnel connection was closed" )
453+ <- ctx .Done ()
571454
572- return nil
573- }
455+ return a .Close ()
574456}
575457
576- const (
577- // HandleSSHOpen is the protocol used to open a new SSH connection.
578- HandleSSHOpen = "/ssh/open/1.0.0"
579- // HandleSSHClose is the protocol used to close an existing SSH connection.
580- HandleSSHClose = "/ssh/close/1.0.0"
581- // HandleHTTPProxy is the protocol used to open a new HTTP proxy connection.
582- HandleHTTPProxy = "/http/proxy/1.0.0"
583- )
584-
585- // Listen creates the SSH server and listening for connections.
586- func (a * Agent ) Listen (ctx context.Context ) error {
587- a .mode .Serve (a )
588-
589- a .tunnel = tunnel .NewTunnel ()
458+ func (a * Agent ) listenV2 (ctx context.Context ) error {
459+ tun := tunnel .NewTunnelV2 (a .cli )
590460
591- a . tunnel . Handle (HandleSSHOpen , sshHandler (a ))
592- a . tunnel . Handle (HandleSSHClose , sshCloseHandler (a ))
593- a . tunnel . Handle (HandleHTTPProxy , httpProxyHandler (a ))
461+ tun . Handle (HandleSSHOpenV2 , sshHandlerV2 (a ))
462+ tun . Handle (HandleSSHCloseV2 , sshCloseHandlerV2 (a ))
463+ tun . Handle (HandleHTTPProxyV2 , httpProxyHandlerV2 (a ))
594464
595465 go a .ping (ctx , AgentPingDefaultInterval ) //nolint:errcheck
596466
597467 logger := log .WithFields (log.Fields {
598- "version" : AgentVersion ,
599- "tenant_id" : a .authData .Namespace ,
600- "server_address" : a .config .ServerAddress ,
601- "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
602- "api_endpoint" : a .serverInfo .Endpoints .API ,
603- "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
468+ "version" : AgentVersion ,
469+ "tenant_id" : a .authData .Namespace ,
470+ "server_address" : a .config .ServerAddress ,
471+ "ssh_endpoint" : a .serverInfo .Endpoints .SSH ,
472+ "api_endpoint" : a .serverInfo .Endpoints .API ,
473+ "connection_version" : a .config .ConnectionVersion ,
474+ "sshid" : fmt .Sprintf ("%s.%s@%s" , a .authData .Namespace , a .authData .Name , strings .Split (a .serverInfo .Endpoints .SSH , ":" )[0 ]),
604475 })
605476
606477 ctx , cancel := context .WithCancel (ctx )
@@ -614,24 +485,30 @@ func (a *Agent) Listen(ctx context.Context) error {
614485 return
615486 }
616487
617- DefaultShellHubConnectPath := "/connection"
488+ ShellHubConnectV2Path := "/connection"
489+
490+ logger .Debug ("Using tunnel version 2" )
618491
619- conn , err := a .cli .Connect (ctx , a .authData .Token , DefaultShellHubConnectPath )
492+ listener , err := a .cli .NewReverseListenerV2 (
493+ ctx ,
494+ a .authData .Token ,
495+ ShellHubConnectV2Path ,
496+ client .NewReverseV2ConfigFromMap (a .authData .Config ),
497+ )
620498 if err != nil {
621499 logger .Error ("Failed to connect to server through reverse tunnel. Retry in 10 seconds" )
622500
623501 time .Sleep (time .Second * 10 )
624502
625503 continue
626504 }
627-
628- a .conn = conn
505+ a .listener .Store (& listener )
629506
630507 logger .Info ("Server connection established" )
631508
632509 a .listening <- true
633510
634- if err := a . tunnel . Listen (conn , tunnel . NewConfigFromMap ( a . authData . Config ) ); err != nil {
511+ if err := tun . Listen (ctx , listener ); err != nil {
635512 logger .WithError (err ).Error ("Tunnel listener exited with error" )
636513 }
637514
0 commit comments