diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d152fc9c..4f52df8fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ - `postgresflex`: [v1.0.1](services/postgresflex/CHANGELOG.md#v101-2025-03-12) - **Bugfix:** `DeleteUserWaitHandler` is now also using the region as parameter. +- `modelserving`: [v0.2.0](services/modelserving/CHANGELOG.md#v020-2025-03-14) + - **New**: STACKIT Model Serving module wait handler added. ## Release (2025-03-05) diff --git a/services/modelserving/CHANGELOG.md b/services/modelserving/CHANGELOG.md index 867b8760c..a34eca7a5 100644 --- a/services/modelserving/CHANGELOG.md +++ b/services/modelserving/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.2.0 (2025-03-14) + +- **New**: STACKIT Model Serving module wait handler added. + ## v0.1.0 (2025-02-25) - **New**: STACKIT Model Serving module can be used to manage the STACKIT Model Serving. diff --git a/services/modelserving/go.mod b/services/modelserving/go.mod index 8f8494db6..56fd69ca3 100644 --- a/services/modelserving/go.mod +++ b/services/modelserving/go.mod @@ -2,10 +2,12 @@ module github.com/stackitcloud/stackit-sdk-go/services/modelserving go 1.21 -require github.com/stackitcloud/stackit-sdk-go/core v0.16.0 +require ( + github.com/google/go-cmp v0.7.0 + github.com/stackitcloud/stackit-sdk-go/core v0.16.0 +) require ( github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect ) diff --git a/services/modelserving/wait/wait.go b/services/modelserving/wait/wait.go new file mode 100644 index 000000000..7c7341394 --- /dev/null +++ b/services/modelserving/wait/wait.go @@ -0,0 +1,75 @@ +package wait + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" + "github.com/stackitcloud/stackit-sdk-go/core/wait" + "github.com/stackitcloud/stackit-sdk-go/services/modelserving" +) + +const ( + activeState = "active" +) + +type APIClientInterface interface { + GetTokenExecute(ctx context.Context, region, projectId, tokenId string) (*modelserving.GetTokenResponse, error) +} + +func CreateModelServingWaitHandler(ctx context.Context, a APIClientInterface, region, projectId, tokenId string) *wait.AsyncActionHandler[modelserving.GetTokenResponse] { + handler := wait.New(func() (waitFinished bool, response *modelserving.GetTokenResponse, err error) { + getTokenResp, err := a.GetTokenExecute(ctx, region, projectId, tokenId) + if err != nil { + return false, nil, err + } + if getTokenResp.Token.State == nil { + return false, nil, fmt.Errorf( + "token state is missing for token with id %s", + tokenId, + ) + } + if *getTokenResp.Token.State == activeState { + return true, getTokenResp, nil + } + + return false, nil, nil + }) + + handler.SetTimeout(10 * time.Minute) + + return handler +} + +// UpdateModelServingWaitHandler will wait for the model serving auth token to be updated. +// Eventually it will have a different implementation, but for now it's the same as the create handler. +func UpdateModelServingWaitHandler(ctx context.Context, a APIClientInterface, region, projectId, tokenId string) *wait.AsyncActionHandler[modelserving.GetTokenResponse] { + return CreateModelServingWaitHandler(ctx, a, region, projectId, tokenId) +} + +func DeleteModelServingWaitHandler(ctx context.Context, a APIClientInterface, region, projectId, tokenId string) *wait.AsyncActionHandler[modelserving.GetTokenResponse] { + handler := wait.New( + func() (waitFinished bool, response *modelserving.GetTokenResponse, err error) { + _, err = a.GetTokenExecute(ctx, region, projectId, tokenId) + if err != nil { + var oapiErr *oapierror.GenericOpenAPIError + if errors.As(err, &oapiErr) { + if oapiErr.StatusCode == http.StatusNotFound { + return true, nil, nil + } + } + + return false, nil, err + } + + return false, nil, nil + }, + ) + + handler.SetTimeout(10 * time.Minute) + + return handler +} diff --git a/services/modelserving/wait/wait_test.go b/services/modelserving/wait/wait_test.go new file mode 100644 index 000000000..31fd6f7b1 --- /dev/null +++ b/services/modelserving/wait/wait_test.go @@ -0,0 +1,231 @@ +package wait + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" + "github.com/stackitcloud/stackit-sdk-go/core/utils" + "github.com/stackitcloud/stackit-sdk-go/services/modelserving" +) + +type apiClientMocked struct { + getFails bool + resourceState string + statusCode int +} + +func (a *apiClientMocked) GetTokenExecute(_ context.Context, _, _, _ string) (*modelserving.GetTokenResponse, error) { + if a.getFails { + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: a.statusCode, + } + } + + return &modelserving.GetTokenResponse{ + Token: &modelserving.Token{ + State: utils.Ptr(a.resourceState), + Id: utils.Ptr("tid"), + }, + }, nil +} + +func TestCreateModelServingWaitHandler(t *testing.T) { + tests := []struct { + desc string + getFails bool + statusCode int + resourceState string + wantErr bool + wantResp bool + }{ + { + desc: "create_succeeded", + getFails: false, + statusCode: 200, + resourceState: activeState, + wantErr: false, + wantResp: true, + }, + { + desc: "get_fails", + getFails: true, + statusCode: 500, + resourceState: "", + wantErr: true, + wantResp: false, + }, + { + desc: "timeout", + getFails: false, + statusCode: 200, + resourceState: "ANOTHER_STATE", + wantErr: true, + wantResp: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + apiClient := &apiClientMocked{ + getFails: tt.getFails, + statusCode: tt.statusCode, + resourceState: tt.resourceState, + } + + var wantRes *modelserving.GetTokenResponse + if tt.wantResp { + wantRes = &modelserving.GetTokenResponse{ + Token: &modelserving.Token{ + State: utils.Ptr(tt.resourceState), + Id: utils.Ptr("tid"), + }, + } + } + + handler := CreateModelServingWaitHandler(context.Background(), apiClient, "region", "pid", "tid") + + gotRes, err := handler.SetTimeout(10 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + if !cmp.Equal(gotRes, wantRes) { + t.Fatalf("handler gotRes = %v, want %v", gotRes, wantRes) + } + }) + } +} + +func TestUpdateModelServingWaitHandler(t *testing.T) { + tests := []struct { + desc string + getFails bool + statusCode int + resourceState string + wantErr bool + wantResp bool + }{ + { + desc: "update_succeeded", + getFails: false, + statusCode: 200, + resourceState: activeState, + wantErr: false, + wantResp: true, + }, + { + desc: "get_fails", + getFails: true, + statusCode: 500, + resourceState: "", + wantErr: true, + wantResp: false, + }, + { + desc: "timeout", + getFails: false, + statusCode: 200, + resourceState: "ANOTHER_STATE", + wantErr: true, + wantResp: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + apiClient := &apiClientMocked{ + getFails: tt.getFails, + statusCode: tt.statusCode, + resourceState: tt.resourceState, + } + + var wantRes *modelserving.GetTokenResponse + if tt.wantResp { + wantRes = &modelserving.GetTokenResponse{ + Token: &modelserving.Token{ + State: utils.Ptr(tt.resourceState), + Id: utils.Ptr("tid"), + }, + } + } + + handler := UpdateModelServingWaitHandler(context.Background(), apiClient, "region", "pid", "tid") + + gotRes, err := handler.SetTimeout(10 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + if !cmp.Equal(gotRes, wantRes) { + t.Fatalf("handler gotRes = %v, want %v", gotRes, wantRes) + } + }) + } +} + +func TestDeleteModelServingWaitHandler(t *testing.T) { + tests := []struct { + desc string + getFails bool + statusCode int + resourceState string + wantErr bool + wantResp bool + }{ + { + desc: "delete_succeeded", + getFails: true, + statusCode: 404, + resourceState: "", + wantErr: false, + wantResp: false, + }, + { + desc: "delete_in_progress", + getFails: false, + statusCode: 200, + resourceState: "DELETING", + wantErr: true, // Should timeout since delete is not complete + wantResp: false, + }, + { + desc: "get_fails_with_other_error", + getFails: true, + statusCode: 500, + resourceState: "", + wantErr: true, + wantResp: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + apiClient := &apiClientMocked{ + getFails: tt.getFails, + statusCode: tt.statusCode, + resourceState: tt.resourceState, + } + + var wantRes *modelserving.GetTokenResponse + if tt.wantResp { + wantRes = &modelserving.GetTokenResponse{ + Token: &modelserving.Token{ + State: utils.Ptr(tt.resourceState), + Id: utils.Ptr("tid"), + }, + } + } + + handler := DeleteModelServingWaitHandler(context.Background(), apiClient, "region", "pid", "tid") + + gotRes, err := handler.SetTimeout(10 * time.Millisecond).WaitWithContext(context.Background()) + + if (err != nil) != tt.wantErr { + t.Fatalf("handler error = %v, wantErr %v", err, tt.wantErr) + } + if !cmp.Equal(gotRes, wantRes) { + t.Fatalf("handler gotRes = %v, want %v", gotRes, wantRes) + } + }) + } +}