diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 84c7b18d80..05ce33df61 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -198,6 +198,15 @@ class RegisterModelRequest(BaseModel): persist: bool +class AddModelRequest(BaseModel): + model_type: str + model_json: Dict[str, Any] + + +class UpdateModelRequest(BaseModel): + model_type: str + + class BuildGradioInterfaceRequest(BaseModel): model_type: str model_name: str @@ -900,6 +909,26 @@ async def internal_exception_handler(request: Request, exc: Exception): else None ), ) + self._router.add_api_route( + "/v1/models/add", + self.add_model, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:add"])] + if self.is_authenticated() + else None + ), + ) + self._router.add_api_route( + "/v1/models/update_type", + self.update_model_type, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:add"])] + if self.is_authenticated() + else None + ), + ) self._router.add_api_route( "/v1/cache/models", self.list_cached_models, @@ -3123,13 +3152,139 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon raise HTTPException(status_code=500, detail=str(e)) return JSONResponse(content=None) + async def add_model(self, request: Request) -> JSONResponse: + try: + # Debug: Log incoming request + logger.info(f"[DEBUG] Add model API called") + logger.info(f"[DEBUG] Request headers: {dict(request.headers)}") + + # Parse request + raw_json = await request.json() + logger.info(f"[DEBUG] Raw request JSON: {raw_json}") + + if "model_type" in raw_json and "model_json" in raw_json: + body = AddModelRequest.parse_obj(raw_json) + model_type = body.model_type + model_json = body.model_json + logger.info(f"[DEBUG] Using wrapped format, model_type: {model_type}") + else: + model_json = raw_json + + # Priority 1: Check if model_type is explicitly provided in the JSON + if "model_type" in model_json: + model_type = model_json["model_type"] + logger.info( + f"[DEBUG] Using explicit model_type from JSON: {model_type}" + ) + else: + # model_type is required in the JSON when using unwrapped format + logger.error( + f"[DEBUG] model_type not provided in JSON, this is required" + ) + raise HTTPException( + status_code=400, + detail="model_type is required in the model JSON. 
Supported types: LLM, embedding, audio, image, video, rerank", + ) + + logger.info(f"[DEBUG] Parsed model_type: {model_type}") + logger.info( + f"[DEBUG] Parsed model_json keys: {list(model_json.keys()) if isinstance(model_json, dict) else 'Not a dict'}" + ) + if isinstance(model_json, dict): + logger.info(f"[DEBUG] Model JSON content: {model_json}") + + # Debug: Check supervisor reference + logger.info(f"[DEBUG] Getting supervisor reference...") + supervisor_ref = await self._get_supervisor_ref() + logger.info(f"[DEBUG] Supervisor reference obtained: {supervisor_ref}") + + # Call supervisor + logger.info( + f"[DEBUG] Calling supervisor.add_model with model_type: {model_type}" + ) + await supervisor_ref.add_model(model_type, model_json) + logger.info(f"[DEBUG] Supervisor.add_model completed successfully") + + except ValueError as re: + logger.error(f"[DEBUG] ValueError in add_model API: {re}", exc_info=True) + logger.error(f"[DEBUG] ValueError details: {type(re).__name__}: {re}") + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error( + f"[DEBUG] Unexpected error in add_model API: {e}", exc_info=True + ) + logger.error(f"[DEBUG] Error details: {type(e).__name__}: {e}") + import traceback + + logger.error(f"[DEBUG] Full traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(e)) + + logger.info( + f"[DEBUG] Add model API completed successfully for model_type: {model_type}" + ) + return JSONResponse( + content={"message": f"Model added successfully for type: {model_type}"} + ) + + async def update_model_type(self, request: Request) -> JSONResponse: + try: + # Parse request + raw_json = await request.json() + logger.info(f"[DEBUG] Update model type API called with: {raw_json}") + + body = UpdateModelRequest.parse_obj(raw_json) + model_type = body.model_type + + logger.info(f"[DEBUG] Parsed model_type for update: {model_type}") + + # Get supervisor reference + supervisor_ref = await self._get_supervisor_ref() + + # Call supervisor to update model type + logger.info( + f"[DEBUG] Calling supervisor.update_model_type with model_type: {model_type}" + ) + await supervisor_ref.update_model_type(model_type) + logger.info(f"[DEBUG] Supervisor.update_model_type completed successfully") + + except ValueError as re: + logger.error( + f"[DEBUG] ValueError in update_model_type API: {re}", exc_info=True + ) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error( + f"[DEBUG] Unexpected error in update_model_type API: {e}", exc_info=True + ) + raise HTTPException(status_code=500, detail=str(e)) + + logger.info( + f"[DEBUG] Update model type API completed successfully for model_type: {model_type}" + ) + return JSONResponse( + content={ + "message": f"Model configurations updated successfully for type: {model_type}" + } + ) + async def list_model_registrations( self, model_type: str, detailed: bool = Query(False) ) -> JSONResponse: try: + logger.info( + f"[DEBUG API] list_model_registrations called with model_type: {model_type}, detailed: {detailed}" + ) + data = await (await self._get_supervisor_ref()).list_model_registrations( model_type, detailed=detailed ) + + logger.info(f"[DEBUG API] Raw data from supervisor: {len(data)} items") + for i, item in enumerate(data): + logger.info( + f"[DEBUG API] Item {i}: {item.get('model_name', 'Unknown')} (builtin: {item.get('is_builtin', 'Unknown')})" + ) + # Remove duplicate model names. 
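For orientation, a minimal client-side sketch of how the two routes added above could be exercised. The host/port, model name and embedding spec fields are illustrative placeholders, not values defined by this change:

```python
# Hypothetical client calls against the new endpoints; BASE_URL and the
# embedding JSON below are placeholders, not part of this PR.
import requests

BASE_URL = "http://127.0.0.1:9997"  # assumed local Xinference supervisor

# Wrapped payload: matches AddModelRequest (model_type + model_json).
add_payload = {
    "model_type": "embedding",
    "model_json": {
        "model_name": "my-embedding-model",            # hypothetical model
        "dimensions": 768,
        "max_tokens": 512,
        "language": ["en"],
        "model_specs": [
            {
                "model_format": "pytorch",
                "quantization": "none",
                "model_hub": "huggingface",
                "model_id": "my-org/my-embedding-model",  # hypothetical hub id
            }
        ],
    },
}
r = requests.post(f"{BASE_URL}/v1/models/add", json=add_payload, timeout=30)
r.raise_for_status()
print(r.json())  # {"message": "Model added successfully for type: embedding"}

# The unwrapped form is also accepted, provided "model_type" is inside the JSON.
unwrapped = {"model_type": "embedding", **add_payload["model_json"]}
requests.post(f"{BASE_URL}/v1/models/add", json=unwrapped, timeout=30)

# Refresh the built-in configurations for one model type from the remote hub.
r = requests.post(
    f"{BASE_URL}/v1/models/update_type", json={"model_type": "llm"}, timeout=60
)
print(r.json())  # {"message": "Model configurations updated successfully for type: llm"}
```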
model_names = set() final_data = [] @@ -3137,11 +3292,31 @@ async def list_model_registrations( if item["model_name"] not in model_names: model_names.add(item["model_name"]) final_data.append(item) + + logger.info(f"[DEBUG API] After deduplication: {len(final_data)} items") + builtin_count = sum( + 1 for item in final_data if item.get("is_builtin", False) + ) + custom_count = sum( + 1 for item in final_data if not item.get("is_builtin", False) + ) + logger.info( + f"[DEBUG API] Built-in models: {builtin_count}, Custom models: {custom_count}" + ) + return JSONResponse(content=final_data) except ValueError as re: + logger.error( + f"[DEBUG API] ValueError in list_model_registrations: {re}", + exc_info=True, + ) logger.error(re, exc_info=True) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: + logger.error( + f"[DEBUG API] Unexpected error in list_model_registrations: {e}", + exc_info=True, + ) logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 1ed96cd703..03bd14173e 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -14,6 +14,7 @@ import asyncio import itertools +import json import os import signal import time @@ -217,6 +218,13 @@ async def __post_create__(self): register_rerank, unregister_rerank, ) + from ..model.video import ( + CustomVideoModelFamilyV2, + generate_video_description, + get_video_model_descriptions, + register_video, + unregister_video, + ) self._custom_register_type_to_cls: Dict[str, Tuple] = { # type: ignore "LLM": ( @@ -249,6 +257,12 @@ async def __post_create__(self): unregister_audio, generate_audio_description, ), + "video": ( + CustomVideoModelFamilyV2, + register_video, + unregister_video, + generate_video_description, + ), "flexible": ( FlexibleModelSpec, register_flexible_model, @@ -264,6 +278,7 @@ async def __post_create__(self): model_version_infos.update(get_rerank_model_descriptions()) model_version_infos.update(get_image_model_descriptions()) model_version_infos.update(get_audio_model_descriptions()) + model_version_infos.update(get_video_model_descriptions()) model_version_infos.update(get_flexible_model_descriptions()) await self._cache_tracker_ref.record_model_version( model_version_infos, self.address @@ -609,33 +624,74 @@ def sort_helper(item): assert isinstance(item["model_name"], str) return item.get("model_name").lower() + logger.info( + f"[DEBUG SUPERVISOR] list_model_registrations called with model_type: {model_type}, detailed: {detailed}" + ) + ret = [] if not self.is_local_deployment(): + logger.info(f"[DEBUG SUPERVISOR] Not local deployment, checking workers...") workers = list(self._worker_address_to_worker.values()) for worker in workers: - ret.extend(await worker.list_model_registrations(model_type, detailed)) + worker_data = await worker.list_model_registrations( + model_type, detailed + ) + logger.info( + f"[DEBUG SUPERVISOR] Worker returned {len(worker_data)} models" + ) + ret.extend(worker_data) + else: + logger.info(f"[DEBUG SUPERVISOR] Local deployment mode") - if model_type == "LLM": - from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families + if model_type.upper() == "LLM": + from ..model.llm import ( + BUILTIN_LLM_FAMILIES, + get_user_defined_llm_families, + register_builtin_model, + ) + + logger.info(f"[DEBUG SUPERVISOR] Processing LLM models") + + register_builtin_model() for family in BUILTIN_LLM_FAMILIES: + logger.debug( + f"[DEBUG SUPERVISOR] 
Processing builtin LLM: {family.model_name}" + ) if detailed: - ret.append(await self._to_llm_reg(family, True)) + reg_data = await self._to_llm_reg(family, True) + ret.append(reg_data) else: ret.append({"model_name": family.model_name, "is_builtin": True}) - for family in get_user_defined_llm_families(): - if detailed: - ret.append(await self._to_llm_reg(family, False)) - else: - ret.append({"model_name": family.model_name, "is_builtin": False}) + user_defined_families = get_user_defined_llm_families() + builtin_names = {family.model_name for family in BUILTIN_LLM_FAMILIES} - ret.sort(key=sort_helper) + for family in user_defined_families: + if family.model_name not in builtin_names: + logger.debug( + f"[DEBUG SUPERVISOR] Processing dynamic LLM: {family.model_name}" + ) + if detailed: + reg_data = await self._to_llm_reg(family, True) + ret.append(reg_data) + else: + ret.append( + {"model_name": family.model_name, "is_builtin": True} + ) + + ret.sort(key=sort_helper) + logger.info(f"[DEBUG SUPERVISOR] LLM: Returning {len(ret)} total models") return ret elif model_type == "embedding": - from ..model.embedding import BUILTIN_EMBEDDING_MODELS + from ..model.embedding import ( + BUILTIN_EMBEDDING_MODELS, + register_builtin_model, + ) from ..model.embedding.custom import get_user_defined_embeddings + register_builtin_model() + for model_name, family in BUILTIN_EMBEDDING_MODELS.items(): if detailed: ret.append( @@ -645,21 +701,55 @@ def sort_helper(item): ret.append({"model_name": model_name, "is_builtin": True}) for model_spec in get_user_defined_embeddings(): + # Check if this model is persisted (added via add_model API) + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_builtin_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_builtin_dir_prefix, + "embedding", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_persist_path): + is_persisted_model = True + else: + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + potential_custom_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "embedding", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_custom_path): + is_persisted_model = True + + is_builtin = is_persisted_model # Treat persisted models as built-in + logger.info( + f"[DEBUG SUPERVISOR] Embedding model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_builtin}" + ) + if detailed: ret.append( - await self._to_embedding_model_reg(model_spec, is_builtin=False) + await self._to_embedding_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) return ret elif model_type == "image": - from ..model.image import BUILTIN_IMAGE_MODELS + from ..model.image import BUILTIN_IMAGE_MODELS, register_builtin_model from ..model.image.custom import get_user_defined_images + register_builtin_model() + for model_name, families in BUILTIN_IMAGE_MODELS.items(): if detailed: family = [x for x in families if x.model_hub == "huggingface"][0] @@ -670,21 +760,54 @@ def sort_helper(item): ret.append({"model_name": model_name, "is_builtin": True}) for model_spec in get_user_defined_images(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if 
hasattr(cache_manager, "_v2_builtin_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_builtin_dir_prefix, + "image", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_persist_path): + is_persisted_model = True + else: + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + potential_custom_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "image", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_custom_path): + is_persisted_model = True + + is_builtin = is_persisted_model # Treat persisted models as built-in + logger.info( + f"[DEBUG SUPERVISOR] Image model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_builtin}" + ) + if detailed: ret.append( - await self._to_image_model_reg(model_spec, is_builtin=False) + await self._to_image_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) return ret elif model_type == "audio": - from ..model.audio import BUILTIN_AUDIO_MODELS + from ..model.audio import BUILTIN_AUDIO_MODELS, register_builtin_model from ..model.audio.custom import get_user_defined_audios + register_builtin_model() + for model_name, families in BUILTIN_AUDIO_MODELS.items(): if detailed: family = [x for x in families if x.model_hub == "huggingface"][0] @@ -695,19 +818,53 @@ def sort_helper(item): ret.append({"model_name": model_name, "is_builtin": True}) for model_spec in get_user_defined_audios(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_builtin_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_builtin_dir_prefix, + "audio", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_persist_path): + is_persisted_model = True + else: + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + potential_custom_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "audio", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_custom_path): + is_persisted_model = True + + is_builtin = is_persisted_model # Treat persisted models as built-in + logger.info( + f"[DEBUG SUPERVISOR] Audio model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_builtin}" + ) + if detailed: ret.append( - await self._to_audio_model_reg(model_spec, is_builtin=False) + await self._to_audio_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) return ret elif model_type == "video": - from ..model.video import BUILTIN_VIDEO_MODELS + from ..model.video import BUILTIN_VIDEO_MODELS, register_builtin_model + from ..model.video.custom import get_user_defined_videos + + register_builtin_model() for model_name, families in BUILTIN_VIDEO_MODELS.items(): if detailed: @@ -717,13 +874,54 @@ def sort_helper(item): ret.append(info) else: ret.append({"model_name": model_name, "is_builtin": True}) - + for model_spec in get_user_defined_videos(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_builtin_dir_prefix"): + import os + + potential_persist_path = 
os.path.join( + cache_manager._v2_builtin_dir_prefix, + "video", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_persist_path): + is_persisted_model = True + else: + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + potential_custom_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "video", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_custom_path): + is_persisted_model = True + logger.debug( + f"[DEBUG SUPERVISOR] Video model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_persisted_model}" + ) + if detailed: + ret.append( + await self._to_video_model_reg( + model_spec, is_builtin=is_persisted_model + ) + ) + else: + ret.append( + { + "model_name": model_spec.model_name, + "is_builtin": is_persisted_model, + } + ) ret.sort(key=sort_helper) return ret elif model_type == "rerank": - from ..model.rerank import BUILTIN_RERANK_MODELS + from ..model.rerank import BUILTIN_RERANK_MODELS, register_builtin_model from ..model.rerank.custom import get_user_defined_reranks + register_builtin_model() + for model_name, family in BUILTIN_RERANK_MODELS.items(): if detailed: ret.append(await self._to_rerank_model_reg(family, is_builtin=True)) @@ -731,13 +929,44 @@ def sort_helper(item): ret.append({"model_name": model_name, "is_builtin": True}) for model_spec in get_user_defined_reranks(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_builtin_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_builtin_dir_prefix, + "rerank", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_persist_path): + is_persisted_model = True + else: + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + potential_custom_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "rerank", + f"{model_spec.model_name}.json", + ) + if os.path.exists(potential_custom_path): + is_persisted_model = True + + is_builtin = is_persisted_model # Treat persisted models as built-in + logger.info( + f"[DEBUG SUPERVISOR] Rerank model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_builtin}" + ) + if detailed: ret.append( - await self._to_rerank_model_reg(model_spec, is_builtin=False) + await self._to_rerank_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) @@ -748,13 +977,34 @@ def sort_helper(item): ret = [] for model_spec in get_flexible_models(): + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "flexible", + f"{model_spec.model_name}.json", + ) + is_persisted_model = os.path.exists(potential_persist_path) + + is_builtin = is_persisted_model # Treat persisted models as built-in + logger.info( + f"[DEBUG SUPERVISOR] Flexible model {model_spec.model_name} persisted: {is_persisted_model}, treating as builtin: {is_builtin}" + ) + if detailed: ret.append( - await self._to_flexible_model_reg(model_spec, is_builtin=False) + await self._to_flexible_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, 
"is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) @@ -772,7 +1022,7 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if f is not None: return f - if model_type == "LLM": + if model_type.upper() == "LLM": from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families for f in BUILTIN_LLM_FAMILIES + get_user_defined_llm_families(): @@ -837,6 +1087,7 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: raise ValueError(f"Model {model_name} not found") elif model_type == "video": from ..model.video import BUILTIN_VIDEO_MODELS + from ..model.video.custom import get_user_defined_videos if model_name in BUILTIN_VIDEO_MODELS: return [ @@ -844,6 +1095,10 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: for x in BUILTIN_VIDEO_MODELS[model_name] if x.model_hub == "huggingface" ][0] + else: + for f in get_user_defined_videos(): + if f.model_name == model_name: + return f raise ValueError(f"Model {model_name} not found") else: raise ValueError(f"Unsupported model type: {model_type}") @@ -932,6 +1187,414 @@ async def register_model( else: raise ValueError(f"Unsupported model type: {model_type}") + @log_async(logger=logger) + async def add_model(self, model_type: str, model_json: Dict[str, Any]): + """ + Add a new model by parsing the provided JSON and registering it. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) + model_json: JSON configuration for the model + """ + logger.info( + f"[DEBUG SUPERVISOR] add_model called with model_type: {model_type}" + ) + logger.info(f"[DEBUG SUPERVISOR] model_json type: {type(model_json)}") + logger.info( + f"[DEBUG SUPERVISOR] model_json keys: {list(model_json.keys()) if isinstance(model_json, dict) else 'Not a dict'}" + ) + if isinstance(model_json, dict): + logger.info(f"[DEBUG SUPERVISOR] model_json content: {model_json}") + + # Validate model type (with case normalization) + supported_types = list(self._custom_register_type_to_cls.keys()) + logger.info(f"[DEBUG SUPERVISOR] Supported model types: {supported_types}") + logger.info(f"[DEBUG SUPERVISOR] Received model_type: '{model_type}'") + + normalized_model_type = model_type + + if model_type.lower() == "llm" and "LLM" in supported_types: + normalized_model_type = "LLM" + elif model_type.lower() == "llm" and "llm" in supported_types: + normalized_model_type = "llm" + + logger.info( + f"[DEBUG SUPERVISOR] Normalized model_type: '{normalized_model_type}'" + ) + + if normalized_model_type not in self._custom_register_type_to_cls: + logger.error( + f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_model_type} (original: {model_type})" + ) + raise ValueError( + f"Unsupported model type '{model_type}'. 
" + f"Supported types are: {', '.join(supported_types)}" + ) + + # Use normalized model type for the rest of the function + model_type = normalized_model_type + logger.info( + f"[DEBUG SUPERVISOR] Using model_type: '{model_type}' for registration" + ) + + # Get the appropriate model class and register function + ( + model_spec_cls, + register_fn, + unregister_fn, + generate_fn, + ) = self._custom_register_type_to_cls[model_type] + logger.info(f"[DEBUG SUPERVISOR] Model spec class: {model_spec_cls}") + logger.info(f"[DEBUG SUPERVISOR] Register function: {register_fn}") + logger.info(f"[DEBUG SUPERVISOR] Unregister function: {unregister_fn}") + logger.info(f"[DEBUG SUPERVISOR] Generate function: {generate_fn}") + + # Validate required fields (only model_name is required) + required_fields = ["model_name"] + logger.info(f"[DEBUG SUPERVISOR] Checking required fields: {required_fields}") + for field in required_fields: + if field not in model_json: + logger.error(f"[DEBUG SUPERVISOR] Missing required field: {field}") + raise ValueError(f"Missing required field: {field}") + + # Validate model name format + from ..model.utils import is_valid_model_name + + model_name = model_json["model_name"] + logger.info(f"[DEBUG SUPERVISOR] Extracted model_name: {model_name}") + + if not is_valid_model_name(model_name): + logger.error(f"[DEBUG SUPERVISOR] Invalid model name format: {model_name}") + raise ValueError(f"Invalid model name format: {model_name}") + + logger.info(f"[DEBUG SUPERVISOR] Model name validation passed") + + # Convert model hub JSON format to Xinference expected format + logger.info(f"[DEBUG SUPERVISOR] Converting model JSON format...") + try: + converted_model_json = self._convert_model_json_format(model_json) + logger.info( + f"[DEBUG SUPERVISOR] Converted model JSON: {converted_model_json}" + ) + except Exception as e: + logger.error( + f"[DEBUG SUPERVISOR] Format conversion failed: {str(e)}", exc_info=True + ) + raise ValueError(f"Failed to convert model JSON format: {str(e)}") + + # Parse the JSON into the appropriate model spec + logger.info(f"[DEBUG SUPERVISOR] Parsing model spec...") + try: + model_spec = model_spec_cls.parse_obj(converted_model_json) + logger.info(f"[DEBUG SUPERVISOR] Parsed model spec: {model_spec}") + except Exception as e: + logger.error( + f"[DEBUG SUPERVISOR] Model spec parsing failed: {str(e)}", exc_info=True + ) + raise ValueError(f"Invalid model JSON format: {str(e)}") + + # Check if model already exists + logger.info(f"[DEBUG SUPERVISOR] Checking if model already exists...") + try: + existing_model = await self.get_model_registration( + model_type, model_spec.model_name + ) + logger.info( + f"[DEBUG SUPERVISOR] Existing model check result: {existing_model}" + ) + + if existing_model is not None: + logger.error( + f"[DEBUG SUPERVISOR] Model already exists: {model_spec.model_name}" + ) + raise ValueError( + f"Model '{model_spec.model_name}' already exists for type '{model_type}'. " + f"Please choose a different model name or remove the existing model first." 
+ ) + + except ValueError as e: + if "not found" in str(e): + # Model doesn't exist, we can proceed + logger.info( + f"[DEBUG SUPERVISOR] Model doesn't exist yet, proceeding with registration" + ) + pass + else: + # Re-raise validation errors + logger.error( + f"[DEBUG SUPERVISOR] Validation error during model check: {str(e)}" + ) + raise e + except Exception as ex: + logger.error( + f"[DEBUG SUPERVISOR] Unexpected error during model check: {str(ex)}", + exc_info=True, + ) + raise ValueError(f"Failed to validate model registration: {str(ex)}") + + logger.info(f"[DEBUG SUPERVISOR] Storing single model as built-in...") + try: + # Create CacheManager and store as built-in model + from ..model.cache_manager import CacheManager + + cache_manager = CacheManager(model_spec) + cache_manager.register_builtin_model(model_type.lower()) + logger.info(f"[DEBUG SUPERVISOR] Built-in model stored successfully") + + # Register in the model registry without persisting to avoid duplicate storage + register_fn(model_spec, persist=False) + logger.info( + f"[DEBUG SUPERVISOR] Model registry registration completed successfully" + ) + + # Record model version + logger.info(f"[DEBUG SUPERVISOR] Generating version info...") + version_info = generate_fn(model_spec) + logger.info(f"[DEBUG SUPERVISOR] Generated version_info: {version_info}") + + logger.info( + f"[DEBUG SUPERVISOR] Recording model version in cache tracker..." + ) + await self._cache_tracker_ref.record_model_version( + version_info, self.address + ) + logger.info(f"[DEBUG SUPERVISOR] Cache tracker recording completed") + + # Sync to workers if not local deployment + is_local = self.is_local_deployment() + logger.info(f"[DEBUG SUPERVISOR] Is local deployment: {is_local}") + if not is_local: + # Convert back to JSON string for sync compatibility + model_json_str = json.dumps(converted_model_json) + logger.info(f"[DEBUG SUPERVISOR] Syncing model to workers...") + await self._sync_register_model( + model_type, model_json_str, True, model_spec.model_name + ) + logger.info(f"[DEBUG SUPERVISOR] Model sync to workers completed") + + logger.info( + f"Successfully added model '{model_spec.model_name}' (type: {model_type})" + ) + + except ValueError as e: + # Validation errors - don't need cleanup as model wasn't registered + logger.error(f"[DEBUG SUPERVISOR] ValueError during registration: {str(e)}") + raise e + except Exception as e: + # Unexpected errors - attempt cleanup + logger.error( + f"[DEBUG SUPERVISOR] Unexpected error during registration: {str(e)}", + exc_info=True, + ) + try: + logger.info(f"[DEBUG SUPERVISOR] Attempting cleanup...") + unregister_fn(model_spec.model_name, raise_error=False) + logger.info(f"[DEBUG SUPERVISOR] Cleanup completed successfully") + except Exception as cleanup_error: + logger.warning(f"[DEBUG SUPERVISOR] Cleanup failed: {cleanup_error}") + raise ValueError( + f"Failed to register model '{model_spec.model_name}': {str(e)}" + ) + + logger.info(f"[DEBUG SUPERVISOR] add_model completed successfully") + + def _convert_model_json_format(self, model_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert model hub JSON format to Xinference expected format. + + The input format uses nested 'model_src' structure, but Xinference expects + flattened fields at the spec level. + + Also handles cases where model_specs field is missing by providing a default. 
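To make the conversion concrete, a sketch of the input and output shapes handled by this helper. The model name, hub id and quantizations are made up for illustration; only the field layout reflects the code below:

```python
# Illustrative only: a hub-style entry using the nested "model_src" layout...
hub_style = {
    "model_name": "demo-llm",                      # hypothetical
    "model_lang": ["en"],
    "model_ability": ["generate"],
    "model_specs": [
        {
            "model_format": "ggufv2",
            "model_size_in_billions": 7,
            "model_src": {
                "huggingface": {
                    "model_id": "demo-org/demo-llm-gguf",   # hypothetical hub id
                    "quantizations": ["Q4_K_M", "Q8_0"],
                }
            },
        }
    ],
}

# ...is flattened into one spec per quantization, roughly:
converted_specs = [
    {
        "model_format": "ggufv2",
        "model_size_in_billions": 7,
        "quantization": "Q4_K_M",
        "model_hub": "huggingface",
        "model_id": "demo-org/demo-llm-gguf",
        # default template derived from model_name when none is provided
        "model_file_name_template": "demo-llm-{quantization}.gguf",
    },
    {
        "model_format": "ggufv2",
        "model_size_in_billions": 7,
        "quantization": "Q8_0",
        "model_hub": "huggingface",
        "model_id": "demo-org/demo-llm-gguf",
        "model_file_name_template": "demo-llm-{quantization}.gguf",
    },
]
```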
+ """ + logger.info(f"[DEBUG SUPERVISOR] _convert_model_json_format called") + logger.info(f"[DEBUG SUPERVISOR] Input model_json: {model_json}") + + if model_json.get("model_id") is None and "model_src" in model_json: + logger.info( + f"[DEBUG SUPERVISOR] model_id is null, attempting to extract from model_src" + ) + model_src = model_json["model_src"] + + if "huggingface" in model_src and "model_id" in model_src["huggingface"]: + model_json["model_id"] = model_src["huggingface"]["model_id"] + logger.info( + f"[DEBUG SUPERVISOR] Extracted model_id from huggingface: {model_json['model_id']}" + ) + elif "modelscope" in model_src and "model_id" in model_src["modelscope"]: + model_json["model_id"] = model_src["modelscope"]["model_id"] + logger.info( + f"[DEBUG SUPERVISOR] Extracted model_id from modelscope: {model_json['model_id']}" + ) + + if model_json.get("model_revision") is None: + if ( + "huggingface" in model_src + and "model_revision" in model_src["huggingface"] + ): + model_json["model_revision"] = model_src["huggingface"][ + "model_revision" + ] + logger.info( + f"[DEBUG SUPERVISOR] Extracted model_revision from huggingface: {model_json['model_revision']}" + ) + elif ( + "modelscope" in model_src + and "model_revision" in model_src["modelscope"] + ): + model_json["model_revision"] = model_src["modelscope"][ + "model_revision" + ] + logger.info( + f"[DEBUG SUPERVISOR] Extracted model_revision from modelscope: {model_json['model_revision']}" + ) + + # If model_specs is missing, provide a default minimal spec + if "model_specs" not in model_json or not model_json["model_specs"]: + logger.info( + f"[DEBUG SUPERVISOR] model_specs missing or empty, creating default spec" + ) + # Create a minimal default spec + default_spec = { + **model_json, + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": None, + "quantization": "none", + } + ], + } + logger.info(f"[DEBUG SUPERVISOR] Created default spec: {default_spec}") + return default_spec + + logger.info( + f"[DEBUG SUPERVISOR] Found model_specs: {model_json['model_specs']}" + ) + + # Check if conversion is needed (detect model_src structure) + needs_conversion = False + for i, spec in enumerate(model_json["model_specs"]): + logger.info(f"[DEBUG SUPERVISOR] Checking spec {i}: {spec}") + if "model_src" in spec: + logger.info( + f"[DEBUG SUPERVISOR] Found model_src in spec {i}, conversion needed" + ) + needs_conversion = True + break + + if not needs_conversion: + logger.info( + f"[DEBUG SUPERVISOR] No conversion needed, returning original model_json" + ) + return model_json + + converted = model_json.copy() + converted_specs = [] + + for spec in model_json["model_specs"]: + model_format = spec.get("model_format", "pytorch") + model_size = spec.get("model_size_in_billions") + + if "model_src" not in spec: + # No model_src, keep spec as is but ensure required fields + converted_spec = spec.copy() + if "quantization" not in converted_spec: + converted_spec["quantization"] = "none" + if "model_format" not in converted_spec: + converted_spec["model_format"] = "pytorch" + if "model_file_name_template" not in converted_spec: + converted_spec["model_file_name_template"] = "model.bin" + if "model_hub" not in converted_spec and "model_id" in converted_spec: + converted_spec["model_hub"] = "huggingface" + converted_specs.append(converted_spec) + continue + + model_src = spec["model_src"] + + # Handle different model sources + if "huggingface" in model_src: + hf_info = model_src["huggingface"] + quantizations = 
hf_info.get("quantizations", ["none"]) + + # Create separate specs for each quantization + for quant in quantizations: + converted_spec = { + "model_format": model_format, + "model_size_in_billions": model_size, + "quantization": quant, + "model_hub": "huggingface", + } + + # Add common fields + if "model_id" in hf_info: + converted_spec["model_id"] = hf_info["model_id"] + if "model_revision" in hf_info: + converted_spec["model_revision"] = hf_info["model_revision"] + + # Format-specific fields + if model_format == "ggufv2": + if "model_id" in hf_info: + converted_spec["model_id"] = hf_info["model_id"] + if "model_file_name_template" in hf_info: + converted_spec["model_file_name_template"] = hf_info[ + "model_file_name_template" + ] + else: + # Default template + model_name = model_json["model_name"] + converted_spec["model_file_name_template"] = ( + f"{model_name}-{{quantization}}.gguf" + ) + elif model_format in ["pytorch", "mlx"]: + if "model_id" in hf_info: + converted_spec["model_id"] = hf_info["model_id"] + if "model_revision" in hf_info: + converted_spec["model_revision"] = hf_info["model_revision"] + converted_spec["model_file_name_template"] = "pytorch_model.bin" + + converted_specs.append(converted_spec) + + elif "modelscope" in model_src: + # Handle ModelScope similarly + ms_info = model_src["modelscope"] + quantizations = ms_info.get("quantizations", ["none"]) + + for quant in quantizations: + converted_spec = { + "model_format": model_format, + "model_size_in_billions": model_size, + "quantization": quant, + "model_hub": "modelscope", + } + + if "model_id" in ms_info: + converted_spec["model_id"] = ms_info["model_id"] + if "model_revision" in ms_info: + converted_spec["model_revision"] = ms_info["model_revision"] + converted_spec["model_file_name_template"] = "pytorch_model.bin" + + converted_specs.append(converted_spec) + + else: + # Unknown model source, skip or handle as error + logger.warning( + f"Unknown model source in spec: {list(model_src.keys())}" + ) + # Keep original spec but add required fields + converted_spec = spec.copy() + if "quantization" not in converted_spec: + converted_spec["quantization"] = "none" + if "model_format" not in converted_spec: + converted_spec["model_format"] = "pytorch" + if "model_file_name_template" not in converted_spec: + converted_spec["model_file_name_template"] = "model.bin" + converted_specs.append(converted_spec) + + converted["model_specs"] = converted_specs + + return converted + async def _sync_register_model( self, model_type: str, model: str, persist: bool, model_name: str ): @@ -956,6 +1619,274 @@ async def _sync_register_model( logger.warning(f"finish unregister model: {model} for {name}") raise e + @log_async(logger=logger) + async def update_model_type(self, model_type: str): + """ + Update model configurations for a specific model type by downloading + the latest JSON from the remote API and storing it locally. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) 
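Before the full implementation, a condensed sketch of the download step this method performs. The helper name is hypothetical; the URL, query parameter and timeout mirror the body below, and error handling plus the per-type reload calls are omitted:

```python
# A condensed sketch of the remote fetch performed by update_model_type.
import requests


def fetch_hub_configs(model_type: str):
    # The remote API is keyed by the lower-cased model type.
    url = (
        "https://model.xinference.io/api/models/download"
        f"?model_type={model_type.lower()}"
    )
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    data = response.json()
    # The endpoint may return a single configuration (dict) or a list of them.
    return data if isinstance(data, list) else [data]
```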
+ """ + import json + + import requests + + logger.info( + f"[DEBUG SUPERVISOR] update_model_type called with model_type: {model_type}" + ) + + supported_types = list(self._custom_register_type_to_cls.keys()) + + normalized_for_validation = model_type + if model_type.lower() == "llm" and "LLM" in supported_types: + normalized_for_validation = "LLM" + elif model_type.lower() == "llm" and "llm" in supported_types: + normalized_for_validation = "llm" + + if normalized_for_validation not in supported_types: + logger.error( + f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_for_validation}" + ) + raise ValueError( + f"Unsupported model type '{model_type}'. " + f"Supported types are: {', '.join(supported_types)}" + ) + + model_type_for_operations = normalized_for_validation + logger.info( + f"[DEBUG SUPERVISOR] Using model_type: '{model_type_for_operations}' for operations" + ) + + # Construct the URL to download JSON + url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}" + logger.info(f"[DEBUG SUPERVISOR] Downloading model configurations from: {url}") + + try: + # Download JSON from remote API + response = requests.get(url, timeout=30) + response.raise_for_status() + + # Parse JSON response + model_data = response.json() + logger.info( + f"[DEBUG SUPERVISOR] Successfully downloaded JSON for model type: {model_type}" + ) + logger.info(f"[DEBUG SUPERVISOR] JSON data type: {type(model_data)}") + + if isinstance(model_data, dict): + logger.info( + f"[DEBUG SUPERVISOR] JSON data keys: {list(model_data.keys())}" + ) + elif isinstance(model_data, list): + logger.info( + f"[DEBUG SUPERVISOR] JSON data contains {len(model_data)} items" + ) + if model_data: + logger.info( + f"[DEBUG SUPERVISOR] First item keys: {list(model_data[0].keys()) if isinstance(model_data[0], dict) else 'Not a dict'}" + ) + + # Store the JSON data using CacheManager as built-in models + logger.info( + f"[DEBUG SUPERVISOR] Storing model configurations as built-in models..." + ) + await self._store_model_configurations(model_type, model_data) + logger.info( + f"[DEBUG SUPERVISOR] Built-in model configurations stored successfully" + ) + + # Dynamically reload built-in models to make them immediately available + logger.info( + f"[DEBUG SUPERVISOR] Reloading built-in models for immediate availability..." 
+ ) + try: + if model_type.lower() == "llm": + from ..model.llm import register_builtin_model + + register_builtin_model() + logger.info(f"[DEBUG SUPERVISOR] LLM models reloaded successfully") + elif model_type.lower() == "embedding": + from ..model.embedding import register_builtin_model + + register_builtin_model() + logger.info( + f"[DEBUG SUPERVISOR] Embedding models reloaded successfully" + ) + elif model_type.lower() == "audio": + from ..model.audio import register_builtin_model + + register_builtin_model() + logger.info( + f"[DEBUG SUPERVISOR] Audio models reloaded successfully" + ) + elif model_type.lower() == "image": + from ..model.image import register_builtin_model + + register_builtin_model() + logger.info( + f"[DEBUG SUPERVISOR] Image models reloaded successfully" + ) + elif model_type.lower() == "rerank": + from ..model.rerank import register_builtin_model + + register_builtin_model() + logger.info( + f"[DEBUG SUPERVISOR] Rerank models reloaded successfully" + ) + elif model_type.lower() == "video": + from ..model.video import register_builtin_model + + register_builtin_model() + logger.info( + f"[DEBUG SUPERVISOR] Video models reloaded successfully" + ) + else: + logger.warning( + f"[DEBUG SUPERVISOR] No dynamic loading available for model type: {model_type}" + ) + except Exception as reload_error: + logger.error( + f"[DEBUG SUPERVISOR] Error reloading built-in models: {reload_error}", + exc_info=True, + ) + # Don't fail the update if reload fails, just log the error + + except requests.exceptions.RequestException as e: + logger.error( + f"[DEBUG SUPERVISOR] Network error downloading model configurations: {e}" + ) + raise ValueError(f"Failed to download model configurations: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"[DEBUG SUPERVISOR] JSON decode error: {e}") + raise ValueError(f"Invalid JSON response from remote API: {str(e)}") + except Exception as e: + logger.error( + f"[DEBUG SUPERVISOR] Unexpected error during model update: {e}", + exc_info=True, + ) + raise ValueError(f"Failed to update model configurations: {str(e)}") + + async def _store_model_configurations(self, model_type: str, model_data): + """ + Store model configurations using the appropriate CacheManager as built-in models. 
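As a quick pointer, a sketch of where a stored configuration ends up on disk. The model type and name are placeholders; the directory layout is the one used by CacheManager.register_builtin_model further down in this PR:

```python
# Where a stored built-in configuration lands on disk (sketch only).
import os

from xinference.constants import XINFERENCE_MODEL_DIR

model_type = "embedding"           # as passed by the caller
model_name = "my-embedding-model"  # hypothetical
persist_path = os.path.join(
    XINFERENCE_MODEL_DIR, "v2", "builtin", model_type.lower(), f"{model_name}.json"
)
# e.g. ~/.xinference/model/v2/builtin/embedding/my-embedding-model.json
```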
+ + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing model configurations + """ + + logger.info( + f"[DEBUG SUPERVISOR] Storing configurations for model type: {model_type}" + ) + + try: + # Create a temporary model spec to get CacheManager instance + # We need to determine the appropriate model spec class for this model type + lookup_key = None + for key in self._custom_register_type_to_cls.keys(): + if key.lower() == model_type.lower(): + lookup_key = key + break + + if lookup_key is None: + raise ValueError(f"Unsupported model type: {model_type}") + + model_spec_cls, _, _, _ = self._custom_register_type_to_cls[lookup_key] + logger.info( + f"[DEBUG SUPERVISOR] Using model spec class: {model_spec_cls.__name__} with key: {lookup_key}" + ) + + # Handle different response formats + if isinstance(model_data, dict): + # Single model configuration + logger.info(f"[DEBUG SUPERVISOR] Processing single model configuration") + await self._store_single_model_config( + model_type, model_data, model_spec_cls + ) + elif isinstance(model_data, list): + # Multiple model configurations + logger.info( + f"[DEBUG SUPERVISOR] Processing {len(model_data)} model configurations" + ) + for i, model_config in enumerate(model_data): + if isinstance(model_config, dict): + logger.info(f"[DEBUG SUPERVISOR] Processing model config {i+1}") + await self._store_single_model_config( + model_type, model_config, model_spec_cls + ) + else: + logger.warning( + f"[DEBUG SUPERVISOR] Skipping invalid model config {i+1}: not a dict" + ) + else: + raise ValueError( + f"Invalid model data format: expected dict or list, got {type(model_data)}" + ) + + except Exception as e: + logger.error( + f"[DEBUG SUPERVISOR] Error storing model configurations: {e}", + exc_info=True, + ) + raise + + async def _store_single_model_config( + self, model_type: str, model_config: dict, model_spec_cls + ): + """ + Store a single model configuration as built-in model. 
+ + Args: + model_type: Type of model + model_config: Single model configuration dictionary + model_spec_cls: Model specification class + """ + from ..model.cache_manager import CacheManager + + # Ensure required fields are present + if "model_name" not in model_config: + logger.warning( + f"[DEBUG SUPERVISOR] Skipping model config without model_name: {model_config}" + ) + return + + model_name = model_config["model_name"] + logger.info(f"[DEBUG SUPERVISOR] Storing model: {model_name}") + + # Validate model name format + from ..model.utils import is_valid_model_name + + if not is_valid_model_name(model_name): + logger.warning( + f"[DEBUG SUPERVISOR] Skipping model with invalid name: {model_name}" + ) + return + + try: + # Convert model hub JSON format to Xinference expected format + converted_config = self._convert_model_json_format(model_config) + logger.info(f"[DEBUG SUPERVISOR] Converted model config for: {model_name}") + + # Create model spec instance + model_spec = model_spec_cls.parse_obj(converted_config) + logger.info(f"[DEBUG SUPERVISOR] Created model spec for: {model_name}") + + # Create CacheManager and store the configuration as built-in model + cache_manager = CacheManager(model_spec) + cache_manager.register_builtin_model(model_type) + logger.info( + f"[DEBUG SUPERVISOR] Stored built-in model configuration for: {model_name}" + ) + + except Exception as e: + logger.error( + f"[DEBUG SUPERVISOR] Error storing model {model_name}: {e}", + exc_info=True, + ) + # Continue with other models instead of failing completely + return + @log_async(logger=logger) async def unregister_model(self, model_type: str, model_name: str): if model_type in self._custom_register_type_to_cls: diff --git a/xinference/model/audio/__init__.py b/xinference/model/audio/__init__.py index 9465771917..89a8cb0a4e 100644 --- a/xinference/model/audio/__init__.py +++ b/xinference/model/audio/__init__.py @@ -60,6 +60,71 @@ def register_custom_model(): warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}") +def register_builtin_model(): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("audio") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_audio_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "audio") + if os.path.isdir(builtin_audio_dir): + for f in os.listdir(builtin_audio_dir): + if f.endswith(".json"): + try: + with codecs.open( + os.path.join(builtin_audio_dir, f), encoding="utf-8" + ) as fd: + model_data = json.load(fd) + + # Apply conversion logic to handle null model_id and other issues + if ( + model_data.get("model_id") is None + and "model_src" in model_data + ): + model_src = model_data["model_src"] + # Extract model_id from available sources + if ( + "huggingface" in model_src + and "model_id" in model_src["huggingface"] + ): + model_data["model_id"] = model_src["huggingface"][ + "model_id" + ] + elif ( + "modelscope" in model_src + and "model_id" in model_src["modelscope"] + ): + model_data["model_id"] = model_src["modelscope"][ + "model_id" + ] + + # Extract model_revision if available + if model_data.get("model_revision") is None: + if ( + "huggingface" in model_src + and "model_revision" in model_src["huggingface"] + ): + model_data["model_revision"] = model_src[ + "huggingface" + ]["model_revision"] + elif ( + "modelscope" in model_src + and "model_revision" in model_src["modelscope"] + ): + model_data["model_revision"] = model_src[ + "modelscope" + ]["model_revision"] + + builtin_audio_family = 
AudioModelFamilyV2.parse_obj(model_data) + + # Only register if model doesn't already exist + if builtin_audio_family.model_name not in existing_model_names: + register_audio(builtin_audio_family, persist=False) + existing_model_names.add(builtin_audio_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_audio_dir}/{f} has error, {e}") + + def _need_filter(spec: dict): if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get( "engine", "" diff --git a/xinference/model/audio/builtin.py b/xinference/model/audio/builtin.py new file mode 100644 index 0000000000..e78cc756d6 --- /dev/null +++ b/xinference/model/audio/builtin.py @@ -0,0 +1,115 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .custom import AudioModelFamilyV2 + + +class BuiltinAudioModelRegistry: + """ + Registry for built-in audio models downloaded from official model hub. + + These models are treated as built-in models and don't require model_family validation. + They are stored in ~/.xinference/model/v2/builtin/audio/ directory. + """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "audio") + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["AudioModelFamilyV2"]: + """Load all built-in audio models from the builtin directory.""" + from .custom import AudioModelFamilyV2 + + models: List["AudioModelFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Parse using AudioModelFamilyV2 (no model_family validation required) + model = AudioModelFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in audio model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model) -> None: + """Register a built-in audio model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in audio model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in audio model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + os.remove(persist_path) + 
logger.info(f"Unregistered built-in audio model: {model_name}") + else: + logger.warning(f"Built-in model file not found: {persist_path}") + + +# Global registry instance +_builtin_registry = None + + +def get_builtin_audio_registry() -> BuiltinAudioModelRegistry: + """Get the global built-in audio model registry instance.""" + global _builtin_registry + if _builtin_registry is None: + _builtin_registry = BuiltinAudioModelRegistry() + return _builtin_registry + + +def get_builtin_audio_families() -> List: + """Get all built-in audio model families.""" + return get_builtin_audio_registry().get_builtin_models() + + +def register_builtin_audio(audio_family) -> None: + """Register a built-in audio model family.""" + return get_builtin_audio_registry().register_builtin_model(audio_family) + + +def unregister_builtin_audio(model_name: str) -> None: + """Unregister a built-in audio model family.""" + return get_builtin_audio_registry().unregister_builtin_model(model_name) diff --git a/xinference/model/cache_manager.py b/xinference/model/cache_manager.py index ae9a9f1bfd..e4b74e2177 100644 --- a/xinference/model/cache_manager.py +++ b/xinference/model/cache_manager.py @@ -16,8 +16,12 @@ def __init__(self, model_family: "CacheableModelSpec"): self._model_family = model_family self._v2_cache_dir_prefix = os.path.join(XINFERENCE_CACHE_DIR, "v2") self._v2_custom_dir_prefix = os.path.join(XINFERENCE_MODEL_DIR, "v2") + self._v2_builtin_dir_prefix = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin" + ) os.makedirs(self._v2_cache_dir_prefix, exist_ok=True) os.makedirs(self._v2_custom_dir_prefix, exist_ok=True) + os.makedirs(self._v2_builtin_dir_prefix, exist_ok=True) self._cache_dir = os.path.join( self._v2_cache_dir_prefix, self._model_family.model_name.replace(".", "_") ) @@ -109,9 +113,21 @@ def cache(self) -> str: return self._cache() def register_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, + f"{self._model_family.model_name}.json", + ) + os.makedirs(os.path.dirname(persist_path), exist_ok=True) + with open(persist_path, mode="w") as fd: + fd.write(self._model_family.json()) + + def register_builtin_model(self, model_type: str): + model_type_dir = model_type.lower() + persist_path = os.path.join( + self._v2_builtin_dir_prefix, + model_type_dir, f"{self._model_family.model_name}.json", ) os.makedirs(os.path.dirname(persist_path), exist_ok=True) @@ -119,9 +135,10 @@ def register_custom_model(self, model_type: str): fd.write(self._model_family.json()) def unregister_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, f"{self._model_family.model_name}.json", ) if os.path.exists(persist_path): diff --git a/xinference/model/custom.py b/xinference/model/custom.py index f08a09dfea..a1adee9aea 100644 --- a/xinference/model/custom.py +++ b/xinference/model/custom.py @@ -118,6 +118,7 @@ def get_registry(cls, model_type: str) -> ModelRegistry: from .image.custom import ImageModelRegistry from .llm.custom import LLMModelRegistry from .rerank.custom import RerankModelRegistry + from .video.custom import VideoModelRegistry if model_type not in cls._instances: if model_type == "rerank": @@ -126,6 +127,8 @@ def get_registry(cls, model_type: str) -> ModelRegistry: cls._instances[model_type] = ImageModelRegistry() elif model_type == "audio": cls._instances[model_type] = 
AudioModelRegistry() + elif model_type == "video": + cls._instances[model_type] = VideoModelRegistry() elif model_type == "llm": cls._instances[model_type] = LLMModelRegistry() elif model_type == "flexible": diff --git a/xinference/model/embedding/__init__.py b/xinference/model/embedding/__init__.py index f1e822e112..cebf1eee03 100644 --- a/xinference/model/embedding/__init__.py +++ b/xinference/model/embedding/__init__.py @@ -64,6 +64,40 @@ def register_custom_model(): warnings.warn(f"{user_defined_embedding_dir}/{f} has error, {e}") +def register_builtin_model(): + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("embedding") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_embedding_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", "embedding" + ) + if os.path.isdir(builtin_embedding_dir): + for f in os.listdir(builtin_embedding_dir): + if f.endswith(".json"): + try: + with codecs.open( + os.path.join(builtin_embedding_dir, f), encoding="utf-8" + ) as fd: + builtin_embedding_family = EmbeddingModelFamilyV2.parse_obj( + json.load(fd) + ) + + # Only register if model doesn't already exist + if ( + builtin_embedding_family.model_name + not in existing_model_names + ): + register_embedding(builtin_embedding_family, persist=False) + existing_model_names.add( + builtin_embedding_family.model_name + ) + except Exception as e: + warnings.warn(f"{builtin_embedding_dir}/{f} has error, {e}") + + def check_format_with_engine(model_format, engine): if model_format in ["ggufv2"] and engine not in ["llama.cpp"]: return False diff --git a/xinference/model/embedding/builtin.py b/xinference/model/embedding/builtin.py new file mode 100644 index 0000000000..d100931136 --- /dev/null +++ b/xinference/model/embedding/builtin.py @@ -0,0 +1,117 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .custom import EmbeddingModelFamilyV2 + + +class BuiltinEmbeddingModelRegistry: + """ + Registry for built-in embedding models downloaded from official model hub. + + These models are treated as built-in models. + They are stored in ~/.xinference/model/v2/builtin/embedding/ directory. 
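A usage sketch of the module-level helpers defined in the rest of this file. The spec is abbreviated and hypothetical; a real entry would need whatever fields EmbeddingModelFamilyV2 actually requires:

```python
# Illustrative use of the built-in embedding registry helpers; the family
# below is a placeholder, not a real hub entry.
from xinference.model.embedding.builtin import (
    get_builtin_embedding_families,
    register_builtin_embedding,
    unregister_builtin_embedding,
)
from xinference.model.embedding.custom import EmbeddingModelFamilyV2

family = EmbeddingModelFamilyV2.parse_obj(
    {
        "model_name": "my-embedding-model",  # hypothetical
        "dimensions": 768,
        "max_tokens": 512,
        "language": ["en"],
        "model_specs": [
            {
                "model_format": "pytorch",
                "quantization": "none",
                "model_hub": "huggingface",
                "model_id": "my-org/my-embedding-model",
            }
        ],
    }
)

# Writes ~/.xinference/model/v2/builtin/embedding/my-embedding-model.json
register_builtin_embedding(family)
print([f.model_name for f in get_builtin_embedding_families()])
unregister_builtin_embedding("my-embedding-model")  # removes the JSON file again
```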
+ """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", "embedding" + ) + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["EmbeddingModelFamilyV2"]: + """Load all built-in embedding models from the builtin directory.""" + from .custom import EmbeddingModelFamilyV2 + + models: List["EmbeddingModelFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Parse using EmbeddingFamilyV2 + model = EmbeddingModelFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in embedding model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model) -> None: + """Register a built-in embedding model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in embedding model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in embedding model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + os.remove(persist_path) + logger.info(f"Unregistered built-in embedding model: {model_name}") + else: + logger.warning(f"Built-in model file not found: {persist_path}") + + +# Global registry instance +_builtin_registry = None + + +def get_builtin_embedding_registry() -> BuiltinEmbeddingModelRegistry: + """Get the global built-in embedding model registry instance.""" + global _builtin_registry + if _builtin_registry is None: + _builtin_registry = BuiltinEmbeddingModelRegistry() + return _builtin_registry + + +def get_builtin_embedding_families() -> List: + """Get all built-in embedding model families.""" + return get_builtin_embedding_registry().get_builtin_models() + + +def register_builtin_embedding(embedding_family) -> None: + """Register a built-in embedding model family.""" + return get_builtin_embedding_registry().register_builtin_model(embedding_family) + + +def unregister_builtin_embedding(model_name: str) -> None: + """Unregister a built-in embedding model family.""" + return get_builtin_embedding_registry().unregister_builtin_model(model_name) diff --git a/xinference/model/flexible/launchers/__init__.py b/xinference/model/flexible/launchers/__init__.py index f8de4cd8d4..09138b5b2a 100644 --- a/xinference/model/flexible/launchers/__init__.py +++ b/xinference/model/flexible/launchers/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from .image_process_launcher import launcher as image_process -from .modelscope_launcher import launcher as modelscope -from .transformers_launcher import launcher as transformers -from .yolo_launcher import launcher as yolo diff --git a/xinference/model/image/__init__.py b/xinference/model/image/__init__.py index 14230ea41c..a08b72c1dd 100644 --- a/xinference/model/image/__init__.py +++ b/xinference/model/image/__init__.py @@ -55,6 +55,33 @@ def register_custom_model(): warnings.warn(f"{user_defined_image_dir}/{f} has error, {e}") +def register_builtin_model(): + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("image") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "image") + if os.path.isdir(builtin_image_dir): + for f in os.listdir(builtin_image_dir): + if f.endswith(".json"): + try: + with codecs.open( + os.path.join(builtin_image_dir, f), encoding="utf-8" + ) as fd: + builtin_image_family = ImageModelFamilyV2.parse_obj( + json.load(fd) + ) + + # Only register if model doesn't already exist + if builtin_image_family.model_name not in existing_model_names: + register_image(builtin_image_family, persist=False) + existing_model_names.add(builtin_image_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_image_dir}/{f} has error, {e}") + + def _install(): load_model_family_from_json("model_spec.json", BUILTIN_IMAGE_MODELS) diff --git a/xinference/model/image/builtin.py b/xinference/model/image/builtin.py new file mode 100644 index 0000000000..230b3b8d7c --- /dev/null +++ b/xinference/model/image/builtin.py @@ -0,0 +1,115 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .custom import ImageModelFamilyV2 + + +class BuiltinImageModelRegistry: + """ + Registry for built-in image models downloaded from official model hub. + + These models are treated as built-in models and don't require model_family validation. + They are stored in ~/.xinference/model/v2/builtin/image/ directory. 
+ """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "image") + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["ImageModelFamilyV2"]: + """Load all built-in image models from the builtin directory.""" + from .custom import ImageModelFamilyV2 + + models: List["ImageModelFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Parse using ImageModelFamilyV2 (no model_family validation required) + model = ImageModelFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in image model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model) -> None: + """Register a built-in image model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in image model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in image model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + os.remove(persist_path) + logger.info(f"Unregistered built-in image model: {model_name}") + else: + logger.warning(f"Built-in model file not found: {persist_path}") + + +# Global registry instance +_builtin_registry = None + + +def get_builtin_image_registry() -> BuiltinImageModelRegistry: + """Get the global built-in image model registry instance.""" + global _builtin_registry + if _builtin_registry is None: + _builtin_registry = BuiltinImageModelRegistry() + return _builtin_registry + + +def get_builtin_image_families() -> List: + """Get all built-in image model families.""" + return get_builtin_image_registry().get_builtin_models() + + +def register_builtin_image(image_family) -> None: + """Register a built-in image model family.""" + return get_builtin_image_registry().register_builtin_model(image_family) + + +def unregister_builtin_image(model_name: str) -> None: + """Unregister a built-in image model family.""" + return get_builtin_image_registry().unregister_builtin_model(model_name) diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index a4c4704ce4..f417d1acfc 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -128,6 +128,31 @@ def register_custom_model(): warnings.warn(f"{user_defined_llm_dir}/{f} has error, {e}") +def register_builtin_model(): + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("llm") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "llm") + if os.path.isdir(builtin_llm_dir): + for f in os.listdir(builtin_llm_dir): + if f.endswith(".json"): + try: + with 
codecs.open( + os.path.join(builtin_llm_dir, f), encoding="utf-8" + ) as fd: + builtin_llm_family = LLMFamilyV2.parse_raw(fd.read()) + + # Only register if model doesn't already exist + if builtin_llm_family.model_name not in existing_model_names: + register_llm(builtin_llm_family, persist=False) + existing_model_names.add(builtin_llm_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_llm_dir}/{f} has error, {e}") + + def load_model_family_from_json(json_filename, target_families): json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), json_filename) for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")): diff --git a/xinference/model/llm/builtin.py b/xinference/model/llm/builtin.py new file mode 100644 index 0000000000..d82378db1e --- /dev/null +++ b/xinference/model/llm/builtin.py @@ -0,0 +1,115 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from ..llm.llm_family import LLMFamilyV2 + +logger = logging.getLogger(__name__) + + +class BuiltinLLMModelRegistry: + """ + Registry for built-in LLM models downloaded from official model hub. + + These models are treated as built-in models and don't require model_family validation. + They are stored in ~/.xinference/model/v2/builtin/llm/ directory. 
+ """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "llm") + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["LLMFamilyV2"]: + """Load all built-in LLM models from the builtin directory.""" + from ..llm.llm_family import LLMFamilyV2 + + models: List["LLMFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Parse using LLMFamilyV2 (no model_family validation required) + model = LLMFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in LLM model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model: "LLMFamilyV2") -> None: + """Register a built-in LLM model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in LLM model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in LLM model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + os.remove(persist_path) + logger.info(f"Unregistered built-in LLM model: {model_name}") + else: + logger.warning(f"Built-in model file not found: {persist_path}") + + +# Global registry instance +_builtin_registry = None + + +def get_builtin_llm_registry() -> BuiltinLLMModelRegistry: + """Get the global built-in LLM model registry instance.""" + global _builtin_registry + if _builtin_registry is None: + _builtin_registry = BuiltinLLMModelRegistry() + return _builtin_registry + + +def get_builtin_llm_families() -> List["LLMFamilyV2"]: + """Get all built-in LLM model families.""" + return get_builtin_llm_registry().get_builtin_models() + + +def register_builtin_llm(llm_family: "LLMFamilyV2") -> None: + """Register a built-in LLM model family.""" + return get_builtin_llm_registry().register_builtin_model(llm_family) + + +def unregister_builtin_llm(model_name: str) -> None: + """Unregister a built-in LLM model family.""" + return get_builtin_llm_registry().unregister_builtin_model(model_name) diff --git a/xinference/model/rerank/__init__.py b/xinference/model/rerank/__init__.py index 36334cb9fc..5ed2b2fd14 100644 --- a/xinference/model/rerank/__init__.py +++ b/xinference/model/rerank/__init__.py @@ -63,6 +63,32 @@ def register_custom_model(): warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}") +def register_builtin_model(): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("rerank") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_rerank_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "rerank") + if os.path.isdir(builtin_rerank_dir): + for f in os.listdir(builtin_rerank_dir): + if f.endswith(".json"): + try: + with codecs.open( + os.path.join(builtin_rerank_dir, f), 
encoding="utf-8" + ) as fd: + builtin_rerank_family = RerankModelFamilyV2.parse_obj( + json.load(fd) + ) + + # Only register if model doesn't already exist + if builtin_rerank_family.model_name not in existing_model_names: + register_rerank(builtin_rerank_family, persist=False) + existing_model_names.add(builtin_rerank_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_rerank_dir}/{f} has error, {e}") + + def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"): model_name = model_family.model_name engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get( diff --git a/xinference/model/rerank/builtin.py b/xinference/model/rerank/builtin.py new file mode 100644 index 0000000000..3fe0cd927b --- /dev/null +++ b/xinference/model/rerank/builtin.py @@ -0,0 +1,115 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .custom import RerankModelFamilyV2 + + +class BuiltinRerankModelRegistry: + """ + Registry for built-in rerank models downloaded from official model hub. + + These models are treated as built-in models and don't require model_family validation. + They are stored in ~/.xinference/model/v2/builtin/rerank/ directory. 
+ """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "rerank") + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["RerankModelFamilyV2"]: + """Load all built-in rerank models from the builtin directory.""" + from .custom import RerankModelFamilyV2 + + models: List["RerankModelFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Parse using RerankModelFamilyV2 (no model_family validation required) + model = RerankModelFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in rerank model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model) -> None: + """Register a built-in rerank model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in rerank model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in rerank model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + os.remove(persist_path) + logger.info(f"Unregistered built-in rerank model: {model_name}") + else: + logger.warning(f"Built-in model file not found: {persist_path}") + + +# Global registry instance +_builtin_registry = None + + +def get_builtin_rerank_registry() -> BuiltinRerankModelRegistry: + """Get the global built-in rerank model registry instance.""" + global _builtin_registry + if _builtin_registry is None: + _builtin_registry = BuiltinRerankModelRegistry() + return _builtin_registry + + +def get_builtin_rerank_families() -> List: + """Get all built-in rerank model families.""" + return get_builtin_rerank_registry().get_builtin_models() + + +def register_builtin_rerank(rerank_family) -> None: + """Register a built-in rerank model family.""" + return get_builtin_rerank_registry().register_builtin_model(rerank_family) + + +def unregister_builtin_rerank(model_name: str) -> None: + """Unregister a built-in rerank model family.""" + return get_builtin_rerank_registry().unregister_builtin_model(model_name) diff --git a/xinference/model/video/__init__.py b/xinference/model/video/__init__.py index 5002fcc039..5d6d95425b 100644 --- a/xinference/model/video/__init__.py +++ b/xinference/model/video/__init__.py @@ -15,6 +15,7 @@ import codecs import json import os +import warnings from ..utils import flatten_model_src from .core import ( @@ -24,6 +25,57 @@ generate_video_description, get_video_model_descriptions, ) +from .custom import ( + CustomVideoModelFamilyV2, + register_video, + unregister_video, +) + + +def register_custom_model(): + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import migrate_from_v1_to_v2 + + # migrate from v1 to v2 first + migrate_from_v1_to_v2("video", CustomVideoModelFamilyV2) + + 
user_defined_video_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "video") + if os.path.isdir(user_defined_video_dir): + for f in os.listdir(user_defined_video_dir): + try: + with codecs.open( + os.path.join(user_defined_video_dir, f), encoding="utf-8" + ) as fd: + user_defined_video_family = CustomVideoModelFamilyV2.parse_obj( + json.load(fd) + ) + register_video(user_defined_video_family, persist=False) + except Exception as e: + warnings.warn(f"{user_defined_video_dir}/{f} has error, {e}") + + +def register_builtin_model(): + """ + Dynamically load built-in video models from builtin/video directory. + This function is called every time model list is requested, + ensuring real-time updates without server restart. + """ + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + # Use the builtin registry to load models + from .builtin import BuiltinVideoModelRegistry + + builtin_registry = BuiltinVideoModelRegistry() + builtin_models = builtin_registry.get_builtin_models() + + for model in builtin_models: + # Only register if model doesn't already exist + if model.model_name not in existing_model_names: + register_video(model, persist=False) + existing_model_names.add(model.model_name) def _install(): diff --git a/xinference/model/video/builtin.py b/xinference/model/video/builtin.py new file mode 100644 index 0000000000..6affe65ab2 --- /dev/null +++ b/xinference/model/video/builtin.py @@ -0,0 +1,128 @@ +# Copyright 2022-2025 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import TYPE_CHECKING, List + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .custom import CustomVideoModelFamilyV2 + + +class BuiltinVideoModelRegistry: + """ + Registry for built-in video models downloaded from official model hub. + + These models are treated as built-in models and don't require model_family validation. + They are stored in ~/.xinference/model/v2/builtin/video/ directory. 
+ """ + + def __init__(self): + from ...constants import XINFERENCE_MODEL_DIR + + self.builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "video") + os.makedirs(self.builtin_dir, exist_ok=True) + + def get_builtin_models(self) -> List["CustomVideoModelFamilyV2"]: + """Load all built-in video models from the builtin directory.""" + from .custom import CustomVideoModelFamilyV2 + + models: List["CustomVideoModelFamilyV2"] = [] + + if not os.path.exists(self.builtin_dir): + return models + + for filename in os.listdir(self.builtin_dir): + if filename.endswith(".json"): + file_path = os.path.join(self.builtin_dir, filename) + try: + with open(file_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Apply conversion logic to handle null model_id and other issues + if model_data.get("model_id") is None and "model_src" in model_data: + model_src = model_data["model_src"] + # Extract model_id from available sources + if ( + "huggingface" in model_src + and "model_id" in model_src["huggingface"] + ): + model_data["model_id"] = model_src["huggingface"][ + "model_id" + ] + elif ( + "modelscope" in model_src + and "model_id" in model_src["modelscope"] + ): + model_data["model_id"] = model_src["modelscope"]["model_id"] + + # Extract model_revision if available + if model_data.get("model_revision") is None: + if ( + "huggingface" in model_src + and "model_revision" in model_src["huggingface"] + ): + model_data["model_revision"] = model_src["huggingface"][ + "model_revision" + ] + elif ( + "modelscope" in model_src + and "model_revision" in model_src["modelscope"] + ): + model_data["model_revision"] = model_src["modelscope"][ + "model_revision" + ] + + # Parse using CustomVideoModelFamilyV2 + model = CustomVideoModelFamilyV2.parse_obj(model_data) + models.append(model) + logger.info(f"Loaded built-in video model: {model.model_name}") + + except Exception as e: + logger.warning( + f"Failed to load built-in model from {filename}: {e}" + ) + + return models + + def register_builtin_model(self, model) -> None: + """Register a built-in video model by saving it to the builtin directory.""" + persist_path = os.path.join(self.builtin_dir, f"{model.model_name}.json") + + try: + with open(persist_path, "w", encoding="utf-8") as f: + f.write(model.json(exclude_none=True)) + logger.info(f"Registered built-in video model: {model.model_name}") + except Exception as e: + logger.error(f"Failed to register built-in model {model.model_name}: {e}") + raise + + def unregister_builtin_model(self, model_name: str) -> None: + """Unregister a built-in video model by removing its JSON file.""" + persist_path = os.path.join(self.builtin_dir, f"{model_name}.json") + + if os.path.exists(persist_path): + try: + os.remove(persist_path) + logger.info(f"Unregistered built-in video model: {model_name}") + except Exception as e: + logger.error(f"Failed to unregister built-in model {model_name}: {e}") + raise + else: + logger.warning( + f"Built-in video model {model_name} not found for unregistration" + ) diff --git a/xinference/model/video/custom.py b/xinference/model/video/custom.py new file mode 100644 index 0000000000..11fe07a2d5 --- /dev/null +++ b/xinference/model/video/custom.py @@ -0,0 +1,70 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, List, Optional + +from ..._compat import ( + Literal, +) +from ..custom import ModelRegistry +from .core import VideoModelFamilyV2 + +logger = logging.getLogger(__name__) + + +class CustomVideoModelFamilyV2(VideoModelFamilyV2): + version: Literal[2] = 2 + model_id: Optional[str] # type: ignore + model_revision: Optional[str] # type: ignore + model_uri: Optional[str] + + +if TYPE_CHECKING: + from typing import TypeVar + + _T = TypeVar("_T", bound="CustomVideoModelFamilyV2") + + +class VideoModelRegistry(ModelRegistry): + model_type = "video" + + def __init__(self): + super().__init__() + + def get_user_defined_models(self) -> List["CustomVideoModelFamilyV2"]: + return self.get_custom_models() + + +video_registry = VideoModelRegistry() + + +def register_video(model_spec: CustomVideoModelFamilyV2, persist: bool = True): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + registry.register(model_spec, persist) + + +def unregister_video(model_name: str, raise_error: bool = True): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + registry.unregister(model_name, raise_error) + + +def get_user_defined_videos() -> List[CustomVideoModelFamilyV2]: + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + return registry.get_custom_models() diff --git a/xinference/ui/web/ui/src/locales/en.json b/xinference/ui/web/ui/src/locales/en.json index a12662732d..437fb45a5c 100644 --- a/xinference/ui/web/ui/src/locales/en.json +++ b/xinference/ui/web/ui/src/locales/en.json @@ -124,7 +124,23 @@ "featured": "featured", "all": "all", "cancelledSuccessfully": "Cancelled Successfully!", - "mustBeUnique": "{{key}} must be unique" + "mustBeUnique": "{{key}} must be unique", + "addModel": "Add Model", + "addModelDialog": { + "introPrefix": "To add a model, please go to the", + "platformLinkText": "Xinference Model Hub", + "introSuffix": "and fill in the corresponding model name.", + "modelName": "Model Name", + "modelName.tip": "Please enter the model name", + "placeholder": "e.g. 
qwen3 (case-sensitive)" + }, + "update": "Update", + "error": { + "name_not_matched": "No exact model name match found (case-sensitive)", + "downloadFailed": "Download failed", + "requestFailed": "Request failed", + "json_parse_error": "Failed to parse JSON" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/ja.json b/xinference/ui/web/ui/src/locales/ja.json index dc1636bfd3..e4075f9e1d 100644 --- a/xinference/ui/web/ui/src/locales/ja.json +++ b/xinference/ui/web/ui/src/locales/ja.json @@ -124,7 +124,23 @@ "featured": "おすすめとお気に入り", "all": "すべて", "cancelledSuccessfully": "正常にキャンセルされました!", - "mustBeUnique": "{{key}} は一意でなければなりません" + "mustBeUnique": "{{key}} は一意でなければなりません", + "addModel": "モデルを追加", + "addModelDialog": { + "introPrefix": "モデルを追加するには、", + "platformLinkText": "Xinference モデルセンター", + "introSuffix": "で対応するモデル名を入力してください。", + "modelName": "モデル名", + "modelName.tip": "モデル名を入力してください", + "placeholder": "例:qwen3(大文字と小文字を区別します)" + }, + "update": "更新", + "error": { + "name_not_matched": "完全に一致するモデル名が見つかりません(大文字と小文字を区別します)", + "downloadFailed": "ダウンロードに失敗しました", + "requestFailed": "リクエストに失敗しました", + "json_parse_error": "JSON の解析に失敗しました" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/ko.json b/xinference/ui/web/ui/src/locales/ko.json index 17ad7626a6..36fd0cd0c2 100644 --- a/xinference/ui/web/ui/src/locales/ko.json +++ b/xinference/ui/web/ui/src/locales/ko.json @@ -124,7 +124,23 @@ "featured": "추천 및 즐겨찾기", "all": "모두", "cancelledSuccessfully": "성공적으로 취소되었습니다!", - "mustBeUnique": "{{key}} 는 고유해야 합니다" + "mustBeUnique": "{{key}} 는 고유해야 합니다", + "addModel": "모델 추가", + "addModelDialog": { + "introPrefix": "모델을 추가하려면", + "platformLinkText": "Xinference 모델 센터", + "introSuffix": "에서 해당 모델 이름을 입력하세요.", + "modelName": "모델 이름", + "modelName.tip": "모델 이름을 입력하세요", + "placeholder": "예: qwen3 (대소문자를 구분합니다)" + }, + "update": "업데이트", + "error": { + "name_not_matched": "완전히 일치하는 모델 이름을 찾을 수 없습니다(대소문자 구분)", + "downloadFailed": "다운로드 실패", + "requestFailed": "요청 실패", + "json_parse_error": "JSON 구문 분석에 실패했습니다" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/zh.json b/xinference/ui/web/ui/src/locales/zh.json index 36daec1756..3a0a1d7a19 100644 --- a/xinference/ui/web/ui/src/locales/zh.json +++ b/xinference/ui/web/ui/src/locales/zh.json @@ -124,7 +124,23 @@ "featured": "推荐和收藏", "all": "全部", "cancelledSuccessfully": "取消成功!", - "mustBeUnique": "{{key}} 必须唯一" + "mustBeUnique": "{{key}} 必须唯一", + "addModel": "添加模型", + "addModelDialog": { + "introPrefix": "添加模型需基于", + "platformLinkText": "Xinference 模型中心", + "introSuffix": ",填写模型对应的名称", + "modelName": "模型名称", + "modelName.tip": "请输入模型名称", + "placeholder": "例如:qwen3(需大小写完全匹配)" + }, + "update": "更新", + "error": { + "name_not_matched": "未找到完全匹配的模型名称(需大小写一致)", + "downloadFailed": "下载失败", + "requestFailed": "请求失败", + "json_parse_error": "JSON 解析失败" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js index cba7bf9a65..623a122b6d 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js @@ -10,9 +10,11 @@ import { Select, } from '@mui/material' import React, { + forwardRef, useCallback, useContext, useEffect, + useImperativeHandle, useRef, useState, } from 'react' @@ -28,494 +30,507 @@ import ModelCard from './modelCard' // Toggle pagination globally for this page. Set to false to disable pagination and load all items. 
const ENABLE_PAGINATION = false -const LaunchModelComponent = ({ modelType, gpuAvailable, featureModels }) => { - const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext) - const { isUpdatingModel } = useContext(ApiContext) - const { setErrorMsg } = useContext(ApiContext) - const [cookie] = useCookies(['token']) - - const [registrationData, setRegistrationData] = useState([]) - // States used for filtering - const [searchTerm, setSearchTerm] = useState('') - const [status, setStatus] = useState('') - const [statusArr, setStatusArr] = useState([]) - const [collectionArr, setCollectionArr] = useState([]) - const [filterArr, setFilterArr] = useState([]) - const { t } = useTranslation() - const [modelListType, setModelListType] = useState('featured') - const [modelAbilityData, setModelAbilityData] = useState({ - type: modelType, - modelAbility: '', - options: [], - }) - const [selectedModel, setSelectedModel] = useState(null) - const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] = useState(false) - - // Pagination status - const [displayedData, setDisplayedData] = useState([]) - const [currentPage, setCurrentPage] = useState(1) - const [hasMore, setHasMore] = useState(true) - const itemsPerPage = 20 - const loaderRef = useRef(null) - - const filter = useCallback( - (registration) => { - if (searchTerm !== '') { - if (!registration || typeof searchTerm !== 'string') return false - const modelName = registration.model_name - ? registration.model_name.toLowerCase() - : '' - const modelDescription = registration.model_description - ? registration.model_description.toLowerCase() - : '' +const LaunchModelComponent = forwardRef( + ({ modelType, gpuAvailable, featureModels }, ref) => { + const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext) + const { isUpdatingModel } = useContext(ApiContext) + const { setErrorMsg } = useContext(ApiContext) + const [cookie] = useCookies(['token']) + + const [registrationData, setRegistrationData] = useState([]) + // States used for filtering + const [searchTerm, setSearchTerm] = useState('') + const [status, setStatus] = useState('') + const [statusArr, setStatusArr] = useState([]) + const [collectionArr, setCollectionArr] = useState([]) + const [filterArr, setFilterArr] = useState([]) + const { t } = useTranslation() + const [modelListType, setModelListType] = useState('featured') + const [modelAbilityData, setModelAbilityData] = useState({ + type: modelType, + modelAbility: '', + options: [], + }) + const [selectedModel, setSelectedModel] = useState(null) + const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] = + useState(false) + + // Pagination status + const [displayedData, setDisplayedData] = useState([]) + const [currentPage, setCurrentPage] = useState(1) + const [hasMore, setHasMore] = useState(true) + const itemsPerPage = 20 + const loaderRef = useRef(null) + + const filter = useCallback( + (registration) => { + if (searchTerm !== '') { + if (!registration || typeof searchTerm !== 'string') return false + const modelName = registration.model_name + ? registration.model_name.toLowerCase() + : '' + const modelDescription = registration.model_description + ? 
registration.model_description.toLowerCase() + : '' + + if ( + !modelName.includes(searchTerm.toLowerCase()) && + !modelDescription.includes(searchTerm.toLowerCase()) + ) { + return false + } + } - if ( - !modelName.includes(searchTerm.toLowerCase()) && - !modelDescription.includes(searchTerm.toLowerCase()) - ) { - return false + if (modelListType === 'featured') { + if ( + featureModels.length && + !featureModels.includes(registration.model_name) && + !collectionArr?.includes(registration.model_name) + ) { + return false + } } - } - if (modelListType === 'featured') { if ( - featureModels.length && - !featureModels.includes(registration.model_name) && - !collectionArr?.includes(registration.model_name) - ) { + modelAbilityData.modelAbility && + ((Array.isArray(registration.model_ability) && + registration.model_ability.indexOf(modelAbilityData.modelAbility) < + 0) || + (typeof registration.model_ability === 'string' && + registration.model_ability !== modelAbilityData.modelAbility)) + ) return false - } - } - if ( - modelAbilityData.modelAbility && - ((Array.isArray(registration.model_ability) && - registration.model_ability.indexOf(modelAbilityData.modelAbility) < - 0) || - (typeof registration.model_ability === 'string' && - registration.model_ability !== modelAbilityData.modelAbility)) - ) - return false - - if (statusArr.length === 1) { - if (statusArr[0] === 'cached') { + if (statusArr.length === 1) { + if (statusArr[0] === 'cached') { + const judge = + registration.model_specs?.some((spec) => filterCache(spec)) || + registration?.cache_status + return judge + } else { + return collectionArr?.includes(registration.model_name) + } + } else if (statusArr.length > 1) { const judge = registration.model_specs?.some((spec) => filterCache(spec)) || registration?.cache_status - return judge - } else { - return collectionArr?.includes(registration.model_name) + return judge && collectionArr?.includes(registration.model_name) } - } else if (statusArr.length > 1) { - const judge = - registration.model_specs?.some((spec) => filterCache(spec)) || - registration?.cache_status - return judge && collectionArr?.includes(registration.model_name) - } - return true - }, - [ - searchTerm, - modelListType, - featureModels, - collectionArr, - modelAbilityData.modelAbility, - statusArr, - ] - ) - - const filterCache = useCallback((spec) => { - if (Array.isArray(spec.cache_status)) { - return spec.cache_status?.some((cs) => cs) - } else { - return spec.cache_status === true - } - }, []) - - function getUniqueModelAbilities(arr) { - const uniqueAbilities = new Set() + return true + }, + [ + searchTerm, + modelListType, + featureModels, + collectionArr, + modelAbilityData.modelAbility, + statusArr, + ] + ) - arr.forEach((item) => { - if (Array.isArray(item.model_ability)) { - item.model_ability.forEach((ability) => { - uniqueAbilities.add(ability) - }) + const filterCache = useCallback((spec) => { + if (Array.isArray(spec.cache_status)) { + return spec.cache_status?.some((cs) => cs) + } else { + return spec.cache_status === true } - }) + }, []) - return Array.from(uniqueAbilities) - } + function getUniqueModelAbilities(arr) { + const uniqueAbilities = new Set() - const update = () => { - if ( - isCallingApi || - isUpdatingModel || - (cookie.token !== 'no_auth' && !sessionStorage.getItem('token')) - ) - return - - try { - setIsCallingApi(true) - - fetchWrapper - .get(`/v1/model_registrations/${modelType}?detailed=true`) - .then((data) => { - const builtinRegistrations = data.filter((v) => v.is_builtin) - 
setModelAbilityData({ - ...modelAbilityData, - options: getUniqueModelAbilities(builtinRegistrations), + arr.forEach((item) => { + if (Array.isArray(item.model_ability)) { + item.model_ability.forEach((ability) => { + uniqueAbilities.add(ability) }) - setRegistrationData(builtinRegistrations) - const collectionData = JSON.parse( - localStorage.getItem('collectionArr') - ) - setCollectionArr(collectionData) + } + }) - // Reset pagination status - setCurrentPage(1) - setHasMore(true) - }) - .catch((error) => { - console.error('Error:', error) - if (error.response.status !== 403 && error.response.status !== 401) { - setErrorMsg(error.message) - } - }) - } catch (error) { - console.error('Error:', error) - } finally { - setIsCallingApi(false) + return Array.from(uniqueAbilities) } - } - useEffect(() => { - update() - }, [cookie.token]) + const update = () => { + if ( + isCallingApi || + isUpdatingModel || + (cookie.token !== 'no_auth' && !sessionStorage.getItem('token')) + ) + return + + try { + setIsCallingApi(true) + + fetchWrapper + .get(`/v1/model_registrations/${modelType}?detailed=true`) + .then((data) => { + const builtinRegistrations = data.filter((v) => v.is_builtin) + setModelAbilityData({ + ...modelAbilityData, + options: getUniqueModelAbilities(builtinRegistrations), + }) + setRegistrationData(builtinRegistrations) + const collectionData = JSON.parse( + localStorage.getItem('collectionArr') + ) + setCollectionArr(collectionData) + + // Reset pagination status + setCurrentPage(1) + setHasMore(true) + }) + .catch((error) => { + console.error('Error:', error) + if ( + error.response.status !== 403 && + error.response.status !== 401 + ) { + setErrorMsg(error.message) + } + }) + } catch (error) { + console.error('Error:', error) + } finally { + setIsCallingApi(false) + } + } - // Update pagination data - const updateDisplayedData = useCallback(() => { - const filteredData = registrationData.filter((registration) => - filter(registration) - ) + useEffect(() => { + update() + }, [cookie.token]) - const sortedData = [...filteredData].sort((a, b) => { - if (modelListType === 'featured') { - const indexA = featureModels.indexOf(a.model_name) - const indexB = featureModels.indexOf(b.model_name) - return ( - (indexA !== -1 ? indexA : Infinity) - - (indexB !== -1 ? indexB : Infinity) - ) + // Update pagination data + const updateDisplayedData = useCallback(() => { + const filteredData = registrationData.filter((registration) => + filter(registration) + ) + + const sortedData = [...filteredData].sort((a, b) => { + if (modelListType === 'featured') { + const indexA = featureModels.indexOf(a.model_name) + const indexB = featureModels.indexOf(b.model_name) + return ( + (indexA !== -1 ? indexA : Infinity) - + (indexB !== -1 ? 
indexB : Infinity) + ) + } + return 0 + }) + + // If pagination is disabled, show all data at once + if (!ENABLE_PAGINATION) { + setDisplayedData(sortedData) + setHasMore(false) + return } - return 0 - }) - // If pagination is disabled, show all data at once - if (!ENABLE_PAGINATION) { - setDisplayedData(sortedData) - setHasMore(false) - return - } + const startIndex = (currentPage - 1) * itemsPerPage + const endIndex = currentPage * itemsPerPage + const newData = sortedData.slice(startIndex, endIndex) - const startIndex = (currentPage - 1) * itemsPerPage - const endIndex = currentPage * itemsPerPage - const newData = sortedData.slice(startIndex, endIndex) + if (currentPage === 1) { + setDisplayedData(newData) + } else { + setDisplayedData((prev) => [...prev, ...newData]) + } + setHasMore(endIndex < sortedData.length) + }, [ + registrationData, + filter, + modelListType, + featureModels, + currentPage, + itemsPerPage, + ]) - if (currentPage === 1) { - setDisplayedData(newData) - } else { - setDisplayedData((prev) => [...prev, ...newData]) - } - setHasMore(endIndex < sortedData.length) - }, [ - registrationData, - filter, - modelListType, - featureModels, - currentPage, - itemsPerPage, - ]) - - useEffect(() => { - updateDisplayedData() - }, [updateDisplayedData]) - - // Reset pagination when filters change - useEffect(() => { - setCurrentPage(1) - setHasMore(true) - }, [searchTerm, modelAbilityData.modelAbility, status, modelListType]) - - // Infinite scroll observer - useEffect(() => { - if (!ENABLE_PAGINATION) return - - const observer = new IntersectionObserver( - (entries) => { - if (entries[0].isIntersecting && hasMore && !isCallingApi) { - setCurrentPage((prev) => prev + 1) - } - }, - { threshold: 1.0 } - ) + useEffect(() => { + updateDisplayedData() + }, [updateDisplayedData]) - if (loaderRef.current) { - observer.observe(loaderRef.current) - } + // Reset pagination when filters change + useEffect(() => { + setCurrentPage(1) + setHasMore(true) + }, [searchTerm, modelAbilityData.modelAbility, status, modelListType]) + + // Infinite scroll observer + useEffect(() => { + if (!ENABLE_PAGINATION) return + + const observer = new IntersectionObserver( + (entries) => { + if (entries[0].isIntersecting && hasMore && !isCallingApi) { + setCurrentPage((prev) => prev + 1) + } + }, + { threshold: 1.0 } + ) - return () => { if (loaderRef.current) { - observer.unobserve(loaderRef.current) + observer.observe(loaderRef.current) } - } - }, [hasMore, isCallingApi, currentPage]) - const getCollectionArr = (data) => { - setCollectionArr(data) - } + return () => { + if (loaderRef.current) { + observer.unobserve(loaderRef.current) + } + } + }, [hasMore, isCallingApi, currentPage]) - const handleChangeFilter = (type, value) => { - const typeMap = { - modelAbility: { - setter: (value) => { - setModelAbilityData({ - ...modelAbilityData, - modelAbility: value, - }) - }, - filterArr: modelAbilityData.options, - }, - status: { setter: setStatus, filterArr: [] }, + const getCollectionArr = (data) => { + setCollectionArr(data) } - const { setter, filterArr: excludeArr } = typeMap[type] || {} - if (!setter) return + const handleChangeFilter = (type, value) => { + const typeMap = { + modelAbility: { + setter: (value) => { + setModelAbilityData({ + ...modelAbilityData, + modelAbility: value, + }) + }, + filterArr: modelAbilityData.options, + }, + status: { setter: setStatus, filterArr: [] }, + } - setter(value) + const { setter, filterArr: excludeArr } = typeMap[type] || {} + if (!setter) return - const 
updatedFilterArr = Array.from( - new Set([ - ...filterArr.filter((item) => !excludeArr.includes(item)), - value, - ]) - ) + setter(value) + + const updatedFilterArr = Array.from( + new Set([ + ...filterArr.filter((item) => !excludeArr.includes(item)), + value, + ]) + ) - setFilterArr(updatedFilterArr) + setFilterArr(updatedFilterArr) - if (type === 'status') { - setStatusArr( - updatedFilterArr.filter( - (item) => ![...modelAbilityData.options].includes(item) + if (type === 'status') { + setStatusArr( + updatedFilterArr.filter( + (item) => ![...modelAbilityData.options].includes(item) + ) ) - ) - } + } - // Reset pagination status - setDisplayedData([]) - setCurrentPage(1) - setHasMore(true) - } + // Reset pagination status + setDisplayedData([]) + setCurrentPage(1) + setHasMore(true) + } - const handleDeleteChip = (item) => { - setFilterArr( - filterArr.filter((subItem) => { - return subItem !== item - }) - ) - if (item === modelAbilityData.modelAbility) { - setModelAbilityData({ - ...modelAbilityData, - modelAbility: '', - }) - } else { - setStatusArr( - statusArr.filter((subItem) => { + const handleDeleteChip = (item) => { + setFilterArr( + filterArr.filter((subItem) => { return subItem !== item }) ) - if (item === status) setStatus('') - } - - // Reset pagination status - setCurrentPage(1) - setHasMore(true) - } - - const handleModelType = (newModelType) => { - if (newModelType !== null) { - setModelListType(newModelType) + if (item === modelAbilityData.modelAbility) { + setModelAbilityData({ + ...modelAbilityData, + modelAbility: '', + }) + } else { + setStatusArr( + statusArr.filter((subItem) => { + return subItem !== item + }) + ) + if (item === status) setStatus('') + } // Reset pagination status - setDisplayedData([]) setCurrentPage(1) setHasMore(true) } - } - function getLabel(item) { - const translation = t(`launchModel.${item}`) - return translation === `launchModel.${item}` ? item : translation - } + const handleModelType = (newModelType) => { + if (newModelType !== null) { + setModelListType(newModelType) - return ( - -
{ - const hasAbility = modelAbilityData.options.length > 0 - const hasFeature = featureModels.length > 0 - - const baseColumns = hasAbility ? ['200px', '150px'] : ['200px'] - const altColumns = hasAbility ? ['150px', '150px'] : ['150px'] - - const columns = hasFeature - ? [...baseColumns, '150px', '1fr'] - : [...altColumns, '1fr'] - - return columns.join(' ') - })(), - columnGap: '20px', - margin: '30px 2rem', - alignItems: 'center', - }} - > - {featureModels.length > 0 && ( - - - + + + + )} + {modelAbilityData.options.length > 0 && ( + + + {t('launchModel.modelAbility')} + + + + )} - - {t('launchModel.modelAbility')} + + {t('launchModel.status')} - )} - - {t('launchModel.status')} - - - - - { - setSearchTerm(e.target.value) - }} - size="small" - hotkey="Enter" - t={t} - /> - -
-
- {filterArr.map((item, index) => ( - handleDeleteChip(item)} - /> - ))} -
-
- {displayedData.map((filteredRegistration) => ( - + { + setSearchTerm(e.target.value) + }} + size="small" + hotkey="Enter" + t={t} + /> + +
+
+ {filterArr.map((item, index) => ( + handleDeleteChip(item)} + /> + ))} +
+
+ {displayedData.map((filteredRegistration) => ( + { + setSelectedModel(filteredRegistration) + setIsOpenLaunchModelDrawer(true) + }} + /> + ))} +
+ +
+ {ENABLE_PAGINATION && hasMore && !isCallingApi && ( +
+ +
+ )} +
+ + {selectedModel && ( + { - setSelectedModel(filteredRegistration) - setIsOpenLaunchModelDrawer(true) - }} + gpuAvailable={gpuAvailable} + open={isOpenLaunchModelDrawer} + onClose={() => setIsOpenLaunchModelDrawer(false)} /> - ))} - - -
- {ENABLE_PAGINATION && hasMore && !isCallingApi && ( -
- -
)} -
- - {selectedModel && ( - setIsOpenLaunchModelDrawer(false)} - /> - )} -
- ) -} + + ) + } +) + +LaunchModelComponent.displayName = 'LaunchModelComponent' export default LaunchModelComponent diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js new file mode 100644 index 0000000000..b258a2bdb5 --- /dev/null +++ b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js @@ -0,0 +1,192 @@ +import { + Button, + Dialog, + DialogActions, + DialogContent, + DialogTitle, + TextField, +} from '@mui/material' +import React, { useContext, useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { ApiContext } from '../../../components/apiContext' + +const API_BASE_URL = 'https://model.xinference.io' + +const AddModelDialog = ({ open, onClose, onUpdateList }) => { + const { t } = useTranslation() + const [modelName, setModelName] = useState('') + const [loading, setLoading] = useState(false) + const { endPoint, setErrorMsg } = useContext(ApiContext) + + const searchModelByName = async (name) => { + try { + const url = `${API_BASE_URL}/api/models?order=featured&query=${encodeURIComponent( + name + )}&page=1&pageSize=5` + const res = await fetch(url, { method: 'GET' }) + const rawText = await res.text().catch(() => '') + if (!res.ok) { + setErrorMsg(rawText || `HTTP ${res.status}`) + return null + } + try { + const data = JSON.parse(rawText) + const items = data?.data || [] + const exact = items.find((it) => it?.model_name === name) + if (!exact) { + setErrorMsg(t('launchModel.error.name_not_matched')) + return null + } + const id = exact?.id + const modelType = exact?.model_type + if (!id || !modelType) { + setErrorMsg(t('launchModel.error.downloadFailed')) + return null + } + return { id, modelType } + } catch { + setErrorMsg(rawText || t('launchModel.error.json_parse_error')) + return null + } + } catch (err) { + console.error(err) + setErrorMsg(err.message || t('launchModel.error.requestFailed')) + return null + } + } + + const fetchModelJson = async (modelId) => { + try { + const res = await fetch( + `${API_BASE_URL}/api/models/download?model_id=${encodeURIComponent( + modelId + )}`, + { method: 'GET' } + ) + const rawText = await res.text().catch(() => '') + if (!res.ok) { + setErrorMsg(rawText || `HTTP ${res.status}`) + return null + } + try { + const data = JSON.parse(rawText) + return data + } catch { + setErrorMsg(rawText || t('launchModel.error.json_parse_error')) + return null + } + } catch (err) { + console.error(err) + setErrorMsg(err.message || t('launchModel.error.requestFailed')) + return null + } + } + + const addToLocal = async (modelType, modelJson) => { + try { + const res = await fetch(endPoint + '/v1/models/add', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ model_type: modelType, model_json: modelJson }), + }) + const rawText = await res.text().catch(() => '') + if (!res.ok) { + setErrorMsg(rawText || `HTTP ${res.status}`) + return + } + onClose(`/launch_model/${modelType}`) + onUpdateList(modelType) + } catch (error) { + console.error('Error:', error) + if (error?.response?.status !== 403) { + setErrorMsg(error.message) + } + } + } + + const handleFormSubmit = async (e) => { + e.preventDefault() + const name = modelName?.trim() + if (!name) { + setErrorMsg(t('launchModel.addModelDialog.modelName.tip')) + return + } + setLoading(true) + setErrorMsg('') + try { + const found = await searchModelByName(name) + if (!found) return + const { id, modelType } = 
found + + const modelJson = await fetchModelJson(id) + if (!modelJson) return + + await addToLocal(modelType, modelJson) + } finally { + setLoading(false) + } + } + + return ( + onClose()} width={500}> + {t('launchModel.addModel')} + +
+
+ {t('launchModel.addModelDialog.introPrefix')}{' '} + + {t('launchModel.addModelDialog.platformLinkText')} + + {t('launchModel.addModelDialog.introSuffix')} +
+
+ { + setModelName(e.target.value) + }} + disabled={loading} + /> + +
+
+ + + + +
+ ) +} + +export default AddModelDialog diff --git a/xinference/ui/web/ui/src/scenes/launch_model/index.js b/xinference/ui/web/ui/src/scenes/launch_model/index.js index 24f886a80d..4ac6cff612 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/index.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/index.js @@ -1,6 +1,7 @@ -import { TabContext, TabList, TabPanel } from '@mui/lab' -import { Box, Tab } from '@mui/material' -import React, { useContext, useEffect, useState } from 'react' +import Add from '@mui/icons-material/Add' +import { LoadingButton, TabContext, TabList, TabPanel } from '@mui/lab' +import { Box, Button, MenuItem, Select, Tab } from '@mui/material' +import React, { useContext, useEffect, useRef, useState } from 'react' import { useCookies } from 'react-cookie' import { useTranslation } from 'react-i18next' import { useNavigate } from 'react-router-dom' @@ -11,6 +12,7 @@ import fetchWrapper from '../../components/fetchWrapper' import SuccessMessageSnackBar from '../../components/successMessageSnackBar' import Title from '../../components/Title' import { isValidBearerToken } from '../../components/utils' +import AddModelDialog from './components/addModelDialog' import { featureModels } from './data/data' import LaunchCustom from './launchCustom' import LaunchModelComponent from './LaunchModel' @@ -22,13 +24,17 @@ const LaunchModel = () => { : '/launch_model/llm' ) const [gpuAvailable, setGPUAvailable] = useState(-1) + const [open, setOpen] = useState(false) + const [modelType, setModelType] = useState('llm') + const [loading, setLoading] = useState(false) const { setErrorMsg } = useContext(ApiContext) const [cookie] = useCookies(['token']) const navigate = useNavigate() const { t } = useTranslation() + const LaunchModelRefs = useRef({}) - const handleTabChange = (event, newValue) => { + const handleTabChange = (newValue) => { setValue(newValue) navigate(newValue) sessionStorage.setItem('modelType', newValue) @@ -59,14 +65,56 @@ const LaunchModel = () => { } }, [cookie.token]) + const updateList = (modelType) => { + LaunchModelRefs.current[modelType]?.update() + } + + const handleClose = (value) => { + setOpen(false) + if (value) { + handleTabChange(value) + } + } + + const updateModels = () => { + setLoading(true) + fetchWrapper + .post('/v1/models/update_type', { model_type: modelType }) + .then(() => { + handleTabChange(`/launch_model/${modelType}`) + updateList(modelType) + }) + .catch((error) => { + console.error('Error:', error) + if (error.response.status !== 403 && error.response.status !== 401) { + setErrorMsg(error.message) + } + }) + .finally(() => { + setLoading(false) + }) + } + return ( <ErrorMessageSnackBar /> <SuccessMessageSnackBar /> <TabContext value={value}> - <Box sx={{ borderBottom: 1, borderColor: 'divider' }}> - <TabList value={value} onChange={handleTabChange} aria-label="tabs"> + <Box + sx={{ + borderBottom: 1, + borderColor: 'divider', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }} + > + <TabList + value={value} + onChange={(_, value) => handleTabChange(value)} + aria-label="tabs" + > <Tab label={t('model.languageModels')} value="/launch_model/llm" /> <Tab label={t('model.embeddingModels')} @@ -81,6 +129,53 @@ const LaunchModel = () => { value="/launch_model/custom/llm" /> </TabList> + <Box + sx={{ + display: 'flex', + alignItems: 'center', + gap: '10px', + }} + > + <Box sx={{ display: 'flex', gap: 0 }}> + <Select + value={modelType} + onChange={(e) => setModelType(e.target.value)} + size="small" + sx={{ 
+ borderTopRightRadius: 0, + borderBottomRightRadius: 0, + minWidth: 100, + }} + > + <MenuItem value="llm">LLM</MenuItem> + <MenuItem value="embedding">Embedding</MenuItem> + <MenuItem value="rerank">Rerank</MenuItem> + <MenuItem value="image">Image</MenuItem> + <MenuItem value="audio">Audio</MenuItem> + <MenuItem value="video">Video</MenuItem> + </Select> + + <LoadingButton + variant="contained" + onClick={updateModels} + loading={loading} + sx={{ + borderTopLeftRadius: 0, + borderBottomLeftRadius: 0, + whiteSpace: 'nowrap', + }} + > + {t('launchModel.update')} + </LoadingButton> + </Box> + <Button + variant="outlined" + startIcon={<Add />} + onClick={() => setOpen(true)} + > + {t('launchModel.addModel')} + </Button> + </Box> </Box> <TabPanel value="/launch_model/llm" sx={{ padding: 0 }}> <LaunchModelComponent @@ -89,6 +184,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'llm').feature_models } + ref={(ref) => (LaunchModelRefs.current.llm = ref)} /> </TabPanel> <TabPanel value="/launch_model/embedding" sx={{ padding: 0 }}> @@ -99,6 +195,7 @@ const LaunchModel = () => { featureModels.find((item) => item.type === 'embedding') .feature_models } + ref={(ref) => (LaunchModelRefs.current.embedding = ref)} /> </TabPanel> <TabPanel value="/launch_model/rerank" sx={{ padding: 0 }}> @@ -109,6 +206,7 @@ const LaunchModel = () => { featureModels.find((item) => item.type === 'rerank') .feature_models } + ref={(ref) => (LaunchModelRefs.current.rerank = ref)} /> </TabPanel> <TabPanel value="/launch_model/image" sx={{ padding: 0 }}> @@ -118,6 +216,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'image').feature_models } + ref={(ref) => (LaunchModelRefs.current.image = ref)} /> </TabPanel> <TabPanel value="/launch_model/audio" sx={{ padding: 0 }}> @@ -127,6 +226,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'audio').feature_models } + ref={(ref) => (LaunchModelRefs.current.audio = ref)} /> </TabPanel> <TabPanel value="/launch_model/video" sx={{ padding: 0 }}> @@ -136,12 +236,20 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'video').feature_models } + ref={(ref) => (LaunchModelRefs.current.video = ref)} /> </TabPanel> <TabPanel value="/launch_model/custom/llm" sx={{ padding: 0 }}> <LaunchCustom gpuAvailable={gpuAvailable} /> </TabPanel> </TabContext> + {open && ( + <AddModelDialog + onUpdateList={updateList} + open={open} + onClose={handleClose} + /> + )} </Box> ) }
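For reference, a minimal sketch (not part of the patch) of the requests the new UI issues against the server, assuming a local Xinference endpoint on the default port and the `requests` library; `model_json` stands for the family JSON fetched from the model hub, and authentication headers are omitted:

    import requests

    XINFERENCE = "http://127.0.0.1:9997"  # assumed default local endpoint

    def add_model(model_type: str, model_json: dict) -> None:
        # Mirrors AddModelDialog: wrapped payload {"model_type": ..., "model_json": ...}
        resp = requests.post(
            f"{XINFERENCE}/v1/models/add",
            json={"model_type": model_type, "model_json": model_json},
        )
        resp.raise_for_status()

    def refresh_builtin(model_type: str) -> None:
        # Mirrors the Update button: refresh model definitions for one model type
        resp = requests.post(
            f"{XINFERENCE}/v1/models/update_type",
            json={"model_type": model_type},
        )
        resp.raise_for_status()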