Skip to content

Commit 6c41448

Browse files
committed
spin-python-engine: init llm inferencing
Signed-off-by: Danielle Lancashire <[email protected]>
1 parent 656e4ad commit 6c41448

File tree

1 file changed

+116
-0
lines changed
  • crates/spin-python-engine/src

1 file changed

+116
-0
lines changed

crates/spin-python-engine/src/lib.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use {
1616
key_value, outbound_http,
1717
redis::{self, RedisParameter, RedisResult},
1818
sqlite,
19+
llm
1920
},
2021
std::{collections::HashMap, env, ops::Deref, str, sync::Arc},
2122
};
@@ -432,12 +433,127 @@ fn spin_config_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
432433
module.add_function(pyo3::wrap_pyfunction!(config_get, module)?)
433434
}
434435

436+
#[derive(Clone)]
437+
#[pyo3::pyclass]
438+
#[pyo3(name = "LLMInferencingUsage")]
439+
struct LLMInferencingUsage {
440+
#[pyo3(get)]
441+
prompt_token_count: u32,
442+
#[pyo3(get)]
443+
generated_token_count: u32
444+
}
445+
446+
/// Converts the SDK usage record into its Python-facing counterpart.
impl From<llm::InferencingUsage> for LLMInferencingUsage {
    /// Copies both token counters across unchanged.
    fn from(usage: llm::InferencingUsage) -> Self {
        Self {
            prompt_token_count: usage.prompt_token_count,
            generated_token_count: usage.generated_token_count,
        }
    }
}
454+
455+
456+
#[derive(Clone)]
457+
#[pyo3::pyclass]
458+
#[pyo3(name = "LLMInferencingResult")]
459+
struct LLMInferencingResult {
460+
#[pyo3(get)]
461+
text: String,
462+
#[pyo3(get)]
463+
usage: LLMInferencingUsage
464+
}
465+
466+
impl From<llm::InferencingResult> for LLMInferencingResult {
467+
fn from(result: llm::InferencingResult) -> Self {
468+
LLMInferencingResult{
469+
text: result.text.clone(),
470+
usage: LLMInferencingUsage::from(result.usage)
471+
}
472+
}
473+
}
474+
475+
476+
#[derive(Clone)]
477+
#[pyo3::pyclass]
478+
#[pyo3(name = "LLMInferencingParams")]
479+
struct LLMInferencingParams {
480+
#[pyo3(get,set)]
481+
max_tokens: u32,
482+
#[pyo3(get,set)]
483+
repeat_penalty: f32,
484+
#[pyo3(get,set)]
485+
repeat_penalty_last_n_token_count: u32,
486+
#[pyo3(get,set)]
487+
temperature: f32,
488+
#[pyo3(get,set)]
489+
top_k: u32,
490+
#[pyo3(get,set)]
491+
top_p: f32
492+
}
493+
494+
#[pyo3::pymethods]
495+
impl LLMInferencingParams {
496+
#[new]
497+
fn new(max_tokens: u32, repeat_penalty: f32, repeat_penalty_last_n_token_count: u32, temperature: f32, top_k: u32, top_p: f32) -> Self {
498+
Self{
499+
max_tokens: max_tokens,
500+
repeat_penalty: repeat_penalty,
501+
repeat_penalty_last_n_token_count: repeat_penalty_last_n_token_count,
502+
temperature: temperature,
503+
top_k: top_k,
504+
top_p: top_p,
505+
}
506+
}
507+
}
508+
509+
impl From<LLMInferencingParams> for llm::InferencingParams {
510+
fn from(p: LLMInferencingParams) -> Self {
511+
llm::InferencingParams{
512+
max_tokens: p.max_tokens,
513+
repeat_penalty: p.repeat_penalty,
514+
repeat_penalty_last_n_token_count: p.repeat_penalty_last_n_token_count,
515+
temperature: p.temperature,
516+
top_k: p.top_k,
517+
top_p: p.top_p,
518+
}
519+
}
520+
}
521+
522+
#[pyo3::pyfunction]
523+
fn llm_infer(model: &str, prompt: &str, options: Option<LLMInferencingParams>) -> Result<LLMInferencingResult, Anyhow> {
524+
let m = match model {
525+
"llama2-chat" => llm::InferencingModel::Llama2Chat,
526+
"codellama-instruct" => llm::InferencingModel::CodellamaInstruct,
527+
_ => llm::InferencingModel::Other(model)
528+
};
529+
530+
let opts = match options {
531+
Some(o) => llm::InferencingParams::from(o),
532+
_ => llm::InferencingParams::default()
533+
};
534+
535+
llm::infer_with_options(m, prompt, opts)
536+
.map_err(Anyhow::from)
537+
.map(LLMInferencingResult::from)
538+
}
539+
540+
#[pyo3::pymodule]
541+
#[pyo3(name = "spin_llm")]
542+
fn spin_llm_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
543+
module.add_function(pyo3::wrap_pyfunction!(llm_infer, module)?)?;
544+
module.add_class::<LLMInferencingUsage>()?;
545+
module.add_class::<LLMInferencingParams>()?;
546+
module.add_class::<LLMInferencingResult>()
547+
}
548+
549+
435550
fn do_init() -> Result<()> {
436551
pyo3::append_to_inittab!(spin_http_module);
437552
pyo3::append_to_inittab!(spin_redis_module);
438553
pyo3::append_to_inittab!(spin_config_module);
439554
pyo3::append_to_inittab!(spin_key_value_module);
440555
pyo3::append_to_inittab!(spin_sqlite_module);
556+
pyo3::append_to_inittab!(spin_llm_module);
441557

442558
pyo3::prepare_freethreaded_python();
443559

0 commit comments

Comments
 (0)