diff --git a/src/AspNetCoreRateLimit.Redis/RedisProcessingStrategy.cs b/src/AspNetCoreRateLimit.Redis/RedisProcessingStrategy.cs index 8481be4..7a070c8 100644 --- a/src/AspNetCoreRateLimit.Redis/RedisProcessingStrategy.cs +++ b/src/AspNetCoreRateLimit.Redis/RedisProcessingStrategy.cs @@ -1,5 +1,4 @@ -using AspNetCoreRateLimit; -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using StackExchange.Redis; using System; using System.Threading; @@ -21,8 +20,7 @@ public RedisProcessingStrategy(IConnectionMultiplexer connectionMultiplexer, IRa _logger = logger; } - static private readonly LuaScript _atomicIncrement = LuaScript.Prepare("local count = redis.call(\"INCRBYFLOAT\", @key, tonumber(@delta)) local ttl = redis.call(\"TTL\", @key) if ttl == -1 then redis.call(\"EXPIRE\", @key, @timeout) end return count"); - + static private readonly LuaScript _atomicIncrement = LuaScript.Prepare("local count = redis.call(\"INCRBYFLOAT\", @key, tonumber(@delta)) local ttl = redis.call(\"TTL\", @key) if ttl == -1 then redis.call(\"EXPIRE\", @key, @timeout) end return { 'count', count, 'ttl', ttl }"); public override async Task ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default) { var counterId = BuildCounterKey(requestIdentity, rule, counterKeyBuilder, rateLimitOptions); @@ -31,16 +29,18 @@ public override async Task ProcessRequestAsync(ClientRequestId public async Task IncrementAsync(string counterId, TimeSpan interval, Func RateIncrementer = null) { - var now = DateTime.UtcNow; - var numberOfIntervals = now.Ticks / interval.Ticks; - var intervalStart = new DateTime(numberOfIntervals * interval.Ticks, DateTimeKind.Utc); - _logger.LogDebug("Calling Lua script. {counterId}, {timeout}, {delta}", counterId, interval.TotalSeconds, 1D); - var count = await _connectionMultiplexer.GetDatabase().ScriptEvaluateAsync(_atomicIncrement, new { key = new RedisKey(counterId), timeout = interval.TotalSeconds, delta = RateIncrementer?.Invoke() ?? 1D }); + var cacheStart = DateTime.UtcNow; + var cached = await _connectionMultiplexer.GetDatabase().ScriptEvaluateAsync(_atomicIncrement, new { key = new RedisKey(counterId), timeout = interval.TotalSeconds, delta = RateIncrementer?.Invoke() ?? 1D }); + var responseDict = cached.ToDictionary(); + var ttlSeconds = (int)responseDict["ttl"]; + if (ttlSeconds != -1) + cacheStart = cacheStart.Add(-interval).AddSeconds(ttlSeconds); // Subtract the amount of seconds the interval adds, then add the amount of seconds still left to live. + var count = (double)responseDict["count"]; return new RateLimitCounter { - Count = (double)count, - Timestamp = intervalStart + Count = count, + Timestamp = cacheStart }; } }