Skip to content

Commit 16f3cb3

Browse files
Add origin based rate limit (#41)
1 parent e6e2802 commit 16f3cb3

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ services:
66
- 80:8080
77
environment:
88
- RATE_LIMIT=60
9+
- ORIGIN_RATE_LIMIT=600
910
# - CLIENT_IP_HEADER=X-Real-IP
1011
restart: always

main.go

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"fmt"
66
"io/ioutil"
77
"net/http"
8+
"net/url"
89
"os"
910
"strconv"
11+
"strings"
1012
"sync"
1113
"time"
1214
)
@@ -25,6 +27,7 @@ type Response struct {
2527
var client *http.Client
2628
var limiter sync.Map
2729
var rateLimit int
30+
var originRateLimit int
2831

2932
func init() {
3033
client = &http.Client{}
@@ -37,6 +40,14 @@ func init() {
3740
panic("Invalid RATE_LIMIT value")
3841
}
3942
}
43+
originRateLimit = 0
44+
if os.Getenv("ORIGIN_RATE_LIMIT") != "" {
45+
var err error
46+
originRateLimit, err = strconv.Atoi(os.Getenv("ORIGIN_RATE_LIMIT"))
47+
if err != nil {
48+
panic("Invalid ORIGIN_RATE_LIMIT value")
49+
}
50+
}
4051
}
4152

4253
func CORS(next http.Handler) http.Handler {
@@ -96,13 +107,38 @@ func tunnel(URL string) Response {
96107
return result
97108
}
98109

99-
func check(address string) bool {
110+
func isLocalOrigin(origin string) bool {
111+
if origin == "" {
112+
return false
113+
}
114+
u, err := url.Parse(origin)
115+
if err != nil {
116+
return false
117+
}
118+
119+
hostname := u.Hostname()
120+
121+
localHostnames := []string{"localhost", "127.0.0.1", "0.0.0.0"}
122+
for _, lh := range localHostnames {
123+
if hostname == lh {
124+
return true
125+
}
126+
}
127+
128+
return strings.HasPrefix(hostname, "192.168.")
129+
}
130+
131+
func checkRateLimit(IP string, origin string) bool {
100132
var value int
101133

102-
count, _ := limiter.LoadOrStore(address, &value)
134+
if originRateLimit > 0 && !isLocalOrigin(origin) {
135+
count, _ := limiter.LoadOrStore(origin, &value)
136+
*count.(*int) += 1
137+
return *count.(*int) < originRateLimit
138+
}
103139

140+
count, _ := limiter.LoadOrStore(IP, &value)
104141
*count.(*int) += 1
105-
106142
return *count.(*int) < rateLimit
107143
}
108144

@@ -117,19 +153,27 @@ func getIP(request *http.Request) string {
117153

118154
func get(writer http.ResponseWriter, request *http.Request) {
119155
URL := request.URL.Query().Get("url")
156+
callback := request.URL.Query().Get("callback")
157+
158+
IP := getIP(request)
159+
origin := request.Header.Get("Origin")
120160

121161
if URL == "" {
122162
writer.Write([]byte("URL parameter is required."))
123163
return
124164
}
125165

126-
callback := request.URL.Query().Get("callback")
127-
128-
IP := getIP(request)
129-
allowed := check(IP)
166+
if origin == "" {
167+
writer.Write([]byte("Origin header is required."))
168+
return
169+
}
130170

131-
if !allowed {
132-
writer.Write([]byte(fmt.Sprintf("rate limited: you have a max of %d request (s) per minute", rateLimit)))
171+
if !checkRateLimit(IP, origin) {
172+
rateLimitValue := rateLimit
173+
if originRateLimit > 0 && !isLocalOrigin(origin) {
174+
rateLimitValue = originRateLimit
175+
}
176+
writer.Write([]byte(fmt.Sprintf("rate limited: limit %d request (s) per minute", rateLimitValue)))
133177
return
134178
}
135179

0 commit comments

Comments
 (0)