Skip to content

Commit 6187a4e

Browse files
committed
proxy: limit concurrent traceroute requests
1 parent 843af2e commit 6187a4e

File tree

4 files changed

+93
-12
lines changed

4 files changed

+93
-12
lines changed

proxy/main.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,21 @@ func accessHandler(next http.Handler) http.Handler {
6161
}
6262

6363
type settingType struct {
64-
birdSocket string
65-
listen []string
66-
allowedNets []*net.IPNet
67-
tr_bin string
68-
tr_flags []string
69-
tr_raw bool
64+
birdSocket string
65+
listen []string
66+
allowedNets []*net.IPNet
67+
tr_bin string
68+
tr_flags []string
69+
tr_raw bool
70+
tr_max_concurrent int
7071
}
7172

7273
var setting settingType
7374

7475
// Wrapper of tracer
7576
func main() {
7677
parseSettings()
78+
initTracerouteSemaphore(setting.tr_max_concurrent)
7779
tracerouteAutodetect()
7880

7981
mux := http.NewServeMux()

proxy/settings.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ import (
1111
)
1212

1313
type viperSettingType struct {
14-
BirdSocket string `mapstructure:"bird_socket"`
15-
Listen []string `mapstructure:"listen"`
16-
AllowedNets string `mapstructure:"allowed_ips"`
17-
TracerouteBin string `mapstructure:"traceroute_bin"`
18-
TracerouteFlags string `mapstructure:"traceroute_flags"`
19-
TracerouteRaw bool `mapstructure:"traceroute_raw"`
14+
BirdSocket string `mapstructure:"bird_socket"`
15+
Listen []string `mapstructure:"listen"`
16+
AllowedNets string `mapstructure:"allowed_ips"`
17+
TracerouteBin string `mapstructure:"traceroute_bin"`
18+
TracerouteFlags string `mapstructure:"traceroute_flags"`
19+
TracerouteRaw bool `mapstructure:"traceroute_raw"`
20+
TracerouteMaxConcurrent int `mapstructure:"traceroute_max_concurrent"`
2021
}
2122

2223
// Parse settings with viper, and convert to legacy setting format
@@ -52,6 +53,9 @@ func parseSettings() {
5253
pflag.Bool("traceroute_raw", false, "whether to display traceroute outputs raw; set via parameter or environment variable BIRDLG_TRACEROUTE_RAW")
5354
viper.BindPFlag("traceroute_raw", pflag.Lookup("traceroute_raw"))
5455

56+
pflag.Int("traceroute_max_concurrent", 10, "max concurrent traceroute requests allowed")
57+
viper.BindPFlag("traceroute_max_concurrent", pflag.Lookup("traceroute_max_concurrent"))
58+
5559
pflag.Parse()
5660

5761
if err := viper.ReadInConfig(); err != nil {
@@ -101,6 +105,7 @@ func parseSettings() {
101105
}
102106

103107
setting.tr_raw = viperSettings.TracerouteRaw
108+
setting.tr_max_concurrent = viperSettings.TracerouteMaxConcurrent
104109

105110
fmt.Printf("%#v\n", setting)
106111
}

proxy/traceroute.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ import (
1111
"github.com/google/shlex"
1212
)
1313

14+
var tracerouteSemaphore chan struct{}
15+
16+
func initTracerouteSemaphore(maxConcurrent int) {
17+
if maxConcurrent > 0 {
18+
tracerouteSemaphore = make(chan struct{}, maxConcurrent)
19+
}
20+
}
21+
1422
func tracerouteArgsToString(cmd string, args []string, target []string) string {
1523
var cmdCombined = append([]string{cmd}, args...)
1624
cmdCombined = append(cmdCombined, target...)
@@ -83,6 +91,20 @@ func tracerouteAutodetect() {
8391
}
8492

8593
func tracerouteHandler(httpW http.ResponseWriter, httpR *http.Request) {
94+
// Check concurrency limit
95+
if setting.tr_max_concurrent > 0 {
96+
select {
97+
case tracerouteSemaphore <- struct{}{}:
98+
// Successfully acquired semaphore slot
99+
defer func() { <-tracerouteSemaphore }()
100+
default:
101+
// Semaphore is full, reject request
102+
httpW.WriteHeader(http.StatusServiceUnavailable)
103+
httpW.Write([]byte("Too many concurrent traceroute requests. Please try again later.\n"))
104+
return
105+
}
106+
}
107+
86108
query := string(httpR.URL.Query().Get("q"))
87109
query = strings.TrimSpace(query)
88110

proxy/traceroute_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ func TestTracerouteAutodetectFlagsOnly(t *testing.T) {
9797
}
9898

9999
func TestTracerouteHandlerWithoutQuery(t *testing.T) {
100+
initTracerouteSemaphore(setting.tr_max_concurrent)
100101
r := httptest.NewRequest(http.MethodGet, "/traceroute", nil)
101102
w := httptest.NewRecorder()
102103
tracerouteHandler(w, r)
@@ -107,6 +108,7 @@ func TestTracerouteHandlerWithoutQuery(t *testing.T) {
107108
}
108109

109110
func TestTracerouteHandlerShlexError(t *testing.T) {
111+
initTracerouteSemaphore(setting.tr_max_concurrent)
110112
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("\"1.1.1.1"), nil)
111113
w := httptest.NewRecorder()
112114
tracerouteHandler(w, r)
@@ -117,6 +119,7 @@ func TestTracerouteHandlerShlexError(t *testing.T) {
117119
}
118120

119121
func TestTracerouteHandlerNoTracerouteFound(t *testing.T) {
122+
initTracerouteSemaphore(setting.tr_max_concurrent)
120123
setting.tr_bin = ""
121124
setting.tr_flags = nil
122125

@@ -130,6 +133,7 @@ func TestTracerouteHandlerNoTracerouteFound(t *testing.T) {
130133
}
131134

132135
func TestTracerouteHandlerExecuteError(t *testing.T) {
136+
initTracerouteSemaphore(setting.tr_max_concurrent)
133137
setting.tr_bin = "sh"
134138
setting.tr_flags = []string{"-c", "false"}
135139
setting.tr_raw = true
@@ -144,6 +148,7 @@ func TestTracerouteHandlerExecuteError(t *testing.T) {
144148
}
145149

146150
func TestTracerouteHandlerRaw(t *testing.T) {
151+
initTracerouteSemaphore(setting.tr_max_concurrent)
147152
setting.tr_bin = "sh"
148153
setting.tr_flags = []string{"-c", "echo Mock"}
149154
setting.tr_raw = true
@@ -156,6 +161,7 @@ func TestTracerouteHandlerRaw(t *testing.T) {
156161
}
157162

158163
func TestTracerouteHandlerPostprocess(t *testing.T) {
164+
initTracerouteSemaphore(setting.tr_max_concurrent)
159165
setting.tr_bin = "sh"
160166
setting.tr_flags = []string{"-c", "echo \"first line\n 2 *\nthird line\""}
161167
setting.tr_raw = false
@@ -166,3 +172,49 @@ func TestTracerouteHandlerPostprocess(t *testing.T) {
166172
assert.Equal(t, w.Code, http.StatusOK)
167173
assert.Equal(t, w.Body.String(), "first line\nthird line\n\n1 hops not responding.")
168174
}
175+
176+
func TestTracerouteHandlerConcurrencyLimit(t *testing.T) {
177+
// Set a low limit for testing
178+
maxConcurrent := 2
179+
initTracerouteSemaphore(maxConcurrent)
180+
setting.tr_max_concurrent = maxConcurrent
181+
182+
// Use a slow command to keep requests running
183+
setting.tr_bin = "sh"
184+
setting.tr_flags = []string{"-c", "sleep 1; echo Done"}
185+
setting.tr_raw = true
186+
187+
// Launch more concurrent requests than the limit
188+
numRequests := 5
189+
responses := make(chan int, numRequests)
190+
191+
for i := 0; i < numRequests; i++ {
192+
go func() {
193+
r := httptest.NewRequest(http.MethodGet, "/traceroute?q="+url.QueryEscape("1.1.1.1"), nil)
194+
w := httptest.NewRecorder()
195+
tracerouteHandler(w, r)
196+
responses <- w.Code
197+
}()
198+
}
199+
200+
// Collect all responses
201+
statusCodes := make(map[int]int)
202+
for i := 0; i < numRequests; i++ {
203+
code := <-responses
204+
statusCodes[code]++
205+
}
206+
207+
// Verify that some requests succeeded (200) and some were rejected (503)
208+
if statusCodes[http.StatusOK] == 0 {
209+
t.Error("Expected at least one request to succeed with 200")
210+
}
211+
if statusCodes[http.StatusServiceUnavailable] == 0 {
212+
t.Error("Expected at least one request to be rejected with 503")
213+
}
214+
215+
// Verify we didn't get any unexpected status codes
216+
totalRequests := statusCodes[http.StatusOK] + statusCodes[http.StatusServiceUnavailable]
217+
if totalRequests != numRequests {
218+
t.Errorf("Expected %d total requests, got %d", numRequests, totalRequests)
219+
}
220+
}

0 commit comments

Comments
 (0)