Commit c25c927

feat(wasip2/llm): add support for llm
Signed-off-by: Rajat Jindal <[email protected]>
1 parent 9f90e21 commit c25c927

File tree: 6 files changed, +156 -0 lines changed

v2/examples/llm/.gitignore

Lines changed: 2 additions & 0 deletions
main.wasm
.spin/

v2/examples/llm/go.mod

Lines changed: 12 additions & 0 deletions
module github.com/spinframework/spin-go-sdk/v2/examples/llm

go 1.24.1

require github.com/spinframework/spin-go-sdk/v2 v2.0.0

require (
	github.com/julienschmidt/httprouter v1.3.0 // indirect
	go.bytecodealliance.org/cm v0.2.2 // indirect
)

replace github.com/spinframework/spin-go-sdk/v2 => ../../

v2/examples/llm/go.sum

Lines changed: 4 additions & 0 deletions
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
go.bytecodealliance.org/cm v0.2.2 h1:M9iHS6qs884mbQbIjtLX1OifgyPG9DuMs2iwz8G4WQA=
go.bytecodealliance.org/cm v0.2.2/go.mod h1:JD5vtVNZv7sBoQQkvBvAAVKJPhR/bqBH7yYXTItMfZI=

v2/examples/llm/main.go

Lines changed: 34 additions & 0 deletions
package main

import (
	"fmt"
	"net/http"

	spinhttp "github.com/spinframework/spin-go-sdk/v2/http"
	"github.com/spinframework/spin-go-sdk/v2/llm"
)

func init() {
	spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
		// Run an inferencing request against the llama2-chat model.
		result, err := llm.Infer("llama2-chat", "Tell me a joke", nil)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		fmt.Printf("Prompt tokens: %d\n", result.Usage.PromptTokenCount)
		fmt.Printf("Generated tokens: %d\n", result.Usage.GeneratedTokenCount)
		fmt.Fprint(w, result.Text)
		fmt.Fprint(w, "\n\n")

		// Generate embeddings for a list of strings.
		embeddings, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world"})
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		fmt.Printf("Embedding dimensions: %d\n", len(embeddings.Embeddings[0]))
		fmt.Printf("Prompt tokens: %d\n", embeddings.Usage.PromptTokenCount)
	})
}

func main() {}
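The handler above passes nil for the optional inferencing parameters, so the host applies its defaults. Generation can also be tuned explicitly via llm.InferencingParams. A minimal sketch, assuming the generated bindings expose the fermyon:spin/llm inferencing-params record fields under Go names such as MaxTokens and Temperature (the field names are an assumption, not shown in this diff):

	// Hypothetical field names; check the generated bindings in
	// internal/fermyon/spin/v2.0.0/llm for the exact record layout.
	params := &llm.InferencingParams{
		MaxTokens:   100, // cap the length of the generated completion
		Temperature: 0.8, // higher values give more varied output
	}
	result, err := llm.Infer("llama2-chat", "Tell me a joke", params)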

v2/examples/llm/spin.toml

Lines changed: 19 additions & 0 deletions
spin_manifest_version = 2

[application]
authors = ["Fermyon Engineering <[email protected]>"]
description = "Simple example using the llm sdk."
name = "llm-example"
version = "0.1.0"

[[trigger.http]]
route = "/..."
component = "llm"

[component.llm]
source = "main.wasm"
allowed_outbound_hosts = []
ai_models = ["llama2-chat", "all-minilm-l6-v2"]
[component.llm.build]
command = "tinygo build -target=wasip2 --wit-package $(go list -mod=readonly -m -f '{{.Dir}}' github.com/spinframework/spin-go-sdk/v2)/wit --wit-world http-trigger -gc=leaking -no-debug -o main.wasm main.go"
watch = ["**/*.go", "go.mod"]
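Note: spin build runs the tinygo command from [component.llm.build], and the ai_models list is what grants the component access to the two models main.go uses; requests against a model not listed there are rejected by the host.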

v2/llm/llm.go

Lines changed: 85 additions & 0 deletions
// Package llm provides an idiomatic wrapper around Spin's LLM interface for
// inferencing and embeddings generation.
package llm

import (
	"fmt"

	"github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
	"go.bytecodealliance.org/cm"
)

// Models available for inferencing.
const (
	Llama2Chat        InferencingModel = "llama2-chat"
	CodellamaInstruct InferencingModel = "codellama-instruct"
)

type InferencingParams llm.InferencingParams
type InferencingResult llm.InferencingResult
type InferencingModel llm.InferencingModel

// EmbeddingsResult mirrors llm.EmbeddingsResult but exposes plain Go slices.
type EmbeddingsResult struct {
	// Embeddings are the embeddings generated by the request.
	Embeddings [][]float32
	// Usage is usage related to an embeddings generation request.
	Usage *EmbeddingsUsage
}

// EmbeddingsUsage describes token usage for an embeddings generation request.
type EmbeddingsUsage struct {
	// PromptTokenCount is the number of tokens in the prompt.
	PromptTokenCount int
}

// Infer performs inferencing using the provided model and prompt with the
// given optional parameters.
func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
	iparams := cm.None[llm.InferencingParams]()
	if params != nil {
		iparams = cm.Some(llm.InferencingParams(*params))
	}

	result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
	if result.IsErr() {
		return InferencingResult{}, errorVariantToError(*result.Err())
	}

	return InferencingResult(*result.OK()), nil
}

// GenerateEmbeddings generates embeddings for the supplied list of text.
func GenerateEmbeddings(model InferencingModel, text []string) (*EmbeddingsResult, error) {
	result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
	if result.IsErr() {
		return nil, errorVariantToError(*result.Err())
	}

	llmEmbeddingResult := llm.EmbeddingsResult(*result.OK())

	// Copy the component-model list-of-lists into plain Go slices.
	rows := llmEmbeddingResult.Embeddings.Slice()
	embeddings := make([][]float32, 0, len(rows))
	for _, row := range rows {
		embeddings = append(embeddings, row.Slice())
	}

	return &EmbeddingsResult{
		Embeddings: embeddings,
		Usage: &EmbeddingsUsage{
			PromptTokenCount: int(llmEmbeddingResult.Usage.PromptTokenCount),
		},
	}, nil
}

// errorVariantToError converts the component-model error variant into a Go error.
func errorVariantToError(err llm.Error) error {
	switch {
	case err == llm.ErrorModelNotSupported():
		return fmt.Errorf("model not supported")
	case err.RuntimeError() != nil:
		return fmt.Errorf("runtime error: %s", *err.RuntimeError())
	case err.InvalidInput() != nil:
		return fmt.Errorf("invalid input: %s", *err.InvalidInput())
	default:
		return fmt.Errorf("no error provided by host implementation")
	}
}
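Because GenerateEmbeddings flattens the component-model lists into plain [][]float32, the returned vectors can be consumed by ordinary Go numeric code. A minimal sketch comparing two returned embeddings with cosine similarity; cosineSimilarity and compare are hypothetical helpers written for illustration, not part of this commit:

	package main

	import (
		"fmt"
		"math"

		"github.com/spinframework/spin-go-sdk/v2/llm"
	)

	// cosineSimilarity computes the cosine of the angle between two vectors.
	// It assumes both are non-zero and of equal length.
	func cosineSimilarity(a, b []float32) float64 {
		var dot, normA, normB float64
		for i := range a {
			dot += float64(a[i]) * float64(b[i])
			normA += float64(a[i]) * float64(a[i])
			normB += float64(b[i]) * float64(b[i])
		}
		return dot / (math.Sqrt(normA) * math.Sqrt(normB))
	}

	func compare() error {
		res, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world", "Hi there"})
		if err != nil {
			return err
		}
		fmt.Printf("similarity: %f\n", cosineSimilarity(res.Embeddings[0], res.Embeddings[1]))
		return nil
	}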
