1
1
package app.util
2
2
3
3
import io.javalin.Javalin
4
+ import java.util.*
4
5
import java.util.concurrent.ConcurrentHashMap
5
- import java.util.concurrent.Executors
6
- import java.util.concurrent.TimeUnit
7
6
8
7
/* *
9
8
* A very naive IP-based rate-limiting mechanism
@@ -22,28 +21,30 @@ object RateLimitUtil {
22
21
fun enableTerribleRateLimiting (app : Javalin ) {
23
22
24
23
app.before { ctx ->
25
- ipReqCount.compute(ctx.ip(), { _, count ->
26
- when (count) {
27
- null -> 1
28
- in 0 .. 25 -> count + 1
29
- else -> throw TerribleRateLimitException ()
30
- }
31
- })
24
+ if (ipReqCount[ctx.ip()] ? : 0 > 25 ) {
25
+ throw TerribleRateLimitException ()
26
+ }
27
+ ipReqCount[ctx.ip()] = (ipReqCount[ctx.ip()] ? : 0 ) + 1
32
28
}
33
29
34
- app.exception(TerribleRateLimitException ::class .java) { _ , ctx ->
30
+ app.exception(TerribleRateLimitException ::class .java) { e , ctx ->
35
31
ctx.result(" You can't spam this much. I'll give you a new request every five seconds." )
36
32
}
37
33
38
- Executors .newSingleThreadScheduledExecutor()
39
- .scheduleAtFixedRate(decrementAllCounters, 0 , 5 , TimeUnit .SECONDS )
34
+ Timer ().scheduleAtFixedRate(decrementAllCounters(), 0 , 5000 ) // every 5s
40
35
41
36
}
42
37
43
- private val decrementAllCounters = Runnable {
44
- ipReqCount.forEachKey(1 , { ip ->
45
- ipReqCount.computeIfPresent(ip, { _, count -> if (count > 1 ) count - 1 else null })
46
- })
38
+ private fun decrementAllCounters () = object : TimerTask () {
39
+ override fun run () {
40
+ ipReqCount.forEach { ip, count ->
41
+ if (count > 0 ) {
42
+ ipReqCount[ip] = ipReqCount[ip]!! - 1
43
+ } else {
44
+ ipReqCount.remove(ip)
45
+ }
46
+ }
47
+ }
47
48
}
48
49
49
50
}
0 commit comments