Commit 3b8bb70

Merge pull request #18 from rajatjindal/spin-w2-llm

add llm support for wasip2 sdk

2 parents: 9f90e21 + 8c008a4

6 files changed: +206 -0 lines changed

v2/examples/llm/.gitignore

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+main.wasm
+.spin/

v2/examples/llm/go.mod

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+module github.com/spinframework/spin-go-sdk/v2/examples/llm
+
+go 1.24.1
+
+require github.com/spinframework/spin-go-sdk/v2 v2.0.0
+
+require (
+    github.com/julienschmidt/httprouter v1.3.0 // indirect
+    go.bytecodealliance.org/cm v0.2.2 // indirect
+)
+
+replace github.com/spinframework/spin-go-sdk/v2 => ../../
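
The replace directive points the example at the SDK checkout two directories up, so it builds against the in-repo SDK source rather than the published v2.0.0 module.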

v2/examples/llm/go.sum

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
+github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
+go.bytecodealliance.org/cm v0.2.2 h1:M9iHS6qs884mbQbIjtLX1OifgyPG9DuMs2iwz8G4WQA=
+go.bytecodealliance.org/cm v0.2.2/go.mod h1:JD5vtVNZv7sBoQQkvBvAAVKJPhR/bqBH7yYXTItMfZI=

v2/examples/llm/main.go

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+package main
+
+import (
+    "fmt"
+    "net/http"
+
+    spinhttp "github.com/spinframework/spin-go-sdk/v2/http"
+    "github.com/spinframework/spin-go-sdk/v2/llm"
+)
+
+func init() {
+    spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
+        result, err := llm.Infer("llama2-chat", "Tell me a joke", nil)
+        if err != nil {
+            http.Error(w, err.Error(), http.StatusInternalServerError)
+            return
+        }
+        fmt.Printf("Prompt tokens: %d\n", result.Usage.PromptTokenCount)
+        fmt.Printf("Generated tokens: %d\n", result.Usage.GeneratedTokenCount)
+        fmt.Fprint(w, result.Text)
+        fmt.Fprintf(w, "\n\n")
+
+        embeddings, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world"})
+        if err != nil {
+            http.Error(w, err.Error(), http.StatusInternalServerError)
+            return
+        }
+        fmt.Printf("%d\n", len(embeddings.Embeddings[0]))
+        fmt.Printf("Prompt Tokens: %d\n", embeddings.Usage.PromptTokenCount)
+
+    })
+}
+
+func main() {}
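
Note: the example passes nil for the optional inferencing parameters, so the host's defaults apply. A minimal sketch of tuning a request by replacing that nil argument inside the handler above, using the SDK's InferencingParams struct (the field values here are purely illustrative, not recommended settings):

    // Illustrative values only; pick parameters for your model and workload.
    params := &llm.InferencingParams{
        MaxTokens:                    128,
        RepeatPenalty:                1.1,
        RepeatPenaltyLastNTokenCount: 64,
        Temperature:                  0.8,
        TopK:                         40,
        TopP:                         0.9,
    }
    result, err := llm.Infer("llama2-chat", "Tell me a joke", params)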

v2/examples/llm/spin.toml

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+spin_manifest_version = 2
+
+[application]
+authors = ["Fermyon Engineering <[email protected]>"]
+description = "Simple example using the llm sdk."
+name = "llm-example"
+version = "0.1.0"
+
+[[trigger.http]]
+route = "/..."
+component = "llm"
+
+[component.llm]
+source = "main.wasm"
+allowed_outbound_hosts = []
+ai_models = ["llama2-chat", "all-minilm-l6-v2"]
+[component.llm.build]
+command = "tinygo build -target=wasip2 --wit-package $(go list -mod=readonly -m -f '{{.Dir}}' github.com/spinframework/spin-go-sdk/v2)/wit --wit-world http-trigger -gc=leaking -no-debug -o main.wasm main.go"
+watch = ["**/*.go", "go.mod"]
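
A note on the manifest: ai_models grants the component access to the two models the handler uses; a model not listed there is rejected at runtime. In the build command, go list -mod=readonly -m -f '{{.Dir}}' prints the on-disk directory of the named module, so --wit-package can point TinyGo at the WIT definitions bundled with the SDK, while --wit-world http-trigger selects the component world to target. Building this way assumes a TinyGo release with wasip2 support.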

v2/llm/llm.go

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+package llm
+
+import (
+    "fmt"
+
+    "github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
+    "go.bytecodealliance.org/cm"
+)
+
+// The model used for inferencing
+const (
+    Llama2Chat        InferencingModel = "llama2-chat"
+    CodellamaInstruct InferencingModel = "codellama-instruct"
+)
+
+type InferencingParams struct {
+    // The maximum tokens that should be inferred.
+    //
+    // Note: the backing implementation may return fewer tokens.
+    MaxTokens uint32 `json:"max-tokens"`
+
+    // The amount the model should avoid repeating tokens.
+    RepeatPenalty float32 `json:"repeat-penalty"`
+
+    // The number of tokens the model should apply the repeat penalty to.
+    RepeatPenaltyLastNTokenCount uint32 `json:"repeat-penalty-last-n-token-count"`
+
+    // The randomness with which the next token is selected.
+    Temperature float32 `json:"temperature"`
+
+    // The number of possible next tokens the model will choose from.
+    TopK uint32 `json:"top-k"`
+
+    // The probability total of next tokens the model will choose from.
+    TopP float32 `json:"top-p"`
+}
+
+type InferencingResult struct {
+    // The text generated by the model
+    // TODO: this should be a stream
+    Text string `json:"text"`
+
+    // Usage information about the inferencing request
+    Usage InferencingUsage `json:"usage"`
+}
+
+// Usage information related to the inferencing result
+type InferencingUsage struct {
+    _ cm.HostLayout `json:"-"`
+    // Number of tokens in the prompt
+    PromptTokenCount uint32 `json:"prompt-token-count"`
+
+    // Number of tokens generated by the inferencing operation
+    GeneratedTokenCount uint32 `json:"generated-token-count"`
+}
+
+// A Large Language Model.
+type InferencingModel string
+
+// The model used for generating embeddings
+type EmbeddingModel string
+
+type EmbeddingsResult struct {
+    // Embeddings are the embeddings generated by the request.
+    Embeddings [][]float32
+    // Usage is usage related to an embeddings generation request.
+    Usage *EmbeddingsUsage
+}
+
+type EmbeddingsUsage struct {
+    // PromptTokenCount is number of tokens in the prompt.
+    PromptTokenCount int
+}
+
+// Infer performs inferencing using the provided model and prompt with the
+// given optional parameters.
+func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
+    iparams := cm.None[llm.InferencingParams]()
+    if params != nil {
+        iparams = cm.Some(llm.InferencingParams{
+            MaxTokens:                    params.MaxTokens,
+            RepeatPenalty:                params.RepeatPenalty,
+            RepeatPenaltyLastNTokenCount: params.RepeatPenaltyLastNTokenCount,
+            Temperature:                  params.Temperature,
+            TopK:                         params.TopK,
+            TopP:                         params.TopP,
+        })
+    }
+
+    result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
+    if result.IsErr() {
+        return InferencingResult{}, errorVariantToError(*result.Err())
+    }
+
+    return InferencingResult{
+        Text: result.OK().Text,
+        Usage: InferencingUsage{
+            PromptTokenCount:    result.OK().Usage.PromptTokenCount,
+            GeneratedTokenCount: result.OK().Usage.GeneratedTokenCount,
+        },
+    }, nil
+}
+
+// GenerateEmbeddings generates the embeddings for the supplied list of text.
+func GenerateEmbeddings(model EmbeddingModel, text []string) (*EmbeddingsResult, error) {
+    result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
+    if result.IsErr() {
+        return &EmbeddingsResult{}, errorVariantToError(*result.Err())
+    }
+
+    embeddings := [][]float32{}
+    for _, l := range result.OK().Embeddings.Slice() {
+        embeddings = append(embeddings, l.Slice())
+    }
+
+    return &EmbeddingsResult{
+        Embeddings: embeddings,
+        Usage: &EmbeddingsUsage{
+            PromptTokenCount: int(result.OK().Usage.PromptTokenCount),
+        },
+    }, nil
+}
+
+func errorVariantToError(err llm.Error) error {
+    switch {
+    case llm.ErrorModelNotSupported() == err:
+        return fmt.Errorf("model not supported")
+    case err.RuntimeError() != nil:
+        return fmt.Errorf("runtime error %s", *err.RuntimeError())
+    case err.InvalidInput() != nil:
+        return fmt.Errorf("invalid input %s", *err.InvalidInput())
+    default:
+        return fmt.Errorf("no error provided by host implementation")
+    }
+}
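
GenerateEmbeddings returns raw [][]float32 vectors, so comparing them is left to the caller. A minimal sketch, not part of this SDK, of cosine similarity between two returned embeddings (assumes a standard library "math" import; cosineSimilarity is a hypothetical helper):

    // cosineSimilarity computes the cosine of the angle between two
    // equal-length vectors; it returns 0 for mismatched or zero vectors.
    func cosineSimilarity(a, b []float32) float64 {
        if len(a) != len(b) {
            return 0
        }
        var dot, na, nb float64
        for i := range a {
            dot += float64(a[i]) * float64(b[i])
            na += float64(a[i]) * float64(a[i])
            nb += float64(b[i]) * float64(b[i])
        }
        if na == 0 || nb == 0 {
            return 0
        }
        return dot / (math.Sqrt(na) * math.Sqrt(nb))
    }

    // Usage, assuming res came from llm.GenerateEmbeddings with two inputs:
    //   score := cosineSimilarity(res.Embeddings[0], res.Embeddings[1])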
