// Package llm provides access to large language models (LLMs) served by the
// Spin host, supporting inferencing and embeddings generation.
package llm

import (
    "fmt"

    "github.com/spinframework/spin-go-sdk/v2/internal/fermyon/spin/v2.0.0/llm"
    "go.bytecodealliance.org/cm"
)

// The models available for inferencing.
const (
    Llama2Chat        InferencingModel = "llama2-chat"
    CodellamaInstruct InferencingModel = "codellama-instruct"
)

// InferencingParams controls the behavior of an inferencing request.
type InferencingParams struct {
    // The maximum number of tokens that should be inferred.
    //
    // Note: the backing implementation may return fewer tokens.
    MaxTokens uint32 `json:"max-tokens"`

    // The degree to which the model should avoid repeating tokens.
    RepeatPenalty float32 `json:"repeat-penalty"`

    // The number of tokens the model should apply the repeat penalty to.
    RepeatPenaltyLastNTokenCount uint32 `json:"repeat-penalty-last-n-token-count"`

    // The randomness with which the next token is selected.
    Temperature float32 `json:"temperature"`

    // The number of possible next tokens the model will choose from.
    TopK uint32 `json:"top-k"`

    // The cumulative probability of the next tokens the model will choose from.
    TopP float32 `json:"top-p"`
}
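
// defaultExampleParams is a hypothetical, illustrative set of inferencing
// parameters. Neither the variable nor its values are defined by the SDK or
// the Spin host; they are only a sketch of a plausible starting point that
// can be tuned per application.
var defaultExampleParams = &InferencingParams{
    MaxTokens:                    100,
    RepeatPenalty:                1.1,
    RepeatPenaltyLastNTokenCount: 64,
    Temperature:                  0.8,
    TopK:                         40,
    TopP:                         0.9,
}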

// InferencingResult describes the result of an inferencing request.
type InferencingResult struct {
    // The text generated by the model
    // TODO: this should be a stream
    Text string `json:"text"`

    // Usage information about the inferencing request
    Usage InferencingUsage `json:"usage"`
}

// Usage information related to the inferencing result
type InferencingUsage struct {
    _ cm.HostLayout `json:"-"`
    // Number of tokens in the prompt
    PromptTokenCount uint32 `json:"prompt-token-count"`

    // Number of tokens generated by the inferencing operation
    GeneratedTokenCount uint32 `json:"generated-token-count"`
}

// A Large Language Model.
type InferencingModel string

// The model used for generating embeddings
type EmbeddingModel string

// EmbeddingsResult describes the result of an embeddings generation request.
type EmbeddingsResult struct {
    // Embeddings are the embeddings generated by the request.
    Embeddings [][]float32
    // Usage is usage related to an embeddings generation request.
    Usage *EmbeddingsUsage
}

// EmbeddingsUsage describes token usage for an embeddings generation request.
type EmbeddingsUsage struct {
    // PromptTokenCount is the number of tokens in the prompt.
    PromptTokenCount int
}

// Infer performs inferencing using the provided model and prompt with the
// given optional parameters. If params is nil, no inferencing parameters are
// passed to the host.
func Infer(model string, prompt string, params *InferencingParams) (InferencingResult, error) {
    iparams := cm.None[llm.InferencingParams]()
    if params != nil {
        iparams = cm.Some(llm.InferencingParams{
            MaxTokens:                    params.MaxTokens,
            RepeatPenalty:                params.RepeatPenalty,
            RepeatPenaltyLastNTokenCount: params.RepeatPenaltyLastNTokenCount,
            Temperature:                  params.Temperature,
            TopK:                         params.TopK,
            TopP:                         params.TopP,
        })
    }

    result := llm.Infer(llm.InferencingModel(model), prompt, iparams)
    if result.IsErr() {
        return InferencingResult{}, errorVariantToError(*result.Err())
    }

    return InferencingResult{
        Text: result.OK().Text,
        Usage: InferencingUsage{
            PromptTokenCount:    result.OK().Usage.PromptTokenCount,
            GeneratedTokenCount: result.OK().Usage.GeneratedTokenCount,
        },
    }, nil
}
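
// exampleInfer is a hypothetical helper, not part of the SDK surface,
// sketching a minimal call to Infer: the "llama2-chat" model name matches the
// Llama2Chat constant above, and passing nil params leaves parameter
// selection to the host. A populated *InferencingParams, such as the
// defaultExampleParams sketch above, could be passed instead.
func exampleInfer() (string, error) {
    result, err := Infer("llama2-chat", "What is the capital of France?", nil)
    if err != nil {
        return "", err
    }
    return result.Text, nil
}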

// GenerateEmbeddings generates the embeddings for the supplied list of text.
func GenerateEmbeddings(model EmbeddingModel, text []string) (*EmbeddingsResult, error) {
    result := llm.GenerateEmbeddings(llm.EmbeddingModel(model), cm.ToList(text))
    if result.IsErr() {
        return &EmbeddingsResult{}, errorVariantToError(*result.Err())
    }

    embeddings := [][]float32{}
    for _, l := range result.OK().Embeddings.Slice() {
        embeddings = append(embeddings, l.Slice())
    }

    return &EmbeddingsResult{
        Embeddings: embeddings,
        Usage: &EmbeddingsUsage{
            PromptTokenCount: int(result.OK().Usage.PromptTokenCount),
        },
    }, nil
}
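
// exampleGenerateEmbeddings is a hypothetical helper, not part of the SDK
// surface, sketching a call to GenerateEmbeddings. The "all-minilm-l6-v2"
// model name is assumed here for illustration; the actual model must be
// configured for the application by the host.
func exampleGenerateEmbeddings() ([][]float32, error) {
    result, err := GenerateEmbeddings("all-minilm-l6-v2", []string{
        "The quick brown fox",
        "jumps over the lazy dog",
    })
    if err != nil {
        return nil, err
    }
    return result.Embeddings, nil
}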

// errorVariantToError converts an llm.Error variant returned by the host into
// a standard Go error.
func errorVariantToError(err llm.Error) error {
    switch {
    case llm.ErrorModelNotSupported() == err:
        return fmt.Errorf("model not supported")
    case err.RuntimeError() != nil:
        return fmt.Errorf("runtime error: %s", *err.RuntimeError())
    case err.InvalidInput() != nil:
        return fmt.Errorf("invalid input: %s", *err.InvalidInput())
    default:
        return fmt.Errorf("no error provided by host implementation")
    }
}