@@ -5,8 +5,10 @@ import (
55 "fmt"
66 "io/ioutil"
77 "net/http"
8+ "net/url"
89 "os"
910 "strconv"
11+ "strings"
1012 "sync"
1113 "time"
1214)
@@ -25,6 +27,7 @@ type Response struct {
2527var client * http.Client
2628var limiter sync.Map
2729var rateLimit int
30+ var originRateLimit int
2831
2932func init () {
3033 client = & http.Client {}
@@ -37,6 +40,14 @@ func init() {
3740 panic ("Invalid RATE_LIMIT value" )
3841 }
3942 }
43+ originRateLimit = 0
44+ if os .Getenv ("ORIGIN_RATE_LIMIT" ) != "" {
45+ var err error
46+ originRateLimit , err = strconv .Atoi (os .Getenv ("ORIGIN_RATE_LIMIT" ))
47+ if err != nil {
48+ panic ("Invalid ORIGIN_RATE_LIMIT value" )
49+ }
50+ }
4051}
4152
4253func CORS (next http.Handler ) http.Handler {
@@ -96,13 +107,38 @@ func tunnel(URL string) Response {
96107 return result
97108}
98109
99- func check (address string ) bool {
110+ func isLocalOrigin (origin string ) bool {
111+ if origin == "" {
112+ return false
113+ }
114+ u , err := url .Parse (origin )
115+ if err != nil {
116+ return false
117+ }
118+
119+ hostname := u .Hostname ()
120+
121+ localHostnames := []string {"localhost" , "127.0.0.1" , "0.0.0.0" }
122+ for _ , lh := range localHostnames {
123+ if hostname == lh {
124+ return true
125+ }
126+ }
127+
128+ return strings .HasPrefix (hostname , "192.168." )
129+ }
130+
131+ func checkRateLimit (IP string , origin string ) bool {
100132 var value int
101133
102- count , _ := limiter .LoadOrStore (address , & value )
134+ if originRateLimit > 0 && ! isLocalOrigin (origin ) {
135+ count , _ := limiter .LoadOrStore (origin , & value )
136+ * count .(* int ) += 1
137+ return * count .(* int ) < originRateLimit
138+ }
103139
140+ count , _ := limiter .LoadOrStore (IP , & value )
104141 * count .(* int ) += 1
105-
106142 return * count .(* int ) < rateLimit
107143}
108144
@@ -117,19 +153,27 @@ func getIP(request *http.Request) string {
117153
118154func get (writer http.ResponseWriter , request * http.Request ) {
119155 URL := request .URL .Query ().Get ("url" )
156+ callback := request .URL .Query ().Get ("callback" )
157+
158+ IP := getIP (request )
159+ origin := request .Header .Get ("Origin" )
120160
121161 if URL == "" {
122162 writer .Write ([]byte ("URL parameter is required." ))
123163 return
124164 }
125165
126- callback := request . URL . Query (). Get ( "callback" )
127-
128- IP := getIP ( request )
129- allowed := check ( IP )
166+ if origin == "" {
167+ writer . Write ([] byte ( "Origin header is required." ))
168+ return
169+ }
130170
131- if ! allowed {
132- writer .Write ([]byte (fmt .Sprintf ("rate limited: you have a max of %d request (s) per minute" , rateLimit )))
171+ if ! checkRateLimit (IP , origin ) {
172+ rateLimitValue := rateLimit
173+ if originRateLimit > 0 && ! isLocalOrigin (origin ) {
174+ rateLimitValue = originRateLimit
175+ }
176+ writer .Write ([]byte (fmt .Sprintf ("rate limited: limit %d request (s) per minute" , rateLimitValue )))
133177 return
134178 }
135179
0 commit comments