Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
files: xinference
repos:
- repo: https://github.com/psf/black
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.0
hooks:
- id: black
Expand Down
51 changes: 51 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,57 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
v["replica"] = self._model_uid_to_replica_info[k].replica
return running_model_info

# Receive model infos of workers
@log_async(logger=logger)
async def sync_models(
    self, worker_address: str, model_desc: Dict[str, Dict[str, Any]]
):
    """Rebuild supervisor-side bookkeeping from the models a worker reports.

    For every replica in ``model_desc`` that this supervisor does not yet
    track, the method restores three pieces of state:

    * ``self._replica_model_uid_to_worker`` (replica uid -> worker ref),
    * ``self._model_uid_to_replica_info`` (model uid -> replica count and
      scheduler),
    * the instance info held by ``self._status_guard_ref``.

    Args:
        worker_address: Address of the reporting worker. It must already be
            present in ``self._worker_address_to_worker``.
        model_desc: Mapping of replica model uid to a description dict. The
            optional keys ``"model_name"`` and ``"model_version"`` are read
            and default to ``""`` when missing.

    Raises:
        AssertionError: If ``worker_address`` is not a known worker while an
            untracked replica is being processed.
    """
    for replica_model_uid, desc_dict in model_desc.items():
        # Replicas that are already tracked need no rebuilding.
        if replica_model_uid in self._replica_model_uid_to_worker:
            continue

        model_name = desc_dict.get("model_name", "")
        model_version = desc_dict.get("model_version", "")
        logger.debug(
            f"Receive model replica: {replica_model_uid} {worker_address} {model_name}"
        )

        assert (
            worker_address in self._worker_address_to_worker
        ), f"Worker {worker_address} not exists when sync_models"

        # Rebuild self._replica_model_uid_to_worker
        self._replica_model_uid_to_worker[
            replica_model_uid
        ] = self._worker_address_to_worker[worker_address]

        # Rebuild self._model_uid_to_replica_info: the replica id is
        # zero-based, so seeing id N implies at least N + 1 replicas.
        model_uid, rep_id = parse_replica_model_uid(replica_model_uid)
        replica = rep_id + 1
        if (
            model_uid not in self._model_uid_to_replica_info
            or replica > self._model_uid_to_replica_info[model_uid].replica
        ):
            self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
                replica=replica, scheduler=itertools.cycle(range(replica))
            )

        # Rebuild self._status_guard_ref. Report the largest replica count
        # seen so far rather than the one derived from the current replica,
        # otherwise a lower-index replica arriving later (e.g. from another
        # worker) would shrink the reported count.
        known_replica = self._model_uid_to_replica_info[model_uid].replica
        instance_info = InstanceInfo(
            model_name=model_name,
            model_uid=model_uid,
            model_version=model_version,
            model_ability=[],
            replica=known_replica,
            status=LaunchStatus.READY.name,
            instance_created_ts=int(time.time()),
        )
        await self._status_guard_ref.set_instance_info(model_uid, instance_info)

def is_local_deployment(self) -> bool:
# TODO: temporary.
return (
Expand Down
87 changes: 87 additions & 0 deletions xinference/core/tests/test_restart_supervisor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import pytest
import xoscar as xo
from typing import List, Optional, Union, Dict
import multiprocessing

from ...core.supervisor import SupervisorActor



# test restart supervisor
@pytest.mark.asyncio
async def test_restart_supervisor():
    """Check that replica info survives a supervisor restart.

    The test starts a supervisor and a worker in subprocesses, launches a
    model, kills and restarts the supervisor on the same address, and then
    asserts that ``describe_model`` reports the same ``"replica"`` value as
    before the restart. All spawned subprocesses are terminated on exit,
    whether the test passes or fails.
    """
    from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
    from ...deploy.worker import main as _start_worker

    def worker_run_in_subprocess(
        address: str, supervisor_address: str, logging_conf: Optional[Dict] = None
    ) -> multiprocessing.Process:
        """Start a worker in its own process and return the process handle."""
        p = multiprocessing.Process(
            target=_start_worker,
            args=(address, supervisor_address, None, None, logging_conf),
        )
        p.start()
        return p

    # start supervisor
    supervisor_address = f"localhost:{xo.utils.get_next_port()}"
    proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
    proc_worker = None

    try:
        await asyncio.sleep(5)

        # start worker
        proc_worker = worker_run_in_subprocess(
            address=f"localhost:{xo.utils.get_next_port()}",
            supervisor_address=supervisor_address,
        )

        await asyncio.sleep(10)

        # load model
        supervisor_ref = await xo.actor_ref(
            supervisor_address, SupervisorActor.default_uid()
        )

        model_uid = "qwen1.5-chat"
        await supervisor_ref.launch_builtin_model(
            model_uid=model_uid,
            model_name="qwen1.5-chat",
            model_size_in_billions="0_5",
            quantization="q4_0",
            model_engine="vLLM",
        )

        # query replica info
        model_replica_info = await supervisor_ref.describe_model(model_uid)

        # kill supervisor
        proc_supervisor.terminate()
        proc_supervisor.join()

        # restart supervisor
        proc_supervisor = supervisor_run_in_subprocess(supervisor_address)

        await asyncio.sleep(5)

        supervisor_ref = await xo.actor_ref(
            supervisor_address, SupervisorActor.default_uid()
        )

        # check replica info
        model_replica_info_check = await supervisor_ref.describe_model(model_uid)

        assert model_replica_info["replica"] == model_replica_info_check["replica"]
    finally:
        # Always reap the subprocesses so a failing test does not leave a
        # worker or supervisor running and holding its port.
        for proc in (proc_worker, proc_supervisor):
            if proc is not None:
                proc.terminate()
                proc.join()
14 changes: 14 additions & 0 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ async def get_supervisor_ref(self, add_worker: bool = True) -> xo.ActorRefType:
await self._supervisor_ref.add_worker(self.address)
logger.info("Connected to supervisor as a fresh worker")

# Reconnect to a newly started supervisor while this worker still has running models
if add_worker and len(self._model_uid_to_model) > 0:
# Reconnect to Newly started supervisor, notify supervisor
await self._supervisor_ref.add_worker(self.address)
# Sync replica model infos
running_models = {}
running_models.update(await self.list_models())
await self._supervisor_ref.sync_models(self.address, running_models)
logger.info(
f"Connected to supervisor as a old worker with {len(running_models)} models"
)

self._status_guard_ref = await xo.actor_ref(
address=self._supervisor_address, uid=StatusGuardActor.default_uid()
)
Expand Down Expand Up @@ -1049,6 +1061,8 @@ async def _periodical_report_status(self):
except (
Exception
) as ex: # pragma: no cover # noqa: E722 # nosec # pylint: disable=bare-except
# Drop the supervisor reference; the supervisor may have restarted
self._supervisor_ref = None
logger.error(f"Failed to upload node info: {ex}")
try:
await asyncio.sleep(XINFERENCE_HEALTH_CHECK_INTERVAL)
Expand Down
Loading