@@ -71,6 +71,7 @@ void ModelRegistry::register_causallm_factory(const std::string& name,
7171 << " already registered." );
7272 } else {
7373 instance->model_registry_ [name].causal_lm_factory = factory;
74+ instance->model_backend_ [name] = " llm" ;
7475 }
7576}
7677
@@ -83,6 +84,7 @@ void ModelRegistry::register_causalvlm_factory(const std::string& name,
8384 << " already registered." );
8485 } else {
8586 instance->model_registry_ [name].causal_vlm_factory = factory;
87+ instance->model_backend_ [name] = " vlm" ;
8688 }
8789}
8890
@@ -95,6 +97,7 @@ void ModelRegistry::register_embeddinglm_factory(const std::string& name,
9597 << " already registered." );
9698 } else {
9799 instance->model_registry_ [name].embedding_lm_factory = factory;
100+ instance->model_backend_ [name] = " llm" ;
98101 }
99102}
100103
@@ -107,6 +110,7 @@ void ModelRegistry::register_dit_model_factory(const std::string& name,
107110 << " already registered." );
108111 } else {
109112 instance->model_registry_ [name].dit_model_factory = factory;
113+ instance->model_backend_ [name] = " dit" ;
110114 }
111115}
112116
@@ -229,6 +233,11 @@ TokenizerArgsLoader ModelRegistry::get_tokenizer_args_loader(
229233 return instance->model_registry_ [name].tokenizer_args_loader ;
230234}
231235
236+ std::string ModelRegistry::get_model_backend (const std::string& name) {
237+ ModelRegistry* instance = get_instance ();
238+ return instance->model_backend_ [name];
239+ }
240+
232241std::unique_ptr<CausalLM> create_llm_model (const ModelContext& context) {
233242 // get the factory function for the model type from model registry
234243 auto factory = ModelRegistry::get_causallm_factory (
0 commit comments