Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ class RegisterModelRequest(BaseModel):
persist: bool


class AddModelRequest(BaseModel):
    """Wrapped request body accepted by the ``add_model`` endpoint.

    ``add_model`` parses the body into this model only when both keys are
    present in the incoming JSON; otherwise the raw JSON is used as the model
    definition directly.
    """

    # Model category; forwarded as the first argument to the supervisor's
    # ``add_model`` call.
    model_type: str
    # Model definition; forwarded unchanged as the second argument to the
    # supervisor's ``add_model`` call.
    model_json: Dict[str, Any]


class UpdateModelRequest(BaseModel):
    """Request body parsed by the ``update_model_type`` endpoint."""

    # Model category passed to the supervisor's ``update_model_type`` call.
    model_type: str


class BuildGradioInterfaceRequest(BaseModel):
model_type: str
model_name: str
Expand Down Expand Up @@ -900,6 +909,26 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/models/add",
self.add_model,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:add"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/models/update_type",
self.update_model_type,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:add"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/cache/models",
self.list_cached_models,
Expand Down Expand Up @@ -3123,25 +3152,171 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon
raise HTTPException(status_code=500, detail=str(e))
return JSONResponse(content=None)

async def add_model(self, request: Request) -> JSONResponse:
    """Add a model definition through the supervisor.

    The request body must be a JSON object in one of two shapes:

    * wrapped: ``{"model_type": ..., "model_json": {...}}``, validated
      with ``AddModelRequest``;
    * unwrapped: the model JSON itself, which must contain a
      ``model_type`` key.

    Args:
        request: Incoming HTTP request whose body carries the model JSON.

    Returns:
        A JSON response with a success message naming the model type.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``model_type`` is missing, or when parsing or the supervisor
            call raises ``ValueError``; 500 for any other failure.
    """
    try:
        # NOTE: request headers are deliberately not logged; they can carry
        # credentials such as an Authorization token.
        raw_json = await request.json()

        if not isinstance(raw_json, dict):
            raise HTTPException(
                status_code=400,
                detail="Request body must be a JSON object",
            )

        if "model_type" in raw_json and "model_json" in raw_json:
            # Wrapped format: let pydantic validate both fields.
            body = AddModelRequest.parse_obj(raw_json)
            model_type = body.model_type
            model_json = body.model_json
        elif "model_type" in raw_json:
            # Unwrapped format: the body itself is the model definition.
            model_json = raw_json
            model_type = raw_json["model_type"]
        else:
            raise HTTPException(
                status_code=400,
                detail="model_type is required in the model JSON. Supported types: LLM, embedding, audio, image, video, rerank",
            )

        logger.debug(
            "add_model: model_type=%s, model_json keys=%s",
            model_type,
            list(model_json.keys()),
        )

        supervisor_ref = await self._get_supervisor_ref()
        await supervisor_ref.add_model(model_type, model_json)
    except HTTPException:
        # Deliberate client errors raised above must keep their status code
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except ValueError as err:
        logger.error("ValueError in add_model API: %s", err, exc_info=True)
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as exc:
        # exc_info=True already records the full traceback.
        logger.error(
            "Unexpected error in add_model API: %s: %s",
            type(exc).__name__,
            exc,
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc))

    logger.debug("add_model completed successfully for model_type: %s", model_type)
    return JSONResponse(
        content={"message": f"Model added successfully for type: {model_type}"}
    )

async def update_model_type(self, request: Request) -> JSONResponse:
    """Ask the supervisor to update model configurations for a model type.

    The JSON request body is validated as ``UpdateModelRequest`` and its
    ``model_type`` is handed to the supervisor's ``update_model_type``.

    Args:
        request: Incoming HTTP request carrying the JSON body.

    Returns:
        A JSON response with a success message naming the model type.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while parsing the
            body or by the supervisor call; 500 for any other exception.
    """
    try:
        payload = await request.json()
        logger.info(f"[DEBUG] Update model type API called with: {payload}")

        model_type = UpdateModelRequest.parse_obj(payload).model_type
        logger.info(f"[DEBUG] Parsed model_type for update: {model_type}")

        supervisor = await self._get_supervisor_ref()

        logger.info(
            f"[DEBUG] Calling supervisor.update_model_type with model_type: {model_type}"
        )
        await supervisor.update_model_type(model_type)
        logger.info("[DEBUG] Supervisor.update_model_type completed successfully")
    except ValueError as err:
        # Validation and value problems are reported as client errors.
        logger.error(
            f"[DEBUG] ValueError in update_model_type API: {err}", exc_info=True
        )
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as exc:
        # Anything else is reported as a server error.
        logger.error(
            f"[DEBUG] Unexpected error in update_model_type API: {exc}",
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc))
    else:
        logger.info(
            f"[DEBUG] Update model type API completed successfully for model_type: {model_type}"
        )
        message = (
            f"Model configurations updated successfully for type: {model_type}"
        )
        return JSONResponse(content={"message": message})

async def list_model_registrations(
    self, model_type: str, detailed: bool = Query(False)
) -> JSONResponse:
    """List model registrations of ``model_type``, one entry per model name.

    Args:
        model_type: Model category forwarded to the supervisor.
        detailed: Forwarded to the supervisor as the ``detailed`` keyword.

    Returns:
        A JSON response containing the supervisor's registrations with
        duplicate ``model_name`` entries removed (the first one wins).

    Raises:
        HTTPException: 400 when the supervisor call raises ``ValueError``;
            500 for any other failure.
    """
    try:
        logger.debug(
            "list_model_registrations called with model_type: %s, detailed: %s",
            model_type,
            detailed,
        )

        supervisor_ref = await self._get_supervisor_ref()
        data = await supervisor_ref.list_model_registrations(
            model_type, detailed=detailed
        )

        # Remove duplicate model names, keeping the first registration seen
        # for each name; dict insertion order preserves the original order.
        unique_by_name = {}
        for item in data:
            unique_by_name.setdefault(item["model_name"], item)
        final_data = list(unique_by_name.values())

        # Summary only: logging every item on every request floods the log.
        builtin_count = sum(
            1 for item in final_data if item.get("is_builtin", False)
        )
        logger.debug(
            "list_model_registrations: %d raw items, %d after deduplication "
            "(built-in: %d, custom: %d)",
            len(data),
            len(final_data),
            builtin_count,
            len(final_data) - builtin_count,
        )

        return JSONResponse(content=final_data)
    except ValueError as err:
        # Log once; exc_info=True already attaches the traceback.
        logger.error(
            "ValueError in list_model_registrations: %s", err, exc_info=True
        )
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as exc:
        logger.error(
            "Unexpected error in list_model_registrations: %s",
            exc,
            exc_info=True,
        )
        raise HTTPException(status_code=500, detail=str(exc))

Expand Down
Loading
Loading