@@ -14,19 +14,20 @@ type dialFunc = func(ctx context.Context, network, addr string) (net.Conn, error
1414
1515const idLen = 8
1616
17- type onerequester interface {
17+ type Onerequester interface {
1818 RequestAndRecv (sendBytes []byte ) ([]byte , error )
1919 Close () error
2020 SetDialer (dialer dialFunc ) error
2121}
2222
2323type Requester struct {
24- parent onerequester
25- mtu uint
24+ createRequester func () (Onerequester , error )
25+ mtu uint
26+ dialer dialFunc
2627}
2728
28- func NewRequester (parent onerequester , mtu uint ) (* Requester , error ) {
29- return & Requester {parent : parent , mtu : mtu }, nil
29+ func NewRequester (createRequester func () ( Onerequester , error ) , mtu uint ) (* Requester , error ) {
30+ return & Requester {createRequester : createRequester , mtu : mtu }, nil
3031}
3132
3233func (r * Requester ) RequestAndRecv (sendBytes []byte ) ([]byte , error ) {
@@ -39,32 +40,71 @@ func (r *Requester) RequestAndRecv(sendBytes []byte) ([]byte, error) {
3940
4041 parts := splitIntoChunks (sendBytes , int (r .mtu ))
4142
42- for i , partBytes := range parts {
43- toSend := & pb.DnsPartReq {Id : id [:], PartNum : proto .Uint32 (uint32 (i )), TotalParts : proto .Uint32 (uint32 (len (parts ))), Data : partBytes }
44- toSendBytes , err := proto .Marshal (toSend )
45- if err != nil {
46- return nil , fmt .Errorf ("error marshal part %v: %v" , i , err )
47- }
43+ resCh := make (chan []byte , len (parts ))
44+ errCh := make (chan error , len (parts ))
45+ waitCh := make (chan struct {}, len (parts ))
4846
49- respBytes , err := r .parent .RequestAndRecv (toSendBytes )
50- if err != nil {
51- return nil , fmt .Errorf ("error request part %v: %v" , i , err )
52- }
47+ for i , partBytes := range parts {
48+ i := i
49+ partBytes := partBytes
50+ go func () {
51+ toSend := & pb.DnsPartReq {Id : id [:], PartNum : proto .Uint32 (uint32 (i )), TotalParts : proto .Uint32 (uint32 (len (parts ))), Data : partBytes }
52+ toSendBytes , err := proto .Marshal (toSend )
53+ if err != nil {
54+ errCh <- fmt .Errorf ("error marshal part %v: %v" , i , err )
55+ return
56+ }
57+
58+ req , err := r .createRequester ()
59+ if err != nil {
60+ errCh <- fmt .Errorf ("error creating requester in part %v: %v" , i , err )
61+ return
62+ }
63+
64+ if r .dialer != nil {
65+ err = req .SetDialer (r .dialer )
66+ if err != nil {
67+ errCh <- fmt .Errorf ("error setting dialer in part %v: %v" , i , err )
68+ return
69+ }
70+ }
71+
72+ respBytes , err := req .RequestAndRecv (toSendBytes )
73+ if err != nil {
74+ errCh <- fmt .Errorf ("error request part %v: %v" , i , err )
75+ return
76+ }
77+
78+ resp := & pb.DnsPartResp {}
79+ err = proto .Unmarshal (respBytes , resp )
80+ if err != nil {
81+ errCh <- fmt .Errorf ("error unmarshal response: %v" , err )
82+ return
83+ }
84+
85+ if resp .GetWaiting () {
86+ waitCh <- struct {}{}
87+ return
88+ }
89+
90+ resCh <- resp .GetData ()
91+
92+ }()
93+ }
5394
54- resp := & pb.DnsPartResp {}
55- err = proto .Unmarshal (respBytes , resp )
56- if err != nil {
57- return nil , fmt .Errorf ("error unmarshal response: %v" , err )
58- }
95+ errs := []error {}
5996
60- if resp .GetWaiting () {
61- continue
97+ for range parts {
98+ select {
99+ case res := <- resCh :
100+ return res , nil
101+ case err = <- errCh :
102+ errs = append (errs , err )
103+ case <- waitCh :
62104 }
63-
64- return resp .GetData (), nil
65105 }
66106
67- return nil , fmt .Errorf ("no response" )
107+ return nil , fmt .Errorf ("errors occurred: %v" , errs )
68108}
69109
70110func splitIntoChunks (data []byte , mtu int ) [][]byte {
@@ -81,10 +121,11 @@ func splitIntoChunks(data []byte, mtu int) [][]byte {
81121
82122// Close closes the parent transport
83123func (r * Requester ) Close () error {
84- return r . parent . Close ()
124+ return nil
85125}
86126
87127// SetDialer sets the parent dialer
88128func (r * Requester ) SetDialer (dialer dialFunc ) error {
89- return r .parent .SetDialer (dialer )
129+ r .dialer = dialer
130+ return nil
90131}
0 commit comments