Commit 362c75d

Merge pull request #1795 from fermyon/allow-other-models
Allow other AI models besides well-known ones
2 parents 2269651 + e74b2a8 commit 362c75d

File tree

4 files changed: +135, -24 lines


Cargo.lock

Lines changed: 46 additions & 0 deletions
Some generated files are not rendered by default.

crates/llm-local/src/lib.rs

Lines changed: 88 additions & 6 deletions
@@ -6,11 +6,11 @@ use candle::DType;
 use candle_nn::VarBuilder;
 use llm::{
     InferenceFeedback, InferenceParameters, InferenceResponse, InferenceSessionConfig, Model,
-    ModelKVMemoryType, ModelParameters,
+    ModelArchitecture, ModelKVMemoryType, ModelParameters,
 };
 use rand::SeedableRng;
 use spin_core::async_trait;
-use spin_llm::{model_arch, model_name, LlmEngine, MODEL_ALL_MINILM_L6_V2};
+use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
 use spin_world::llm::{self as wasi_llm};
 use std::{
     collections::hash_map::Entry,
@@ -170,14 +170,22 @@ impl LocalLlmEngine {
         &mut self,
         model: wasi_llm::InferencingModel,
     ) -> Result<Arc<dyn Model>, wasi_llm::Error> {
-        let model_name = model_name(&model)?;
         let use_gpu = self.use_gpu;
         let progress_fn = |_| {};
-        let model = match self.inferencing_models.entry((model_name.into(), use_gpu)) {
+        let model = match self.inferencing_models.entry((model.clone(), use_gpu)) {
             Entry::Occupied(o) => o.get().clone(),
             Entry::Vacant(v) => v
                 .insert({
-                    let path = self.registry.join(model_name);
+                    let (path, arch) = if let Some(arch) = well_known_inferencing_model_arch(&model) {
+                        let model_binary = self.registry.join(&model);
+                        if model_binary.exists() {
+                            (model_binary, arch.to_owned())
+                        } else {
+                            walk_registry_for_model(&self.registry, model).await?
+                        }
+                    } else {
+                        walk_registry_for_model(&self.registry, model).await?
+                    };
                     if !self.registry.exists() {
                         return Err(wasi_llm::Error::RuntimeError(
                             format!("The directory expected to house the inferencing model '{}' does not exist.", self.registry.display())
@@ -199,7 +207,7 @@ impl LocalLlmEngine {
                         n_gqa: None,
                     };
                     let model = llm::load_dynamic(
-                        Some(model_arch(&model)?),
+                        Some(arch),
                         &path,
                         llm::TokenizerSource::Embedded,
                         params,
@@ -223,6 +231,80 @@ impl LocalLlmEngine {
     }
 }
 
+/// Get the model binary and arch from walking the registry file structure
+async fn walk_registry_for_model(
+    registry_path: &Path,
+    model: String,
+) -> Result<(PathBuf, ModelArchitecture), wasi_llm::Error> {
+    let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Could not read model registry directory '{}': {e}",
+            registry_path.display()
+        ))
+    })?;
+    let mut result = None;
+    'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
+        wasi_llm::Error::RuntimeError(format!(
+            "Failed to read arch directory in model registry: {e}"
+        ))
+    })? {
+        if arch_dir
+            .file_type()
+            .await
+            .map_err(|e| {
+                wasi_llm::Error::RuntimeError(format!(
+                    "Could not read file type of '{}' dir: {e}",
+                    arch_dir.path().display()
+                ))
+            })?
+            .is_file()
+        {
+            continue;
+        }
+        let mut model_files = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading architecture directory in model registry: {e}"
+            ))
+        })?;
+        while let Some(model_file) = model_files.next_entry().await.map_err(|e| {
+            wasi_llm::Error::RuntimeError(format!(
+                "Error reading model file in model registry: {e}"
+            ))
+        })? {
+            if model_file
+                .file_name()
+                .to_str()
+                .map(|m| m == model)
+                .unwrap_or_default()
+            {
+                let arch = arch_dir.file_name();
+                let arch = arch
+                    .to_str()
+                    .ok_or(wasi_llm::Error::ModelNotSupported)?
+                    .parse()
+                    .map_err(|_| wasi_llm::Error::ModelNotSupported)?;
+                result = Some((model_file.path(), arch));
+                break 'outer;
+            }
+        }
+    }
+
+    result.ok_or_else(|| {
+        wasi_llm::Error::InvalidInput(format!(
+            "no model directory found in registry for model '{model}'"
+        ))
+    })
+}
+
+fn well_known_inferencing_model_arch(
+    model: &wasi_llm::InferencingModel,
+) -> Option<ModelArchitecture> {
+    match model.as_str() {
+        "llama2-chat" | "code_llama" => Some(ModelArchitecture::Llama),
+        _ => None,
+    }
+}
+
 async fn generate_embeddings(
     data: Vec<String>,
     model: Arc<(tokenizers::Tokenizer, BertModel)>,

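For context on the diff above: the hard-coded model-name check is replaced by a two-step lookup. A well-known model id ("llama2-chat", "code_llama") maps straight to the Llama architecture and is first tried as a file at the registry root; any other id is resolved by walk_registry_for_model, which treats each registry subdirectory as an architecture name and each file inside it as a model binary named after the model id. The snippet below is a minimal illustrative sketch of that layout walk, not the code in the diff: it is synchronous (std::fs instead of tokio::fs), it returns the architecture as a plain string rather than parsing it into llm::ModelArchitecture, and the example directory layout in the comment is an assumption.

// Sketch only: mirrors the registry layout that walk_registry_for_model
// expects, e.g.
//
//   <registry>/
//       llama/
//           my-custom-model      <- file name must equal the model id
//
// The real implementation is async and converts the directory name into
// llm::ModelArchitecture; here the name is returned as a String.
use std::path::{Path, PathBuf};

fn find_model_in_registry(registry: &Path, model: &str) -> Option<(PathBuf, String)> {
    for arch_dir in std::fs::read_dir(registry).ok()? {
        let arch_dir = arch_dir.ok()?;
        // Only architecture directories are walked; stray files at the
        // registry root are skipped.
        if !arch_dir.file_type().ok()?.is_dir() {
            continue;
        }
        for entry in std::fs::read_dir(arch_dir.path()).ok()? {
            let entry = entry.ok()?;
            // A match is a file whose name equals the requested model id.
            if entry.file_name().to_str() == Some(model) {
                let arch = arch_dir.file_name().to_str()?.to_owned();
                return Some((entry.path(), arch));
            }
        }
    }
    None
}

Under that layout, a hypothetical id such as my-custom-model stored at <registry>/llama/my-custom-model would be expected to resolve to the Llama architecture without appearing on the well-known list.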
crates/llm/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ anyhow = "1.0"
 bytesize = "1.1"
 llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [
     "tokenizers-remote",
-    "llama",
+    "models",
 ], default-features = false }
 spin-app = { path = "../app" }
 spin-core = { path = "../core" }

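Switching the llm crate feature from "llama" to "models" compiles in every architecture the crate supports rather than Llama alone, which is what lets the registry walk above accept an arbitrary architecture directory name. Below is a hedged sketch of that conversion, relying only on the FromStr implementation the diff itself exercises through .parse(); which names parse successfully depends on the pinned llm revision and its enabled features.

// Sketch: convert a registry directory name into an architecture, mirroring
// the `.parse()` call in walk_registry_for_model. Accepted names depend on
// the llm crate version and the model features compiled in.
use llm::ModelArchitecture;

fn arch_from_dir_name(dir_name: &str) -> Option<ModelArchitecture> {
    dir_name.parse::<ModelArchitecture>().ok()
}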
crates/llm/src/lib.rs

Lines changed: 0 additions & 17 deletions
@@ -1,6 +1,5 @@
 pub mod host_component;
 
-use llm::ModelArchitecture;
 use spin_app::MetadataKey;
 use spin_core::async_trait;
 use spin_world::llm::{self as wasi_llm};
@@ -72,22 +71,6 @@ impl wasi_llm::Host for LlmDispatch {
     }
 }
 
-pub fn model_name(model: &wasi_llm::InferencingModel) -> Result<&str, wasi_llm::Error> {
-    match model.as_str() {
-        "llama2-chat" | "codellama-instruct" => Ok(model.as_str()),
-        _ => Err(wasi_llm::Error::ModelNotSupported),
-    }
-}
-
-pub fn model_arch(
-    model: &wasi_llm::InferencingModel,
-) -> Result<ModelArchitecture, wasi_llm::Error> {
-    match model.as_str() {
-        "llama2-chat" | "codellama-instruct" => Ok(ModelArchitecture::Llama),
-        _ => Err(wasi_llm::Error::ModelNotSupported),
-    }
-}
-
 fn access_denied_error(model: &str) -> wasi_llm::Error {
     wasi_llm::Error::InvalidInput(format!(
         "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
