44 "fmt"
55 "io"
66 "net"
7+ "strings"
78 "sync"
89 "time"
910
@@ -12,15 +13,15 @@ import (
1213)
1314
1415func main () {
15- if err := writePIDFile (); err != nil {
16- panic (fmt .Sprintf ("Failed to write PID file: %v" , err ))
17- }
18- defer removePIDFile ()
19-
2016 if err := loadConfig (); err != nil {
2117 panic (err )
2218 }
2319
20+ if err := writePIDFile (); err != nil {
21+ log .Err (err ).Msg ("Failed to write PID file" )
22+ }
23+ defer removePIDFile ()
24+
2425 watcher := watchConfig ()
2526 defer watcher .Close ()
2627
@@ -70,6 +71,7 @@ func runTcp(wg *sync.WaitGroup) {
7071 log .Err (err ).Msg ("Error accepting" )
7172 continue
7273 }
74+ setSocketOptions (conn )
7375 // 处理连接
7476 go handleRequest (conn )
7577 }
@@ -96,23 +98,47 @@ func handleRequest(conn net.Conn) {
9698 // 确保连接关闭
9799 defer conn .Close ()
98100
99- setSocketOptions (conn )
101+ client := mapToHost (conn )
102+ if client == nil {
103+ return
104+ }
105+ defer client .Close ()
100106
107+ var wg sync.WaitGroup
108+ wg .Add (1 )
109+
110+ go handleRead (client , conn , & wg )
111+ handleWrite (client , conn , nil )
112+
113+ // 等待所有读写操作完成
114+ // 不放在 defer 中,以防报错时无法关闭连接
115+ wg .Wait ()
116+ }
117+
118+ func mapToHost (conn net.Conn ) net.Conn {
101119 buf := make ([]byte , 1024 )
102120 n , err := conn .Read (buf )
103121 if err != nil {
104122 log .Err (err ).
105123 Str ("client" , conn .RemoteAddr ().String ()).
106- Msg ("Error reading hostname" )
107- return
124+ Msg ("failed to reading hostname" )
125+ return nil
108126 }
109127 if n == 0 {
110128 log .Err (errEmptyBuffer ).
111129 Str ("client" , conn .RemoteAddr ().String ()).
112- Msg ("Error: buffer is empty" )
113- return
130+ Msg ("buffer is empty" )
131+ return nil
114132 }
133+
115134 mc_host := protocol .GetMcHost (buf [:n ])
135+ if mc_host == "" {
136+ log .Err (errEmptyBuffer ).
137+ Str ("client" , conn .RemoteAddr ().String ()).
138+ Msg ("failed to parse mc host from buffer" )
139+ return nil
140+ }
141+
116142 host , ok := config .Hosts [mc_host ]
117143 if ! ok {
118144 host = config .Hosts ["default" ]
@@ -122,7 +148,7 @@ func handleRequest(conn net.Conn) {
122148 Str ("client" , conn .RemoteAddr ().String ()).
123149 Str ("host" , mc_host ).
124150 Msg ("failed to route host" )
125- return
151+ return nil
126152 }
127153
128154 log .Info ().
@@ -131,27 +157,33 @@ func handleRequest(conn net.Conn) {
131157 Str ("mc" , host ).
132158 Msg ("map to host" )
133159
134- client , err := net .Dial ("tcp" , host )
135- if err != nil {
136- log .Err (err ).Msg ("Error dialing" )
137- return
160+ var client net.Conn
161+
162+ if host , ok := strings .CutPrefix (host , "quic://" ); ok {
163+ client = upstreamQuic (host )
164+ } else if host , ok := strings .CutPrefix (host , "kcp://" ); ok {
165+ client = upstreamKcp (host )
166+ } else {
167+ client = upstreamTcp (host )
168+ }
169+ if client == nil {
170+ return nil
138171 }
139- defer client .Close ()
140- setSocketOptions (client )
141172
142173 client .Write (buf [:n ])
143- // 不需要 buf 了,释放掉
144- buf = nil
145174
146- var wg sync. WaitGroup
147- wg . Add ( 1 )
175+ return client
176+ }
148177
149- go handleRead (client , conn , & wg )
150- handleWrite (client , conn , nil )
178+ func upstreamTcp (host string ) net.Conn {
179+ conn , err := net .Dial ("tcp" , host )
180+ if err != nil {
181+ log .Err (err ).Str ("host" , host ).Msg ("Error dialing upstream" )
182+ return nil
183+ }
184+ setSocketOptions (conn )
185+ return conn
151186
152- // 等待所有读写操作完成
153- // 不放在 defer 中,以防报错时无法关闭连接
154- wg .Wait ()
155187}
156188
157189func handleRead (srv , cli net.Conn , wg * sync.WaitGroup ) {
0 commit comments