Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class RouterType(str, Enum):
    """Router implementations selectable for a service replica group.

    Values match the string literals accepted in replica-group configs
    (see ``ReplicaGroupRouterConfig.type``).
    """

    SGLANG = "sglang"
    DYNAMO = "dynamo"


class SGLangGatewayRouterConfig(CoreModel):
Expand Down Expand Up @@ -45,8 +46,15 @@ class SGLangServiceRouterConfig(CoreModel):

class ReplicaGroupRouterConfig(CoreModel):
    """Router configuration attached to a service replica group.

    The span as given interleaved the pre-change field (``Literal["sglang"]``)
    with the post-change field body, which is not valid Python; this is the
    resolved post-change definition.
    """

    type: Annotated[
        Literal["sglang", "dynamo"],
        Field(
            description=(
                "The router implementation for this replica group. "
                "`sglang` runs the SGLang router and dstack syncs worker URLs to it. "
                "`dynamo` runs the NVIDIA Dynamo frontend, which discovers workers "
                "itself via etcd/NATS."
            ),
        ),
    ] = "sglang"


Expand Down
26 changes: 26 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from dstack._internal.core.models.repos import AnyRunRepoData
from dstack._internal.core.models.resources import Memory, ResourcesSpec
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint
from dstack._internal.utils import common as common_utils
Expand Down Expand Up @@ -603,6 +604,31 @@ def _merged_profile(cls, values) -> Dict:
values["merged_profile"] = merged_profile
return values

@root_validator
def _validate_dynamo_no_retry(cls, values) -> Dict:
    """Reject `retry` for services that have a Dynamo router replica group.

    Dynamo workers cache the router's internal IP at provisioning time. A
    retry would produce a new router and likely a new internal_ip, leaving
    workers bound to a router that no longer exists.
    """
    profile = values.get("merged_profile")
    configuration = values.get("configuration")
    # Nothing to check unless retry is actually configured...
    if profile is None or profile.retry is None:
        return values
    # ...and only services can carry replica groups with routers.
    if not isinstance(configuration, ServiceConfiguration):
        return values
    uses_dynamo_router = any(
        group.router is not None and group.router.type == RouterType.DYNAMO
        for group in configuration.replica_groups
    )
    if uses_dynamo_router:
        raise ValueError(
            "Retry cannot be configured for services with a Dynamo "
            "router replica group. The router's address must remain "
            "stable for the life of the run; allowing retry would "
            "leave workers bound to a router that no longer exists. "
            "Remove `retry` from the profile/configuration and "
            "re-apply."
        )
    return values


class ServiceModelSpec(CoreModel):
name: str
Expand Down
18 changes: 9 additions & 9 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ServiceConfig,
)
from dstack._internal.proxy.lib import models
from dstack._internal.proxy.lib.const import SGLANG_WHITELISTED_PATHS
from dstack._internal.proxy.lib.const import ROUTER_WHITELISTED_PATHS
from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.proxy.lib.services.service_connection import (
Expand Down Expand Up @@ -344,7 +344,7 @@ async def get_nginx_service_config(
) -> ServiceConfig:
limit_req_zones: list[LimitReqZoneConfig] = []
locations: list[LocationConfig] = []
is_sglang = (
is_router = (
service.router is not None and service.router.type == RouterType.SGLANG
) or service.has_router_replica
sglang_limits: dict[str, LimitReqConfig] = {}
Expand All @@ -361,8 +361,8 @@ async def get_nginx_service_config(
limit_req_zones.append(
LimitReqZoneConfig(name=zone_name, key=key, rpm=round(rate_limit.rps * 60))
)
if is_sglang:
for path in SGLANG_WHITELISTED_PATHS:
if is_router:
for path in ROUTER_WHITELISTED_PATHS:
if rate_limit.prefix == path or path.startswith(rate_limit.prefix):
# Use the longest prefix if multiple prefixes match the same path
current_prefix_len = len(rate_limit.prefix)
Expand All @@ -381,9 +381,9 @@ async def get_nginx_service_config(
)
)

# Add SGLang whitelisted paths as locations
if is_sglang:
for path in SGLANG_WHITELISTED_PATHS:
# Add router whitelisted paths as locations
if is_router:
for path in ROUTER_WHITELISTED_PATHS:
# Use prefix match for paths that end with a slash and exact match for paths that don't
if path.endswith("/"):
locations.append(LocationConfig(prefix=path, limit_req=sglang_limits.get(path)))
Expand All @@ -392,8 +392,8 @@ async def get_nginx_service_config(
LocationConfig(prefix=f"= {path}", limit_req=sglang_limits.get(path))
)

# Don't auto-add / location for SGLang routers (catch-all 403 handles it)
if not any(location.prefix == "/" for location in locations) and not is_sglang:
# Don't auto-add / location for router-based services (catch-all 403 handles it)
if not any(location.prefix == "/" for location in locations) and not is_router:
locations.append(LocationConfig(prefix="/", limit_req=None))
return ServiceConfig(
domain=service.domain_safe,
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/proxy/lib/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Shared constants for proxy components (gateway + in-server proxy).
"""

SGLANG_WHITELISTED_PATHS: tuple[str, ...] = (
# Inference endpoints exposed by the in-replica HTTP router. Applies to both
# SGLang's router and Dynamo's `dynamo.frontend` — they share the
# OpenAI-compatible endpoint surface.
ROUTER_WHITELISTED_PATHS: tuple[str, ...] = (
"/generate",
"/v1/",
"/chat/completions",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dstack._internal.core.models.metrics import Metric
from dstack._internal.core.models.profiles import StartupOrder
from dstack._internal.core.models.repos import RemoteRepoCreds
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.runs import (
ClusterInfo,
ImagePullProgress,
Expand Down Expand Up @@ -102,6 +103,13 @@
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.services.runs import is_job_ready, run_model_to_run
from dstack._internal.server.services.runs.replicas import (
ROUTER_FAILED,
ROUTER_NOT_PROVISIONED,
get_router_env_for_job,
get_router_replica_group,
get_router_replica_num,
)
from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.storage import get_default_storage
from dstack._internal.server.utils import sentry_utils
Expand All @@ -114,6 +122,8 @@

JOB_STATUSES_WITH_MIN_PROCESSING_INTERVAL = [JobStatus.PROVISIONING, JobStatus.PULLING]

ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS = 30 * 60

JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2)
"""`The minimum time before terminating active job in case of connectivity issues."""

Expand Down Expand Up @@ -384,8 +394,12 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce
job_submissions=[job_model_to_job_submission(job_model)],
)
else:
# PROVISIONING/PULLING jobs need same-replica siblings for cluster coordination.
# All sibling access is replica-scoped, so only load jobs for this replica.
# PROVISIONING/PULLING jobs need same-replica siblings for cluster
# coordination, plus — when the run has a router replica group —
# the router replica's job (cross-replica) so the env-injection
# gate in _prepare_startup_context can read its status / IP.
# _fetch_run_model handles both: same-replica jobs always, plus
# the router replica's job when one exists.
run_model = await _fetch_run_model(
session=session, run_id=job_model.run_id, replica_num=item.replica_num
)
Expand Down Expand Up @@ -477,6 +491,54 @@ async def _prepare_startup_context(
)
return None

# If this run has a router replica group and this job is a worker, gate
# startup on the router replica's state. The helper returns None for the
# router itself and for runs without a router group, so this whole block
# is a no-op in those cases.
router_env = get_router_env_for_job(
run_model=context.run_model,
run_spec=context.run.run_spec,
job_model=context.job_model,
)
if router_env is ROUTER_FAILED:
# Router has reached a terminal state — the worker cannot recover by
# waiting. Terminate it now with a clear reason instead of letting it
# idle until the run-level reconciler tears the whole run down.
_terminate_job(
job_model=context.job_model,
job_update_map=result.job_update_map,
termination_reason=JobTerminationReason.TERMINATED_BY_SERVER,
termination_reason_message=(
"Router replica is in a terminal state; cannot provision worker "
"without a running router."
),
)
return None
if router_env is ROUTER_NOT_PROVISIONED:
# Router is alive but its internal_ip is not yet known. Defer this
# worker — the next pipeline tick will re-check. Bound the wait so a
# router that is genuinely stuck can't burn worker instance-hours
# forever; see ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS.
waited_seconds = (get_current_datetime() - context.job_model.submitted_at).total_seconds()
if waited_seconds > ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS:
_terminate_job(
job_model=context.job_model,
job_update_map=result.job_update_map,
termination_reason=JobTerminationReason.TERMINATED_BY_SERVER,
termination_reason_message=(
f"Router replica did not acquire an internal IP within "
f"{ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS}s; terminating worker."
),
)
return None
logger.debug(
"%s: waiting for router replica to be provisioned",
fmt(context.job_model),
)
return None
if router_env:
context.job.job_spec.env.update(router_env)

cluster_info = _get_cluster_info(
jobs=context.run.jobs,
replica_num=context.job.job_spec.replica_num,
Expand Down Expand Up @@ -549,7 +611,10 @@ async def _fetch_run_model(
Args:
replica_num: If None, skip loading jobs (for RUNNING jobs that don't need siblings).
If set, load only latest-submission jobs for that replica (for PROVISIONING/PULLING
jobs that need same-replica siblings for cluster coordination).
jobs that need same-replica siblings for cluster coordination). When the run has
a router replica group whose replica_num differs from this one, that replica's
jobs are also loaded so cross-replica router lookups (see get_router_env_for_job
in services/runs/replicas.py) can find it.
"""
query = (
select(RunModel)
Expand All @@ -560,14 +625,39 @@ async def _fetch_run_model(
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
)
if replica_num is not None:
# Pre-fetch the bare run_spec to discover whether the run has a
# router replica group, and if so at which replica_num. The query
# below then includes both this replica AND the router replica
# For runs without a router group (services without one, plus
# all tasks and dev-environments), the helper returns None and we
# fall through to the original single-replica behavior.
spec_res = await session.execute(select(RunModel.run_spec).where(RunModel.id == run_id))
run_spec_str = spec_res.scalar_one()
run_spec = RunSpec.__response__.parse_raw(run_spec_str)
# The router pre-fetch only exists to feed get_router_env_for_job,
# which is gated to Dynamo. Skip it for SGLang and non-router runs.
router_group = get_router_replica_group(run_spec)
if (
router_group is not None
and router_group.router is not None
and router_group.router.type == RouterType.DYNAMO
):
router_replica_num = get_router_replica_num(run_spec)
else:
router_replica_num = None

replica_nums: list[int] = [replica_num]
if router_replica_num is not None and router_replica_num != replica_num:
replica_nums.append(router_replica_num)

latest_submissions_sq = (
select(
JobModel.run_id.label("run_id"),
JobModel.replica_num.label("replica_num"),
JobModel.job_num.label("job_num"),
func.max(JobModel.submission_num).label("max_submission_num"),
)
.where(JobModel.run_id == run_id, JobModel.replica_num == replica_num)
.where(JobModel.run_id == run_id, JobModel.replica_num.in_(replica_nums))
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
.subquery()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.services.runs.router_worker_sync import (
run_model_has_router_replica_group,
run_model_has_sglang_router_replica_group,
sync_router_workers_for_run_model,
)
from dstack._internal.server.utils import sentry_utils
Expand Down Expand Up @@ -212,7 +212,7 @@ async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
run_model.deleted
or run_model.status.is_finished()
or run_model.status != RunStatus.RUNNING
or not run_model_has_router_replica_group(run_model)
or not run_model_has_sglang_router_replica_group(run_model)
):
early_cleanup_update_map: _SyncRowUpdateMap = {"deleted": True}
set_processed_update_map_fields(early_cleanup_update_map)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette.requests import ClientDisconnect

from dstack._internal.core.models.routers import RouterType
from dstack._internal.proxy.lib.const import SGLANG_WHITELISTED_PATHS
from dstack._internal.proxy.lib.const import ROUTER_WHITELISTED_PATHS
from dstack._internal.proxy.lib.deps import ProxyAuthContext
from dstack._internal.proxy.lib.errors import ProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
Expand Down Expand Up @@ -45,7 +45,7 @@ async def proxy(
service.router is not None and service.router.type == RouterType.SGLANG
) or service.has_router_replica:
path_for_match = path if path.startswith("/") else f"/{path}"
if not _is_whitelisted_path(path_for_match, SGLANG_WHITELISTED_PATHS):
if not _is_whitelisted_path(path_for_match, ROUTER_WHITELISTED_PATHS):
raise ProxyError("Path is not allowed for this service", status.HTTP_403_FORBIDDEN)

client = await get_service_replica_client(service, repo, service_conn_pool)
Expand Down
Loading
Loading