Skip to content

Commit cedb9b0

Browse files
authored
Merge pull request #2755 from karthik2804/llama3-llm-factor
llm-factor: migrate to `candle`
2 parents 4206ff0 + 4e40481 commit cedb9b0

File tree

9 files changed

+901
-764
lines changed

9 files changed

+901
-764
lines changed

Cargo.lock

Lines changed: 594 additions & 519 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-llm/src/spin.rs

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ mod local {
4444
/// The default engine creator for the LLM factor when used in the Spin CLI.
4545
pub fn default_engine_creator(
4646
state_dir: Option<PathBuf>,
47-
use_gpu: bool,
4847
) -> anyhow::Result<impl LlmEngineCreator + 'static> {
4948
#[cfg(feature = "llm")]
5049
let engine = {
@@ -53,11 +52,11 @@ pub fn default_engine_creator(
5352
Some(ref dir) => dir.clone(),
5453
None => std::env::current_dir().context("failed to get current working directory")?,
5554
};
56-
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"), use_gpu)
55+
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"))
5756
};
5857
#[cfg(not(feature = "llm"))]
5958
let engine = {
60-
let _ = (state_dir, use_gpu);
59+
let _ = (state_dir);
6160
noop::NoopLlmEngine
6261
};
6362
let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
@@ -91,15 +90,14 @@ impl LlmEngine for RemoteHttpLlmEngine {
9190
pub fn runtime_config_from_toml(
9291
table: &impl GetTomlValue,
9392
state_dir: Option<PathBuf>,
94-
use_gpu: bool,
9593
) -> anyhow::Result<Option<RuntimeConfig>> {
9694
let Some(value) = table.get("llm_compute") else {
9795
return Ok(None);
9896
};
9997
let config: LlmCompute = value.clone().try_into()?;
10098

10199
Ok(Some(RuntimeConfig {
102-
engine: config.into_engine(state_dir, use_gpu)?,
100+
engine: config.into_engine(state_dir)?,
103101
}))
104102
}
105103

@@ -111,19 +109,15 @@ pub enum LlmCompute {
111109
}
112110

113111
impl LlmCompute {
114-
fn into_engine(
115-
self,
116-
state_dir: Option<PathBuf>,
117-
use_gpu: bool,
118-
) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
112+
fn into_engine(self, state_dir: Option<PathBuf>) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
119113
let engine: Arc<Mutex<dyn LlmEngine>> = match self {
120114
#[cfg(not(feature = "llm"))]
121115
LlmCompute::Spin => {
122-
let _ = (state_dir, use_gpu);
116+
let _ = (state_dir);
123117
Arc::new(Mutex::new(noop::NoopLlmEngine))
124118
}
125119
#[cfg(feature = "llm")]
126-
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu)?.create(),
120+
LlmCompute::Spin => default_engine_creator(state_dir)?.create(),
127121
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
128122
config.url,
129123
config.auth_token,

crates/llm-local/Cargo.toml

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,31 +5,29 @@ authors = { workspace = true }
55
edition = { workspace = true }
66

77
[dependencies]
8-
anyhow = { workspace = true }
9-
candle = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5", package = "candle-core" }
10-
candle-nn = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5" }
11-
chrono = "0.4"
12-
llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [
13-
"tokenizers-remote",
14-
"llama",
15-
], default-features = false }
16-
lru = "0.12"
8+
anyhow = "1.0"
9+
candle = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483", package = "candle-core" }
10+
candle-nn = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" }
11+
candle-transformers = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" }
12+
chrono = "0.4.26"
13+
lru = "0.9.0"
1714
num_cpus = "1"
1815
rand = { workspace = true }
1916
safetensors = "0.3.3"
2017
serde = { workspace = true }
18+
serde_json = "1.0.125"
2119
spin-common = { path = "../common" }
2220
spin-core = { path = "../core" }
2321
spin-world = { path = "../world" }
2422
terminal = { path = "../terminal" }
25-
tokenizers = "0.13.4"
26-
tokio = { version = "1", features = ["macros", "sync"] }
23+
tokenizers = "0.19.1"
24+
tokio = { version = "1.32.0", features = ["macros", "sync"] }
2725
tracing = { workspace = true }
2826

2927
[features]
3028
default = []
31-
metal = ["llm/metal"]
32-
cublas = ["llm/cublas"]
29+
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
30+
cublas = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
3331

3432
[lints]
3533
workspace = true

crates/llm-local/src/bert.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
///
55
/// TODO: Remove this file when a new release of Candle makes it obsolete.
66
use anyhow::{bail, Result};
7-
use candle::{DType, Tensor};
7+
use candle::{DType, Module, Tensor};
88
use candle_nn::{Embedding, VarBuilder};
99
use serde::Deserialize;
1010

Comments (0)