
Commit 1ba24e3

Link with llm v2
Signed-off-by: Ryan Levick <[email protected]>
1 parent 561b3e1 commit 1ba24e3

File tree

6 files changed: 99 additions & 19 deletions

crates/llm-local/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ use llm::{
 use rand::SeedableRng;
 use spin_core::async_trait;
 use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v2::llm::{self as wasi_llm};
 use std::{
     collections::hash_map::Entry,
     collections::HashMap,
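
Note on the alias: because the v2 module is re-imported under the old name (`self as wasi_llm`), the version bump is confined to this one import line, and every existing `wasi_llm::...` call site in the file compiles unchanged. The same one-line swap appears in crates/llm-remote-http and crates/trigger below. A minimal sketch of the effect, using the InferencingParams fields listed in this commit's conversion code (the values for temperature, top_k, and top_p are illustrative, not taken from this diff):

use spin_world::v2::llm::{self as wasi_llm};

// Illustrative only: this function body is identical whether `wasi_llm`
// points at the v1 or the v2 module, which is the point of the alias.
fn example_params() -> wasi_llm::InferencingParams {
    wasi_llm::InferencingParams {
        max_tokens: 100,                       // default used by LlmDispatch below
        repeat_penalty: 1.1,                   // default used by LlmDispatch below
        repeat_penalty_last_n_token_count: 64, // default used by LlmDispatch below
        temperature: 0.8,                      // illustrative value
        top_k: 40,                             // illustrative value
        top_p: 0.9,                            // illustrative value
    }
}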

crates/llm-remote-http/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use spin_core::async_trait;
 use spin_llm::LlmEngine;
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v2::llm::{self as wasi_llm};
 
 #[derive(Clone)]
 pub struct RemoteHttpLlmEngine {

crates/llm/src/host_component.rs

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@ impl HostComponent for LlmComponent {
         linker: &mut spin_core::Linker<T>,
         get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
     ) -> anyhow::Result<()> {
-        spin_world::v1::llm::add_to_linker(linker, get)
+        spin_world::v1::llm::add_to_linker(linker, get)?;
+        spin_world::v2::llm::add_to_linker(linker, get)
     }
 
     fn build_data(&self) -> Self::Data {
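
The functional change is the trailing `?`: the v1 registration used to be the method's sole return value; now a v1 failure propagates early and the v2 registration result is returned, leaving both interfaces wired into one linker so guests built against either world can instantiate. A self-contained sketch of that shape with hypothetical helpers (not Spin APIs):

// Hypothetical illustration of chained registration: each step returns
// anyhow::Result<()>, the first is `?`-ed, the last is returned as-is.
fn register(name: &str) -> anyhow::Result<()> {
    println!("registering {name}");
    Ok(())
}

fn register_both() -> anyhow::Result<()> {
    register("v1::llm")?; // a failure here propagates immediately
    register("v2::llm")   // otherwise the v2 result becomes the return value
}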

crates/llm/src/lib.rs

Lines changed: 44 additions & 15 deletions
@@ -2,7 +2,8 @@ pub mod host_component;
 
 use spin_app::MetadataKey;
 use spin_core::async_trait;
-use spin_world::v1::llm::{self as wasi_llm};
+use spin_world::v1::llm::{self as v1};
+use spin_world::v2::llm::{self as v2};
 use std::collections::HashSet;
 
 pub use crate::host_component::LlmComponent;
@@ -14,16 +15,16 @@ pub const AI_MODELS_KEY: MetadataKey<HashSet<String>> = MetadataKey::new("ai_mod
 pub trait LlmEngine: Send + Sync {
     async fn infer(
         &mut self,
-        model: wasi_llm::InferencingModel,
+        model: v1::InferencingModel,
         prompt: String,
-        params: wasi_llm::InferencingParams,
-    ) -> Result<wasi_llm::InferencingResult, wasi_llm::Error>;
+        params: v2::InferencingParams,
+    ) -> Result<v2::InferencingResult, v2::Error>;
 
     async fn generate_embeddings(
         &mut self,
-        model: wasi_llm::EmbeddingModel,
+        model: v2::EmbeddingModel,
         data: Vec<String>,
-    ) -> Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>;
+    ) -> Result<v2::EmbeddingsResult, v2::Error>;
 }
 
 pub struct LlmDispatch {
@@ -32,13 +33,13 @@ pub struct LlmDispatch {
 }
 
 #[async_trait]
-impl wasi_llm::Host for LlmDispatch {
+impl v2::Host for LlmDispatch {
     async fn infer(
         &mut self,
-        model: wasi_llm::InferencingModel,
+        model: v2::InferencingModel,
         prompt: String,
-        params: Option<wasi_llm::InferencingParams>,
-    ) -> anyhow::Result<Result<wasi_llm::InferencingResult, wasi_llm::Error>> {
+        params: Option<v2::InferencingParams>,
+    ) -> anyhow::Result<Result<v2::InferencingResult, v2::Error>> {
         if !self.allowed_models.contains(&model) {
             return Ok(Err(access_denied_error(&model)));
         }
@@ -47,7 +48,7 @@ impl wasi_llm::Host for LlmDispatch {
             .infer(
                 model,
                 prompt,
-                params.unwrap_or(wasi_llm::InferencingParams {
+                params.unwrap_or(v2::InferencingParams {
                     max_tokens: 100,
                     repeat_penalty: 1.1,
                     repeat_penalty_last_n_token_count: 64,
@@ -61,18 +62,46 @@ impl wasi_llm::Host for LlmDispatch {
 
     async fn generate_embeddings(
         &mut self,
-        m: wasi_llm::EmbeddingModel,
+        m: v1::EmbeddingModel,
         data: Vec<String>,
-    ) -> anyhow::Result<Result<wasi_llm::EmbeddingsResult, wasi_llm::Error>> {
+    ) -> anyhow::Result<Result<v2::EmbeddingsResult, v2::Error>> {
         if !self.allowed_models.contains(&m) {
             return Ok(Err(access_denied_error(&m)));
         }
         Ok(self.engine.generate_embeddings(m, data).await)
     }
 }
 
-fn access_denied_error(model: &str) -> wasi_llm::Error {
-    wasi_llm::Error::InvalidInput(format!(
+#[async_trait]
+impl v1::Host for LlmDispatch {
+    async fn infer(
+        &mut self,
+        model: v1::InferencingModel,
+        prompt: String,
+        params: Option<v1::InferencingParams>,
+    ) -> anyhow::Result<Result<v1::InferencingResult, v1::Error>> {
+        Ok(
+            <Self as v2::Host>::infer(self, model, prompt, params.map(Into::into))
+                .await?
+                .map(Into::into)
+                .map_err(Into::into),
+        )
+    }
+
+    async fn generate_embeddings(
+        &mut self,
+        model: v1::EmbeddingModel,
+        data: Vec<String>,
+    ) -> anyhow::Result<Result<v1::EmbeddingsResult, v1::Error>> {
+        Ok(<Self as v2::Host>::generate_embeddings(self, model, data)
+            .await?
+            .map(Into::into)
+            .map_err(Into::into))
+    }
+}
+
+fn access_denied_error(model: &str) -> v2::Error {
+    v2::Error::InvalidInput(format!(
         "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
     ))
 }
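
Two details here are easy to miss. First, the nested return type separates failure domains: the outer anyhow::Result is a host-level failure that traps the guest, while the inner Result carries errors the guest is meant to observe, which is why the access-control check returns Ok(Err(..)) rather than Err(..). Second, the new v1::Host impl is a pure shim: it forwards to the v2::Host methods and downgrades the results with the From impls added in crates/world/src/conversions.rs, so the model allow-list is enforced in exactly one place. Restating the error layering from the hunk above with comments:

// Outer Result: host failure, surfaced to the runtime as a trap.
// Inner Result: guest-visible llm error. Denied access is something the
// guest should see and handle, so it travels on the inner layer:
if !self.allowed_models.contains(&model) {
    return Ok(Err(access_denied_error(&model)));
}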

crates/trigger/src/runtime_config/llm.rs

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 use async_trait::async_trait;
 use spin_llm::LlmEngine;
 use spin_llm_remote_http::RemoteHttpLlmEngine;
-use spin_world::v1::llm as wasi_llm;
+use spin_world::v2::llm as wasi_llm;
 use url::Url;
 
 #[derive(Default)]

crates/world/src/conversions.rs

Lines changed: 50 additions & 0 deletions
@@ -168,3 +168,53 @@ mod redis {
         }
     }
 }
+
+mod llm {
+    use super::*;
+
+    impl From<v1::llm::InferencingParams> for v2::llm::InferencingParams {
+        fn from(value: v1::llm::InferencingParams) -> Self {
+            Self {
+                max_tokens: value.max_tokens,
+                repeat_penalty: value.repeat_penalty,
+                repeat_penalty_last_n_token_count: value.repeat_penalty_last_n_token_count,
+                temperature: value.temperature,
+                top_k: value.top_k,
+                top_p: value.top_p,
+            }
+        }
+    }
+
+    impl From<v2::llm::InferencingResult> for v1::llm::InferencingResult {
+        fn from(value: v2::llm::InferencingResult) -> Self {
+            Self {
+                text: value.text,
+                usage: v1::llm::InferencingUsage {
+                    prompt_token_count: value.usage.prompt_token_count,
+                    generated_token_count: value.usage.generated_token_count,
+                },
+            }
+        }
+    }
+
+    impl From<v2::llm::EmbeddingsResult> for v1::llm::EmbeddingsResult {
+        fn from(value: v2::llm::EmbeddingsResult) -> Self {
+            Self {
+                embeddings: value.embeddings,
+                usage: v1::llm::EmbeddingsUsage {
+                    prompt_token_count: value.usage.prompt_token_count,
+                },
+            }
+        }
+    }
+
+    impl From<v2::llm::Error> for v1::llm::Error {
+        fn from(value: v2::llm::Error) -> Self {
+            match value {
+                v2::llm::Error::ModelNotSupported => Self::ModelNotSupported,
+                v2::llm::Error::RuntimeError(s) => Self::RuntimeError(s),
+                v2::llm::Error::InvalidInput(s) => Self::InvalidInput(s),
+            }
+        }
+    }
+}
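
These From impls are what let the v1 shim in crates/llm/src/lib.rs write `params.map(Into::into)` and `.map(Into::into).map_err(Into::into)` instead of converting field by field at every call site. A hypothetical round-trip check, assuming the Rust field names generated from the WIT records (this test is not part of the commit):

#[test]
fn v2_inferencing_result_downgrades_to_v1() {
    let v2_result = v2::llm::InferencingResult {
        text: "hello".to_string(),
        usage: v2::llm::InferencingUsage {
            prompt_token_count: 3,
            generated_token_count: 1,
        },
    };
    // Downgrade via the From impl above and check the fields carry over.
    let v1_result: v1::llm::InferencingResult = v2_result.into();
    assert_eq!(v1_result.text, "hello");
    assert_eq!(v1_result.usage.generated_token_count, 1);
}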
