diff --git a/config.go b/config.go index 4788ba62a..4b8cfb6fb 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "regexp" + "strings" ) const ( @@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { APIType: APITypeAzure, APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { - return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + // only 3.5 models have the "." stripped in their names + if strings.Contains(model, "3.5") { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + } + return strings.ReplaceAll(model, ":", "") }, HTTPClient: &http.Client{}, diff --git a/config_test.go b/config_test.go index 960230804..f44b80825 100644 --- a/config_test.go +++ b/config_test.go @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { Model: "gpt-3.5-turbo-0301", Expect: "gpt-35-turbo-0301", }, + { + Model: "gpt-4.1", + Expect: "gpt-4.1", + }, { Model: "text-embedding-ada-002", Expect: "text-embedding-ada-002",