package llm

import (
	"fmt"

	"github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
	"go.bytecodealliance.org/cm"
)

// Models that may be used for inferencing.
const (
	Llama2Chat        InferencingModel = "llama2-chat"
	CodellamaInstruct InferencingModel = "codellama-instruct"
)

type InferencingParams llm.InferencingParams
type InferencingResult llm.InferencingResult
type InferencingModel llm.InferencingModel

// EmbeddingsResult is a Go-friendly mirror of llm.EmbeddingsResult, exposing
// plain Go slices in place of the component-model cm.List values.
type EmbeddingsResult struct {
	// Embeddings are the embeddings generated by the request.
	Embeddings [][]float32
	// Usage describes token usage for the embeddings generation request.
	Usage *EmbeddingsUsage
}

// EmbeddingsUsage describes token usage for an embeddings generation request.
type EmbeddingsUsage struct {
	// PromptTokenCount is the number of tokens in the prompt.
	PromptTokenCount int
}

// Infer performs inferencing using the provided model and prompt with the
// given optional parameters.
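//
// A minimal usage sketch. The InferencingParams field names shown here
// (MaxTokens, Temperature) are assumed from Spin's WIT llm interface and
// should be checked against the generated bindings:
//
//	res, err := Infer(string(Llama2Chat), "What is WebAssembly?", &InferencingParams{
//		MaxTokens:   100, // assumed field name
//		Temperature: 0.8, // assumed field name
//	})
//	if err != nil {
//		// handle the error
//	}
//	fmt.Println(res.Text) // Text is assumed from Spin's inferencing-result record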
func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
	// Inferencing parameters are optional at the component-model level, so a
	// nil *InferencingParams maps to cm.None.
	iparams := cm.None[llm.InferencingParams]()
	if params != nil {
		iparams = cm.Some(llm.InferencingParams(*params))
	}

	result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
	if result.IsErr() {
		return InferencingResult{}, errorVariantToError(*result.Err())
	}

	return InferencingResult(*result.OK()), nil
}

// GenerateEmbeddings generates the embeddings for the supplied list of text.
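//
// A minimal usage sketch, assuming the component is granted access to an
// embedding model named "all-minilm-l6-v2" (the model Spin's documentation
// uses as an example):
//
//	res, err := GenerateEmbeddings(InferencingModel("all-minilm-l6-v2"), []string{"hello", "world"})
//	if err != nil {
//		// handle the error
//	}
//	for _, vec := range res.Embeddings {
//		fmt.Println(len(vec)) // one float32 vector per input string
//	}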
func GenerateEmbeddings(model InferencingModel, text []string) (*EmbeddingsResult, error) {
	result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
	if result.IsErr() {
		return nil, errorVariantToError(*result.Err())
	}

	llmEmbeddingResult := llm.EmbeddingsResult(*result.OK())

	// Convert the component-model list of lists into plain Go slices.
	rows := llmEmbeddingResult.Embeddings.Slice()
	embeddings := make([][]float32, 0, len(rows))
	for _, row := range rows {
		embeddings = append(embeddings, row.Slice())
	}

	return &EmbeddingsResult{
		Embeddings: embeddings,
		Usage: &EmbeddingsUsage{
			PromptTokenCount: int(llmEmbeddingResult.Usage.PromptTokenCount),
		},
	}, nil
}

// errorVariantToError converts the generated llm.Error variant into a native
// Go error.
func errorVariantToError(err llm.Error) error {
	switch {
	case err == llm.ErrorModelNotSupported():
		return fmt.Errorf("model not supported")
	case err.RuntimeError() != nil:
		return fmt.Errorf("runtime error: %s", *err.RuntimeError())
	case err.InvalidInput() != nil:
		return fmt.Errorf("invalid input: %s", *err.InvalidInput())
	default:
		return fmt.Errorf("unrecognized error variant from host implementation")
	}
}