@@ -6,11 +6,11 @@ use candle::DType;
 use candle_nn::VarBuilder;
 use llm::{
     InferenceFeedback, InferenceParameters, InferenceResponse, InferenceSessionConfig, Model,
-    ModelKVMemoryType, ModelParameters,
+    ModelArchitecture, ModelKVMemoryType, ModelParameters,
 };
 use rand::SeedableRng;
 use spin_core::async_trait;
-use spin_llm::{model_arch, model_name, LlmEngine, MODEL_ALL_MINILM_L6_V2};
+use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
 use spin_world::llm::{self as wasi_llm};
 use std::{
     collections::hash_map::Entry,
@@ -170,14 +170,22 @@ impl LocalLlmEngine {
         &mut self,
         model: wasi_llm::InferencingModel,
     ) -> Result<Arc<dyn Model>, wasi_llm::Error> {
-        let model_name = model_name(&model)?;
         let use_gpu = self.use_gpu;
         let progress_fn = |_| {};
-        let model = match self.inferencing_models.entry((model_name.into(), use_gpu)) {
+        let model = match self.inferencing_models.entry((model.clone(), use_gpu)) {
             Entry::Occupied(o) => o.get().clone(),
             Entry::Vacant(v) => v
                 .insert({
-                    let path = self.registry.join(model_name);
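+                    // For a well-known model name, try the binary directly under the
+                    // registry root first; otherwise, or if that file is missing, fall
+                    // back to walking the registry layout.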
+                    let (path, arch) = if let Some(arch) = well_known_inferencing_model_arch(&model) {
+                        let model_binary = self.registry.join(&model);
+                        if model_binary.exists() {
+                            (model_binary, arch.to_owned())
+                        } else {
+                            walk_registry_for_model(&self.registry, model).await?
+                        }
+                    } else {
+                        walk_registry_for_model(&self.registry, model).await?
+                    };
                     if !self.registry.exists() {
                         return Err(wasi_llm::Error::RuntimeError(
                             format!("The directory expected to house the inferencing model '{}' does not exist.", self.registry.display())
@@ -199,7 +207,7 @@ impl LocalLlmEngine {
                         n_gqa: None,
                     };
                     let model = llm::load_dynamic(
-                        Some(model_arch(&model)?),
+                        Some(arch),
                         &path,
                         llm::TokenizerSource::Embedded,
                         params,
@@ -223,6 +231,80 @@ impl LocalLlmEngine {
     }
 }
 
+/// Get the model binary and arch from walking the registry file structure
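+/// (expected layout: `<registry>/<architecture>/<model file>`)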
+async fn walk_registry_for_model(
+    registry_path: &Path,
+    model: String,
+) -> Result<(PathBuf, ModelArchitecture), wasi_llm::Error> {
+    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Could not read model registry directory '{}': {e}",
+            registry_path.display()
+        ))
+    })?;
+    let mut result = None;
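+    // Check each architecture subdirectory for a file whose name matches the requested model.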
+    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Failed to read arch directory in model registry: {e}"
+        ))
+    })? {
+        if arch_dir
+            .file_type()
+            .await
+            .map_err(|e| {
+                wasi_llm::Error::RuntimeError(format!(
+                    "Could not read file type of '{}' dir: {e}",
+                    arch_dir.path().display()
+                ))
+            })?
+            .is_file()
+        {
+            continue;
+        }
+        let mut model_files = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading architecture directory in model registry: {e}"
+            ))
+        })?;
+        while let Some(model_file) = model_files.next_entry().await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading model file in model registry: {e}"
+            ))
+        })? {
+            if model_file
+                .file_name()
+                .to_str()
+                .map(|m| m == model)
+                .unwrap_or_default()
+            {
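+                // The architecture is inferred from the name of the containing directory.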
+                let arch = arch_dir.file_name();
+                let arch = arch
+                    .to_str()
+                    .ok_or(wasi_llm::Error::ModelNotSupported)?
+                    .parse()
+                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
+                result = Some((model_file.path(), arch));
+                break 'outer;
+            }
+        }
+    }
+
+    result.ok_or_else(|| {
+        wasi_llm::Error::InvalidInput(format!(
+            "no model directory found in registry for model '{model}'"
+        ))
+    })
+}
+
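+/// Map a well-known model name to its architecture so the registry walk can be skipped
+/// when the model binary sits directly under the registry root.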
+fn well_known_inferencing_model_arch(
+    model: &wasi_llm::InferencingModel,
+) -> Option<ModelArchitecture> {
+    match model.as_str() {
+        "llama2-chat" | "code_llama" => Some(ModelArchitecture::Llama),
+        _ => None,
+    }
+}
+
 async fn generate_embeddings(
     data: Vec<String>,
     model: Arc<(tokenizers::Tokenizer, BertModel)>,