diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index af4e17c38ac..a28c5464cfa 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -7,7 +7,6 @@
 from enum import Enum
 from io import IOBase, BytesIO
 from pathlib import Path
-from types import NoneType
 from typing import (
     Annotated,
     Any,
@@ -21,6 +20,13 @@
     cast,
 )

+try:
+    # Python 3.10+
+    from types import NoneType  # type: ignore[attr-defined]
+except ImportError:  # pragma: no cover
+    # Python 3.9 and below
+    NoneType = type(None)  # type: ignore[misc,assignment]
+
 import filetype
 from tinytag import TinyTag, UnsupportedFormatError
 from typing_extensions import Self
diff --git a/llama-index-core/llama_index/core/rate_limiter.py b/llama-index-core/llama_index/core/rate_limiter.py
index 8a1ee127174..be4af1faa05 100644
--- a/llama-index-core/llama_index/core/rate_limiter.py
+++ b/llama-index-core/llama_index/core/rate_limiter.py
@@ -5,12 +5,21 @@
 import threading
 import time
 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import deque
+from typing import Deque, Optional, Tuple

-from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
+from llama_index.core.bridge.pydantic import (
+    BaseModel,
+    Field,
+    PrivateAttr,
+    model_validator,
+)

 logger = logging.getLogger(__name__)

+# Sliding window duration in seconds (one minute)
+_SLIDING_WINDOW_SECONDS = 60.0
+

 class BaseRateLimiter(ABC):
     """
@@ -206,5 +215,188 @@ async def async_acquire(self, num_tokens: int = 0) -> None:
         await asyncio.sleep(wait)


+class SlidingWindowRateLimiter(BaseRateLimiter, BaseModel):
+    """
+    Sliding-window rate limiter for strict per-minute caps.
+
+    Unlike the token-bucket limiter, this implementation enforces a strict
+    sliding window: at any moment, only the requests (or tokens) that fall
+    within the last 60 seconds count toward the limit. By default there is
+    no burst allowance at window boundaries, which suits APIs that specify
+    hard limits per rolling minute.
+
+    Supports both requests-per-minute (RPM) and tokens-per-minute (TPM)
+    limiting; ``request_burst`` and ``token_burst`` optionally add headroom
+    on top of those caps. Instances can be shared across multiple LLM and
+    embedding objects that hit the same API endpoint.
+
+    Args:
+        requests_per_minute: Maximum requests allowed in any sliding
+            one-minute window. ``None`` disables request-rate limiting.
+        request_burst: Additional requests allowed as burst capacity
+            within the sliding window. Defaults to 0 (strict cap).
+        tokens_per_minute: Maximum tokens allowed in any sliding
+            one-minute window. ``None`` disables token-rate limiting.
+        token_burst: Additional tokens allowed as burst capacity
+            within the sliding window. Defaults to 0 (strict cap).
+
+    Raises:
+        ValueError: If both ``requests_per_minute`` and ``tokens_per_minute``
+            are ``None``, or if either is zero or negative.
+
+    Examples:
+        .. code-block:: python
+
+            from llama_index.core.rate_limiter import SlidingWindowRateLimiter
+
+            limiter = SlidingWindowRateLimiter(requests_per_minute=60)
+            llm = SomeLLM(rate_limiter=limiter)
+
+    """
+
+    requests_per_minute: Optional[float] = Field(
+        default=None,
+        description="Maximum number of requests in any sliding one-minute window.",
+        gt=0,
+    )
+    request_burst: float = Field(
+        default=0.0,
+        ge=0.0,
+        description=(
+            "Additional requests allowed as burst capacity within the sliding window. "
+            "Set to 0 for a strict cap."
+        ),
+    )
+    tokens_per_minute: Optional[float] = Field(
+        default=None,
+        description="Maximum number of tokens in any sliding one-minute window.",
+        gt=0,
+    )
+    token_burst: float = Field(
+        default=0.0,
+        ge=0.0,
+        description=(
+            "Additional tokens allowed as burst capacity within the sliding window. "
+            "Set to 0 for a strict cap."
+        ),
+    )
+
+    _request_timestamps: Deque[float] = PrivateAttr(default_factory=deque)
+    _token_usage: Deque[Tuple[float, float]] = PrivateAttr(default_factory=deque)
+    _lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
+
+    @model_validator(mode="after")
+    def _check_limits(self) -> "SlidingWindowRateLimiter":
+        if self.requests_per_minute is None and self.tokens_per_minute is None:
+            raise ValueError(
+                "At least one of requests_per_minute or tokens_per_minute must be set."
+            )
+        return self
+
+    def _prune_request_timestamps(self, now: float) -> None:
+        """Drop request timestamps outside the window. Caller must hold ``_lock``."""
+        while (
+            self._request_timestamps
+            and self._request_timestamps[0] < now - _SLIDING_WINDOW_SECONDS
+        ):
+            self._request_timestamps.popleft()
+
+    def _prune_token_usage(self, now: float) -> None:
+        """Drop token usage entries outside the window. Caller must hold ``_lock``."""
+        while (
+            self._token_usage
+            and self._token_usage[0][0] < now - _SLIDING_WINDOW_SECONDS
+        ):
+            self._token_usage.popleft()
+
+    def _current_token_usage(self) -> float:
+        """Sum of tokens currently in the sliding window. Caller must hold ``_lock``."""
+        return sum(tokens for _ts, tokens in self._token_usage)
+
+    def _wait_time(self, now: float, num_tokens: int = 0) -> float:
+        """
+        Return seconds to wait before the next request is allowed.
+
+        Must be called while holding ``_lock`` and after pruning.
+        """
+        wait = 0.0
+        if self.requests_per_minute is not None:
+            # Allow optional burst headroom in addition to the strict window cap.
+            allowed_requests = self.requests_per_minute + self.request_burst
+            if len(self._request_timestamps) >= allowed_requests:
+                # Wait until the oldest request exits the window.
+                wait = max(
+                    wait,
+                    self._request_timestamps[0] + _SLIDING_WINDOW_SECONDS - now,
+                )
+        if self.tokens_per_minute is not None and num_tokens > 0:
+            current = self._current_token_usage()
+            allowed_tokens = self.tokens_per_minute + self.token_burst
+            if current + num_tokens > allowed_tokens:
+                # Empty window: the request alone exceeds the cap and can
+                # never fit, so let it through rather than block forever.
+                if not self._token_usage:
+                    return wait
+                needed = current + num_tokens - allowed_tokens
+                # Walk oldest-first until enough usage expires to fit it.
+                remaining = needed
+                for ts, tokens in self._token_usage:
+                    remaining -= tokens
+                    if remaining <= 0:
+                        wait = max(
+                            wait,
+                            ts + _SLIDING_WINDOW_SECONDS - now,
+                        )
+                        break
+        return wait
+
+    def _record_usage(self, now: float, num_tokens: int = 0) -> None:
+        """Record one request and optional token usage. Caller must hold ``_lock``."""
+        if self.requests_per_minute is not None:
+            self._request_timestamps.append(now)
+        if self.tokens_per_minute is not None and num_tokens > 0:
+            self._token_usage.append((now, float(num_tokens)))
+
+    def acquire(self, num_tokens: int = 0) -> None:
+        """
+        Block until one request is allowed (synchronous).
+
+        Args:
+            num_tokens: Estimated token count for this request. Only
+                consulted when ``tokens_per_minute`` is configured.
+
+        """
+        while True:
+            now = time.monotonic()
+            with self._lock:
+                self._prune_request_timestamps(now)
+                self._prune_token_usage(now)
+                wait = self._wait_time(now, num_tokens)
+                if wait <= 0:
+                    self._record_usage(now, num_tokens)
+                    return
+            time.sleep(wait)
+
+    async def async_acquire(self, num_tokens: int = 0) -> None:
+        """
+        Wait until one request is allowed (asynchronous).
+
+        Args:
+            num_tokens: Estimated token count for this request. Only
+                consulted when ``tokens_per_minute`` is configured.
+
+        """
+        while True:
+            now = time.monotonic()
+            with self._lock:
+                self._prune_request_timestamps(now)
+                self._prune_token_usage(now)
+                wait = self._wait_time(now, num_tokens)
+                if wait <= 0:
+                    self._record_usage(now, num_tokens)
+                    return
+            await asyncio.sleep(wait)
+
+
 # Backwards-compatible alias
 RateLimiter = TokenBucketRateLimiter
diff --git a/llama-index-core/tests/test_rate_limiter.py b/llama-index-core/tests/test_rate_limiter.py
index dbcfd6b01cf..6858a1977d2 100644
--- a/llama-index-core/tests/test_rate_limiter.py
+++ b/llama-index-core/tests/test_rate_limiter.py
@@ -11,6 +11,7 @@
 from llama_index.core.rate_limiter import (
     BaseRateLimiter,
     RateLimiter,
+    SlidingWindowRateLimiter,
     TokenBucketRateLimiter,
 )

@@ -279,3 +280,222 @@ def test_shared_limiter_between_llm_and_embedding() -> None:
     llm.rate_limiter = rl
     embed = MockEmbedding(embed_dim=8, rate_limiter=rl)
     assert llm.rate_limiter is embed.rate_limiter
+
+
+# ---------------------------------------------------------------------------
+# SlidingWindowRateLimiter tests
+# ---------------------------------------------------------------------------
+
+
+def test_sliding_window_is_subclass_of_base() -> None:
+    assert issubclass(SlidingWindowRateLimiter, BaseRateLimiter)
+
+
+def test_sliding_window_creation_rpm_only() -> None:
+    rl = SlidingWindowRateLimiter(requests_per_minute=60)
+    assert rl.requests_per_minute == 60
+    assert rl.tokens_per_minute is None
+
+
+def test_sliding_window_creation_tpm_only() -> None:
+    rl = SlidingWindowRateLimiter(tokens_per_minute=10000)
+    assert rl.tokens_per_minute == 10000
+    assert rl.requests_per_minute is None
+
+
+def test_sliding_window_creation_both() -> None:
+    rl = SlidingWindowRateLimiter(
+        requests_per_minute=30,
+        tokens_per_minute=5000,
+    )
+    assert rl.requests_per_minute == 30
+    assert rl.tokens_per_minute == 5000
+
+
+def test_sliding_window_rejects_both_none() -> None:
+    with pytest.raises(ValueError, match="At least one of"):
+        SlidingWindowRateLimiter()
+
+
+def test_sliding_window_rejects_zero_rpm() -> None:
+    with pytest.raises(ValueError):
+        SlidingWindowRateLimiter(requests_per_minute=0)
+
+
+def test_sliding_window_rejects_negative_tpm() -> None:
+    with pytest.raises(ValueError):
+        SlidingWindowRateLimiter(tokens_per_minute=-1)
+
+
+def test_sliding_window_burst_within_limit() -> None:
+    """First N requests within the window should not block."""
+    rl = SlidingWindowRateLimiter(requests_per_minute=10)
+    start = time.monotonic()
+    for _ in range(10):
+        rl.acquire()
+    elapsed = time.monotonic() - start
+    assert elapsed < 1.0
+
+
+def test_sliding_window_blocks_after_limit() -> None:
+    """After exhausting the window, acquire blocks until oldest request exits."""
+    rl = SlidingWindowRateLimiter(requests_per_minute=3)
+    for _ in range(3):
+        rl.acquire()
+
+    with (
+        patch("llama_index.core.rate_limiter.time.sleep") as mock_sleep,
+        patch("llama_index.core.rate_limiter.time.monotonic") as mock_time,
+    ):
+        base = 1000.0
+        rl._request_timestamps.clear()
+        rl._request_timestamps.extend([base - 50, base - 40, base - 30])
+        mock_time.side_effect = [base, base + 10, base + 30]
+        mock_sleep.return_value = None
+        rl.acquire()
+        mock_sleep.assert_called()
+        first_call_arg = mock_sleep.call_args_list[0][0][0]
+        assert 9.0 <= first_call_arg <= 11.0
+
+
+def test_sliding_window_prune_removes_old_entries() -> None:
+    """Entries older than 60 seconds are pruned before checking limit."""
+    rl = SlidingWindowRateLimiter(requests_per_minute=2)
+    now = 2000.0
+    rl._request_timestamps.append(now - 70.0)
+    rl._request_timestamps.append(now - 65.0)
+    rl._request_timestamps.append(now - 5.0)
+    with patch("llama_index.core.rate_limiter.time.monotonic", return_value=now):
+        with rl._lock:
+            rl._prune_request_timestamps(now)
+    assert len(rl._request_timestamps) == 1
+
+
+def test_sliding_window_tpm_within_limit() -> None:
+    """Token usage within TPM limit should not block."""
+    rl = SlidingWindowRateLimiter(tokens_per_minute=1000)
+    start = time.monotonic()
+    rl.acquire(num_tokens=100)
+    rl.acquire(num_tokens=200)
+    rl.acquire(num_tokens=300)
+    elapsed = time.monotonic() - start
+    assert elapsed < 1.0
+
+
+def test_sliding_window_tpm_blocks_when_exceeded() -> None:
+    """When token usage in window would exceed TPM, acquire blocks."""
+    rl = SlidingWindowRateLimiter(tokens_per_minute=100)
+    rl.acquire(num_tokens=100)
+
+    with (
+        patch("llama_index.core.rate_limiter.time.sleep") as mock_sleep,
+        patch("llama_index.core.rate_limiter.time.monotonic") as mock_time,
+    ):
+        base = 5000.0
+        rl._token_usage.clear()
+        rl._token_usage.append((base - 50.0, 100.0))
+        mock_time.side_effect = [base, base + 10.0]
+        mock_sleep.return_value = None
+        rl.acquire(num_tokens=50)
+        mock_sleep.assert_called_once()
+        assert mock_sleep.call_args[0][0] >= 9.0
+
+
+def test_sliding_window_request_burst_allows_additional_requests() -> None:
+    """Configured request_burst should allow extra requests within the window."""
+    rl = SlidingWindowRateLimiter(requests_per_minute=3, request_burst=2)
+    now = 3000.0
+    rl._request_timestamps.clear()
+    rl._request_timestamps.extend([now - 5.0, now - 4.0, now - 3.0, now - 2.0])
+
+    with rl._lock:
+        rl._prune_request_timestamps(now)
+        wait = rl._wait_time(now, 0)
+    # 4 requests within a 3 rpm limit would normally block, but with
+    # request_burst=2 this should still be allowed without waiting.
+    assert wait == 0.0
+
+    # The 6th request should require waiting until at least one request
+    # falls out of the window.
+    rl._request_timestamps.append(now - 1.0)
+    with rl._lock:
+        rl._prune_request_timestamps(now)
+        wait_after = rl._wait_time(now, 0)
+    assert wait_after > 0.0
+
+
+def test_sliding_window_token_burst_allows_additional_tokens() -> None:
+    """Configured token_burst should allow extra token usage within the window."""
+    rl = SlidingWindowRateLimiter(tokens_per_minute=100, token_burst=50)
+    now = 4000.0
+    rl._token_usage.clear()
+    rl._token_usage.append((now - 5.0, 80.0))
+
+    with rl._lock:
+        rl._prune_token_usage(now)
+        wait = rl._wait_time(now, num_tokens=60)
+    # 80 + 60 = 140 tokens; with 100 TPM this would block, but with
+    # token_burst=50 the allowed total is 150 so no wait is required.
+    assert wait == 0.0
+
+    # Pushing beyond the burst should introduce a wait.
+    with rl._lock:
+        rl._prune_token_usage(now)
+        wait_after = rl._wait_time(now, num_tokens=80)
+    assert wait_after > 0.0
+
+
+@pytest.mark.asyncio
+async def test_sliding_window_async_acquire_burst() -> None:
+    rl = SlidingWindowRateLimiter(requests_per_minute=15)
+    start = time.monotonic()
+    for _ in range(15):
+        await rl.async_acquire()
+    elapsed = time.monotonic() - start
+    assert elapsed < 1.0
+
+
+@pytest.mark.asyncio
+async def test_sliding_window_concurrent_async() -> None:
+    """Multiple concurrent async_acquire calls must all complete."""
+    rl = SlidingWindowRateLimiter(requests_per_minute=100)
+    results: list = []
+
+    async def worker(n: int) -> None:
+        await rl.async_acquire()
+        results.append(n)
+
+    tasks = [worker(i) for i in range(25)]
+    await asyncio.gather(*tasks)
+    assert len(results) == 25
+
+
+def test_sliding_window_llm_sync_calls_acquire() -> None:
+    """MockLLM with SlidingWindowRateLimiter should call acquire."""
+    mock_limiter = MagicMock(spec=SlidingWindowRateLimiter)
+    mock_limiter.acquire = MagicMock()
+    mock_limiter.async_acquire = AsyncMock()
+    llm = MockLLM()
+    llm.rate_limiter = mock_limiter
+    llm.complete("hello")
+    mock_limiter.acquire.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_sliding_window_llm_async_calls_async_acquire() -> None:
+    mock_limiter = MagicMock(spec=SlidingWindowRateLimiter)
+    mock_limiter.acquire = MagicMock()
+    mock_limiter.async_acquire = AsyncMock()
+    llm = MockLLM()
+    llm.rate_limiter = mock_limiter
+    await llm.acomplete("hello")
+    mock_limiter.async_acquire.assert_called_once()
+
+
+def test_sliding_window_embedding_integration() -> None:
+    from llama_index.core.embeddings.mock_embed_model import MockEmbedding
+
+    rl = SlidingWindowRateLimiter(requests_per_minute=100)
+    embed = MockEmbedding(embed_dim=8, rate_limiter=rl)
+    result = embed.get_text_embedding("test")
+    assert len(result) == 8
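
Usage sketch for reviewers: the snippet below wires one shared limiter into
both client types, using only the APIs the tests above exercise. The
``MockLLM`` import path is an assumption (the test module's own import is
not shown in this hunk); any real LLM or embedding model that exposes
``rate_limiter`` should work the same way.

    from llama_index.core.embeddings.mock_embed_model import MockEmbedding
    from llama_index.core.llms.mock import MockLLM
    from llama_index.core.rate_limiter import SlidingWindowRateLimiter

    # At most 60 requests and 10_000 tokens in any rolling one-minute
    # window, enforced jointly for every client sharing this instance.
    limiter = SlidingWindowRateLimiter(
        requests_per_minute=60,
        tokens_per_minute=10_000,
    )

    llm = MockLLM()
    llm.rate_limiter = limiter  # attribute assignment, as in the tests
    embed = MockEmbedding(embed_dim=8, rate_limiter=limiter)

    llm.complete("hello")              # consumes one request slot
    embed.get_text_embedding("hello")  # draws from the same window

Sharing a single instance is the point of the design: both clients block on
the same deques, so the combined traffic never exceeds the provider's
rolling-minute cap.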