Skip to content

Commit 503917f

Browse files
authored
Merge pull request #1831 from adamreese/feat/llm-sdk-go
feat(sdk/go): implement llm SDK for TinyGo
2 parents fd4854b + 19297e8 commit 503917f

File tree

10 files changed

+596
-2
lines changed

10 files changed

+596
-2
lines changed

examples/tinygo-llm/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
main.wasm
2+
.spin/

examples/tinygo-llm/go.mod

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module github.com/tinygo_llm
2+
3+
go 1.17
4+
5+
require github.com/fermyon/spin/sdk/go v0.0.0
6+
7+
require github.com/julienschmidt/httprouter v1.3.0 // indirect
8+
9+
replace github.com/fermyon/spin/sdk/go v0.0.0 => ../../sdk/go/

examples/tinygo-llm/go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
github.com/fermyon/spin/sdk/go v1.4.1 h1:n8KTYTnkErTJdyMBBEtPmJe8dXrvMT6R7iVWbLRjq5E=
2+
github.com/fermyon/spin/sdk/go v1.4.1/go.mod h1:yb8lGesopgj/GwPzLPATxcOeqWZT/HjrzEFfwbztAXE=
3+
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
4+
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=

examples/tinygo-llm/main.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
spinhttp "github.com/fermyon/spin/sdk/go/http"
8+
"github.com/fermyon/spin/sdk/go/llm"
9+
)
10+
11+
func init() {
12+
spinhttp.Handle(func(w http.ResponseWriter, r *http.Request) {
13+
result, err := llm.Infer("llama2-chat", "Tell me a joke", nil)
14+
if err != nil {
15+
http.Error(w, err.Error(), http.StatusInternalServerError)
16+
return
17+
}
18+
fmt.Printf("Prompt tokens: %d\n", result.Usage.PromptTokenCount)
19+
fmt.Printf("Generated tokens: %d\n", result.Usage.GeneratedTokenCount)
20+
fmt.Fprint(w, result.Text)
21+
fmt.Fprintf(w, "\n\n")
22+
23+
embeddings, err := llm.GenerateEmbeddings("all-minilm-l6-v2", []string{"Hello world"})
24+
if err != nil {
25+
http.Error(w, err.Error(), http.StatusInternalServerError)
26+
return
27+
}
28+
fmt.Printf("%d\n", len(embeddings.Embeddings[0]))
29+
fmt.Printf("Prompt Tokens: %d\n", embeddings.Usage.PromptTokenCount)
30+
31+
})
32+
}
33+
34+
func main() {}

examples/tinygo-llm/spin.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
spin_manifest_version = "1"
2+
authors = ["Fermyon Engineering <[email protected]>"]
3+
description = "Simple example using the llm sdk."
4+
name = "tinygo-llm"
5+
trigger = { type = "http", base = "/" }
6+
version = "0.1.0"
7+
8+
[[component]]
9+
id = "tinygo-llm"
10+
source = "main.wasm"
11+
allowed_http_hosts = []
12+
ai_models = ["llama2-chat", "all-minilm-l6-v2"]
13+
[component.trigger]
14+
route = "/..."
15+
[component.build]
16+
command = "tinygo build -target=wasi -gc=leaking -no-debug -o main.wasm main.go"
17+
watch = ["**/*.go", "go.mod"]

sdk/go/Makefile

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ build-examples: $(EXAMPLES_DIR)/tinygo-outbound-redis/main.wasm
2828
build-examples: $(EXAMPLES_DIR)/tinygo-redis/main.wasm
2929
build-examples: $(EXAMPLES_DIR)/tinygo-key-value/main.wasm
3030
build-examples: $(EXAMPLES_DIR)/tinygo-sqlite/main.wasm
31+
build-examples: $(EXAMPLES_DIR)/tinygo-llm/main.wasm
3132

3233
$(EXAMPLES_DIR)/%/main.wasm: $(EXAMPLES_DIR)/%/main.go
3334
tinygo build -target=wasi -gc=leaking -no-debug -o $@ $<
@@ -42,13 +43,14 @@ GENERATED_OUTBOUND_REDIS = redis/outbound-redis.c redis/outbound-redis.h
4243
GENERATED_SPIN_REDIS = redis/spin-redis.c redis/spin-redis.h
4344
GENERATED_KEY_VALUE = key_value/key-value.c key_value/key-value.h
4445
GENERATED_SQLITE = sqlite/sqlite.c sqlite/sqlite.h
46+
GENERATED_LLM = llm/llm.c llm/llm.h
4547

4648
SDK_VERSION_SOURCE_FILE = sdk_version/sdk-version-go-template.c
4749

4850
# NOTE: Please update this list if you add a new directory to the SDK:
4951
SDK_VERSION_DEST_FILES = config/sdk-version-go.c http/sdk-version-go.c \
5052
key_value/sdk-version-go.c redis/sdk-version-go.c \
51-
sqlite/sdk-version-go.c
53+
sqlite/sdk-version-go.c llm/sdk-version-go.c
5254

5355
# NOTE: To generate the C bindings you need to install a forked version of wit-bindgen.
5456
#
@@ -58,7 +60,7 @@ SDK_VERSION_DEST_FILES = config/sdk-version-go.c http/sdk-version-go.c \
5860
generate: $(GENERATED_OUTBOUND_HTTP) $(GENERATED_SPIN_HTTP)
5961
generate: $(GENERATED_OUTBOUND_REDIS) $(GENERATED_SPIN_REDIS)
6062
generate: $(GENERATED_SPIN_CONFIG) $(GENERATED_KEY_VALUE)
61-
generate: $(GENERATED_SQLITE)
63+
generate: $(GENERATED_SQLITE) $(GENERATED_LLM)
6264
generate: $(SDK_VERSION_DEST_FILES)
6365

6466
$(SDK_VERSION_DEST_FILES): $(SDK_VERSION_SOURCE_FILE)
@@ -87,6 +89,9 @@ $(GENERATED_KEY_VALUE):
8789
$(GENERATED_SQLITE):
8890
wit-bindgen c --import ../../wit/ephemeral/sqlite.wit --out-dir ./sqlite
8991

92+
$(GENERATED_LLM):
93+
wit-bindgen c --import ../../wit/ephemeral/llm.wit --out-dir ./llm
94+
9095
# ----------------------------------------------------------------------
9196
# Cleanup
9297
# ----------------------------------------------------------------------
@@ -96,6 +101,7 @@ clean:
96101
rm -f $(GENERATED_OUTBOUND_HTTP) $(GENERATED_SPIN_HTTP)
97102
rm -f $(GENERATED_OUTBOUND_REDIS) $(GENERATED_SPIN_REDIS)
98103
rm -f $(GENERATED_KEY_VALUE) $(GENERATED_SQLITE)
104+
rm -f $(GENERATED_LLM)
99105
rm -f $(GENERATED_SDK_VERSION)
100106
rm -f http/testdata/http-tinygo/main.wasm
101107
rm -f $(EXAMPLES_DIR)/http-tinygo/main.wasm
@@ -104,4 +110,5 @@ clean:
104110
rm -f $(EXAMPLES_DIR)/tinygo-redis/main.wasm
105111
rm -f $(EXAMPLES_DIR)/tinygo-key-value/main.wasm
106112
rm -f $(EXAMPLES_DIR)/tinygo-sqlite/main.wasm
113+
rm -f $(EXAMPLES_DIR)/tinygo-llm/main.wasm
107114
rm -f $(SDK_VERSION_DEST_FILES)

sdk/go/llm/internals.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package llm
2+
3+
// #include "llm.h"
4+
import "C"
5+
import (
6+
"errors"
7+
"fmt"
8+
"unsafe"
9+
)
10+
11+
func infer(model, prompt string, params *InferencingParams) (*InferencingResult, error) {
12+
llmModel := toLLMModel(model)
13+
llmPrompt := toLLMString(prompt)
14+
llmParams := toLLMInferencingParams(params)
15+
16+
var ret C.llm_expected_inferencing_result_error_t
17+
defer C.llm_expected_inferencing_result_error_free(&ret)
18+
19+
C.llm_infer(&llmModel, &llmPrompt, &llmParams, &ret)
20+
if ret.is_err {
21+
return nil, toErr((*C.llm_error_t)(unsafe.Pointer(&ret.val)))
22+
}
23+
24+
result := (*C.llm_inferencing_result_t)(unsafe.Pointer(&ret.val))
25+
26+
r := &InferencingResult{
27+
Text: C.GoStringN(result.text.ptr, C.int(result.text.len)),
28+
Usage: &InferencingUsage{
29+
PromptTokenCount: int(result.usage.prompt_token_count),
30+
GeneratedTokenCount: int(result.usage.generated_token_count),
31+
},
32+
}
33+
return r, nil
34+
}
35+
36+
func toErr(err *C.llm_error_t) error {
37+
switch err.tag {
38+
case 0:
39+
return errors.New("model not supported")
40+
case 1:
41+
str := (*C.llm_string_t)(unsafe.Pointer(&err.val))
42+
return fmt.Errorf("runtime error: %s", C.GoStringN(str.ptr, C.int(str.len)))
43+
case 2:
44+
str := (*C.llm_string_t)(unsafe.Pointer(&err.val))
45+
return fmt.Errorf("invalid input error: %s", C.GoStringN(str.ptr, C.int(str.len)))
46+
default:
47+
return fmt.Errorf("unrecognized error: %v", err.tag)
48+
}
49+
}
50+
51+
func toLLMModel(name string) C.llm_inferencing_model_t {
52+
llmString := toLLMString(name)
53+
return *(*C.llm_inferencing_model_t)(unsafe.Pointer(&llmString.ptr))
54+
}
55+
56+
func toLLMString(x string) C.llm_string_t {
57+
return C.llm_string_t{ptr: C.CString(x), len: C.size_t(len(x))}
58+
}
59+
60+
func toLLMInferencingParams(p *InferencingParams) C.llm_option_inferencing_params_t {
61+
if p == nil {
62+
return C.llm_option_inferencing_params_t{is_some: false}
63+
}
64+
llmParams := C.llm_inferencing_params_t{
65+
max_tokens: C.uint32_t(p.MaxTokens),
66+
repeat_penalty: C.float(p.RepeatPenalty),
67+
repeat_penalty_last_n_token_count: C.uint32_t(p.RepeatPenaltyLastNTokenCount),
68+
temperature: C.float(p.Temperature),
69+
top_k: C.uint32_t(p.TopK),
70+
top_p: C.float(p.TopP),
71+
}
72+
return C.llm_option_inferencing_params_t{is_some: true, val: llmParams}
73+
}
74+
75+
func generateEmbeddings(model string, text []string) (*EmbeddingsResult, error) {
76+
llmModel := toLLMEmbeddingModel(model)
77+
llmListString := toLLMListString(text)
78+
79+
var ret C.llm_expected_embeddings_result_error_t
80+
defer C.llm_expected_embeddings_result_error_free(&ret)
81+
82+
C.llm_generate_embeddings(&llmModel, &llmListString, &ret)
83+
if ret.is_err {
84+
return nil, toErr((*C.llm_error_t)(unsafe.Pointer(&ret.val)))
85+
}
86+
87+
result := (*C.llm_embeddings_result_t)(unsafe.Pointer(&ret.val))
88+
89+
r := &EmbeddingsResult{
90+
Embeddings: fromLLMListListFloat32(result.embeddings),
91+
Usage: &EmbeddingsUsage{
92+
PromptTokenCount: int(result.usage.prompt_token_count),
93+
},
94+
}
95+
return r, nil
96+
}
97+
98+
func toLLMEmbeddingModel(name string) C.llm_embedding_model_t {
99+
llmString := toLLMString(name)
100+
return *(*C.llm_embedding_model_t)(unsafe.Pointer(&llmString.ptr))
101+
}
102+
103+
func toLLMListString(xs []string) C.llm_list_string_t {
104+
cxs := make([]C.llm_string_t, len(xs))
105+
for i := 0; i < len(xs); i++ {
106+
cxs[i] = toLLMString(xs[i])
107+
}
108+
return C.llm_list_string_t{ptr: &cxs[0], len: C.size_t(len(cxs))}
109+
}
110+
111+
func fromLLMListListFloat32(list C.llm_list_list_float32_t) [][]float32 {
112+
listLen := int(list.len)
113+
ret := make([][]float32, listLen)
114+
slice := unsafe.Slice(list.ptr, listLen)
115+
for i := 0; i < listLen; i++ {
116+
row := *((*C.llm_list_float32_t)(unsafe.Pointer(&slice[i])))
117+
ret[i] = fromLLMListFloat32(row)
118+
}
119+
return ret
120+
}
121+
122+
func fromLLMListFloat32(list C.llm_list_float32_t) []float32 {
123+
listLen := int(list.len)
124+
ret := make([]float32, listLen)
125+
slice := unsafe.Slice(list.ptr, listLen)
126+
for i := 0; i < listLen; i++ {
127+
v := *((*C.float)(unsafe.Pointer(&slice[i])))
128+
ret[i] = float32(v)
129+
}
130+
return ret
131+
}

0 commit comments

Comments
 (0)