Skip to content

Commit 17061ce

Browse files
authored
Merge pull request #39 from scr-oath/tcp-io
Use the io wrappers for tcp as well - so other lua signatures such as read("*a") and read("*l) are possible.
2 parents d282366 + dda4e72 commit 17061ce

File tree

5 files changed

+107
-61
lines changed

5 files changed

+107
-61
lines changed

io/wrappers.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func IOWriterClose(L *lua.LState) int {
161161
L.Pop(L.GetTop())
162162
if closer, ok := writer.(io.Closer); ok {
163163
if err := closer.Close(); err != nil {
164-
L.RaiseError(err.Error())
164+
L.RaiseError("%v", err)
165165
}
166166
}
167167
return 0
@@ -188,7 +188,7 @@ func IOReaderRead(L *lua.LState) int {
188188
return 1
189189
}
190190
if err != nil {
191-
L.RaiseError(err.Error())
191+
L.RaiseError("%v", err)
192192
return 0
193193
}
194194
if numRead < num {
@@ -209,7 +209,7 @@ func IOReaderRead(L *lua.LState) int {
209209
return 1
210210
}
211211
if err != nil {
212-
L.RaiseError(err.Error())
212+
L.RaiseError("%v", err)
213213
return 0
214214
}
215215
L.Push(num)
@@ -221,7 +221,7 @@ func IOReaderRead(L *lua.LState) int {
221221
return 1
222222
}
223223
if err != nil {
224-
L.RaiseError(err.Error())
224+
L.RaiseError("%v", err)
225225
return 0
226226
}
227227
L.Push(lua.LString(data))
@@ -234,7 +234,7 @@ func IOReaderRead(L *lua.LState) int {
234234
return 1
235235
}
236236
if err != nil {
237-
L.RaiseError(err.Error())
237+
L.RaiseError("%v", err)
238238
return 0
239239
}
240240
L.Push(line)
@@ -251,7 +251,7 @@ func IOReaderClose(L *lua.LState) int {
251251
L.Pop(L.GetTop())
252252
if closer, ok := reader.(io.Closer); ok {
253253
if err := closer.Close(); err != nil {
254-
L.RaiseError(err.Error())
254+
L.RaiseError("%v", err)
255255
}
256256
}
257257
return 0

tcp/api.go

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
package tcp
33

44
import (
5-
"fmt"
5+
lio "github.com/vadv/gopher-lua-libs/io"
66
"net"
77
"time"
88

@@ -22,11 +22,15 @@ const (
2222

2323
type luaTCPClient struct {
2424
net.Conn
25-
address string
25+
address string
26+
dialTimeout time.Duration
27+
writeTimeout time.Duration
28+
readTimeout time.Duration
29+
closeTimeout time.Duration
2630
}
2731

2832
func (c *luaTCPClient) connect() error {
29-
conn, err := net.DialTimeout("tcp", c.address, DefaultDialTimeout)
33+
conn, err := net.DialTimeout("tcp", c.address, c.dialTimeout)
3034
if err != nil {
3135
return err
3236
}
@@ -46,7 +50,16 @@ func checkLuaTCPClient(L *lua.LState, n int) *luaTCPClient {
4650
// Open lua tcp.open(string) returns (tcp_client_ud, err)
4751
func Open(L *lua.LState) int {
4852
addr := L.CheckString(1)
49-
t := &luaTCPClient{address: addr}
53+
t := &luaTCPClient{
54+
address: addr,
55+
dialTimeout: DefaultDialTimeout,
56+
writeTimeout: DefaultWriteTimeout,
57+
readTimeout: DefaultReadTimeout,
58+
closeTimeout: DefaultCloseTimeout,
59+
}
60+
if dialTimeout, ok := L.Get(2).(lua.LNumber); ok {
61+
t.dialTimeout = time.Duration(dialTimeout * lua.LNumber(time.Second))
62+
}
5063
if err := t.connect(); err != nil {
5164
L.Push(lua.LNil)
5265
L.Push(lua.LString(err.Error()))
@@ -62,49 +75,24 @@ func Open(L *lua.LState) int {
6275
// Write lua tcp_client_ud:write() returns err
6376
func Write(L *lua.LState) int {
6477
conn := checkLuaTCPClient(L, 1)
65-
data := L.CheckString(2)
66-
conn.SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
67-
count, err := conn.Write([]byte(data))
68-
if err != nil {
69-
L.Push(lua.LString(fmt.Sprintf("write to `%s`: %s", conn.address, err.Error())))
70-
return 1
71-
}
72-
if count != len(data) {
73-
L.Push(lua.LString(fmt.Sprintf("write to `%s` get: %d except: %d", conn.address, count, len(data))))
74-
return 1
75-
}
76-
return 0
78+
_ = conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
79+
return lio.IOWriterWrite(L)
7780
}
7881

7982
// Read lua tcp_client_ud:read(max_size_int) returns (string, err)
8083
func Read(L *lua.LState) int {
8184
conn := checkLuaTCPClient(L, 1)
82-
count := int(1024)
83-
if L.GetTop() > 1 {
84-
count = int(L.CheckInt64(2))
85-
if count < 1 {
86-
L.ArgError(2, "must be > 1")
87-
}
85+
// Backward compatibility for callers that don't pass a length
86+
if L.GetTop() < 2 {
87+
L.Push(lua.LNumber(1024))
8888
}
89-
buf := make([]byte, count)
90-
conn.SetReadDeadline(time.Now().Add(DefaultReadTimeout))
91-
count, err := conn.Read(buf)
92-
if err != nil {
93-
L.Push(lua.LNil)
94-
L.Push(lua.LString(fmt.Sprintf("read from `%s`: %s", conn.address, err.Error())))
95-
return 2
96-
}
97-
line := string(buf[0:count])
98-
L.Push(lua.LString(line))
99-
return 1
89+
_ = conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
90+
return lio.IOReaderRead(L)
10091
}
10192

10293
// Close lua tcp_client_ud:close()
10394
func Close(L *lua.LState) int {
10495
conn := checkLuaTCPClient(L, 1)
105-
conn.SetDeadline(time.Now().Add(DefaultCloseTimeout))
106-
if conn != nil {
107-
conn.Close()
108-
}
109-
return 0
96+
_ = conn.SetDeadline(time.Now().Add(conn.closeTimeout))
97+
return lio.IOWriterClose(L)
11098
}

tcp/api_test.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package tcp
33
import (
44
"github.com/stretchr/testify/assert"
55
"github.com/stretchr/testify/require"
6-
"github.com/vadv/gopher-lua-libs/strings"
76
"github.com/vadv/gopher-lua-libs/tests"
87
"io"
98
"net"
@@ -54,9 +53,5 @@ func TestApi(t *testing.T) {
5453
})
5554
time.Sleep(time.Second)
5655

57-
preload := tests.SeveralPreloadFuncs(
58-
strings.Preload,
59-
Preload,
60-
)
61-
assert.NotZero(t, tests.RunLuaTestFile(t, preload, "./test/test_api.lua"))
56+
assert.NotZero(t, tests.RunLuaTestFile(t, Preload, "./test/test_api.lua"))
6257
}

tcp/loader.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tcp
22

33
import (
44
lua "github.com/yuin/gopher-lua"
5+
"time"
56
)
67

78
// Preload adds tcp to the given Lua state's package.preload table. After it
@@ -17,11 +18,53 @@ func Loader(L *lua.LState) int {
1718

1819
tcp_client_ud := L.NewTypeMetatable(`tcp_client_ud`)
1920
L.SetGlobal(`tcp_client_ud`, tcp_client_ud)
20-
L.SetField(tcp_client_ud, "__index", L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
21+
22+
funcs := L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
2123
"write": Write,
2224
"close": Close,
2325
"read": Read,
24-
}))
26+
})
27+
L.SetFuncs(tcp_client_ud, map[string]lua.LGFunction{
28+
"__index": func(state *lua.LState) int {
29+
conn := checkLuaTCPClient(L, 1)
30+
k := L.CheckString(2)
31+
var duration time.Duration
32+
switch k {
33+
case "dialTimeout":
34+
duration = conn.dialTimeout
35+
case "writeTimeout":
36+
duration = conn.writeTimeout
37+
case "readTimeout":
38+
duration = conn.readTimeout
39+
case "closeTimeout":
40+
duration = conn.closeTimeout
41+
default:
42+
L.Push(L.GetField(funcs, k))
43+
return 1
44+
}
45+
L.Push(lua.LNumber(duration) / lua.LNumber(time.Second))
46+
return 1
47+
},
48+
"__newindex": func(state *lua.LState) int {
49+
conn := checkLuaTCPClient(L, 1)
50+
k := L.CheckString(2)
51+
var pDuration *time.Duration
52+
switch k {
53+
case "dialTimeout":
54+
pDuration = &conn.dialTimeout
55+
case "writeTimeout":
56+
pDuration = &conn.writeTimeout
57+
case "readTimeout":
58+
pDuration = &conn.readTimeout
59+
case "closeTimeout":
60+
pDuration = &conn.closeTimeout
61+
default:
62+
return 0
63+
}
64+
*pDuration = time.Duration(L.CheckNumber(3) * lua.LNumber(time.Second))
65+
return 0
66+
},
67+
})
2568

2669
t := L.NewTable()
2770
L.SetFuncs(t, api)

tcp/test/test_api.lua

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
local strings = require("strings")
21
local tcp = require("tcp")
32

43
function Test_tcp(t)
@@ -7,12 +6,33 @@ function Test_tcp(t)
76
assert(not err, err)
87
t:Log("done: tcp:open()")
98

10-
err = conn:write("ping")
11-
assert(not err, err)
12-
t:Log("done: tcp_client_ud:write()")
9+
local function assert_equal(expected, got)
10+
assert(got == expected, string.format([[expected "%s": got "%s"]], expected, got))
11+
end
1312

14-
local result, err = conn:read()
15-
assert(not err, err)
16-
assert(strings.trim_space(result) == "pong", string.format([[expected "%s"; got "%s"]], "pong", result))
17-
t:Log("done: tcp_client_ud:read_line()")
18-
end
13+
t:Run("write ping read pong", function(t)
14+
err = conn:write("ping")
15+
assert(not err, err)
16+
t:Log("done: tcp_client_ud:write()")
17+
18+
local result, err = conn:read("*l")
19+
assert(not err, err)
20+
assert_equal("pong", result)
21+
t:Log("done: tcp_client_ud:read_line()")
22+
end)
23+
24+
t:Run("read timeout fields", function(t)
25+
assert_equal(5, conn.dialTimeout)
26+
assert_equal(1, conn.writeTimeout)
27+
assert_equal(1, conn.readTimeout)
28+
assert_equal(1, conn.closeTimeout)
29+
end)
30+
31+
t:Run('write/read timeout fields', function(t)
32+
-- Check setting fields
33+
conn.closeTimeout = 2
34+
assert_equal(2, conn.closeTimeout)
35+
conn.closeTimeout = 0.5
36+
assert_equal(0.5, conn.closeTimeout)
37+
end)
38+
end

0 commit comments

Comments
 (0)