Skip to content

Commit 929d131

Browse files
committed
refactor(proxy): Optimize hop-by-hop header handling and routing logic
- Simplify hop-by-hop headers initialization using map literal - Create a local copy of hop headers to improve header filtering - Enhance routing logic with context-based timeout for alternative target checks - Improve error handling and logging in file size and routing detection - Reduce unnecessary goroutine complexity in target URL selection
1 parent ec07ae0 commit 929d131

File tree

2 files changed

+74
-74
lines changed

2 files changed

+74
-74
lines changed

internal/handler/proxy.go

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,20 @@ const (
2828
)
2929

3030
// 添加 hop-by-hop 头部映射
31-
var hopHeadersMap = make(map[string]bool)
31+
var hopHeadersBase = map[string]bool{
32+
"Connection": true,
33+
"Keep-Alive": true,
34+
"Proxy-Authenticate": true,
35+
"Proxy-Authorization": true,
36+
"Proxy-Connection": true,
37+
"Te": true,
38+
"Trailer": true,
39+
"Transfer-Encoding": true,
40+
"Upgrade": true,
41+
}
3242

3343
func init() {
34-
headers := []string{
35-
"Connection",
36-
"Keep-Alive",
37-
"Proxy-Authenticate",
38-
"Proxy-Authorization",
39-
"Proxy-Connection",
40-
"Te",
41-
"Trailer",
42-
"Transfer-Encoding",
43-
"Upgrade",
44-
}
45-
for _, h := range headers {
46-
hopHeadersMap[h] = true
47-
}
44+
// 移除旧的初始化代码,因为我们直接在 map 字面量中定义了所有值
4845
}
4946

5047
// ErrorHandler 定义错误处理函数类型
@@ -337,16 +334,22 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
337334
}
338335

339336
func copyHeader(dst, src http.Header) {
337+
// 创建一个新的局部 map,复制基础 hop headers
338+
hopHeaders := make(map[string]bool, len(hopHeadersBase))
339+
for k, v := range hopHeadersBase {
340+
hopHeaders[k] = v
341+
}
342+
340343
// 处理 Connection 头部指定的其他 hop-by-hop 头部
341344
if connection := src.Get("Connection"); connection != "" {
342345
for _, h := range strings.Split(connection, ",") {
343-
hopHeadersMap[strings.TrimSpace(h)] = true
346+
hopHeaders[strings.TrimSpace(h)] = true
344347
}
345348
}
346349

347-
// 使用 map 快速查找,跳过 hop-by-hop 头部
350+
// 使用局部 map 快速查找,跳过 hop-by-hop 头部
348351
for k, vv := range src {
349-
if !hopHeadersMap[k] {
352+
if !hopHeaders[k] {
350353
for _, v := range vv {
351354
dst.Add(k, v)
352355
}

internal/utils/utils.go

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -187,88 +187,85 @@ func GetTargetURL(client *http.Client, r *http.Request, pathConfig config.PathCo
187187
// 默认使用默认目标
188188
targetBase := pathConfig.DefaultTarget
189189

190-
// 如果没有设置最小阈值,使用默认值 500KB
191-
minThreshold := pathConfig.SizeThreshold
192-
if minThreshold <= 0 {
193-
minThreshold = 500 * 1024
194-
}
195-
196-
// 如果没有设置最大阈值,使用默认值 10MB
197-
maxThreshold := pathConfig.MaxSize
198-
if maxThreshold <= 0 {
199-
maxThreshold = 10 * 1024 * 1024
200-
}
201-
202-
// 检查文件扩展名
190+
// 如果配置了扩展名映射
203191
if pathConfig.ExtensionMap != nil {
204192
ext := strings.ToLower(filepath.Ext(path))
205193
if ext != "" {
206194
ext = ext[1:] // 移除开头的点
207-
// 先检查是否在扩展名映射中
195+
// 检查是否在扩展名映射中
208196
if altTarget, exists := pathConfig.GetExtensionTarget(ext); exists {
209-
// 使用 channel 来并发获取文件大小和检查可访问性
210-
type result struct {
211-
size int64
212-
accessible bool
213-
err error
214-
}
215-
defaultChan := make(chan result, 1)
216-
altChan := make(chan result, 1)
217-
218-
// 并发检查默认源和备用源
219-
go func() {
220-
size, err := GetFileSize(client, targetBase+path)
221-
defaultChan <- result{size: size, err: err}
222-
}()
223-
go func() {
224-
accessible := isTargetAccessible(client, altTarget+path)
225-
altChan <- result{accessible: accessible}
226-
}()
227-
228-
// 获取默认源结果
229-
defaultResult := <-defaultChan
230-
if defaultResult.err != nil {
231-
log.Printf("[FileSize] Failed to get size from default source for %s: %v", path, defaultResult.err)
197+
// 检查文件大小
198+
contentLength, err := GetFileSize(client, targetBase+path)
199+
if err != nil {
200+
log.Printf("[Route] %s -> %s (error getting size: %v)", path, targetBase, err)
232201
return targetBase
233202
}
234-
contentLength := defaultResult.size
235-
log.Printf("[FileSize] Path: %s, Size: %s (from default source)",
236-
path, FormatBytes(contentLength))
237203

238-
// 检查文件大小是否在阈值范围内
204+
// 如果没有设置最小阈值,使用默认值 500KB
205+
minThreshold := pathConfig.SizeThreshold
206+
if minThreshold <= 0 {
207+
minThreshold = 500 * 1024
208+
}
209+
210+
// 如果没有设置最大阈值,使用默认值 10MB
211+
maxThreshold := pathConfig.MaxSize
212+
if maxThreshold <= 0 {
213+
maxThreshold = 10 * 1024 * 1024
214+
}
215+
239216
if contentLength > minThreshold && contentLength <= maxThreshold {
240-
// 获取备用源检查结果
241-
altResult := <-altChan
242-
if altResult.accessible {
243-
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
244-
path, altTarget, FormatBytes(contentLength),
245-
FormatBytes(minThreshold), FormatBytes(maxThreshold))
246-
return altTarget
247-
} else {
217+
// 创建一个带超时的 context
218+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
219+
defer cancel()
220+
221+
// 使用 channel 来接收备用源检查结果
222+
altChan := make(chan struct {
223+
accessible bool
224+
err error
225+
}, 1)
226+
227+
// 在 goroutine 中检查备用源可访问性
228+
go func() {
229+
accessible := isTargetAccessible(client, altTarget+path)
230+
select {
231+
case altChan <- struct {
232+
accessible bool
233+
err error
234+
}{accessible: accessible}:
235+
case <-ctx.Done():
236+
// context 已取消,不需要发送结果
237+
}
238+
}()
239+
240+
// 等待结果或超时
241+
select {
242+
case result := <-altChan:
243+
if result.accessible {
244+
log.Printf("[Route] %s -> %s (size: %s > %s and <= %s)",
245+
path, altTarget, FormatBytes(contentLength),
246+
FormatBytes(minThreshold), FormatBytes(maxThreshold))
247+
return altTarget
248+
}
248249
log.Printf("[Route] %s -> %s (fallback: alternative target not accessible)",
249250
path, targetBase)
251+
case <-ctx.Done():
252+
log.Printf("[Route] %s -> %s (fallback: alternative target check timeout)",
253+
path, targetBase)
250254
}
251255
} else if contentLength <= minThreshold {
252-
// 如果文件大小不合适,直接丢弃备用源检查结果
253-
go func() { <-altChan }()
254256
log.Printf("[Route] %s -> %s (size: %s <= %s)",
255257
path, targetBase, FormatBytes(contentLength), FormatBytes(minThreshold))
256258
} else {
257-
// 如果文件大小不合适,直接丢弃备用源检查结果
258-
go func() { <-altChan }()
259259
log.Printf("[Route] %s -> %s (size: %s > %s)",
260260
path, targetBase, FormatBytes(contentLength), FormatBytes(maxThreshold))
261261
}
262262
} else {
263-
// 记录没有匹配扩展名映射的情况
264263
log.Printf("[Route] %s -> %s (no extension mapping)", path, targetBase)
265264
}
266265
} else {
267-
// 记录没有扩展名的情况
268266
log.Printf("[Route] %s -> %s (no extension)", path, targetBase)
269267
}
270268
} else {
271-
// 记录没有扩展名映射配置的情况
272269
log.Printf("[Route] %s -> %s (no extension map)", path, targetBase)
273270
}
274271

0 commit comments

Comments
 (0)