[Feature][main]reconstruction kvpool connector to ascend connector #4438
Conversation
Signed-off-by: fems14 <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request refactors the kvpool connector, renaming it to ascend connector and introducing a pluggable backend architecture. The changes are extensive, involving file renaming, code movement, and logic updates to support different storage backends like Mooncake and Memcache. While the refactoring improves modularity, I've identified several critical issues related to error handling, resource management, and potential runtime errors that need to be addressed to ensure the stability and correctness of the new implementation.
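For readers new to the design: the connector talks to a narrow storage interface, and each backend (Mooncake, Memcache) plugs in behind it. Below is a minimal sketch of what such an interface could look like; the class and method names are illustrative, not the actual ones in this PR, though the get/put signatures mirror the excerpts reviewed further down.

from abc import ABC, abstractmethod


class KVStoreBackend(ABC):
    """Illustrative pluggable-backend interface (hypothetical names)."""

    @abstractmethod
    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]) -> None:
        """Copy the KV blocks identified by `keys` from the external store
        into the local buffers described by `addrs`/`sizes`."""

    @abstractmethod
    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]) -> None:
        """Copy the KV blocks in the local buffers out to the external store
        under `keys`."""


# A Mooncake- or Memcache-backed implementation would subclass this and wrap
# its client's batch get/put calls, as in the excerpts below.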
def get(self, key: list[str], addr: list[list[int]],
        size: list[list[int]]):
    try:
        res = self.store.batch_get_into_layers(key, addr, size,
                                               MmcDirect.COPY_G2L.value)
        for value in res:
            if value != 0:
                logger.error(f"Failed to get key {key},res:{res}")
    except Exception as e:
        logger.error(f"Failed to get key {key}. {e}")
The get method catches all exceptions and only logs them as errors, without re-raising. If fetching the KV cache from memcache fails, the model will proceed with uninitialized or incorrect data, leading to silent correctness issues and wrong model outputs. This is a critical bug.
The failure to retrieve cache data should be propagated up by raising an exception to ensure the system can handle the failure appropriately, rather than continuing with corrupted state. Additionally, the log message should be more specific about which key failed.
Suggested change:

def get(self, key: list[str], addr: list[list[int]],
        size: list[list[int]]):
    try:
        res = self.store.batch_get_into_layers(key, addr, size,
                                               MmcDirect.COPY_G2L.value)
        for i, value in enumerate(res):
            if value != 0:
                error_msg = f"Failed to get key {key[i]} from memcache, error code: {value}"
                logger.error(error_msg)
                raise RuntimeError(error_msg)
    except Exception as e:
        logger.error(f"Exception during memcache get for keys {key}: {e}")
        raise
if self.config.protocol == "ascend":
    local_hostname = get_ip()
    transfer_engine = global_te.get_transfer_engine(local_hostname,
                                                    device_name=None)
    self.local_seg = local_hostname + ":" + str(
        transfer_engine.get_rpc_port())
    ret = self.store.setup(self.local_seg, self.config.metadata_server,
                           self.config.global_segment_size,
                           self.config.local_buffer_size,
                           self.config.protocol,
                           self.config.device_name,
                           self.config.master_server_address,
                           transfer_engine.get_engine())
if ret != 0:
    msg = "Initialize mooncake failed."
    logger.error(msg)
    raise RuntimeError(msg)
The __init__ method has a potential UnboundLocalError. The ret variable is only assigned inside the if self.config.protocol == "ascend": block. If the protocol is anything else, ret will not be defined when it's checked on line 52, causing a crash.
If only the 'ascend' protocol is supported, you should add an else block to raise a NotImplementedError or ValueError to make this explicit and prevent runtime errors. If other protocols are intended to be supported, their setup logic is missing.
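As a minimal, self-contained illustration of the failure mode (not code from this PR), note how the return-code variable stays unbound on the non-ascend path:

def setup(protocol: str) -> None:
    if protocol == "ascend":
        ret = 0  # `ret` is only bound on this branch
    if ret != 0:  # UnboundLocalError when protocol != "ascend"
        raise RuntimeError("Initialize mooncake failed.")


setup("tcp")  # raises UnboundLocalError instead of a clear configuration error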
Suggested change:

if self.config.protocol == "ascend":
    local_hostname = get_ip()
    transfer_engine = global_te.get_transfer_engine(local_hostname,
                                                    device_name=None)
    self.local_seg = local_hostname + ":" + str(
        transfer_engine.get_rpc_port())
    ret = self.store.setup(self.local_seg, self.config.metadata_server,
                           self.config.global_segment_size,
                           self.config.local_buffer_size,
                           self.config.protocol,
                           self.config.device_name,
                           self.config.master_server_address,
                           transfer_engine.get_engine())
else:
    raise NotImplementedError(
        f"Protocol '{self.config.protocol}' is not supported in MooncakeBackend.")
if ret != 0:
    msg = f"Initialize mooncake failed with return code: {ret}."
    logger.error(msg)
    raise RuntimeError(msg)
def get(self, keys: list[str], addrs: list[list[int]],
        sizes: list[list[int]]):
    try:
        res = self.store.batch_get_into_multi_buffers(
            keys, addrs, sizes, True)
        for value in res:
            if value < 0:
                logger.error(f"Failed to get key {keys}, res:{res}")
    except Exception as e:
        logger.error(f"Failed to get key {keys}, error:{e}")
The get method catches all exceptions and only logs them, which is a critical issue. If retrieving the KV cache from Mooncake fails, the model will proceed with uninitialized or garbage data, leading to incorrect outputs. This is a silent data corruption bug. Failures in retrieving cache data must be propagated by raising an exception. The log message should also be improved to pinpoint the failing key.
Suggested change:

def get(self, keys: list[str], addrs: list[list[int]],
        sizes: list[list[int]]):
    try:
        res = self.store.batch_get_into_multi_buffers(
            keys, addrs, sizes, True)
        for i, value in enumerate(res):
            if value < 0:
                error_msg = f"Failed to get key {keys[i]} from mooncake, error code: {value}"
                logger.error(error_msg)
                raise RuntimeError(error_msg)
    except Exception as e:
        logger.error(f"Exception during mooncake get for keys {keys}: {e}")
        raise
key_list_c = key_list[self.tp_rank %
                      len(key_list):] + key_list[:self.tp_rank %
                                                 len(key_list)]
addr_list_c = addr_list[self.tp_rank %
                        len(addr_list):] + addr_list[:self.tp_rank %
                                                     len(addr_list)]
size_list_c = size_list[self.tp_rank %
                        len(size_list):] + size_list[:self.tp_rank %
                                                     len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
There is a potential ZeroDivisionError here. If self.token_database.process_tokens yields no items, key_list will be empty. Consequently, len(key_list) will be 0, and the modulo operation self.tp_rank % len(key_list) will cause a ZeroDivisionError, crashing the thread.
You should add a check to ensure key_list is not empty before performing the cyclic shift.
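For context, the pattern being guarded is a plain list rotation by tp_rank. A standalone sketch (illustrative only, not code from the PR) shows both the intended behavior and the failure on an empty list:

def rotate(items: list, tp_rank: int) -> list:
    # Guard first: `tp_rank % len(items)` raises ZeroDivisionError for [].
    if not items:
        return items
    shift = tp_rank % len(items)
    return items[shift:] + items[:shift]


print(rotate(["k0", "k1", "k2"], 1))  # ['k1', 'k2', 'k0']
print(rotate([], 1))                  # [] instead of a crash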
Suggested change:

if key_list:
    key_list_c = key_list[self.tp_rank %
                          len(key_list):] + key_list[:self.tp_rank %
                                                     len(key_list)]
    addr_list_c = addr_list[self.tp_rank %
                            len(addr_list):] + addr_list[:self.tp_rank %
                                                         len(addr_list)]
    size_list_c = size_list[self.tp_rank %
                            len(size_list):] + size_list[:self.tp_rank %
                                                         len(size_list)]
    self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
key_list_c = key_list[self.tp_rank %
                      len(key_list):] + key_list[:self.tp_rank %
                                                 len(key_list)]
addr_list_c = addr_list[self.tp_rank %
                        len(addr_list):] + addr_list[:self.tp_rank %
                                                     len(addr_list)]
size_list_c = size_list[self.tp_rank %
                        len(size_list):] + size_list[:self.tp_rank %
                                                     len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)

self.request_queue.task_done()
self.get_event.set()
Similar to KVCacheStoreRecvingThread, this method is vulnerable to a ZeroDivisionError. If req_meta.keys is empty, key_list will be empty, and len(key_list) will be zero, causing a crash when the modulo operator is used. Please add a guard to handle the case where key_list is empty.
Suggested change:

if key_list:
    key_list_c = key_list[self.tp_rank %
                          len(key_list):] + key_list[:self.tp_rank %
                                                     len(key_list)]
    addr_list_c = addr_list[self.tp_rank %
                            len(addr_list):] + addr_list[:self.tp_rank %
                                                         len(addr_list)]
    size_list_c = size_list[self.tp_rank %
                            len(size_list):] + size_list[:self.tp_rank %
                                                         len(size_list)]
    self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done()
self.get_event.set()
class LookupKeyServer:

    def __init__(
        self,
        pool_worker: KVPoolWorker,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        self.decoder = MsgpackDecoder()
        self.decoder_tensor = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_lookup(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.pool_worker = pool_worker
        self.running = True
        self.use_layerwise = use_layerwise

        def process_request():
            while self.running:
                all_frames = self.socket.recv_multipart(copy=False)
                token_len = int.from_bytes(all_frames[0], byteorder="big")
                hash_frames = all_frames[1:]
                hashes_str = self.decoder.decode(hash_frames)
                result = self.pool_worker.lookup_scheduler(
                    token_len, hashes_str, self.use_layerwise)
                response = result.to_bytes(4, "big")
                self.socket.send(response)

        self.thread = threading.Thread(target=process_request, daemon=True)
        self.thread.start()

    def close(self):
        self.socket.close(linger=0)
        # TODO: close the thread!
The LookupKeyServer does not handle thread termination gracefully. The close method has a TODO and only closes the socket, which will cause an unhandled zmq.ZMQError in the background thread, leading to a crash rather than a clean shutdown. This can cause resource leaks or unpredictable behavior.
To fix this, I recommend the following:
- Move the process_request logic into a class method (e.g., _process_request).
- Add a try...except block within the _process_request loop to catch zmq.ZMQError and exit gracefully when the socket is closed.
- Update the close method to set self.running = False, close the socket, and then join() the thread to ensure it has terminated.
For example:

class LookupKeyServer:

    def __init__(
        self,
        pool_worker: KVPoolWorker,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        self.decoder = MsgpackDecoder()
        self.decoder_tensor = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        socket_path = get_zmq_rpc_path_lookup(vllm_config)
        self.socket = make_zmq_socket(
            self.ctx,
            socket_path,
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )
        self.pool_worker = pool_worker
        self.running = True
        self.use_layerwise = use_layerwise
        self.thread = threading.Thread(target=self._process_request,
                                       daemon=True)
        self.thread.start()

    def _process_request(self):
        while self.running:
            try:
                all_frames = self.socket.recv_multipart(copy=False)
                token_len = int.from_bytes(all_frames[0], byteorder="big")
                hash_frames = all_frames[1:]
                hashes_str = self.decoder.decode(hash_frames)
                result = self.pool_worker.lookup_scheduler(
                    token_len, hashes_str, self.use_layerwise)
                response = result.to_bytes(4, "big")
                self.socket.send(response)
            except zmq.error.ZMQError:
                # The socket is closed, exit the loop.
                break

    def close(self):
        self.running = False
        self.socket.close(linger=0)
        self.thread.join()

def put(self, key: list[str], addr: list[list[int]],
        size: list[list[int]]):
    try:
        res = self.store.batch_put_from_layers(key, addr, size,
                                               MmcDirect.COPY_L2G.value)
        for value in res:
            if value != 0:
                logger.error(f"Failed to get key {key},res:{res}")
    except Exception as e:
        logger.error(f"Failed to put key {key},error:{e}")
Similar to the get method, the put method silently fails by catching all exceptions and only logging them. If saving the KV cache fails, it won't be available for future requests, leading to performance degradation as the prefix will be recomputed. This should be a hard failure to alert operators to a problem with the caching backend.
Exceptions should be raised to ensure that storage failures are not ignored. The log message should also be improved to indicate the specific key that failed.
Suggested change:

def put(self, key: list[str], addr: list[list[int]],
        size: list[list[int]]):
    try:
        res = self.store.batch_put_from_layers(key, addr, size,
                                               MmcDirect.COPY_L2G.value)
        for i, value in enumerate(res):
            if value != 0:
                error_msg = f"Failed to put key {key[i]} to memcache, error code: {value}"
                logger.error(error_msg)
                raise RuntimeError(error_msg)
    except Exception as e:
        logger.error(f"Exception during memcache put for keys {key}: {e}")
        raise
def put(self, keys: list[str], addrs: list[list[int]],
        sizes: list[list[int]]):
    try:
        config = ReplicateConfig()
        # config.preferred_segment = self.local_seg
        config.prefer_alloc_in_same_node = True
        res = self.store.batch_put_from_multi_buffers(
            keys, addrs, sizes, config)
        for value in res:
            if value < 0:
                logger.error(f"Failed to put key {keys},res:{res}")
    except Exception as e:
        logger.error(f"Failed to put key {keys},error:{e}")
The put method catches all exceptions and only logs them, which can lead to silent failures. If saving to Mooncake fails, the KV cache for a prefix will not be stored, causing performance degradation in future requests as it will need to be recomputed. This failure should be propagated by raising an exception to ensure system administrators are aware of issues with the caching backend. The log message should also be improved to pinpoint the failing key.
Suggested change:

def put(self, keys: list[str], addrs: list[list[int]],
        sizes: list[list[int]]):
    try:
        config = ReplicateConfig()
        # config.preferred_segment = self.local_seg
        config.prefer_alloc_in_same_node = True
        res = self.store.batch_put_from_multi_buffers(
            keys, addrs, sizes, config)
        for i, value in enumerate(res):
            if value < 0:
                error_msg = f"Failed to put key {keys[i]} to mooncake, error code: {value}"
                logger.error(error_msg)
                raise RuntimeError(error_msg)
    except Exception as e:
        logger.error(f"Exception during mooncake put for keys {keys}: {e}")
        raise
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?