2 changes: 2 additions & 0 deletions v2/examples/llm/.gitignore
@@ -0,0 +1,2 @@
main.wasm
.spin/
12 changes: 12 additions & 0 deletions v2/examples/llm/go.mod
@@ -0,0 +1,12 @@
module github.com/spinframework/spin-go-sdk/v2/examples/llm

go 1.24.1

require github.com/spinframework/spin-go-sdk/v2 v2.0.0

require (
github.com/julienschmidt/httprouter v1.3.0 // indirect
go.bytecodealliance.org/cm v0.2.2 // indirect
)

replace github.com/spinframework/spin-go-sdk/v2 => ../../
4 changes: 4 additions & 0 deletions v2/examples/llm/go.sum
@@ -0,0 +1,4 @@
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
go.bytecodealliance.org/cm v0.2.2 h1:M9iHS6qs884mbQbIjtLX1OifgyPG9DuMs2iwz8G4WQA=
go.bytecodealliance.org/cm v0.2.2/go.mod h1:JD5vtVNZv7sBoQQkvBvAAVKJPhR/bqBH7yYXTItMfZI=
34 changes: 34 additions & 0 deletions v2/examples/llm/main.go
@@ -0,0 +1,34 @@
package main

import (
"fmt"
"net/http"

spinhttp "github.com/spinframework/spin-go-sdk/v2/http"
"github.com/spinframework/spin-go-sdk/v2/llm"
)

func init() {
spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
result, err := llm.Infer("llama2-chat", "Tell me a joke", nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fmt.Printf("Prompt tokens: %d\n", result.Usage.PromptTokenCount)
fmt.Printf("Generated tokens: %d\n", result.Usage.GeneratedTokenCount)
fmt.Fprint(w, result.Text)
fmt.Fprintf(w, "\n\n")

embeddings, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world"})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fmt.Printf("%d\n", len(embeddings.Embeddings[0]))
fmt.Printf("Prompt Tokens: %d\n", embeddings.Usage.PromptTokenCount)

})
}

func main() {}
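
For reference, the handler above passes nil to Infer to use the default inferencing parameters. The same call with explicit InferencingParams would look like the sketch below; the field values are arbitrary examples, not recommended defaults:

// Hypothetical example values for each InferencingParams field.
params := &llm.InferencingParams{
	MaxTokens:                    100,
	RepeatPenalty:                1.1,
	RepeatPenaltyLastNTokenCount: 64,
	Temperature:                  0.8,
	TopK:                         40,
	TopP:                         0.9,
}
result, err := llm.Infer("llama2-chat", "Tell me a joke", params)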
19 changes: 19 additions & 0 deletions v2/examples/llm/spin.toml
@@ -0,0 +1,19 @@
spin_manifest_version = 2

[application]
authors = ["Fermyon Engineering <[email protected]>"]
description = "Simple example using the llm sdk."
name = "llm-example"
version = "0.1.0"

[[trigger.http]]
route = "/..."
component = "llm"

[component.llm]
source = "main.wasm"
allowed_outbound_hosts = []
ai_models = ["llama2-chat", "all-minilm-l6-v2"]
[component.llm.build]
command = "tinygo build -target=wasip2 --wit-package $(go list -mod=readonly -m -f '{{.Dir}}' github.com/spinframework/spin-go-sdk/v2)/wit --wit-world http-trigger -gc=leaking -no-debug -o main.wasm main.go"
watch = ["**/*.go", "go.mod"]
135 changes: 135 additions & 0 deletions v2/llm/llm.go
@@ -0,0 +1,135 @@
package llm

import (
"fmt"

"github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
"go.bytecodealliance.org/cm"
)

// The models that can be used for inferencing.
const (
Llama2Chat InferencingModel = "llama2-chat"
CodellamaInstruct InferencingModel = "codellama-instruct"
)

type InferencingParams struct {
// The maximum tokens that should be inferred.
//
// Note: the backing implementation may return fewer tokens.
MaxTokens uint32 `json:"max-tokens"`

// The degree to which the model should avoid repeating tokens.
RepeatPenalty float32 `json:"repeat-penalty"`

// The number of tokens the model should apply the repeat penalty to.
RepeatPenaltyLastNTokenCount uint32 `json:"repeat-penalty-last-n-token-count"`

// The randomness with which the next token is selected.
Temperature float32 `json:"temperature"`

// The number of possible next tokens the model will choose from.
TopK uint32 `json:"top-k"`

// The cumulative probability threshold for the set of next tokens the model will choose from.
TopP float32 `json:"top-p"`
}

type InferencingResult struct {
// The text generated by the model
// TODO: this should be a stream
Text string `json:"text"`

// Usage information about the inferencing request
Usage InferencingUsage `json:"usage"`
}

// Usage information related to the inferencing result
type InferencingUsage struct {
_ cm.HostLayout `json:"-"`
// Number of tokens in the prompt
PromptTokenCount uint32 `json:"prompt-token-count"`

// Number of tokens generated by the inferencing operation
GeneratedTokenCount uint32 `json:"generated-token-count"`
}

// InferencingModel identifies a Large Language Model to use for inferencing.
type InferencingModel string

// EmbeddingModel identifies the model used for generating embeddings.
type EmbeddingModel string

type EmbeddingsResult struct {
// Embeddings are the embeddings generated by the request.
Embeddings [][]float32
// Usage is usage related to an embeddings generation request.
Usage *EmbeddingsUsage
}

type EmbeddingsUsage struct {
// PromptTokenCount is the number of tokens in the prompt.
PromptTokenCount int
}

// Infer performs inferencing using the provided model and prompt with the
// given optional parameters.
func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
iparams := cm.None[llm.InferencingParams]()
if params != nil {
iparams = cm.Some(llm.InferencingParams{
MaxTokens: params.MaxTokens,
RepeatPenalty: params.RepeatPenalty,
RepeatPenaltyLastNTokenCount: params.RepeatPenaltyLastNTokenCount,
Temperature: params.Temperature,
TopK: params.TopK,
TopP: params.TopP,
})
}

result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
if result.IsErr() {
return InferencingResult{}, errorVariantToError(*result.Err())
}

return InferencingResult{
Text: result.OK().Text,
Usage: InferencingUsage{
PromptTokenCount: result.OK().Usage.PromptTokenCount,
GeneratedTokenCount: result.OK().Usage.GeneratedTokenCount,
},
}, nil
}

// GenerateEmbeddings generates embeddings for the supplied list of texts.
func GenerateEmbeddings(model EmbeddingModel, text []string) (*EmbeddingsResult, error) {
result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
if result.IsErr() {
return &EmbeddingsResult{}, errorVariantToError(*result.Err())
}

embeddings := [][]float32{}
for _, l := range result.OK().Embeddings.Slice() {
embeddings = append(embeddings, l.Slice())
}

return &EmbeddingsResult{
Embeddings: embeddings,
Usage: &EmbeddingsUsage{
PromptTokenCount: int(result.OK().Usage.PromptTokenCount),
},
}, nil
}

func errorVariantToError(err llm.Error) error {
switch {
case llm.ErrorModelNotSupported() == err:
return fmt.Errorf("model not supported")
case err.RuntimeError() != nil:
return fmt.Errorf("runtime error %s", *err.RuntimeError())
case err.InvalidInput() != nil:
return fmt.Errorf("invalid input %s", *err.InvalidInput())
default:
return fmt.Errorf("no error provided by host implementation")
}
}
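
The embeddings returned by GenerateEmbeddings are plain [][]float32 vectors, and how they are compared is left to the caller. A minimal sketch, assuming cosine similarity is the metric of interest; the cosineSimilarity helper is hypothetical, not part of this package, and relies only on the standard library "math":

// cosineSimilarity is a hypothetical helper illustrating one way to compare
// two embedding vectors taken from EmbeddingsResult.Embeddings. Requires "math".
func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) {
		return 0 // vectors of different lengths are not compared here
	}
	var dot, na, nb float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		na += float64(a[i]) * float64(a[i])
		nb += float64(b[i]) * float64(b[i])
	}
	if na == 0 || nb == 0 {
		return 0
	}
	return dot / (math.Sqrt(na) * math.Sqrt(nb))
}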