@@ -2,7 +2,8 @@ pub mod host_component;
22
33use spin_app:: MetadataKey ;
44use spin_core:: async_trait;
5- use spin_world:: v1:: llm:: { self as wasi_llm} ;
5+ use spin_world:: v1:: llm:: { self as v1} ;
6+ use spin_world:: v2:: llm:: { self as v2} ;
67use std:: collections:: HashSet ;
78
89pub use crate :: host_component:: LlmComponent ;
@@ -14,16 +15,16 @@ pub const AI_MODELS_KEY: MetadataKey<HashSet<String>> = MetadataKey::new("ai_mod
1415pub trait LlmEngine : Send + Sync {
1516 async fn infer (
1617 & mut self ,
17- model : wasi_llm :: InferencingModel ,
18+ model : v1 :: InferencingModel ,
1819 prompt : String ,
19- params : wasi_llm :: InferencingParams ,
20- ) -> Result < wasi_llm :: InferencingResult , wasi_llm :: Error > ;
20+ params : v2 :: InferencingParams ,
21+ ) -> Result < v2 :: InferencingResult , v2 :: Error > ;
2122
2223 async fn generate_embeddings (
2324 & mut self ,
24- model : wasi_llm :: EmbeddingModel ,
25+ model : v2 :: EmbeddingModel ,
2526 data : Vec < String > ,
26- ) -> Result < wasi_llm :: EmbeddingsResult , wasi_llm :: Error > ;
27+ ) -> Result < v2 :: EmbeddingsResult , v2 :: Error > ;
2728}
2829
2930pub struct LlmDispatch {
@@ -32,13 +33,13 @@ pub struct LlmDispatch {
3233}
3334
3435#[ async_trait]
35- impl wasi_llm :: Host for LlmDispatch {
36+ impl v2 :: Host for LlmDispatch {
3637 async fn infer (
3738 & mut self ,
38- model : wasi_llm :: InferencingModel ,
39+ model : v2 :: InferencingModel ,
3940 prompt : String ,
40- params : Option < wasi_llm :: InferencingParams > ,
41- ) -> anyhow:: Result < Result < wasi_llm :: InferencingResult , wasi_llm :: Error > > {
41+ params : Option < v2 :: InferencingParams > ,
42+ ) -> anyhow:: Result < Result < v2 :: InferencingResult , v2 :: Error > > {
4243 if !self . allowed_models . contains ( & model) {
4344 return Ok ( Err ( access_denied_error ( & model) ) ) ;
4445 }
@@ -47,7 +48,7 @@ impl wasi_llm::Host for LlmDispatch {
4748 . infer (
4849 model,
4950 prompt,
50- params. unwrap_or ( wasi_llm :: InferencingParams {
51+ params. unwrap_or ( v2 :: InferencingParams {
5152 max_tokens : 100 ,
5253 repeat_penalty : 1.1 ,
5354 repeat_penalty_last_n_token_count : 64 ,
@@ -61,18 +62,46 @@ impl wasi_llm::Host for LlmDispatch {
6162
6263 async fn generate_embeddings (
6364 & mut self ,
64- m : wasi_llm :: EmbeddingModel ,
65+ m : v1 :: EmbeddingModel ,
6566 data : Vec < String > ,
66- ) -> anyhow:: Result < Result < wasi_llm :: EmbeddingsResult , wasi_llm :: Error > > {
67+ ) -> anyhow:: Result < Result < v2 :: EmbeddingsResult , v2 :: Error > > {
6768 if !self . allowed_models . contains ( & m) {
6869 return Ok ( Err ( access_denied_error ( & m) ) ) ;
6970 }
7071 Ok ( self . engine . generate_embeddings ( m, data) . await )
7172 }
7273}
7374
74- fn access_denied_error ( model : & str ) -> wasi_llm:: Error {
75- wasi_llm:: Error :: InvalidInput ( format ! (
75+ #[ async_trait]
76+ impl v1:: Host for LlmDispatch {
77+ async fn infer (
78+ & mut self ,
79+ model : v1:: InferencingModel ,
80+ prompt : String ,
81+ params : Option < v1:: InferencingParams > ,
82+ ) -> anyhow:: Result < Result < v1:: InferencingResult , v1:: Error > > {
83+ Ok (
84+ <Self as v2:: Host >:: infer ( self , model, prompt, params. map ( Into :: into) )
85+ . await ?
86+ . map ( Into :: into)
87+ . map_err ( Into :: into) ,
88+ )
89+ }
90+
91+ async fn generate_embeddings (
92+ & mut self ,
93+ model : v1:: EmbeddingModel ,
94+ data : Vec < String > ,
95+ ) -> anyhow:: Result < Result < v1:: EmbeddingsResult , v1:: Error > > {
96+ Ok ( <Self as v2:: Host >:: generate_embeddings ( self , model, data)
97+ . await ?
98+ . map ( Into :: into)
99+ . map_err ( Into :: into) )
100+ }
101+ }
102+
103+ fn access_denied_error ( model : & str ) -> v2:: Error {
104+ v2:: Error :: InvalidInput ( format ! (
76105 "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
77106 ) )
78107}
0 commit comments