Skip to content

Commit be24451

Browse files
authored
Merge pull request #2742 from fermyon/integrate-llm-factor
[Factors] Integrate llm factor
2 parents 132faa7 + a7c9163 commit be24451

File tree

15 files changed

+243
-72
lines changed

15 files changed

+243
-72
lines changed

Cargo.lock

Lines changed: 21 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-llm/Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,5 @@
11
[package]
2-
name = "factor-llm"
2+
name = "spin-factor-llm"
33
version.workspace = true
44
authors.workspace = true
55
edition.workspace = true
@@ -11,10 +11,16 @@ rust-version.workspace = true
1111
[dependencies]
1212
anyhow = "1.0"
1313
async-trait = "0.1"
14+
serde = "1.0"
1415
spin-factors = { path = "../factors" }
16+
spin-llm-local = { path = "../llm-local" }
17+
spin-llm-remote-http = { path = "../llm-remote-http" }
1518
spin-locked-app = { path = "../locked-app" }
1619
spin-world = { path = "../world" }
1720
tracing = { workspace = true }
21+
tokio = { version = "1", features = ["sync"] }
22+
toml = "0.8"
23+
url = "2"
1824

1925
[dev-dependencies]
2026
spin-factors-test = { path = "../factors-test" }

crates/factor-llm/src/host.rs

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -16,6 +16,8 @@ impl v2::Host for InstanceState {
1616
return Err(access_denied_error(&model));
1717
}
1818
self.engine
19+
.lock()
20+
.await
1921
.infer(
2022
model,
2123
prompt,
@@ -39,7 +41,7 @@ impl v2::Host for InstanceState {
3941
if !self.allowed_models.contains(&m) {
4042
return Err(access_denied_error(&m));
4143
}
42-
self.engine.generate_embeddings(m, data).await
44+
self.engine.lock().await.generate_embeddings(m, data).await
4345
}
4446

4547
fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {

crates/factor-llm/src/lib.rs

Lines changed: 42 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,5 @@
11
mod host;
2+
pub mod spin;
23

34
use std::collections::{HashMap, HashSet};
45
use std::sync::Arc;
@@ -11,26 +12,28 @@ use spin_factors::{
1112
use spin_locked_app::MetadataKey;
1213
use spin_world::v1::llm::{self as v1};
1314
use spin_world::v2::llm::{self as v2};
15+
use tokio::sync::Mutex;
1416

1517
pub const ALLOWED_MODELS_KEY: MetadataKey<Vec<String>> = MetadataKey::new("ai_models");
1618

19+
/// The factor for LLMs.
1720
pub struct LlmFactor {
18-
create_engine: Box<dyn Fn() -> Box<dyn LlmEngine> + Send + Sync>,
21+
default_engine_creator: Box<dyn LlmEngineCreator>,
1922
}
2023

2124
impl LlmFactor {
22-
pub fn new<F>(create_engine: F) -> Self
23-
where
24-
F: Fn() -> Box<dyn LlmEngine> + Send + Sync + 'static,
25-
{
25+
/// Creates a new LLM factor with the given default engine creator.
26+
///
27+
/// The default engine creator is used to create the engine if no runtime configuration is provided.
28+
pub fn new<F: LlmEngineCreator + 'static>(default_engine_creator: F) -> Self {
2629
Self {
27-
create_engine: Box::new(create_engine),
30+
default_engine_creator: Box::new(default_engine_creator),
2831
}
2932
}
3033
}
3134

3235
impl Factor for LlmFactor {
33-
type RuntimeConfig = ();
36+
type RuntimeConfig = RuntimeConfig;
3437
type AppState = AppState;
3538
type InstanceBuilder = InstanceState;
3639

@@ -45,7 +48,7 @@ impl Factor for LlmFactor {
4548

4649
fn configure_app<T: RuntimeFactors>(
4750
&self,
48-
ctx: ConfigureAppContext<T, Self>,
51+
mut ctx: ConfigureAppContext<T, Self>,
4952
) -> anyhow::Result<Self::AppState> {
5053
let component_allowed_models = ctx
5154
.app()
@@ -62,7 +65,12 @@ impl Factor for LlmFactor {
6265
))
6366
})
6467
.collect::<anyhow::Result<_>>()?;
68+
let engine = ctx
69+
.take_runtime_config()
70+
.map(|c| c.engine)
71+
.unwrap_or_else(|| self.default_engine_creator.create());
6572
Ok(AppState {
73+
engine,
6674
component_allowed_models,
6775
})
6876
}
@@ -78,25 +86,35 @@ impl Factor for LlmFactor {
7886
.get(ctx.app_component().id())
7987
.cloned()
8088
.unwrap_or_default();
89+
let engine = ctx.app_state().engine.clone();
8190

8291
Ok(InstanceState {
83-
engine: (self.create_engine)(),
92+
engine,
8493
allowed_models,
8594
})
8695
}
8796
}
8897

98+
/// The application state for the LLM factor.
8999
pub struct AppState {
100+
engine: Arc<Mutex<dyn LlmEngine>>,
90101
component_allowed_models: HashMap<String, Arc<HashSet<String>>>,
91102
}
92103

104+
/// The instance state for the LLM factor.
93105
pub struct InstanceState {
94-
engine: Box<dyn LlmEngine>,
106+
engine: Arc<Mutex<dyn LlmEngine>>,
95107
pub allowed_models: Arc<HashSet<String>>,
96108
}
97109

110+
/// The runtime configuration for the LLM factor.
111+
pub struct RuntimeConfig {
112+
engine: Arc<Mutex<dyn LlmEngine>>,
113+
}
114+
98115
impl SelfInstanceBuilder for InstanceState {}
99116

117+
/// The interface for a language model engine.
100118
#[async_trait]
101119
pub trait LlmEngine: Send + Sync {
102120
async fn infer(
@@ -112,3 +130,17 @@ pub trait LlmEngine: Send + Sync {
112130
data: Vec<String>,
113131
) -> Result<v2::EmbeddingsResult, v2::Error>;
114132
}
133+
134+
/// A creator for an LLM engine.
135+
pub trait LlmEngineCreator: Send + Sync {
136+
fn create(&self) -> Arc<Mutex<dyn LlmEngine>>;
137+
}
138+
139+
impl<F> LlmEngineCreator for F
140+
where
141+
F: Fn() -> Arc<Mutex<dyn LlmEngine>> + Send + Sync,
142+
{
143+
fn create(&self) -> Arc<Mutex<dyn LlmEngine>> {
144+
self()
145+
}
146+
}

crates/factor-llm/src/spin.rs

Lines changed: 106 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,106 @@
1+
use std::path::PathBuf;
2+
use std::sync::Arc;
3+
4+
pub use spin_llm_local::LocalLlmEngine;
5+
6+
use spin_llm_remote_http::RemoteHttpLlmEngine;
7+
use spin_world::async_trait;
8+
use spin_world::v1::llm::{self as v1};
9+
use spin_world::v2::llm::{self as v2};
10+
use tokio::sync::Mutex;
11+
use url::Url;
12+
13+
use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig};
14+
15+
#[async_trait]
16+
impl LlmEngine for LocalLlmEngine {
17+
async fn infer(
18+
&mut self,
19+
model: v1::InferencingModel,
20+
prompt: String,
21+
params: v2::InferencingParams,
22+
) -> Result<v2::InferencingResult, v2::Error> {
23+
self.infer(model, prompt, params).await
24+
}
25+
26+
async fn generate_embeddings(
27+
&mut self,
28+
model: v2::EmbeddingModel,
29+
data: Vec<String>,
30+
) -> Result<v2::EmbeddingsResult, v2::Error> {
31+
self.generate_embeddings(model, data).await
32+
}
33+
}
34+
35+
#[async_trait]
36+
impl LlmEngine for RemoteHttpLlmEngine {
37+
async fn infer(
38+
&mut self,
39+
model: v1::InferencingModel,
40+
prompt: String,
41+
params: v2::InferencingParams,
42+
) -> Result<v2::InferencingResult, v2::Error> {
43+
self.infer(model, prompt, params).await
44+
}
45+
46+
async fn generate_embeddings(
47+
&mut self,
48+
model: v2::EmbeddingModel,
49+
data: Vec<String>,
50+
) -> Result<v2::EmbeddingsResult, v2::Error> {
51+
self.generate_embeddings(model, data).await
52+
}
53+
}
54+
55+
pub fn runtime_config_from_toml(
56+
table: &toml::Table,
57+
state_dir: PathBuf,
58+
use_gpu: bool,
59+
) -> anyhow::Result<Option<RuntimeConfig>> {
60+
let Some(value) = table.get("llm_compute") else {
61+
return Ok(None);
62+
};
63+
let config: LlmCompute = value.clone().try_into()?;
64+
65+
Ok(Some(RuntimeConfig {
66+
engine: config.into_engine(state_dir, use_gpu),
67+
}))
68+
}
69+
70+
#[derive(Debug, serde::Deserialize)]
71+
#[serde(rename_all = "snake_case", tag = "type")]
72+
pub enum LlmCompute {
73+
Spin,
74+
RemoteHttp(RemoteHttpCompute),
75+
}
76+
77+
impl LlmCompute {
78+
fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> {
79+
match self {
80+
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu).create(),
81+
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
82+
config.url,
83+
config.auth_token,
84+
))),
85+
}
86+
}
87+
}
88+
89+
#[derive(Debug, serde::Deserialize)]
90+
pub struct RemoteHttpCompute {
91+
url: Url,
92+
auth_token: String,
93+
}
94+
95+
/// The default engine creator for the LLM factor when used in the Spin CLI.
96+
pub fn default_engine_creator(
97+
state_dir: PathBuf,
98+
use_gpu: bool,
99+
) -> impl LlmEngineCreator + 'static {
100+
move || {
101+
Arc::new(Mutex::new(LocalLlmEngine::new(
102+
state_dir.join("ai-models"),
103+
use_gpu,
104+
))) as _
105+
}
106+
}

crates/factor-llm/tests/factor_test.rs

Lines changed: 5 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -1,10 +1,12 @@
11
use std::collections::HashSet;
2+
use std::sync::Arc;
23

3-
use factor_llm::{LlmEngine, LlmFactor};
4+
use spin_factor_llm::{LlmEngine, LlmFactor};
45
use spin_factors::{anyhow, RuntimeFactors};
56
use spin_factors_test::{toml, TestEnvironment};
67
use spin_world::v1::llm::{self as v1};
78
use spin_world::v2::llm::{self as v2, Host};
9+
use tokio::sync::Mutex;
810

911
#[derive(RuntimeFactors)]
1012
struct TestFactors {
@@ -37,9 +39,9 @@ async fn llm_works() -> anyhow::Result<()> {
3739
});
3840
let factors = TestFactors {
3941
llm: LlmFactor::new(move || {
40-
Box::new(FakeLLm {
42+
Arc::new(Mutex::new(FakeLLm {
4143
handle: handle.clone(),
42-
}) as _
44+
})) as _
4345
}),
4446
};
4547
let env = TestEnvironment::new(factors).extend_manifest(toml! {

crates/llm-local/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -20,7 +20,6 @@ safetensors = "0.3.3"
2020
serde = { version = "1.0.150", features = ["derive"] }
2121
spin-common = { path = "../common" }
2222
spin-core = { path = "../core" }
23-
spin-llm = { path = "../llm" }
2423
spin-world = { path = "../world" }
2524
terminal = { path = "../terminal" }
2625
tokenizers = "0.13.4"

0 commit comments

Comments
 (0)