1
1
package app.util
2
2
3
3
import io.javalin.Javalin
4
- import java.util.*
5
4
import java.util.concurrent.ConcurrentHashMap
5
+ import java.util.concurrent.Executors
6
+ import java.util.concurrent.TimeUnit
6
7
7
8
/* *
8
9
* A very naive IP-based rate-limiting mechanism
@@ -21,30 +22,28 @@ object RateLimitUtil {
21
22
fun enableTerribleRateLimiting (app : Javalin ) {
22
23
23
24
app.before { ctx ->
24
- if (ipReqCount[ctx.ip()] ? : 0 > 25 ) {
25
- throw TerribleRateLimitException ()
26
- }
27
- ipReqCount[ctx.ip()] = (ipReqCount[ctx.ip()] ? : 0 ) + 1
25
+ ipReqCount.compute(ctx.ip(), { _, count ->
26
+ when (count) {
27
+ null -> 1
28
+ in 0 .. 25 -> count + 1
29
+ else -> throw TerribleRateLimitException ()
30
+ }
31
+ })
28
32
}
29
33
30
- app.exception(TerribleRateLimitException ::class .java) { e , ctx ->
34
+ app.exception(TerribleRateLimitException ::class .java) { _ , ctx ->
31
35
ctx.result(" You can't spam this much. I'll give you a new request every five seconds." )
32
36
}
33
37
34
- Timer ().scheduleAtFixedRate(decrementAllCounters(), 0 , 5000 ) // every 5s
38
+ Executors .newSingleThreadScheduledExecutor()
39
+ .scheduleAtFixedRate(decrementAllCounters, 0 , 5 , TimeUnit .SECONDS )
35
40
36
41
}
37
42
38
- private fun decrementAllCounters () = object : TimerTask () {
39
- override fun run () {
40
- ipReqCount.forEach { ip, count ->
41
- if (count > 0 ) {
42
- ipReqCount[ip] = ipReqCount[ip]!! - 1
43
- } else {
44
- ipReqCount.remove(ip)
45
- }
46
- }
47
- }
43
+ private val decrementAllCounters = Runnable {
44
+ ipReqCount.forEachKey(1 , { ip ->
45
+ ipReqCount.computeIfPresent(ip, { _, count -> if (count > 1 ) count - 1 else null })
46
+ })
48
47
}
49
48
50
49
}
0 commit comments