Skip to content

Commit 859c836

Browse files
committed
feat(sdk/go): implement llm SDK for TinyGo
Signed-off-by: Adam Reese <[email protected]>
1 parent 500001c commit 859c836

File tree

5 files changed

+530
-2
lines changed

5 files changed

+530
-2
lines changed

Makefile

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ build-examples: $(EXAMPLES_DIR)/tinygo-outbound-redis/main.wasm
2828
build-examples: $(EXAMPLES_DIR)/tinygo-redis/main.wasm
2929
build-examples: $(EXAMPLES_DIR)/tinygo-key-value/main.wasm
3030
build-examples: $(EXAMPLES_DIR)/tinygo-sqlite/main.wasm
31+
build-examples: $(EXAMPLES_DIR)/tinygo-llm/main.wasm
3132

3233
$(EXAMPLES_DIR)/%/main.wasm: $(EXAMPLES_DIR)/%/main.go
3334
tinygo build -target=wasi -gc=leaking -no-debug -o $@ $<
@@ -42,13 +43,14 @@ GENERATED_OUTBOUND_REDIS = redis/outbound-redis.c redis/outbound-redis.h
4243
GENERATED_SPIN_REDIS = redis/spin-redis.c redis/spin-redis.h
4344
GENERATED_KEY_VALUE = key_value/key-value.c key_value/key-value.h
4445
GENERATED_SQLITE = sqlite/sqlite.c sqlite/sqlite.h
46+
GENERATED_LLM = llm/llm.c llm/llm.h
4547

4648
SDK_VERSION_SOURCE_FILE = sdk_version/sdk-version-go-template.c
4749

4850
# NOTE: Please update this list if you add a new directory to the SDK:
4951
SDK_VERSION_DEST_FILES = config/sdk-version-go.c http/sdk-version-go.c \
5052
key_value/sdk-version-go.c redis/sdk-version-go.c \
51-
sqlite/sdk-version-go.c
53+
sqlite/sdk-version-go.c llm/sdk-version-go.c
5254

5355
# NOTE: To generate the C bindings you need to install a forked version of wit-bindgen.
5456
#
@@ -58,7 +60,7 @@ SDK_VERSION_DEST_FILES = config/sdk-version-go.c http/sdk-version-go.c \
5860
generate: $(GENERATED_OUTBOUND_HTTP) $(GENERATED_SPIN_HTTP)
5961
generate: $(GENERATED_OUTBOUND_REDIS) $(GENERATED_SPIN_REDIS)
6062
generate: $(GENERATED_SPIN_CONFIG) $(GENERATED_KEY_VALUE)
61-
generate: $(GENERATED_SQLITE)
63+
generate: $(GENERATED_SQLITE) $(GENERATED_LLM)
6264
generate: $(SDK_VERSION_DEST_FILES)
6365

6466
$(SDK_VERSION_DEST_FILES): $(SDK_VERSION_SOURCE_FILE)
@@ -87,6 +89,9 @@ $(GENERATED_KEY_VALUE):
8789
$(GENERATED_SQLITE):
8890
wit-bindgen c --import ../../wit/ephemeral/sqlite.wit --out-dir ./sqlite
8991

92+
$(GENERATED_LLM):
93+
wit-bindgen c --import ../../wit/ephemeral/llm.wit --out-dir ./llm
94+
9095
# ----------------------------------------------------------------------
9196
# Cleanup
9297
# ----------------------------------------------------------------------
@@ -96,6 +101,7 @@ clean:
96101
rm -f $(GENERATED_OUTBOUND_HTTP) $(GENERATED_SPIN_HTTP)
97102
rm -f $(GENERATED_OUTBOUND_REDIS) $(GENERATED_SPIN_REDIS)
98103
rm -f $(GENERATED_KEY_VALUE) $(GENERATED_SQLITE)
104+
rm -f $(GENERATED_LLM)
99105
rm -f $(GENERATED_SDK_VERSION)
100106
rm -f http/testdata/http-tinygo/main.wasm
101107
rm -f $(EXAMPLES_DIR)/http-tinygo/main.wasm
@@ -104,4 +110,5 @@ clean:
104110
rm -f $(EXAMPLES_DIR)/tinygo-redis/main.wasm
105111
rm -f $(EXAMPLES_DIR)/tinygo-key-value/main.wasm
106112
rm -f $(EXAMPLES_DIR)/tinygo-sqlite/main.wasm
113+
rm -f $(EXAMPLES_DIR)/tinygo-llm/main.wasm
107114
rm -f $(SDK_VERSION_DEST_FILES)

llm/internals.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package llm
2+
3+
// #include "llm.h"
4+
import "C"
5+
import (
6+
"errors"
7+
"fmt"
8+
"unsafe"
9+
)
10+
11+
func infer(model, prompt string, params *InferencingParams) (*InferencingResult, error) {
12+
llmModel := toLLMModel(model)
13+
llmPrompt := toLLMString(prompt)
14+
llmParams := toLLMInferencingParams(params)
15+
16+
var ret C.llm_expected_inferencing_result_error_t
17+
defer C.llm_expected_inferencing_result_error_free(&ret)
18+
19+
C.llm_infer(&llmModel, &llmPrompt, &llmParams, &ret)
20+
if ret.is_err {
21+
return nil, toErr((*C.llm_error_t)(unsafe.Pointer(&ret.val)))
22+
}
23+
24+
result := (*C.llm_inferencing_result_t)(unsafe.Pointer(&ret.val))
25+
26+
r := &InferencingResult{
27+
Text: C.GoStringN(result.text.ptr, C.int(result.text.len)),
28+
Usage: &InferencingUsage{
29+
PromptTokenCount: int(result.usage.prompt_token_count),
30+
GeneratedTokenCount: int(result.usage.generated_token_count),
31+
},
32+
}
33+
return r, nil
34+
}
35+
36+
func toErr(err *C.llm_error_t) error {
37+
switch err.tag {
38+
case 0:
39+
return errors.New("model not supported")
40+
case 1:
41+
str := (*C.llm_string_t)(unsafe.Pointer(&err.val))
42+
return fmt.Errorf("runtime error: %s", C.GoStringN(str.ptr, C.int(str.len)))
43+
case 2:
44+
str := (*C.llm_string_t)(unsafe.Pointer(&err.val))
45+
return fmt.Errorf("invalid input error: %s", C.GoStringN(str.ptr, C.int(str.len)))
46+
default:
47+
return fmt.Errorf("unrecognized error: %v", err.tag)
48+
}
49+
}
50+
51+
func toLLMModel(name string) C.llm_inferencing_model_t {
52+
llmString := toLLMString(name)
53+
return *(*C.llm_inferencing_model_t)(unsafe.Pointer(&llmString.ptr))
54+
}
55+
56+
func toLLMString(x string) C.llm_string_t {
57+
return C.llm_string_t{ptr: C.CString(x), len: C.size_t(len(x))}
58+
}
59+
60+
func toLLMInferencingParams(p *InferencingParams) C.llm_option_inferencing_params_t {
61+
if p == nil {
62+
return C.llm_option_inferencing_params_t{is_some: false}
63+
}
64+
llmParams := C.llm_inferencing_params_t{
65+
max_tokens: C.uint32_t(p.MaxTokens),
66+
repeat_penalty: C.float(p.RepeatPenalty),
67+
repeat_penalty_last_n_token_count: C.uint32_t(p.RepeatPenaltyLastNTokenCount),
68+
temperature: C.float(p.Temperature),
69+
top_k: C.uint32_t(p.TopK),
70+
top_p: C.float(p.TopP),
71+
}
72+
return C.llm_option_inferencing_params_t{is_some: true, val: llmParams}
73+
}
74+
75+
func generateEmbeddings(model string, text []string) (*EmbeddingsResult, error) {
76+
llmModel := toLLMEmbeddingModel(model)
77+
llmListString := toLLMListString(text)
78+
79+
var ret C.llm_expected_embeddings_result_error_t
80+
defer C.llm_expected_embeddings_result_error_free(&ret)
81+
82+
C.llm_generate_embeddings(&llmModel, &llmListString, &ret)
83+
if ret.is_err {
84+
return nil, toErr((*C.llm_error_t)(unsafe.Pointer(&ret.val)))
85+
}
86+
87+
result := (*C.llm_embeddings_result_t)(unsafe.Pointer(&ret.val))
88+
89+
r := &EmbeddingsResult{
90+
Embeddings: fromLLMListListFloat32(result.embeddings),
91+
Usage: &EmbeddingsUsage{
92+
PromptTokenCount: int(result.usage.prompt_token_count),
93+
},
94+
}
95+
return r, nil
96+
}
97+
98+
func toLLMEmbeddingModel(name string) C.llm_embedding_model_t {
99+
llmString := toLLMString(name)
100+
return *(*C.llm_embedding_model_t)(unsafe.Pointer(&llmString.ptr))
101+
}
102+
103+
func toLLMListString(xs []string) C.llm_list_string_t {
104+
cxs := make([]C.llm_string_t, len(xs))
105+
for i := 0; i < len(xs); i++ {
106+
cxs[i] = toLLMString(xs[i])
107+
}
108+
return C.llm_list_string_t{ptr: &cxs[0], len: C.size_t(len(cxs))}
109+
}
110+
111+
func fromLLMListListFloat32(list C.llm_list_list_float32_t) [][]float32 {
112+
listLen := int(list.len)
113+
ret := make([][]float32, listLen)
114+
slice := unsafe.Slice(list.ptr, listLen)
115+
for i := 0; i < listLen; i++ {
116+
row := *((*C.llm_list_float32_t)(unsafe.Pointer(&slice[i])))
117+
ret[i] = fromLLMListFloat32(row)
118+
}
119+
return ret
120+
}
121+
122+
func fromLLMListFloat32(list C.llm_list_float32_t) []float32 {
123+
listLen := int(list.len)
124+
ret := make([]float32, listLen)
125+
slice := unsafe.Slice(list.ptr, listLen)
126+
for i := 0; i < listLen; i++ {
127+
v := *((*C.float)(unsafe.Pointer(&slice[i])))
128+
ret[i] = float32(v)
129+
}
130+
return ret
131+
}

0 commit comments

Comments
 (0)