22package tcp
33
44import (
5- "fmt "
5+ lio "github.com/vadv/gopher-lua-libs/io "
66 "net"
77 "time"
88
@@ -22,11 +22,15 @@ const (
2222
2323type luaTCPClient struct {
2424 net.Conn
25- address string
25+ address string
26+ dialTimeout time.Duration
27+ writeTimeout time.Duration
28+ readTimeout time.Duration
29+ closeTimeout time.Duration
2630}
2731
2832func (c * luaTCPClient ) connect () error {
29- conn , err := net .DialTimeout ("tcp" , c .address , DefaultDialTimeout )
33+ conn , err := net .DialTimeout ("tcp" , c .address , c . dialTimeout )
3034 if err != nil {
3135 return err
3236 }
@@ -46,7 +50,16 @@ func checkLuaTCPClient(L *lua.LState, n int) *luaTCPClient {
4650// Open lua tcp.open(string) returns (tcp_client_ud, err)
4751func Open (L * lua.LState ) int {
4852 addr := L .CheckString (1 )
49- t := & luaTCPClient {address : addr }
53+ t := & luaTCPClient {
54+ address : addr ,
55+ dialTimeout : DefaultDialTimeout ,
56+ writeTimeout : DefaultWriteTimeout ,
57+ readTimeout : DefaultReadTimeout ,
58+ closeTimeout : DefaultCloseTimeout ,
59+ }
60+ if dialTimeout , ok := L .Get (2 ).(lua.LNumber ); ok {
61+ t .dialTimeout = time .Duration (dialTimeout * lua .LNumber (time .Second ))
62+ }
5063 if err := t .connect (); err != nil {
5164 L .Push (lua .LNil )
5265 L .Push (lua .LString (err .Error ()))
@@ -62,49 +75,24 @@ func Open(L *lua.LState) int {
6275// Write lua tcp_client_ud:write() returns err
6376func Write (L * lua.LState ) int {
6477 conn := checkLuaTCPClient (L , 1 )
65- data := L .CheckString (2 )
66- conn .SetWriteDeadline (time .Now ().Add (DefaultWriteTimeout ))
67- count , err := conn .Write ([]byte (data ))
68- if err != nil {
69- L .Push (lua .LString (fmt .Sprintf ("write to `%s`: %s" , conn .address , err .Error ())))
70- return 1
71- }
72- if count != len (data ) {
73- L .Push (lua .LString (fmt .Sprintf ("write to `%s` get: %d except: %d" , conn .address , count , len (data ))))
74- return 1
75- }
76- return 0
78+ _ = conn .SetWriteDeadline (time .Now ().Add (conn .writeTimeout ))
79+ return lio .IOWriterWrite (L )
7780}
7881
7982// Read lua tcp_client_ud:read(max_size_int) returns (string, err)
8083func Read (L * lua.LState ) int {
8184 conn := checkLuaTCPClient (L , 1 )
82- count := int (1024 )
83- if L .GetTop () > 1 {
84- count = int (L .CheckInt64 (2 ))
85- if count < 1 {
86- L .ArgError (2 , "must be > 1" )
87- }
85+ // Backward compatibility for callers that don't pass a length
86+ if L .GetTop () < 2 {
87+ L .Push (lua .LNumber (1024 ))
8888 }
89- buf := make ([]byte , count )
90- conn .SetReadDeadline (time .Now ().Add (DefaultReadTimeout ))
91- count , err := conn .Read (buf )
92- if err != nil {
93- L .Push (lua .LNil )
94- L .Push (lua .LString (fmt .Sprintf ("read from `%s`: %s" , conn .address , err .Error ())))
95- return 2
96- }
97- line := string (buf [0 :count ])
98- L .Push (lua .LString (line ))
99- return 1
89+ _ = conn .SetReadDeadline (time .Now ().Add (conn .readTimeout ))
90+ return lio .IOReaderRead (L )
10091}
10192
10293// Close lua tcp_client_ud:close()
10394func Close (L * lua.LState ) int {
10495 conn := checkLuaTCPClient (L , 1 )
105- conn .SetDeadline (time .Now ().Add (DefaultCloseTimeout ))
106- if conn != nil {
107- conn .Close ()
108- }
109- return 0
96+ _ = conn .SetDeadline (time .Now ().Add (conn .closeTimeout ))
97+ return lio .IOWriterClose (L )
11098}
0 commit comments