diff --git a/v2/examples/llm/.gitignore b/v2/examples/llm/.gitignore
new file mode 100644
index 0000000..b565010
--- /dev/null
+++ b/v2/examples/llm/.gitignore
@@ -0,0 +1,2 @@
+main.wasm
+.spin/
diff --git a/v2/examples/llm/go.mod b/v2/examples/llm/go.mod
new file mode 100644
index 0000000..cc73aec
--- /dev/null
+++ b/v2/examples/llm/go.mod
@@ -0,0 +1,12 @@
+module github.com/spinframework/spin-go-sdk/v2/examples/llm
+
+go 1.24.1
+
+require github.com/spinframework/spin-go-sdk/v2 v2.0.0
+
+require (
+	github.com/julienschmidt/httprouter v1.3.0 // indirect
+	go.bytecodealliance.org/cm v0.2.2 // indirect
+)
+
+replace github.com/spinframework/spin-go-sdk/v2 => ../../
diff --git a/v2/examples/llm/go.sum b/v2/examples/llm/go.sum
new file mode 100644
index 0000000..c1ebfdf
--- /dev/null
+++ b/v2/examples/llm/go.sum
@@ -0,0 +1,4 @@
+github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
+github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
+go.bytecodealliance.org/cm v0.2.2 h1:M9iHS6qs884mbQbIjtLX1OifgyPG9DuMs2iwz8G4WQA=
+go.bytecodealliance.org/cm v0.2.2/go.mod h1:JD5vtVNZv7sBoQQkvBvAAVKJPhR/bqBH7yYXTItMfZI=
diff --git a/v2/examples/llm/main.go b/v2/examples/llm/main.go
new file mode 100644
index 0000000..55dae4e
--- /dev/null
+++ b/v2/examples/llm/main.go
@@ -0,0 +1,34 @@
+package main
+
+import (
+	"fmt"
+	"net/http"
+
+	spinhttp "github.com/spinframework/spin-go-sdk/v2/http"
+	"github.com/spinframework/spin-go-sdk/v2/llm"
+)
+
+func init() {
+	spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
+		result, err := llm.Infer("llama2-chat", "Tell me a joke", nil)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		fmt.Printf("Prompt tokens: %d\n", result.Usage.PromptTokenCount)
+		fmt.Printf("Generated tokens: %d\n", result.Usage.GeneratedTokenCount)
+		fmt.Fprint(w, result.Text)
+		fmt.Fprintf(w, "\n\n")
+
+		embeddings, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world"})
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		fmt.Printf("%d\n", len(embeddings.Embeddings[0]))
+		fmt.Printf("Prompt Tokens: %d\n", embeddings.Usage.PromptTokenCount)
+
+	})
+}
+
+func main() {}
diff --git a/v2/examples/llm/spin.toml b/v2/examples/llm/spin.toml
new file mode 100644
index 0000000..bcdda80
--- /dev/null
+++ b/v2/examples/llm/spin.toml
@@ -0,0 +1,19 @@
+spin_manifest_version = 2
+
+[application]
+authors = ["Fermyon Engineering"]
+description = "Simple example using the llm sdk."
+name = "llm-example"
+version = "0.1.0"
+
+[[trigger.http]]
+route = "/..."
+component = "llm"
+
+[component.llm]
+source = "main.wasm"
+allowed_outbound_hosts = []
+ai_models = ["llama2-chat", "all-minilm-l6-v2"]
+[component.llm.build]
+command = "tinygo build -target=wasip2 --wit-package $(go list -mod=readonly -m -f '{{.Dir}}' github.com/spinframework/spin-go-sdk/v2)/wit --wit-world http-trigger -gc=leaking -no-debug -o main.wasm main.go"
+watch = ["**/*.go", "go.mod"]
diff --git a/v2/llm/llm.go b/v2/llm/llm.go
new file mode 100644
index 0000000..ac4b94d
--- /dev/null
+++ b/v2/llm/llm.go
@@ -0,0 +1,135 @@
+package llm
+
+import (
+	"fmt"
+
+	"github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
+	"go.bytecodealliance.org/cm"
+)
+
+// The models used for inferencing.
+const (
+	Llama2Chat        InferencingModel = "llama2-chat"
+	CodellamaInstruct InferencingModel = "codellama-instruct"
+)
+
+type InferencingParams struct {
+	// The maximum number of tokens that should be inferred.
+	//
+	// Note: the backing implementation may return fewer tokens.
+	MaxTokens uint32 `json:"max-tokens"`
+
+	// The amount the model should avoid repeating tokens.
+	RepeatPenalty float32 `json:"repeat-penalty"`
+
+	// The number of tokens the model should apply the repeat penalty to.
+	RepeatPenaltyLastNTokenCount uint32 `json:"repeat-penalty-last-n-token-count"`
+
+	// The randomness with which the next token is selected.
+	Temperature float32 `json:"temperature"`
+
+	// The number of possible next tokens the model will choose from.
+	TopK uint32 `json:"top-k"`
+
+	// The probability total of next tokens the model will choose from.
+	TopP float32 `json:"top-p"`
+}
+
+type InferencingResult struct {
+	// The text generated by the model.
+	// TODO: this should be a stream
+	Text string `json:"text"`
+
+	// Usage information about the inferencing request.
+	Usage InferencingUsage `json:"usage"`
+}
+
+// Usage information related to the inferencing result.
+type InferencingUsage struct {
+	_ cm.HostLayout `json:"-"`
+	// Number of tokens in the prompt.
+	PromptTokenCount uint32 `json:"prompt-token-count"`
+
+	// Number of tokens generated by the inferencing operation.
+	GeneratedTokenCount uint32 `json:"generated-token-count"`
+}
+
+// A Large Language Model.
+type InferencingModel string
+
+// The model used for generating embeddings.
+type EmbeddingModel string
+
+type EmbeddingsResult struct {
+	// Embeddings are the embeddings generated by the request.
+	Embeddings [][]float32
+	// Usage is usage related to an embeddings generation request.
+	Usage *EmbeddingsUsage
+}
+
+type EmbeddingsUsage struct {
+	// PromptTokenCount is the number of tokens in the prompt.
+	PromptTokenCount int
+}
+
+// Infer performs inferencing using the provided model and prompt with the
+// given optional parameters.
+func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
+	iparams := cm.None[llm.InferencingParams]()
+	if params != nil {
+		iparams = cm.Some(llm.InferencingParams{
+			MaxTokens:                    params.MaxTokens,
+			RepeatPenalty:                params.RepeatPenalty,
+			RepeatPenaltyLastNTokenCount: params.RepeatPenaltyLastNTokenCount,
+			Temperature:                  params.Temperature,
+			TopK:                         params.TopK,
+			TopP:                         params.TopP,
+		})
+	}
+
+	result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
+	if result.IsErr() {
+		return InferencingResult{}, errorVariantToError(*result.Err())
+	}
+
+	return InferencingResult{
+		Text: result.OK().Text,
+		Usage: InferencingUsage{
+			PromptTokenCount:    result.OK().Usage.PromptTokenCount,
+			GeneratedTokenCount: result.OK().Usage.GeneratedTokenCount,
+		},
+	}, nil
+}
+
+// GenerateEmbeddings generates the embeddings for the supplied list of text.
+func GenerateEmbeddings(model EmbeddingModel, text []string) (*EmbeddingsResult, error) {
+	result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
+	if result.IsErr() {
+		return &EmbeddingsResult{}, errorVariantToError(*result.Err())
+	}
+
+	embeddings := [][]float32{}
+	for _, l := range result.OK().Embeddings.Slice() {
+		embeddings = append(embeddings, l.Slice())
+	}
+
+	return &EmbeddingsResult{
+		Embeddings: embeddings,
+		Usage: &EmbeddingsUsage{
+			PromptTokenCount: int(result.OK().Usage.PromptTokenCount),
+		},
+	}, nil
+}
+
+func errorVariantToError(err llm.Error) error {
+	switch {
+	case llm.ErrorModelNotSupported() == err:
+		return fmt.Errorf("model not supported")
+	case err.RuntimeError() != nil:
+		return fmt.Errorf("runtime error %s", *err.RuntimeError())
+	case err.InvalidInput() != nil:
+		return fmt.Errorf("invalid input %s", *err.InvalidInput())
+	default:
+		return fmt.Errorf("no error provided by host implementation")
+	}
+}
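
For reference, a minimal usage sketch of the new llm package with explicit InferencingParams, rather than the nil defaults used in the example component above. It is not part of the diff; the prompt and parameter values are illustrative assumptions, not values taken from the change.

package main

import (
	"fmt"
	"net/http"

	spinhttp "github.com/spinframework/spin-go-sdk/v2/http"
	"github.com/spinframework/spin-go-sdk/v2/llm"
)

func init() {
	spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
		// Illustrative parameter values; tune them per model and workload.
		params := &llm.InferencingParams{
			MaxTokens:                    128,
			RepeatPenalty:                1.1,
			RepeatPenaltyLastNTokenCount: 64,
			Temperature:                  0.8,
			TopK:                         40,
			TopP:                         0.9,
		}
		// Infer takes the model name as a string, so convert the exported constant.
		result, err := llm.Infer(string(llm.Llama2Chat), "Write a haiku about WebAssembly", params)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		fmt.Fprintln(w, result.Text)
	})
}

func main() {}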