Skip to content

Commit 55c2f0a

Browse files
committed
upstream 支持 quic 和 kcp
1 parent 8530df6 commit 55c2f0a

File tree

10 files changed

+186
-59
lines changed

10 files changed

+186
-59
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
.vscode
22

3-
config.json
43
config.toml
5-
/gateway
4+
/gateway*
5+
/kcp*
6+
/quic*
67
/logs/

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,20 @@
1010
- log
1111
- KCP 的 data_shards 和 parity_Shards
1212
- QUIC 的 application_protocols
13+
- pid_file
14+
15+
### 顶层配置
16+
17+
| 配置 | 类型 | 备注 |
18+
| -------- | ------ | -------- |
19+
| pid_file | string | pid 文件 |
20+
21+
> pid_file 在非 windows 平台默认会写入 /var/run/mc-gateway.pid,
22+
> 在 windows 平台默认不会写入任何文件
1323
1424
### hosts
1525

16-
hosts 使用期望的 host 做 key,转发的目的地址为 value。参考`config.example.toml`
26+
hosts 使用期望的 host 做 key,转发的目的地址为 value。参考`config.example.toml`默认的 fallback host 配置 key 为 `default`
1727

1828
### log
1929

cmd/gateway/config.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type (
2727
Kcp KcpConfig `toml:"kcp"`
2828
Hosts map[string]string `toml:"hosts"`
2929
Log LogConfig `toml:"log"`
30+
PidFile string `toml:"pid_file"`
3031
}
3132

3233
ProtocolConfig struct {
@@ -42,9 +43,9 @@ type (
4243
}
4344

4445
QuicConfig struct {
45-
Enable bool `toml:"enable"`
46-
Port int `toml:"port"`
47-
ApplicionProtocols []string `toml:"application_protocols"`
46+
Enable bool `toml:"enable"`
47+
Port int `toml:"port"`
48+
ApplicationProtocols []string `toml:"application_protocols"`
4849
}
4950

5051
LogConfig struct {
@@ -72,6 +73,8 @@ func loadConfig() error {
7273
return err
7374
}
7475

76+
writePIDFile()
77+
7578
return loadLogger()
7679
}
7780

cmd/gateway/kcp.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"fmt"
5+
"net"
56
"sync"
67

78
"github.com/rs/zerolog/log"
@@ -39,3 +40,15 @@ func runKcp(wg *sync.WaitGroup) {
3940
go handleRequest(conn)
4041
}
4142
}
43+
44+
func upstreamKcp(host string) net.Conn {
45+
conn, err := kcp.DialWithOptions(host, nil, config.Kcp.DataShards, config.Kcp.ParityShards)
46+
if err != nil {
47+
log.Error().Err(err).
48+
Msg("Failed to dial KCP server")
49+
}
50+
defer conn.Close()
51+
52+
conn.SetACKNoDelay(true)
53+
return conn
54+
}

cmd/gateway/main.go

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"io"
66
"net"
7+
"strings"
78
"sync"
89
"time"
910

@@ -12,15 +13,15 @@ import (
1213
)
1314

1415
func main() {
15-
if err := writePIDFile(); err != nil {
16-
panic(fmt.Sprintf("Failed to write PID file: %v", err))
17-
}
18-
defer removePIDFile()
19-
2016
if err := loadConfig(); err != nil {
2117
panic(err)
2218
}
2319

20+
if err := writePIDFile(); err != nil {
21+
log.Err(err).Msg("Failed to write PID file")
22+
}
23+
defer removePIDFile()
24+
2425
watcher := watchConfig()
2526
defer watcher.Close()
2627

@@ -70,6 +71,7 @@ func runTcp(wg *sync.WaitGroup) {
7071
log.Err(err).Msg("Error accepting")
7172
continue
7273
}
74+
setSocketOptions(conn)
7375
// 处理连接
7476
go handleRequest(conn)
7577
}
@@ -96,23 +98,47 @@ func handleRequest(conn net.Conn) {
9698
// 确保连接关闭
9799
defer conn.Close()
98100

99-
setSocketOptions(conn)
101+
client := mapToHost(conn)
102+
if client == nil {
103+
return
104+
}
105+
defer client.Close()
100106

107+
var wg sync.WaitGroup
108+
wg.Add(1)
109+
110+
go handleRead(client, conn, &wg)
111+
handleWrite(client, conn, nil)
112+
113+
// 等待所有读写操作完成
114+
// 不放在 defer 中,以防报错时无法关闭连接
115+
wg.Wait()
116+
}
117+
118+
func mapToHost(conn net.Conn) net.Conn {
101119
buf := make([]byte, 1024)
102120
n, err := conn.Read(buf)
103121
if err != nil {
104122
log.Err(err).
105123
Str("client", conn.RemoteAddr().String()).
106-
Msg("Error reading hostname")
107-
return
124+
Msg("failed to reading hostname")
125+
return nil
108126
}
109127
if n == 0 {
110128
log.Err(errEmptyBuffer).
111129
Str("client", conn.RemoteAddr().String()).
112-
Msg("Error: buffer is empty")
113-
return
130+
Msg("buffer is empty")
131+
return nil
114132
}
133+
115134
mc_host := protocol.GetMcHost(buf[:n])
135+
if mc_host == "" {
136+
log.Err(errEmptyBuffer).
137+
Str("client", conn.RemoteAddr().String()).
138+
Msg("failed to parse mc host from buffer")
139+
return nil
140+
}
141+
116142
host, ok := config.Hosts[mc_host]
117143
if !ok {
118144
host = config.Hosts["default"]
@@ -122,7 +148,7 @@ func handleRequest(conn net.Conn) {
122148
Str("client", conn.RemoteAddr().String()).
123149
Str("host", mc_host).
124150
Msg("failed to route host")
125-
return
151+
return nil
126152
}
127153

128154
log.Info().
@@ -131,27 +157,33 @@ func handleRequest(conn net.Conn) {
131157
Str("mc", host).
132158
Msg("map to host")
133159

134-
client, err := net.Dial("tcp", host)
135-
if err != nil {
136-
log.Err(err).Msg("Error dialing")
137-
return
160+
var client net.Conn
161+
162+
if host, ok := strings.CutPrefix(host, "quic://"); ok {
163+
client = upstreamQuic(host)
164+
} else if host, ok := strings.CutPrefix(host, "kcp://"); ok {
165+
client = upstreamKcp(host)
166+
} else {
167+
client = upstreamTcp(host)
168+
}
169+
if client == nil {
170+
return nil
138171
}
139-
defer client.Close()
140-
setSocketOptions(client)
141172

142173
client.Write(buf[:n])
143-
// 不需要 buf 了,释放掉
144-
buf = nil
145174

146-
var wg sync.WaitGroup
147-
wg.Add(1)
175+
return client
176+
}
148177

149-
go handleRead(client, conn, &wg)
150-
handleWrite(client, conn, nil)
178+
func upstreamTcp(host string) net.Conn {
179+
conn, err := net.Dial("tcp", host)
180+
if err != nil {
181+
log.Err(err).Str("host", host).Msg("Error dialing upstream")
182+
return nil
183+
}
184+
setSocketOptions(conn)
185+
return conn
151186

152-
// 等待所有读写操作完成
153-
// 不放在 defer 中,以防报错时无法关闭连接
154-
wg.Wait()
155187
}
156188

157189
func handleRead(srv, cli net.Conn, wg *sync.WaitGroup) {

cmd/gateway/pid.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"os"
6+
)
7+
8+
var currentPidFile string
9+
10+
func writePIDFile() error {
11+
newPidFile := getPidFileFromConfig()
12+
if newPidFile == currentPidFile {
13+
return nil
14+
}
15+
16+
if currentPidFile != "" {
17+
removePIDFile()
18+
}
19+
20+
pid := os.Getpid()
21+
if err := os.WriteFile(newPidFile, []byte(fmt.Sprintf("%d\n", pid)), 0644); err != nil {
22+
return err
23+
}
24+
25+
currentPidFile = newPidFile
26+
return nil
27+
}
28+
29+
func removePIDFile() {
30+
os.Remove(currentPidFile)
31+
}

cmd/gateway/pid_unix.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,10 @@
33

44
package main
55

6-
import (
7-
"fmt"
8-
"os"
9-
)
10-
11-
func writePIDFile() error {
12-
pid := os.Getpid()
13-
return os.WriteFile("/dev/shm/mc-gateway.pid", []byte(fmt.Sprintf("%d\n", pid)), 0644)
14-
}
15-
16-
func removePIDFile() {
17-
os.Remove("/dev/shm/mc-gateway.pid")
6+
func getPidFileFromConfig() string {
7+
pidFile := config.PidFile
8+
if pidFile == "" {
9+
return "/dev/shm/mc-gateway.pid"
10+
}
11+
return pidFile
1812
}

cmd/gateway/pid_windows.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33

44
package main
55

6-
func writePIDFile() error {
7-
return nil
8-
}
9-
10-
func removePIDFile() {
6+
func getPidFileFromConfig() string {
7+
return config.PidFile
118
}

cmd/gateway/quic.go

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ import (
1818
"github.com/rs/zerolog/log"
1919
)
2020

21-
type quicConn struct {
22-
quic.Connection
23-
quic.Stream
24-
}
21+
type (
22+
quicConn struct {
23+
quic.Connection
24+
quic.Stream
25+
}
26+
)
2527

2628
func runQuic(wg *sync.WaitGroup) {
2729
if wg != nil {
@@ -59,6 +61,34 @@ func runQuic(wg *sync.WaitGroup) {
5961
}
6062
}
6163

64+
func upstreamQuic(host string) net.Conn {
65+
tlsConf := &tls.Config{
66+
InsecureSkipVerify: true, // 跳过证书检查
67+
NextProtos: getQuicNextProtos(),
68+
}
69+
70+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
71+
defer cancel()
72+
73+
conn, err := quic.DialAddr(ctx, host, tlsConf, nil)
74+
if err != nil {
75+
log.Err(err).Str("host", host).Msg("Failed to dial QUIC")
76+
return nil
77+
}
78+
79+
stream, err := conn.OpenStream()
80+
if err != nil {
81+
log.Err(err).Str("host", host).Msg("Failed to open stream")
82+
return nil
83+
}
84+
log.Info().Str("host", host).Msg("QUIC stream opened")
85+
86+
return quicConn{
87+
Connection: conn,
88+
Stream: stream,
89+
}
90+
}
91+
6292
func handleQuicRequest(conn quic.Connection) {
6393
defer conn.CloseWithError(0, "Closing connection")
6494

@@ -118,14 +148,24 @@ func generateTLSConfig() (*tls.Config, error) {
118148
return nil, err
119149
}
120150

121-
nextProtos := config.Quic.ApplicionProtocols
122-
if len(nextProtos) == 0 {
123-
nextProtos = []string{"minecraft", "quic", "raw", "h3"} // 默认协议
124-
}
125-
126151
// 返回 tls.Config
127152
return &tls.Config{
128153
Certificates: []tls.Certificate{cert},
129-
NextProtos: nextProtos,
154+
NextProtos: getQuicNextProtos(),
130155
}, nil
131156
}
157+
158+
func getQuicNextProtos() []string {
159+
nextProtos := config.Quic.ApplicationProtocols
160+
if len(nextProtos) == 0 {
161+
return []string{"minecraft", "quic", "raw", "h3"} // 默认协议
162+
}
163+
return nextProtos
164+
}
165+
166+
func (c quicConn) Close() error {
167+
if err := c.Stream.Close(); err != nil {
168+
log.Err(err).Msg("Failed to close QUIC stream")
169+
}
170+
return c.Connection.CloseWithError(0, "Closing QUIC connection")
171+
}

0 commit comments

Comments
 (0)