|
16 | 16 | key_value, outbound_http, |
17 | 17 | redis::{self, RedisParameter, RedisResult}, |
18 | 18 | sqlite, |
| 19 | + llm, |
19 | 20 | }, |
20 | 21 | std::{collections::HashMap, env, ops::Deref, str, sync::Arc}, |
21 | 22 | }; |
@@ -432,12 +433,127 @@ fn spin_config_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> { |
432 | 433 | module.add_function(pyo3::wrap_pyfunction!(config_get, module)?) |
433 | 434 | } |
434 | 435 |
|
| 436 | +#[derive(Clone)] |
| 437 | +#[pyo3::pyclass] |
| 438 | +#[pyo3(name = "LLMInferencingUsage")] |
| 439 | +struct LLMInferencingUsage { |
| 440 | + #[pyo3(get)] |
| 441 | + prompt_token_count: u32, |
| 442 | + #[pyo3(get)] |
| 443 | + generated_token_count: u32, |
| 444 | +} |
| 445 | + |
| 446 | +impl From<llm::InferencingUsage> for LLMInferencingUsage { |
| 447 | + fn from(result: llm::InferencingUsage) -> Self { |
| 448 | + LLMInferencingUsage { |
| 449 | + prompt_token_count: result.prompt_token_count, |
| 450 | + generated_token_count: result.generated_token_count, |
| 451 | + } |
| 452 | + } |
| 453 | +} |
| 454 | + |
| 455 | + |
| 456 | +#[derive(Clone)] |
| 457 | +#[pyo3::pyclass] |
| 458 | +#[pyo3(name = "LLMInferencingResult")] |
| 459 | +struct LLMInferencingResult { |
| 460 | + #[pyo3(get)] |
| 461 | + text: String, |
| 462 | + #[pyo3(get)] |
| 463 | + usage: LLMInferencingUsage, |
| 464 | +} |
| 465 | + |
| 466 | +impl From<llm::InferencingResult> for LLMInferencingResult { |
| 467 | + fn from(result: llm::InferencingResult) -> Self { |
| 468 | + LLMInferencingResult { |
| 469 | + text: result.text, |
| 470 | + usage: LLMInferencingUsage::from(result.usage), |
| 471 | + } |
| 472 | + } |
| 473 | +} |
| 474 | + |
| 475 | + |
| 476 | +#[derive(Clone)] |
| 477 | +#[pyo3::pyclass] |
| 478 | +#[pyo3(name = "LLMInferencingParams")] |
| 479 | +struct LLMInferencingParams { |
| 480 | + #[pyo3(get, set)] |
| 481 | + max_tokens: u32, |
| 482 | + #[pyo3(get, set)] |
| 483 | + repeat_penalty: f32, |
| 484 | + #[pyo3(get, set)] |
| 485 | + repeat_penalty_last_n_token_count: u32, |
| 486 | + #[pyo3(get, set)] |
| 487 | + temperature: f32, |
| 488 | + #[pyo3(get, set)] |
| 489 | + top_k: u32, |
| 490 | + #[pyo3(get, set)] |
| 491 | + top_p: f32, |
| 492 | +} |
| 493 | + |
| 494 | +#[pyo3::pymethods] |
| 495 | +impl LLMInferencingParams { |
| 496 | + #[new] |
| 497 | + fn new(max_tokens: u32, repeat_penalty: f32, repeat_penalty_last_n_token_count: u32, temperature: f32, top_k: u32, top_p: f32) -> Self { |
| 498 | + Self { |
| 499 | + max_tokens, |
| 500 | + repeat_penalty, |
| 501 | + repeat_penalty_last_n_token_count, |
| 502 | + temperature, |
| 503 | + top_k, |
| 504 | + top_p, |
| 505 | + } |
| 506 | + } |
| 507 | +} |
| 508 | + |
| 509 | +impl From<LLMInferencingParams> for llm::InferencingParams { |
| 510 | + fn from(p: LLMInferencingParams) -> Self { |
| 511 | + llm::InferencingParams { |
| 512 | + max_tokens: p.max_tokens, |
| 513 | + repeat_penalty: p.repeat_penalty, |
| 514 | + repeat_penalty_last_n_token_count: p.repeat_penalty_last_n_token_count, |
| 515 | + temperature: p.temperature, |
| 516 | + top_k: p.top_k, |
| 517 | + top_p: p.top_p, |
| 518 | + } |
| 519 | + } |
| 520 | +} |
| 521 | + |
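
Since the `#[new]` constructor above takes all six tuning knobs positionally, constructing the params object from Python would look like the sketch below. This assumes the code runs inside the Spin-embedded interpreter where `spin_llm` is registered via the inittab; the numbers are arbitrary illustration values, not SDK defaults.

```python
from spin_llm import LLMInferencingParams

# Positional order matches the #[new] constructor: max_tokens,
# repeat_penalty, repeat_penalty_last_n_token_count, temperature,
# top_k, top_p. The values here are illustrative only.
params = LLMInferencingParams(128, 1.1, 64, 0.8, 40, 0.9)

# Every field is declared #[pyo3(get, set)], so it can also be
# adjusted after construction.
params.temperature = 0.2
```
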
| 522 | +#[pyo3::pyfunction] |
| 523 | +fn llm_infer(model: &str, prompt: &str, options: Option<LLMInferencingParams>) -> Result<LLMInferencingResult, Anyhow> { |
| 524 | + let m = match model { |
| 525 | + "llama2-chat" => llm::InferencingModel::Llama2Chat, |
| 526 | + "codellama-instruct" => llm::InferencingModel::CodellamaInstruct, |
| 527 | + _ => llm::InferencingModel::Other(model), |
| 528 | + }; |
| 529 | + |
| 530 | + let opts = match options { |
| 531 | + Some(o) => llm::InferencingParams::from(o), |
| 532 | + None => llm::InferencingParams::default(), |
| 533 | + }; |
| 534 | + |
| 535 | + llm::infer_with_options(m, prompt, opts) |
| 536 | + .map_err(Anyhow::from) |
| 537 | + .map(LLMInferencingResult::from) |
| 538 | +} |
| 539 | + |
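
With the binding in place, inference from Python is a single call. A minimal sketch, again assuming the Spin-embedded interpreter; the `Option<LLMInferencingParams>` argument is passed explicitly as `None` here rather than relying on pyo3 to synthesize a default.

```python
from spin_llm import llm_infer

# Run against the built-in llama2-chat model with default parameters.
result = llm_infer("llama2-chat", "Say hello in five words.", None)
print(result.text)
print(result.usage.prompt_token_count, result.usage.generated_token_count)
```
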
| 540 | +#[pyo3::pymodule] |
| 541 | +#[pyo3(name = "spin_llm")] |
| 542 | +fn spin_llm_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> { |
| 543 | + module.add_function(pyo3::wrap_pyfunction!(llm_infer, module)?)?; |
| 544 | + module.add_class::<LLMInferencingUsage>()?; |
| 545 | + module.add_class::<LLMInferencingParams>()?; |
| 546 | + module.add_class::<LLMInferencingResult>() |
| 547 | +} |
| 548 | + |
| 549 | + |
435 | 550 | fn do_init() -> Result<()> { |
436 | 551 | pyo3::append_to_inittab!(spin_http_module); |
437 | 552 | pyo3::append_to_inittab!(spin_redis_module); |
438 | 553 | pyo3::append_to_inittab!(spin_config_module); |
439 | 554 | pyo3::append_to_inittab!(spin_key_value_module); |
440 | 555 | pyo3::append_to_inittab!(spin_sqlite_module); |
| 556 | + pyo3::append_to_inittab!(spin_llm_module); |
441 | 557 |
|
442 | 558 | pyo3::prepare_freethreaded_python(); |
443 | 559 |
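
One design note on `llm_infer`: the two well-known model names map to the built-in `InferencingModel` variants, and any other string falls through to `InferencingModel::Other`, so custom model names pass straight to the host. Errors from `infer_with_options` surface through the `Anyhow` wrapper as a Python exception. A sketch under those assumptions (the model name is hypothetical, and the exact exception type depends on how `Anyhow` maps to `PyErr`):

```python
from spin_llm import llm_infer

try:
    # Not "llama2-chat" or "codellama-instruct", so this string reaches
    # the host via InferencingModel::Other; whether it is accepted
    # depends on the runtime's configured models.
    result = llm_infer("my-finetuned-model", "ping", None)
    print(result.text)
except Exception as err:
    print(f"inference failed: {err}")
```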