@@ -289,6 +289,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
289289 if err != nil {
290290 return conn , err
291291 }
292+
292293 greeting := conn .Greeting ()
293294 if greeting .Salt == "" {
294295 conn .Close ()
@@ -309,7 +310,7 @@ func (d AuthDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
309310 }
310311 }
311312
312- if err := authenticate (conn , d .Auth , d .Username , d .Password ,
313+ if err := authenticate (ctx , conn , d .Auth , d .Username , d .Password ,
313314 conn .Greeting ().Salt ); err != nil {
314315 conn .Close ()
315316 return nil , fmt .Errorf ("failed to authenticate: %w" , err )
@@ -340,7 +341,7 @@ func (d ProtocolDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
340341 protocolInfo : d .RequiredProtocolInfo ,
341342 }
342343
343- protocolConn .protocolInfo , err = identify (& protocolConn )
344+ protocolConn .protocolInfo , err = identify (ctx , & protocolConn )
344345 if err != nil {
345346 protocolConn .Close ()
346347 return nil , fmt .Errorf ("failed to identify: %w" , err )
@@ -372,11 +373,12 @@ func (d GreetingDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) {
372373 greetingConn := greetingConn {
373374 Conn : conn ,
374375 }
375- version , salt , err := readGreeting (greetingConn )
376+ version , salt , err := readGreeting (ctx , & greetingConn )
376377 if err != nil {
377378 greetingConn .Close ()
378379 return nil , fmt .Errorf ("failed to read greeting: %w" , err )
379380 }
381+
380382 greetingConn .greeting = Greeting {
381383 Version : version ,
382384 Salt : salt ,
@@ -410,31 +412,62 @@ func parseAddress(address string) (string, string) {
410412 return network , address
411413}
412414
415+ // ioWaiter waits in a background until an io operation done or a context
416+ // is expired. It closes the connection and writes a context error into the
417+ // output channel if a context expired.
418+ func ioWaiter (ctx context.Context , conn Conn , done <- chan struct {}) <- chan error {
419+ doneWait := make (chan error , 1 )
420+
421+ go func () {
422+ defer close (doneWait )
423+
424+ select {
425+ case <- ctx .Done ():
426+ conn .Close ()
427+ <- done
428+ doneWait <- ctx .Err ()
429+ case <- done :
430+ doneWait <- nil
431+ }
432+ }()
433+
434+ return doneWait
435+ }
436+
413437// readGreeting reads a greeting message.
414- func readGreeting (reader io. Reader ) (string , string , error ) {
438+ func readGreeting (ctx context. Context , conn Conn ) (string , string , error ) {
415439 var version , salt string
416440
441+ doneRead := make (chan struct {})
442+ doneWait := ioWaiter (ctx , conn , doneRead )
443+
417444 data := make ([]byte , 128 )
418- _ , err := io .ReadFull (reader , data )
445+ _ , err := io .ReadFull (conn , data )
446+ close (doneRead )
447+
419448 if err == nil {
420449 version = bytes .NewBuffer (data [:64 ]).String ()
421450 salt = bytes .NewBuffer (data [64 :108 ]).String ()
422451 }
423452
453+ if waitErr := <- doneWait ; waitErr != nil {
454+ err = waitErr
455+ }
456+
424457 return version , salt , err
425458}
426459
427460// identify sends info about client protocol, receives info
428461// about server protocol in response and stores it in the connection.
429- func identify (conn Conn ) (ProtocolInfo , error ) {
462+ func identify (ctx context. Context , conn Conn ) (ProtocolInfo , error ) {
430463 var info ProtocolInfo
431464
432465 req := NewIdRequest (clientProtocolInfo )
433- if err := writeRequest (conn , req ); err != nil {
466+ if err := writeRequest (ctx , conn , req ); err != nil {
434467 return info , err
435468 }
436469
437- resp , err := readResponse (conn , req )
470+ resp , err := readResponse (ctx , conn , req )
438471 if err != nil {
439472 if resp != nil &&
440473 resp .Header ().Error == iproto .ER_UNKNOWN_REQUEST_TYPE {
@@ -495,7 +528,7 @@ func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error {
495528}
496529
497530// authenticate authenticates for a connection.
498- func authenticate (c Conn , auth Auth , user string , pass string , salt string ) error {
531+ func authenticate (ctx context. Context , c Conn , auth Auth , user , pass , salt string ) error {
499532 var req Request
500533 var err error
501534
@@ -511,37 +544,74 @@ func authenticate(c Conn, auth Auth, user string, pass string, salt string) erro
511544 return errors .New ("unsupported method " + auth .String ())
512545 }
513546
514- if err = writeRequest (c , req ); err != nil {
547+ if err = writeRequest (ctx , c , req ); err != nil {
515548 return err
516549 }
517- if _ , err = readResponse (c , req ); err != nil {
550+ if _ , err = readResponse (ctx , c , req ); err != nil {
518551 return err
519552 }
520553 return nil
521554}
522555
523556// writeRequest writes a request to the writer.
524- func writeRequest (w writeFlusher , req Request ) error {
557+ func writeRequest (ctx context. Context , conn Conn , req Request ) error {
525558 var packet smallWBuf
526559 err := pack (& packet , msgpack .NewEncoder (& packet ), 0 , req , ignoreStreamId , nil )
527560
528561 if err != nil {
529562 return fmt .Errorf ("pack error: %w" , err )
530563 }
531- if _ , err = w .Write (packet .b ); err != nil {
564+
565+ doneWrite := make (chan struct {})
566+ doneWait := ioWaiter (ctx , conn , doneWrite )
567+
568+ _ , err = conn .Write (packet .b )
569+ close (doneWrite )
570+
571+ if waitErr := <- doneWait ; waitErr != nil {
572+ err = waitErr
573+ }
574+
575+ if err != nil {
532576 return fmt .Errorf ("write error: %w" , err )
533577 }
534- if err = w .Flush (); err != nil {
578+
579+ doneWrite = make (chan struct {})
580+ doneWait = ioWaiter (ctx , conn , doneWrite )
581+
582+ err = conn .Flush ()
583+ close (doneWrite )
584+
585+ if waitErr := <- doneWait ; waitErr != nil {
586+ err = waitErr
587+ }
588+
589+ if err != nil {
535590 return fmt .Errorf ("flush error: %w" , err )
536591 }
592+
593+ if waitErr := <- doneWait ; waitErr != nil {
594+ err = waitErr
595+ }
596+
537597 return err
538598}
539599
540600// readResponse reads a response from the reader.
541- func readResponse (r io. Reader , req Request ) (Response , error ) {
601+ func readResponse (ctx context. Context , conn Conn , req Request ) (Response , error ) {
542602 var lenbuf [packetLengthBytes ]byte
543603
544- respBytes , err := read (r , lenbuf [:])
604+ doneRead := make (chan struct {})
605+ doneWait := ioWaiter (ctx , conn , doneRead )
606+
607+ respBytes , err := read (conn , lenbuf [:])
608+
609+ close (doneRead )
610+
611+ if waitErr := <- doneWait ; waitErr != nil {
612+ err = waitErr
613+ }
614+
545615 if err != nil {
546616 return nil , fmt .Errorf ("read error: %w" , err )
547617 }
@@ -555,10 +625,12 @@ func readResponse(r io.Reader, req Request) (Response, error) {
555625 if err != nil {
556626 return nil , fmt .Errorf ("decode response header error: %w" , err )
557627 }
628+
558629 resp , err := req .Response (header , & buf )
559630 if err != nil {
560631 return nil , fmt .Errorf ("creating response error: %w" , err )
561632 }
633+
562634 _ , err = resp .Decode ()
563635 if err != nil {
564636 switch err .(type ) {
@@ -568,5 +640,6 @@ func readResponse(r io.Reader, req Request) (Response, error) {
568640 return resp , fmt .Errorf ("decode response body error: %w" , err )
569641 }
570642 }
643+
571644 return resp , nil
572645}
0 commit comments