Skip to content

Commit 67300cb

Browse files
committed
feat: add tests for filter options
Signed-off-by: Sidney Glinton <[email protected]>
1 parent a6df1ad commit 67300cb

File tree

4 files changed

+297
-7
lines changed

4 files changed

+297
-7
lines changed

catalog/internal/catalog/db_catalog_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,168 @@ func TestDBCatalog(t *testing.T) {
794794
assert.Contains(t, err.Error(), "invalid model name")
795795
})
796796
})
797+
798+
t.Run("TestGetFilterOptions", func(t *testing.T) {
799+
// Create models with various properties for filter options testing
800+
model1 := &models.CatalogModelImpl{
801+
TypeID: apiutils.Of(int32(catalogModelTypeID)),
802+
Attributes: &models.CatalogModelAttributes{
803+
Name: apiutils.Of("filter-options-model-1"),
804+
ExternalID: apiutils.Of("filter-opt-1"),
805+
},
806+
Properties: &[]mr_models.Properties{
807+
{Name: "source_id", StringValue: apiutils.Of("filter-test-source")},
808+
{Name: "license", StringValue: apiutils.Of("MIT")},
809+
{Name: "provider", StringValue: apiutils.Of("HuggingFace")},
810+
{Name: "maturity", StringValue: apiutils.Of("stable")},
811+
{Name: "library_name", StringValue: apiutils.Of("transformers")},
812+
{Name: "language", StringValue: apiutils.Of(`["python", "rust"]`)},
813+
{Name: "tasks", StringValue: apiutils.Of(`["text-classification", "token-classification"]`)},
814+
},
815+
}
816+
817+
model2 := &models.CatalogModelImpl{
818+
TypeID: apiutils.Of(int32(catalogModelTypeID)),
819+
Attributes: &models.CatalogModelAttributes{
820+
Name: apiutils.Of("filter-options-model-2"),
821+
ExternalID: apiutils.Of("filter-opt-2"),
822+
},
823+
Properties: &[]mr_models.Properties{
824+
{Name: "source_id", StringValue: apiutils.Of("filter-test-source")},
825+
{Name: "license", StringValue: apiutils.Of("Apache-2.0")},
826+
{Name: "provider", StringValue: apiutils.Of("OpenAI")},
827+
{Name: "maturity", StringValue: apiutils.Of("experimental")},
828+
{Name: "library_name", StringValue: apiutils.Of("openai")},
829+
{Name: "language", StringValue: apiutils.Of(`["python", "javascript"]`)},
830+
{Name: "tasks", StringValue: apiutils.Of(`["text-generation", "conversational"]`)},
831+
{Name: "readme", StringValue: apiutils.Of("This is a very long readme that exceeds 100 characters and should be excluded from filter options because it's too verbose for filtering purposes.")},
832+
},
833+
}
834+
835+
model3 := &models.CatalogModelImpl{
836+
TypeID: apiutils.Of(int32(catalogModelTypeID)),
837+
Attributes: &models.CatalogModelAttributes{
838+
Name: apiutils.Of("filter-options-model-3"),
839+
ExternalID: apiutils.Of("filter-opt-3"),
840+
},
841+
Properties: &[]mr_models.Properties{
842+
{Name: "source_id", StringValue: apiutils.Of("filter-test-source")},
843+
{Name: "license", StringValue: apiutils.Of("MIT")},
844+
{Name: "provider", StringValue: apiutils.Of("PyTorch")},
845+
{Name: "maturity", StringValue: apiutils.Of("stable")},
846+
{Name: "language", StringValue: apiutils.Of(`["python"]`)},
847+
{Name: "tasks", StringValue: apiutils.Of(`["image-classification"]`)},
848+
{Name: "logo", StringValue: apiutils.Of("https://example.com/logo.png")},
849+
{Name: "license_link", StringValue: apiutils.Of("https://example.com/license")},
850+
},
851+
}
852+
853+
_, err := catalogModelRepo.Save(model1)
854+
require.NoError(t, err)
855+
_, err = catalogModelRepo.Save(model2)
856+
require.NoError(t, err)
857+
_, err = catalogModelRepo.Save(model3)
858+
require.NoError(t, err)
859+
860+
// Test GetFilterOptions
861+
filterOptions, err := dbCatalog.GetFilterOptions(ctx)
862+
require.NoError(t, err)
863+
require.NotNil(t, filterOptions)
864+
require.NotNil(t, filterOptions.Filters)
865+
866+
filters := *filterOptions.Filters
867+
868+
// Should include short properties
869+
assert.Contains(t, filters, "license")
870+
assert.Contains(t, filters, "provider")
871+
assert.Contains(t, filters, "maturity")
872+
assert.Contains(t, filters, "library_name")
873+
assert.Contains(t, filters, "language")
874+
assert.Contains(t, filters, "tasks")
875+
876+
// Should exclude internal/verbose fields
877+
assert.NotContains(t, filters, "source_id", "source_id should be excluded")
878+
assert.NotContains(t, filters, "logo", "logo should be excluded")
879+
assert.NotContains(t, filters, "license_link", "license_link should be excluded")
880+
assert.NotContains(t, filters, "readme", "readme should be excluded (too long)")
881+
882+
licenseFilter := filters["license"]
883+
assert.Equal(t, "string", licenseFilter.Type)
884+
assert.NotNil(t, licenseFilter.Values)
885+
assert.GreaterOrEqual(t, len(licenseFilter.Values), 2, "Should have at least MIT and Apache-2.0")
886+
887+
// Convert to string slice for easier checking
888+
licenseValues := make([]string, 0)
889+
for _, v := range licenseFilter.Values {
890+
if strVal, ok := v.(string); ok {
891+
licenseValues = append(licenseValues, strVal)
892+
}
893+
}
894+
assert.Contains(t, licenseValues, "MIT")
895+
assert.Contains(t, licenseValues, "Apache-2.0")
896+
897+
// Verify provider filter options
898+
providerFilter := filters["provider"]
899+
assert.Equal(t, "string", providerFilter.Type)
900+
providerValues := make([]string, 0)
901+
for _, v := range providerFilter.Values {
902+
if strVal, ok := v.(string); ok {
903+
providerValues = append(providerValues, strVal)
904+
}
905+
}
906+
assert.Contains(t, providerValues, "HuggingFace")
907+
assert.Contains(t, providerValues, "OpenAI")
908+
assert.Contains(t, providerValues, "PyTorch")
909+
910+
// Verify JSON array fields are properly parsed and expanded
911+
languageFilter := filters["language"]
912+
assert.Equal(t, "string", languageFilter.Type)
913+
languageValues := make([]string, 0)
914+
for _, v := range languageFilter.Values {
915+
if strVal, ok := v.(string); ok {
916+
languageValues = append(languageValues, strVal)
917+
}
918+
}
919+
// Should contain individual values from JSON arrays
920+
assert.Contains(t, languageValues, "python")
921+
assert.Contains(t, languageValues, "rust")
922+
assert.Contains(t, languageValues, "javascript")
923+
924+
// Verify tasks are properly expanded
925+
tasksFilter := filters["tasks"]
926+
assert.Equal(t, "string", tasksFilter.Type)
927+
tasksValues := make([]string, 0)
928+
for _, v := range tasksFilter.Values {
929+
if strVal, ok := v.(string); ok {
930+
tasksValues = append(tasksValues, strVal)
931+
}
932+
}
933+
assert.Contains(t, tasksValues, "text-classification")
934+
assert.Contains(t, tasksValues, "token-classification")
935+
assert.Contains(t, tasksValues, "text-generation")
936+
assert.Contains(t, tasksValues, "conversational")
937+
assert.Contains(t, tasksValues, "image-classification")
938+
939+
// Verify no duplicates
940+
pythonCount := 0
941+
for _, v := range languageValues {
942+
if v == "python" {
943+
pythonCount++
944+
}
945+
}
946+
assert.Equal(t, 1, pythonCount, "python should appear only once (deduplicated)")
947+
948+
// Verify maturity options
949+
maturityFilter := filters["maturity"]
950+
maturityValues := make([]string, 0)
951+
for _, v := range maturityFilter.Values {
952+
if strVal, ok := v.(string); ok {
953+
maturityValues = append(maturityValues, strVal)
954+
}
955+
}
956+
assert.Contains(t, maturityValues, "stable")
957+
assert.Contains(t, maturityValues, "experimental")
958+
})
797959
}
798960

799961
// Helper functions to get type IDs from database

catalog/internal/db/service/catalog_model.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,20 +237,23 @@ func mapDataLayerToCatalogModel(modelCtx schema.Context, propertiesCtx []schema.
237237
func (r *CatalogModelRepositoryImpl) GetFilterableProperties(maxLength int) (map[string][]string, error) {
238238
config := r.GetConfig()
239239

240+
// Get table names using GORM utilities for database compatibility
241+
contextTable := utils.GetTableName(config.DB, &schema.Context{})
242+
propertyTable := utils.GetTableName(config.DB, &schema.ContextProperty{})
243+
240244
// Simplified query: get distinct property name/value pairs
241-
query := `
245+
query := fmt.Sprintf(`
242246
SELECT DISTINCT cp.name, cp.string_value
243-
FROM "ContextProperty" cp
247+
FROM %s cp
244248
WHERE cp.context_id IN (
245-
SELECT id FROM "Context" WHERE type_id = ?
249+
SELECT id FROM %s WHERE type_id = ?
246250
)
247251
AND cp.name IN (
248-
-- Only include property names where max length is within threshold
249252
SELECT name FROM (
250253
SELECT name, MAX(CHAR_LENGTH(string_value)) as max_len
251-
FROM "ContextProperty"
254+
FROM %s
252255
WHERE context_id IN (
253-
SELECT id FROM "Context" WHERE type_id = ?
256+
SELECT id FROM %s WHERE type_id = ?
254257
)
255258
AND string_value IS NOT NULL
256259
AND string_value != ''
@@ -261,7 +264,7 @@ func (r *CatalogModelRepositoryImpl) GetFilterableProperties(maxLength int) (map
261264
AND cp.string_value IS NOT NULL
262265
AND cp.string_value != ''
263266
ORDER BY cp.name, cp.string_value
264-
`
267+
`, propertyTable, contextTable, propertyTable, contextTable)
265268

266269
type propertyRow struct {
267270
Name string

catalog/internal/db/service/catalog_model_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,85 @@ func TestCatalogModelRepository(t *testing.T) {
337337
require.NotNil(t, retrieved.GetCustomProperties())
338338
assert.Len(t, *retrieved.GetCustomProperties(), 2)
339339
})
340+
341+
t.Run("TestGetFilterableProperties", func(t *testing.T) {
342+
// Create models with various property lengths
343+
shortValueModel := &models.CatalogModelImpl{
344+
Attributes: &models.CatalogModelAttributes{
345+
Name: apiutils.Of("short-value-model"),
346+
ExternalID: apiutils.Of("short-ext"),
347+
},
348+
Properties: &[]dbmodels.Properties{
349+
{Name: "license", StringValue: apiutils.Of("MIT")},
350+
{Name: "provider", StringValue: apiutils.Of("HuggingFace")},
351+
{Name: "maturity", StringValue: apiutils.Of("stable")},
352+
},
353+
}
354+
355+
longValueModel := &models.CatalogModelImpl{
356+
Attributes: &models.CatalogModelAttributes{
357+
Name: apiutils.Of("long-value-model"),
358+
ExternalID: apiutils.Of("long-ext"),
359+
},
360+
Properties: &[]dbmodels.Properties{
361+
{Name: "license", StringValue: apiutils.Of("Apache-2.0")},
362+
{Name: "readme", StringValue: apiutils.Of("This is a very long readme that should be excluded from filterable properties because it exceeds the maximum length threshold of 100 characters. It contains detailed information about the model.")},
363+
{Name: "description", StringValue: apiutils.Of("This is also a very long description that should be excluded from filterable properties because it exceeds 100 chars")},
364+
},
365+
}
366+
367+
jsonArrayModel := &models.CatalogModelImpl{
368+
Attributes: &models.CatalogModelAttributes{
369+
Name: apiutils.Of("json-array-model"),
370+
ExternalID: apiutils.Of("json-ext"),
371+
},
372+
Properties: &[]dbmodels.Properties{
373+
{Name: "language", StringValue: apiutils.Of(`["python", "go"]`)},
374+
{Name: "tasks", StringValue: apiutils.Of(`["text-classification", "question-answering"]`)},
375+
},
376+
}
377+
378+
_, err := repo.Save(shortValueModel)
379+
require.NoError(t, err)
380+
_, err = repo.Save(longValueModel)
381+
require.NoError(t, err)
382+
_, err = repo.Save(jsonArrayModel)
383+
require.NoError(t, err)
384+
385+
// Test with max length of 100
386+
result, err := repo.GetFilterableProperties(100)
387+
require.NoError(t, err)
388+
require.NotNil(t, result)
389+
390+
// Should include short properties
391+
assert.Contains(t, result, "license")
392+
assert.Contains(t, result, "provider")
393+
assert.Contains(t, result, "maturity")
394+
assert.Contains(t, result, "language")
395+
assert.Contains(t, result, "tasks")
396+
397+
// Should exclude long properties
398+
assert.NotContains(t, result, "readme")
399+
assert.NotContains(t, result, "description")
400+
401+
// Verify license has both values
402+
licenseValues := result["license"]
403+
assert.GreaterOrEqual(t, len(licenseValues), 2)
404+
assert.Contains(t, licenseValues, "MIT")
405+
assert.Contains(t, licenseValues, "Apache-2.0")
406+
407+
// Test with smaller max length
408+
result, err = repo.GetFilterableProperties(10)
409+
require.NoError(t, err)
410+
require.NotNil(t, result)
411+
412+
// Should include only very short properties
413+
assert.Contains(t, result, "license")
414+
// Should exclude longer properties
415+
assert.NotContains(t, result, "provider") // "HuggingFace" is > 10 chars
416+
assert.NotContains(t, result, "language")
417+
assert.NotContains(t, result, "tasks")
418+
})
340419
}
341420

342421
// Helper function to get or create CatalogModel type ID

catalog/internal/server/openapi/api_model_catalog_service_service_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,11 @@ func (m *mockModelProvider) GetArtifacts(ctx context.Context, name string, sourc
776776
}, nil
777777
}
778778

779+
func (m *mockModelProvider) GetFilterOptions(ctx context.Context) (*model.FilterOptionsList, error) {
780+
emptyFilters := make(map[string]model.FilterOption)
781+
return &model.FilterOptionsList{Filters: &emptyFilters}, nil
782+
}
783+
779784
func TestGetModel(t *testing.T) {
780785
testCases := []struct {
781786
name string
@@ -1003,3 +1008,44 @@ func TestGetAllModelArtifacts(t *testing.T) {
10031008
})
10041009
}
10051010
}
1011+
1012+
func TestFindModelsFilterOptions(t *testing.T) {
1013+
testCases := []struct {
1014+
name string
1015+
provider catalog.APIProvider
1016+
expectedStatus int
1017+
expectedError bool
1018+
}{
1019+
{
1020+
name: "Successfully retrieve filter options",
1021+
provider: &mockModelProvider{
1022+
models: map[string]*model.CatalogModel{},
1023+
},
1024+
expectedStatus: http.StatusOK,
1025+
expectedError: false,
1026+
},
1027+
}
1028+
1029+
for _, tc := range testCases {
1030+
t.Run(tc.name, func(t *testing.T) {
1031+
sources := catalog.NewSourceCollection()
1032+
service := NewModelCatalogServiceAPIService(tc.provider, sources)
1033+
1034+
resp, err := service.FindModelsFilterOptions(context.Background())
1035+
1036+
assert.Equal(t, tc.expectedStatus, resp.Code)
1037+
1038+
if tc.expectedError {
1039+
assert.Error(t, err)
1040+
return
1041+
}
1042+
require.NotNil(t, resp.Body)
1043+
1044+
// Type assertion to access the FilterOptionsList
1045+
filterOptions, ok := resp.Body.(*model.FilterOptionsList)
1046+
require.True(t, ok, "Response body should be a FilterOptionsList")
1047+
1048+
require.NotNil(t, filterOptions.Filters)
1049+
})
1050+
}
1051+
}

0 commit comments

Comments
 (0)