diff --git a/typesense/api/client_gen.go b/typesense/api/client_gen.go index cf7d6a8..d29616d 100644 --- a/typesense/api/client_gen.go +++ b/typesense/api/client_gen.go @@ -10328,7 +10328,7 @@ func ParseCreateNLSearchModelResponse(rsp *http.Response) (*CreateNLSearchModelR } switch { - case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 201: + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && (rsp.StatusCode == 200 || rsp.StatusCode == 201): var dest NLSearchModelSchema if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err diff --git a/typesense/client.go b/typesense/client.go index 33feb5e..161553e 100644 --- a/typesense/client.go +++ b/typesense/client.go @@ -76,6 +76,14 @@ func (c *Client) Preset(presetName string) PresetInterface { return &preset{apiClient: c.apiClient, presetName: presetName} } +func (c *Client) NLSearchModels() NLSearchModelsInterface { + return &nlSearchModels{apiClient: c.apiClient} +} + +func (c *Client) NLSearchModel(modelID string) NLSearchModelInterface { + return &nlSearchModel{apiClient: c.apiClient, modelID: modelID} +} + func (c *Client) Stopwords() StopwordsInterface { return &stopwords{apiClient: c.apiClient} } diff --git a/typesense/nl_search_model.go b/typesense/nl_search_model.go new file mode 100644 index 0000000..b0e6ea7 --- /dev/null +++ b/typesense/nl_search_model.go @@ -0,0 +1,51 @@ +package typesense + +import ( + "context" + + "github.com/typesense/typesense-go/v3/typesense/api" +) + +type NLSearchModelInterface interface { + Retrieve(ctx context.Context) (*api.NLSearchModelSchema, error) + Update(ctx context.Context, model *api.NLSearchModelUpdateSchema) (*api.NLSearchModelSchema, error) + Delete(ctx context.Context) (*api.NLSearchModelDeleteSchema, error) +} + +type nlSearchModel struct { + apiClient APIClientInterface + modelID string +} + +func (n *nlSearchModel) Retrieve(ctx context.Context) (*api.NLSearchModelSchema, error) { + response, err := n.apiClient.RetrieveNLSearchModelWithResponse(ctx, n.modelID) + if err != nil { + return nil, err + } + if response.JSON200 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + return response.JSON200, nil +} + +func (n *nlSearchModel) Update(ctx context.Context, model *api.NLSearchModelUpdateSchema) (*api.NLSearchModelSchema, error) { + response, err := n.apiClient.UpdateNLSearchModelWithResponse(ctx, n.modelID, *model) + if err != nil { + return nil, err + } + if response.JSON200 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + return response.JSON200, nil +} + +func (n *nlSearchModel) Delete(ctx context.Context) (*api.NLSearchModelDeleteSchema, error) { + response, err := n.apiClient.DeleteNLSearchModelWithResponse(ctx, n.modelID) + if err != nil { + return nil, err + } + if response.JSON200 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + return response.JSON200, nil +} diff --git a/typesense/nl_search_models.go b/typesense/nl_search_models.go new file mode 100644 index 0000000..2e0de70 --- /dev/null +++ b/typesense/nl_search_models.go @@ -0,0 +1,45 @@ +package typesense + +import ( + "context" + + "github.com/typesense/typesense-go/v3/typesense/api" +) + +type NLSearchModelsInterface interface { + Retrieve(ctx context.Context) ([]*api.NLSearchModelSchema, error) + Create(ctx context.Context, model *api.NLSearchModelCreateSchema) (*api.NLSearchModelSchema, error) +} + +type nlSearchModels struct { + apiClient APIClientInterface +} + +func (n *nlSearchModels) Retrieve(ctx context.Context) ([]*api.NLSearchModelSchema, error) { + response, err := n.apiClient.RetrieveAllNLSearchModelsWithResponse(ctx) + if err != nil { + return nil, err + } + if response.JSON200 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + + // Convert []NLSearchModelSchema to []*NLSearchModelSchema + result := make([]*api.NLSearchModelSchema, len(*response.JSON200)) + for i, model := range *response.JSON200 { + modelCopy := model // Create a copy to get address + result[i] = &modelCopy + } + return result, nil +} + +func (n *nlSearchModels) Create(ctx context.Context, model *api.NLSearchModelCreateSchema) (*api.NLSearchModelSchema, error) { + response, err := n.apiClient.CreateNLSearchModelWithResponse(ctx, *model) + if err != nil { + return nil, err + } + if response.JSON201 == nil { + return nil, &HTTPError{Status: response.StatusCode(), Body: response.Body} + } + return response.JSON201, nil +} diff --git a/typesense/test/dbhelpers_test.go b/typesense/test/dbhelpers_test.go index 96e4777..c4cdb04 100644 --- a/typesense/test/dbhelpers_test.go +++ b/typesense/test/dbhelpers_test.go @@ -6,6 +6,7 @@ package test import ( "context" "fmt" + "os" "testing" "time" @@ -441,3 +442,70 @@ func retrieveDocuments(t *testing.T, collectionName string, docIDs ...string) [] } return results } + +func newNLSearchModelCreateSchema() *api.NLSearchModelCreateSchema { + apiKey := os.Getenv("NL_SEARCH_MODEL_API_KEY") + + return &api.NLSearchModelCreateSchema{ + ModelName: pointer.String("openai/gpt-3.5-turbo"), + ApiKey: pointer.String(apiKey), + MaxBytes: pointer.Int(1000), + Temperature: pointer.Float32(0.7), + SystemPrompt: pointer.String("You are a helpful assistant."), + TopP: pointer.Float32(0.9), + TopK: pointer.Int(40), + StopSequences: &[]string{"END", "STOP"}, + ApiVersion: pointer.String("v1"), + } +} + +func newNLSearchModelSchema(modelID string) *api.NLSearchModelSchema { + apiKey := os.Getenv("NL_SEARCH_MODEL_API_KEY") + + return &api.NLSearchModelSchema{ + Id: modelID, + ModelName: pointer.String("openai/gpt-3.5-turbo"), + ApiKey: pointer.String(apiKey), + MaxBytes: pointer.Int(1000), + Temperature: pointer.Float32(0.7), + SystemPrompt: pointer.String("You are a helpful assistant."), + TopP: pointer.Float32(0.9), + TopK: pointer.Int(40), + StopSequences: &[]string{"END", "STOP"}, + ApiVersion: pointer.String("v1"), + } +} + +func newNLSearchModelUpdateSchema() *api.NLSearchModelUpdateSchema { + apiKey := os.Getenv("NL_SEARCH_MODEL_API_KEY") + + return &api.NLSearchModelUpdateSchema{ + ModelName: pointer.String("openai/gpt-4"), + ApiKey: pointer.String(apiKey), + MaxBytes: pointer.Int(2000), + Temperature: pointer.Float32(0.5), + SystemPrompt: pointer.String("You are an expert assistant."), + TopP: pointer.Float32(0.8), + TopK: pointer.Int(50), + StopSequences: &[]string{"END", "STOP", "QUIT"}, + ApiVersion: pointer.String("v1"), + } +} + +func shouldSkipNLSearchModelTests(t *testing.T) { + if os.Getenv("NL_SEARCH_MODEL_API_KEY") == "" { + t.Skip("Skipping NL search model test: NL_SEARCH_MODEL_API_KEY not set") + } +} + +func createNewNLSearchModel(t *testing.T) (string, *api.NLSearchModelSchema) { + t.Helper() + modelID := newUUIDName("nl-model-test") + modelSchema := newNLSearchModelCreateSchema() + modelSchema.Id = pointer.String(modelID) + + result, err := typesenseClient.NLSearchModels().Create(context.Background(), modelSchema) + + require.NoError(t, err) + return modelID, result +} diff --git a/typesense/test/nl_search_model_test.go b/typesense/test/nl_search_model_test.go new file mode 100644 index 0000000..e7fba4a --- /dev/null +++ b/typesense/test/nl_search_model_test.go @@ -0,0 +1,59 @@ +//go:build integration +// +build integration + +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/typesense/typesense-go/v3/typesense/api/pointer" +) + +func nlSearchModelsCleanUp() { + result, _ := typesenseClient.NLSearchModels().Retrieve(context.Background()) + for _, model := range result { + typesenseClient.NLSearchModel(model.Id).Delete(context.Background()) + } +} + +func TestNLSearchModel(t *testing.T) { + shouldSkipNLSearchModelTests(t) + t.Cleanup(nlSearchModelsCleanUp) + + t.Run("Retrieve", func(t *testing.T) { + modelID, expectedResult := createNewNLSearchModel(t) + + result, err := typesenseClient.NLSearchModel(modelID).Retrieve(context.Background()) + + require.NoError(t, err) + require.Equal(t, expectedResult, result) + }) + + t.Run("Update", func(t *testing.T) { + modelID, originalModel := createNewNLSearchModel(t) + + updateSchema := newNLSearchModelUpdateSchema() + updateSchema.Temperature = pointer.Float32(0.8) + + result, err := typesenseClient.NLSearchModel(modelID).Update(context.Background(), updateSchema) + + require.NoError(t, err) + require.Equal(t, "openai/gpt-4", *result.ModelName) + require.Equal(t, float32(0.8), *result.Temperature) + require.Equal(t, originalModel.Id, result.Id) + }) + + t.Run("Delete", func(t *testing.T) { + modelID, expectedResult := createNewNLSearchModel(t) + + result, err := typesenseClient.NLSearchModel(modelID).Delete(context.Background()) + + require.NoError(t, err) + require.Equal(t, expectedResult.Id, result.Id) + + _, err = typesenseClient.NLSearchModel(modelID).Retrieve(context.Background()) + require.Error(t, err) + }) +} \ No newline at end of file diff --git a/typesense/test/search_test.go b/typesense/test/search_test.go index 5296844..779847f 100644 --- a/typesense/test/search_test.go +++ b/typesense/test/search_test.go @@ -156,8 +156,6 @@ func TestCollectionSearchWithPreset(t *testing.T) { newDocument("123", withCompanyName("Company 1"), withNumEmployees(50)), newDocument("125", withCompanyName("Company 2"), withNumEmployees(150)), newDocument("127", withCompanyName("Company 3"), withNumEmployees(250)), - newDocument("129", withCompanyName("Stark Industries 4"), withNumEmployees(500)), - newDocument("131", withCompanyName("Stark Industries 5"), withNumEmployees(1000)), } params := &api.ImportDocumentsParams{Action: pointer.Any(api.Create)}