diff --git a/benchmarks/multi-round-qa/multi-round-qa.py b/benchmarks/multi-round-qa/multi-round-qa.py index e7d024c88..3c0b65dcd 100644 --- a/benchmarks/multi-round-qa/multi-round-qa.py +++ b/benchmarks/multi-round-qa/multi-round-qa.py @@ -40,6 +40,9 @@ class WorkloadConfig: # Whether to include user id in request header enable_user_id: bool + # Max number of unfinished queries allowed (None means no limit) + max_unfinished_queries: Optional[int] + @dataclass class UserConfig: @@ -419,6 +422,13 @@ def step(self, timestamp: float, executor: RequestExecutor): if self.start_time is None: self.start_time = timestamp + pending_queries = len([s for s in self.sessions if s.has_unfinished_request]) + # Only check limit if max_unfinished_queries is set + if (self.workload_config.max_unfinished_queries is not None and + pending_queries > self.workload_config.max_unfinished_queries): + logger.info(f"unfinished queries >{self.workload_config.max_unfinished_queries}, waiting") + return + if timestamp - self.last_user_join > self.gap_between_users: self._create_user_session() self.last_user_join = timestamp @@ -625,6 +635,12 @@ def parse_arguments() -> WorkloadConfig: parser.add_argument( "--sharegpt", action="store_true", help="Whether to use ShareGPT dataset" ) + parser.add_argument( + "--max-unfinished-queries", + type=int, + default=None, + help="Maximum number of unfinished queries allowed (default: no limit)", + ) args = parser.parse_args() return args @@ -675,6 +691,7 @@ def main(): qps=args.qps, model=args.model, enable_user_id=args.request_with_user_id, + max_unfinished_queries=args.max_unfinished_queries, ) manager = UserSessionManager( diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 0713e9c0f..e9a2c8d4a 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -109,6 +109,18 @@ async def lifespan(app: FastAPI): dyn_cfg_watcher.close() +def create_instance_id_to_url(lmcache_instances, static_backends): + if lmcache_instances is None or static_backends is None: + return None + instance_ids = [s.strip() for s in lmcache_instances.split(',') if s.strip()] + urls = parse_static_urls(static_backends) + if not instance_ids or not urls: + return None + if len(instance_ids) != len(urls): + raise ValueError("length of lmcache-instances & static-backends mismatched") + return dict(zip(instance_ids, urls)) + + def initialize_all(app: FastAPI, args): """ Initialize all the components of the router with the given arguments. @@ -206,6 +218,9 @@ def initialize_all(app: FastAPI, args): prefill_model_labels=args.prefill_model_labels, decode_model_labels=args.decode_model_labels, kv_aware_threshold=args.kv_aware_threshold, + tokenizer=args.tokenizer, + instance_id_to_url=create_instance_id_to_url(args.lmcache_instances, + args.static_backends), ) # Initialize feature gates diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 8b12cf983..9e6de9be0 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -20,6 +20,7 @@ from vllm_router.parsers.yaml_utils import ( read_and_process_yaml_config_file, ) +from vllm_router.routers.routing_logic import RoutingLogic from vllm_router.version import __version__ try: @@ -203,13 +204,7 @@ def parse_args(): parser.add_argument( "--routing-logic", type=str, - choices=[ - "roundrobin", - "session", - "kvaware", - "prefixaware", - "disaggregated_prefill", - ], + choices=[routing for routing in RoutingLogic], help="The routing logic to use", ) parser.add_argument( @@ -218,12 +213,25 @@ def parse_args(): default=9000, help="The port of the LMCache controller.", ) + parser.add_argument( + "--lmcache-instances", + type=str, + default=None, + help="The instance id in the lmcache config files, must be with the length of static-backends," + " separated by commas. E.g., instance_0,instance_1", + ) parser.add_argument( "--session-key", type=str, default=None, help="The key (in the header) to identify a session.", ) + parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="The tokenizer model.", + ) parser.add_argument( "--callbacks", type=str, diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 093872b3f..2db263742 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -18,7 +18,8 @@ import math import random import threading -from typing import Dict, List +import traceback +from typing import Dict, List, Optional, Tuple, Union from fastapi import Request @@ -31,6 +32,7 @@ from lmcache.v1.cache_controller import controller_manager from lmcache.v1.cache_controller.message import ( LookupMsg, + FullLookupMsg, QueryInstMsg, ) except ImportError: @@ -40,18 +42,45 @@ from vllm_router.log import init_logger from vllm_router.service_discovery import EndpointInfo from vllm_router.stats.engine_stats import EngineStats -from vllm_router.stats.request_stats import RequestStats +from vllm_router.stats.request_stats import RequestStats, RequestStatsCacheInfo from vllm_router.utils import SingletonABCMeta logger = init_logger(__name__) +def extract_prompt(request_json: Dict): + """Extract prompt message from the request json object.""" + if "messages" in request_json: + # Get the last message from the messages array + messages = request_json["messages"] + if messages: + # Concatenate all message content + prompt_parts = [] + for message in messages: + content = message.get("content", "") + if isinstance(content, list): + # Handle multimodal messages + text_content = " ".join( + part.get("text", "") + for part in content + if part.get("type") == "text" + ) + prompt_parts.append(text_content) + elif content is not None: + prompt_parts.append(content) + return "\n".join(prompt_parts) + return "" + # Handle regular completions + return request_json["prompt"] + + class RoutingLogic(str, enum.Enum): ROUND_ROBIN = "roundrobin" SESSION_BASED = "session" KVAWARE = "kvaware" PREFIXAWARE = "prefixaware" DISAGGREGATED_PREFILL = "disaggregated_prefill" + TTFT = "ttft" class RoutingInterface(metaclass=SingletonABCMeta): @@ -108,7 +137,7 @@ def route_request( engine_stats: Dict[str, EngineStats], request_stats: Dict[str, RequestStats], request: Request, - ) -> str: + ) -> Union(str, Tuple[str, RequestStatsCacheInfo]): """ Route the request to the appropriate engine URL @@ -229,6 +258,8 @@ def __init__( lmcache_controller_port: int, session_key: str, kv_aware_threshold: int = 2000, + tokenizer_name: Optional[str] = None, + instance_id_to_url: Optional[Dict[str, str]] = None, ): self.lmcache_controller_port = lmcache_controller_port logger.info( @@ -238,9 +269,13 @@ def __init__( f"0.0.0.0:{self.lmcache_controller_port}" ) self.req_id = 0 - self.instance_id_to_ip = {} + if instance_id_to_url is None: + self.instance_id_to_url = {} + else: + self.instance_id_to_url = instance_id_to_url self.session_key = session_key self.hash_ring = HashRing() + self.tokenizer_name = tokenizer_name self.tokenizer = None self.threshold = kv_aware_threshold @@ -252,6 +287,8 @@ def start_kv_manager(self): self.thread = threading.Thread(target=self.loop.run_forever, daemon=True) self.thread.start() asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop) + if self.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) def query_manager(self, msg) -> str: """ @@ -288,8 +325,8 @@ async def route_request( self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0]) url = endpoints[0].url + "/tokenize" # TODO (Yuhan): Handle chat completions - token_ids = self.tokenizer.encode(request_json["prompt"]) - msg = LookupMsg(tokens=token_ids) + token_ids = self.tokenizer.encode(extract_prompt(request_json)) + msg = LookupMsg(event_id="", tokens=token_ids) instance_id = await self.query_manager(msg) matched_tokens = math.inf if len(list(instance_id.layout_info.keys())) > 0: @@ -319,23 +356,24 @@ async def route_request( return url else: queried_instance_ids = [info for info in instance_id.layout_info] - if queried_instance_ids[0] not in self.instance_id_to_ip: + if queried_instance_ids[0] not in self.instance_id_to_url: for endpoint in endpoints: query_message = QueryInstMsg( + event_id="", ip=endpoint.url.split(f":{endpoint.url.split(':')[-1]}")[ 0 ].split("//")[1] ) endpoint_instance_id = await self.query_manager(query_message) - self.instance_id_to_ip[endpoint_instance_id.instance_id] = ( + self.instance_id_to_url[endpoint_instance_id.instance_id] = ( endpoint.url ) - logger.info(f"Instance id to ip: {self.instance_id_to_ip}") + logger.info(f"Instance id to ip: {self.instance_id_to_url}") logger.info( f"Routing request to {queried_instance_ids[0]} found by kvaware router" ) - return self.instance_id_to_ip[queried_instance_ids[0]] + return self.instance_id_to_url[queried_instance_ids[0]] class PrefixAwareRouter(RoutingInterface): @@ -378,33 +416,7 @@ async def route_request( request_json (Dict): The request body (needed for finding the longest prefix match) """ - - # Handle chat completions - if "messages" in request_json: - # Get the last message from the messages array - messages = request_json["messages"] - if messages: - # Concatenate all message content - prompt_parts = [] - for message in messages: - content = message.get("content", "") - if isinstance(content, list): - # Handle multimodal messages - text_content = " ".join( - part.get("text", "") - for part in content - if part.get("type") == "text" - ) - prompt_parts.append(text_content) - elif content is not None: - prompt_parts.append(content) - prompt = "\n".join(prompt_parts) - else: - prompt = "" - else: - # Handle regular completions - prompt = request_json["prompt"] - + prompt = extract_prompt(request_json) available_endpoints = set(endpoint.url for endpoint in endpoints) _, matched_endpoint = await self.hashtrie.longest_prefix_match( prompt, available_endpoints @@ -460,6 +472,216 @@ def route_request( return decoder_endpoints[0].url +class TtftRouter(RoutingInterface): + """ + Route the request to the qppropriate engine URL by the least estimated TTFT. + """ + + CACHE_LOC_TO_TRANS_TIME: Dict[str, float] = { + "LocalCPUBackend": 0.01, + "LocalDiskBackend": 0.015, + } + + DEFAULT_CACHE_TRANS_TIME: float = 0.01 + + def __init__( + self, + lmcache_controller_port: int, + session_key: str, + tokenizer_name: Optional[str] = None, + instance_id_to_url: Optional[Dict[str, str]] = None, + ): + logger.info( + f"Initializing TtftRouter with lmcache addr: 0.0.0.0:{lmcache_controller_port}" + ) + self.kv_manager = controller_manager.LMCacheControllerManager( + f"0.0.0.0:{lmcache_controller_port}" + ) + if instance_id_to_url is None: + self.instance_id_to_url = {} + else: + self.instance_id_to_url = instance_id_to_url + self.session_key = session_key + self.hash_ring = HashRing() + self.tokenizer_name = tokenizer_name + self.tokenizer = None + self.cached_prefix_tokens = None + + def start_kv_manager(self): + """ + Start the kv manager + """ + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=self.loop.run_forever, daemon=True) + self.thread.start() + asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop) + if self.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + + async def route_request( + self, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + request: Request, + request_json: Dict, + ) -> Tuple[str, RequestStatsCacheInfo]: + """ + Route the request to the appropriate engine URL by where the KV cache + of the longest prefix match is found. + If there is no session id in the reqest header, it will pick a server + with round robin. + + Args: + endpoints (List[EndpointInfo]): The list of engine URLs + engine_stats (Dict[str, EngineStats]): The engine stats indicating + the 'physical' load of each engine + request_stats (Dict[str, RequestStats]): The request stats + indicating the request-level performance of each engine + request (Request): The incoming request + request_json (Dist): The request body (needed for finding the + longest prefix match) + """ + if self.tokenizer is None: + # fallback to use the model of the first endpoint as tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(endpoints[0].model_names[0]) + + token_ids = self.tokenizer.encode(extract_prompt(request_json)) + cache_info = RequestStatsCacheInfo() + cache_info.num_prefix_tokens = len(token_ids) + try: + if request_stats is None: + raise ValueError("no request stats was provided") + msg = FullLookupMsg(event_id="", tokens=token_ids) + ret_msg = await self.kv_manager.handle_orchestration_message(msg) + matched_infos = ret_msg.matched_info + if matched_infos: + best_matched_info = self._find_best_matched(matched_infos) + best_ttft_url = await self._find_best_ttft(endpoints, matched_infos, + best_matched_info, request_stats) + cache_info.num_cached_tokens = best_matched_info[1][-1][1] + return best_ttft_url, cache_info + except ValueError: + logger.info("Fallback to QPS routing due to:") + logger.info(traceback.format_exc()) + cache_info.num_cached_tokens = 0 + return self._fallback_routing(endpoints, request_stats, request), cache_info + + def _find_best_matched(self, matched_infos): + best_matched_info = None + for instance_info in matched_infos: + if best_matched_info is None or instance_info[1][-1][1] > best_matched_info[1][-1][1]: + best_matched_info = instance_info + if best_matched_info is None: + raise ValueError("no best matched instance was found") + return best_matched_info + + async def _find_best_ttft(self, endpoints, matched_infos, best_matched_info, + request_stats): + matched_stats = [] + matched_urls = [] + for matched_info in matched_infos: + url = await self._get_instance_url(endpoints, matched_info[0]) + stats = request_stats.get(url, None) + if stats is None: + raise ValueError(f"{url} provides no request stats ") + if stats.uncomputed_prefix_tokens > 0 and stats.engine_prefill_tps <= 0: + raise ValueError(f"{url} provides no way to forecasted queue time") + matched_urls.append(url) + matched_stats.append(stats) + + # cache matched pass + best_ttft = float('inf') + best_ttft_url = None + for i, matched_info in enumerate(matched_infos): + logger.debug(f"-------------- URL:{matched_urls[i]} --------------") + ttft = self._estimate_ttft(matched_info, best_matched_info, + matched_stats[i]) + if best_ttft_url is None or ttft <= best_ttft: + best_ttft = ttft + best_ttft_url = matched_urls[i] + + # cache not matched pass + matched_url_set = set(matched_urls) + not_matched_endpoints = [endpoint for endpoint in endpoints if endpoint.url not in matched_url_set] + for endpoint in not_matched_endpoints: + url = endpoint.url + stats = request_stats.get(url, None) + if stats is None: + raise ValueError(f"{url} provides no request stats ") + logger.debug(f"-------------- URL:{url} --------------") + ttft = self._estimate_ttft(None, best_matched_info, stats) + if best_ttft_url is None or ttft <= best_ttft: + best_ttft = ttft + best_ttft_url = url + + if best_ttft_url is None: + raise ValueError(f"no best TTFT instance was found") + return best_ttft_url + + def _estimate_ttft(self, matched_info, best_matched_info, stats): + transfer_time = self._calc_transfer_time(matched_info, best_matched_info) + # TODO take computation time of num_uncached_token into account + if stats.uncomputed_prefix_tokens == 0: + forecasted_queue_time = 0 + else: + forecasted_queue_time = (stats.uncomputed_prefix_tokens / + stats.engine_prefill_tps) + ttft = forecasted_queue_time + transfer_time + + logger.debug(f"-------------- time estimations --------------") + logger.debug(f"uncomputed_prefix_tokens: {stats.uncomputed_prefix_tokens}") + logger.debug(f"engine_prefill_tps: {stats.engine_prefill_tps}") + logger.debug(f"transfer_time: {transfer_time}") + logger.debug(f"forecasted_queue_time: {forecasted_queue_time}") + logger.debug(f"ttft: {ttft}") + return ttft + + async def _get_instance_url(self, endpoints, instance_id): + url = self.instance_id_to_url.get(instance_id, None) + if url is not None: + return url + for endpoint in endpoints: + msg = QueryInstMsg( + event_id="", + ip=endpoint.url.split(f":{endpoint.url.split(":")[-1]}")[ + 0 + ].split("//")[1] + ) + ret_msg = await self.kv_manager.handle_orchestration_message(msg) + self.instance_id_to_url[ret_msg.instance_id] = endpoint.url + if ret_msg.instance_id == instance_id: + url = endpoint.url + if url is None: + raise ValueError(f"cannot resolve URL for {instance_id}") + return url + + def _calc_transfer_time(self, matched_info, best_matched_info): + transfer_time = 0 + for chunk in best_matched_info[1]: + if matched_info is not None and chunk[1] <= matched_info[1][-1][1]: + continue + # TODO chunk transfer time measured realtime inside vllm engine + transfer_time += self.CACHE_LOC_TO_TRANS_TIME.get(chunk[0], + self.DEFAULT_CACHE_TRANS_TIME) + return transfer_time + + def _fallback_routing(self, endpoints, request_stats, request): + session_id = request.headers.get(self.session_key, None) + logger.debug(f"Got session id: {session_id}") + + # Update the hash ring with the current list of endpoints + self._update_hash_ring(endpoints) + + if session_id is None: + # Route base on QPS if no session ID is present + url = self._qps_routing(endpoints, request_stats) + else: + # Use the hash ring to get the endpoint for the session ID + url = self.hash_ring.get_node(session_id) + return url + + # Instead of managing a global _global_router, we can define the initialization functions as: def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs @@ -476,6 +698,8 @@ def initialize_routing_logic( kwargs.get("lmcache_controller_port"), kwargs.get("session_key"), kwargs.get("kv_aware_threshold"), + kwargs.get("tokenizer"), + kwargs.get("instance_id_to_url"), ) router.start_kv_manager() return router @@ -487,6 +711,16 @@ def initialize_routing_logic( return DisaggregatedPrefillRouter( kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels") ) + elif routing_logic == RoutingLogic.TTFT: + logger.info("Initializing ttft routing logic") + router = TtftRouter( + kwargs.get("lmcache_controller_port"), + kwargs.get("session_key"), + kwargs.get("tokenizer"), + kwargs.get("instance_id_to_url"), + ) + router.start_kv_manager() + return router else: raise ValueError(f"Invalid routing logic {routing_logic}") @@ -514,6 +748,7 @@ def get_routing_logic() -> RoutingInterface: KvawareRouter, PrefixAwareRouter, DisaggregatedPrefillRouter, + TtftRouter, ): if cls in SingletonABCMeta._instances: return cls() diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 83e647927..d3c60cf6f 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -28,6 +28,7 @@ DisaggregatedPrefillRouter, KvawareRouter, PrefixAwareRouter, + TtftRouter, ) from vllm_router.service_discovery import get_service_discovery from vllm_router.services.request_service.rewriter import ( @@ -59,6 +60,7 @@ async def process_request( endpoint, background_tasks: BackgroundTasks, debug_request=None, + cache_info=None, ): """ Process a request by sending it to the chosen backend. @@ -71,7 +73,7 @@ async def process_request( endpoint: The endpoint to send the request to on the backend. debug_request: The original request object from the client, used for optional debug logging. - + cache_info: Cache information. Yields: The response headers and status code, followed by the response content. @@ -82,7 +84,7 @@ async def process_request( total_len = 0 start_time = time.time() request.app.state.request_stats_monitor.on_new_request( - backend_url, request_id, start_time + backend_url, request_id, start_time, cache_info ) # Check if this is a streaming request try: @@ -221,7 +223,8 @@ async def route_general_request( ) engine_stats = request.app.state.engine_stats_scraper.get_engine_stats() request_stats = request.app.state.request_stats_monitor.get_request_stats( - time.time() + time.time(), + [endpoint.url for endpoint in endpoints], ) else: endpoints = list( @@ -241,24 +244,31 @@ async def route_general_request( }, ) + route_result = None + cache_info = None + logger.debug(f"Routing request {request_id} for model: {requested_model}") if request_endpoint: server_url = endpoints[0].url logger.debug( f"Routing request {request_id} to engine with Id: {endpoints[0].Id}" ) - - elif isinstance(request.app.state.router, KvawareRouter) or isinstance( - request.app.state.router, PrefixAwareRouter - ): - server_url = await request.app.state.router.route_request( + elif isinstance(request.app.state.router, (KvawareRouter, PrefixAwareRouter, TtftRouter)): + route_result = await request.app.state.router.route_request( endpoints, engine_stats, request_stats, request, request_json ) else: - server_url = request.app.state.router.route_request( + route_result = request.app.state.router.route_request( endpoints, engine_stats, request_stats, request ) + if isinstance(route_result, (tuple, list)): + server_url = route_result[0] + if len(route_result) > 1: + cache_info = route_result[1] + elif isinstance(route_result, str): + server_url = route_result + curr_time = time.time() # Extract actual session ID from request headers for logging session_key = ( @@ -289,6 +299,7 @@ async def route_general_request( request_id, endpoint, background_tasks, + cache_info=cache_info, ) headers, status = await anext(stream_generator) headers_dict = {key: value for key, value in headers.items()} diff --git a/src/vllm_router/stats/request_stats.py b/src/vllm_router/stats/request_stats.py index f0409b912..755fb137d 100644 --- a/src/vllm_router/stats/request_stats.py +++ b/src/vllm_router/stats/request_stats.py @@ -14,7 +14,8 @@ import time from collections import deque from dataclasses import dataclass -from typing import Deque, Dict, Tuple +from numbers import Number +from typing import Deque, Dict, Tuple, Set, List from vllm_router.log import init_logger @@ -53,6 +54,44 @@ class RequestStats: avg_itl: float # Number of swapped requests (moved from GPU to CPU) num_swapped_requests: int + # Engine prefill computation speed + engine_prefill_comp_speed: float + # Uncomputed prefix tokens + uncomputed_prefix_tokens: int + + +class TimePeriods: + """ + Utility for computing length of overlapping time periods. + """ + def __init__(self): + self.periods: List[Tuple[float, float]] = [] + + def union(self, begin: float, end: float): + overlap_periods = [] + for i, period in enumerate(self.periods): + if ((begin >= period[0] and begin <= period[1]) or \ + (end >= period[0] and end <= period[1])) or \ + ((period[0] >= begin and period[0] <= end) or \ + (period[1] >= begin and period[1] <= end)): + self.periods[i] = (min(period[0], begin), max(period[1], end)) + overlap_periods.append(i) + if len(overlap_periods) == 0: + self.periods.append((begin, end)) + return + if len(overlap_periods) == 1: + return + # merge all overlapping periods + merge_begin = min([self.periods[i][0] for i in overlap_periods]) + merge_end = max([self.periods[i][1] for i in overlap_periods]) + + remove_indices = set(overlap_periods) + new_periods = [period for i, period in enumerate(self.periods) if i not in remove_indices] + self.periods = new_periods + self.periods.append((merge_begin, merge_end)) + + def compute_length(self) -> float: + return sum([period[1] - period[0] for period in self.periods]) class MovingAverageMonitor: @@ -103,6 +142,15 @@ def get_sum(self) -> float: return sum(self.values) +class RequestStatsCacheInfo: + """ + Cache information. + """ + def __init__(self): + self.num_prefix_tokens : int = 0 + self.num_cached_tokens : int = 0 + + class RequestStatsMonitor(metaclass=SingletonMeta): """ Monitors the request statistics of all serving engines. @@ -127,6 +175,8 @@ def __init__(self, sliding_window_size: float = None): self.request_start_time: Dict[Tuple[str, str], float] = {} # Record time when first token is received: (engine_url, request_id) -> timestamp self.first_token_time: Dict[Tuple[str, str], float] = {} + # The number of cached prefix tokens + self.cache_infos: Dict[Tuple[str, str], RequestStatsCacheInfo] = {} # Number of requests in different stages (from the start of the router) self.in_prefill_requests: Dict[str, int] = {} @@ -142,7 +192,10 @@ def __init__(self, sliding_window_size: float = None): self.first_query_time: float = None self._initialized = True - def on_new_request(self, engine_url: str, request_id: str, timestamp: float): + def on_new_request(self, engine_url: str, + request_id: str, + timestamp: float, + cache_info: RequestStatsCacheInfo = None): """ Tell the monitor that a new request has been created. @@ -150,9 +203,13 @@ def on_new_request(self, engine_url: str, request_id: str, timestamp: float): engine_url: The URL of the serving engine request_id: The global request ID timestamp: the timestamp when the request was created + cache_info: The cache information """ self.request_start_time[(engine_url, request_id)] = timestamp + if cache_info is not None: + self.cache_infos[(engine_url, request_id)] = cache_info + if engine_url not in self.in_prefill_requests: self.in_prefill_requests[engine_url] = 0 self.in_prefill_requests[engine_url] += 1 @@ -197,7 +254,9 @@ def on_request_response(self, engine_url: str, request_id: str, timestamp: float self.sliding_window_size ) # Update TTFT as time from request start to first token - ttft = timestamp - self.request_start_time[(engine_url, request_id)] + # ttft = timestamp - self.request_start_time[(engine_url, request_id)] + start_time = self.request_start_time[(engine_url, request_id)] + ttft = timestamp - start_time self.ttft_monitors[engine_url].update(timestamp, ttft) def on_request_complete(self, engine_url: str, request_id: str, timestamp: float): @@ -235,12 +294,13 @@ def on_request_swapped(self, engine_url: str, request_id: str, timestamp: float) self.swapped_requests[engine_url] = 0 self.swapped_requests[engine_url] += 1 - def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: + def get_request_stats(self, current_time: float, urls: List[str] = None) -> Dict[str, RequestStats]: """ Get the request statistics for each serving engine Args: current_time: The current timestamp in seconds + urls: The URLs of engines Returns: A dictionary where the key is the serving engine URL and the value @@ -248,10 +308,11 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: The TTFT and inter token latency will be -1 if there is no requests finished in the sliding window. """ + if urls is None: + urls = set(self.in_prefill_requests.keys()).union( + set(self.in_decoding_requests.keys()) + ) ret = {} - urls = set(self.in_prefill_requests.keys()).union( - set(self.in_decoding_requests.keys()) - ) for engine_url in urls: if engine_url not in self.qps_monitors: qps = -1 @@ -289,6 +350,10 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: else: swapped = 0 + + engine_prefill_comp_speed = self._calc_engine_prefill_comp_speed(current_time, engine_url) + uncomputed_prefix_tokens = self._get_uncomputed_prefix_tokens(engine_url) + ret[engine_url] = RequestStats( qps=qps, ttft=ttft, @@ -302,9 +367,46 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: avg_latency=avg_lat, avg_itl=avg_itl_val, num_swapped_requests=swapped, + engine_prefill_comp_speed=engine_prefill_comp_speed, + uncomputed_prefix_tokens=uncomputed_prefix_tokens, ) return ret + def _calc_engine_prefill_comp_speed(self, current_time: float, engine_url: str) -> float: + min_start_time = current_time - self.sliding_window_size + prefill_periods = TimePeriods() + total_comp_amount = 0 + for (url, request_id), start_time in self.request_start_time.items(): + if url != engine_url or start_time < min_start_time: + continue + if ((url, request_id) not in self.first_token_time or + (url, request_id) not in self.cache_infos): + continue + + cache_info = self.cache_infos[(url, request_id)] + computed_tokens = cache_info.num_prefix_tokens - cache_info.num_cached_tokens + if computed_tokens > 0: + prefill_periods.union(start_time, self.first_token_time[(url, request_id)]) + # find computation amount by trapezoid area formula + top = cache_info.num_cached_tokens + bottom = cache_info.num_prefix_tokens - 1 + height = computed_tokens + total_comp_amount += (top + bottom) * height / 2 + + length = prefill_periods.compute_length() + if length > 0: + return total_comp_amount / length + return -1 + + def _get_uncomputed_prefix_tokens(self, engine_url: str) -> int: + uncomputed_prefix_tokens = 0 + for (url, request_id), cache_info in self.cache_infos.items(): + if url != engine_url or (url, request_id) in self.first_token_time: + continue + uncomputed_prefix_tokens += (cache_info.num_prefix_tokens - + cache_info.num_cached_tokens) + return uncomputed_prefix_tokens + def initialize_request_stats_monitor(sliding_window_size: float): return RequestStatsMonitor(sliding_window_size)