Skip to content

Commit fa65bae

Browse files
adding model based round robin for DP to CP flow
1 parent 3a0ba96 commit fa65bae

File tree

3 files changed

+243
-133
lines changed

3 files changed

+243
-133
lines changed

apim-apk-agent/internal/constants/constants.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,7 @@ const (
4848
// Version constants
4949
V1 = "v1"
5050
V2 = "v2"
51+
52+
// Policy Types
53+
CommonType = "common"
5154
)

apim-apk-agent/pkg/managementserver/rest_server.go

Lines changed: 191 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,104 @@ func createAPIYaml(apiCPEvent *APICPEvent) (string, string, string) {
162162
provider = config.ControlPlane.Provider
163163
}
164164
context := removeVersionSuffix(apiCPEvent.API.BasePath, apiCPEvent.API.APIVersion)
165-
operations, scopes, operationsErr := extractOperations(*apiCPEvent)
165+
166+
multiEndpoints := apiCPEvent.API.MultiEndpoints
167+
apimEndpints := []APIMEndpoint{}
168+
prodCount := 0
169+
sandCount := 0
170+
primaryProductionEndpointID := ""
171+
primarySandboxEndpointID := ""
172+
primaryProdcutionURL := ""
173+
primarySandboxURL := ""
174+
for _, endpoint := range multiEndpoints.ProdEndpoints {
175+
prodCount++
176+
var endpointName string
177+
if prodCount == 1 {
178+
endpointName = "Default Production Endpoint"
179+
} else {
180+
endpointName = fmt.Sprintf("%d Production Endpoint", prodCount)
181+
}
182+
prodEndpoint := ""
183+
if endpoint.URL != "" {
184+
prodEndpoint = fmt.Sprintf("%s://%s", multiEndpoints.Protocol, endpoint.URL)
185+
}
186+
endpointUUID := uuid.New().String() + "--PRODUCTION"
187+
if prodCount == 1 {
188+
primaryProductionEndpointID = endpointUUID
189+
primaryProdcutionURL = prodEndpoint
190+
}
191+
apimEndpints = append(apimEndpints, APIMEndpoint{
192+
DeploymentStage: "PRODUCTION",
193+
EndpointUUID: endpointUUID,
194+
EndpointName: endpointName,
195+
EndpointConfig: APIMEndpointConfig{
196+
EndpointType: multiEndpoints.Protocol,
197+
ProductionEndpoints: Endpoints{
198+
URL: prodEndpoint,
199+
},
200+
EndpointSecurity: APIMEndpointSecurity{
201+
Production: SecurityConfig{
202+
Enabled: endpoint.SecurityEnabled,
203+
Type: endpoint.SecurityType,
204+
Username: endpoint.BasicUsername,
205+
Password: endpoint.BasicPassword,
206+
APIKeyIdentifier: endpoint.APIKeyName,
207+
APIKeyValue: endpoint.APIKeyValue,
208+
APIKeyIdentifierType: endpoint.APIKeyIn,
209+
ConnectionTimeoutDuration: -1.0,
210+
SocketTimeoutDuration: -1.0,
211+
ConnectionRequestTimeoutDuration: -1.0,
212+
},
213+
},
214+
},
215+
})
216+
}
217+
for _, endpoint := range multiEndpoints.SandEndpoints {
218+
sandCount++
219+
var endpointName string
220+
if sandCount == 1 {
221+
endpointName = "Default Sandbox Endpoint"
222+
} else {
223+
endpointName = fmt.Sprintf("%d Sandbox Endpoint", sandCount)
224+
}
225+
226+
sandEndpoint := ""
227+
if endpoint.URL != "" {
228+
sandEndpoint = fmt.Sprintf("%s://%s", multiEndpoints.Protocol, endpoint.URL)
229+
}
230+
endpointUUID := uuid.New().String() + "--SANDBOX"
231+
if sandCount == 1 {
232+
primarySandboxEndpointID = endpointUUID
233+
primarySandboxURL = sandEndpoint
234+
}
235+
apimEndpints = append(apimEndpints, APIMEndpoint{
236+
DeploymentStage: "SANDBOX",
237+
EndpointUUID: endpointUUID,
238+
EndpointName: endpointName,
239+
EndpointConfig: APIMEndpointConfig{
240+
EndpointType: multiEndpoints.Protocol,
241+
SandboxEndpoints: Endpoints{
242+
URL: sandEndpoint,
243+
},
244+
EndpointSecurity: APIMEndpointSecurity{
245+
Sandbox: SecurityConfig{
246+
Enabled: endpoint.SecurityEnabled,
247+
Type: endpoint.SecurityType,
248+
Username: endpoint.BasicUsername,
249+
Password: endpoint.BasicPassword,
250+
APIKeyIdentifier: endpoint.APIKeyName,
251+
APIKeyValue: endpoint.APIKeyValue,
252+
APIKeyIdentifierType: endpoint.APIKeyIn,
253+
ConnectionTimeoutDuration: -1.0,
254+
SocketTimeoutDuration: -1.0,
255+
ConnectionRequestTimeoutDuration: -1.0,
256+
},
257+
},
258+
},
259+
})
260+
}
261+
262+
operations, scopes, operationsErr := extractOperations(*apiCPEvent, apimEndpints)
166263
if operationsErr != nil {
167264
logger.LoggerMgtServer.Errorf("Error occured while extracting operations from open API: %s, \nError: %+v", apiCPEvent.API.Definition, operationsErr)
168265
operations = []APIOperation{}
@@ -283,6 +380,7 @@ func createAPIYaml(apiCPEvent *APICPEvent) (string, string, string) {
283380
"accessControlExposeHeaders": apiCPEvent.API.CORSPolicy.AccessControlExposeHeaders,
284381
}
285382
}
383+
286384
maxTps := make(map[string]interface{})
287385

288386
// Handle Production fields
@@ -395,101 +493,6 @@ func createAPIYaml(apiCPEvent *APICPEvent) (string, string, string) {
395493
}
396494
}
397495
}
398-
multiEndpoints := apiCPEvent.API.MultiEndpoints
399-
apimEndpints := []APIMEndpoint{}
400-
prodCount := 0
401-
sandCount := 0
402-
primaryProductionEndpointID := ""
403-
primarySandboxEndpointID := ""
404-
primaryProdcutionURL := ""
405-
primarySandboxURL := ""
406-
for _, endpoint := range multiEndpoints.ProdEndpoints {
407-
prodCount++
408-
var endpointName string
409-
if prodCount == 1 {
410-
endpointName = "Default Production Endpoint"
411-
} else {
412-
endpointName = fmt.Sprintf("%d Production Endpoint", prodCount)
413-
}
414-
prodEndpoint := ""
415-
if endpoint.URL != "" {
416-
prodEndpoint = fmt.Sprintf("%s://%s", multiEndpoints.Protocol, endpoint.URL)
417-
}
418-
endpointUUID := uuid.New().String() + "--PRODUCTION"
419-
if prodCount == 1 {
420-
primaryProductionEndpointID = endpointUUID
421-
primaryProdcutionURL = prodEndpoint
422-
}
423-
apimEndpints = append(apimEndpints, APIMEndpoint{
424-
DeploymentStage: "PRODUCTION",
425-
EndpointUUID: endpointUUID,
426-
EndpointName: endpointName,
427-
EndpointConfig: APIMEndpointConfig{
428-
EndpointType: multiEndpoints.Protocol,
429-
ProductionEndpoints: Endpoints{
430-
URL: prodEndpoint,
431-
},
432-
EndpointSecurity: APIMEndpointSecurity{
433-
Production: SecurityConfig{
434-
Enabled: endpoint.SecurityEnabled,
435-
Type: endpoint.SecurityType,
436-
Username: endpoint.BasicUsername,
437-
Password: endpoint.BasicPassword,
438-
APIKeyIdentifier: endpoint.APIKeyName,
439-
APIKeyValue: endpoint.APIKeyValue,
440-
APIKeyIdentifierType: endpoint.APIKeyIn,
441-
ConnectionTimeoutDuration: -1.0,
442-
SocketTimeoutDuration: -1.0,
443-
ConnectionRequestTimeoutDuration: -1.0,
444-
},
445-
},
446-
},
447-
})
448-
}
449-
for _, endpoint := range multiEndpoints.SandEndpoints {
450-
sandCount++
451-
var endpointName string
452-
if sandCount == 1 {
453-
endpointName = "Default Sandbox Endpoint"
454-
} else {
455-
endpointName = fmt.Sprintf("%d Sandbox Endpoint", sandCount)
456-
}
457-
458-
sandEndpoint := ""
459-
if endpoint.URL != "" {
460-
sandEndpoint = fmt.Sprintf("%s://%s", multiEndpoints.Protocol, endpoint.URL)
461-
}
462-
endpointUUID := uuid.New().String() + "--SANDBOX"
463-
if sandCount == 1 {
464-
primarySandboxEndpointID = endpointUUID
465-
primarySandboxURL = sandEndpoint
466-
}
467-
apimEndpints = append(apimEndpints, APIMEndpoint{
468-
DeploymentStage: "SANDBOX",
469-
EndpointUUID: endpointUUID,
470-
EndpointName: endpointName,
471-
EndpointConfig: APIMEndpointConfig{
472-
EndpointType: multiEndpoints.Protocol,
473-
SandboxEndpoints: Endpoints{
474-
URL: sandEndpoint,
475-
},
476-
EndpointSecurity: APIMEndpointSecurity{
477-
Sandbox: SecurityConfig{
478-
Enabled: endpoint.SecurityEnabled,
479-
Type: endpoint.SecurityType,
480-
Username: endpoint.BasicUsername,
481-
Password: endpoint.BasicPassword,
482-
APIKeyIdentifier: endpoint.APIKeyName,
483-
APIKeyValue: endpoint.APIKeyValue,
484-
APIKeyIdentifierType: endpoint.APIKeyIn,
485-
ConnectionTimeoutDuration: -1.0,
486-
SocketTimeoutDuration: -1.0,
487-
ConnectionRequestTimeoutDuration: -1.0,
488-
},
489-
},
490-
},
491-
})
492-
}
493496

494497
dataArr := make([]map[string]interface{}, 0, len(apimEndpints))
495498

@@ -636,9 +639,38 @@ func createAPIYaml(apiCPEvent *APICPEvent) (string, string, string) {
636639
}
637640
}
638641

639-
logger.LoggerMgtServer.Debugf("API Yaml: %+v", data)
642+
var requestOperationPolicies []OperationPolicy
643+
if apiCPEvent.API.AIModelBasedRoundRobin != nil {
644+
aiModelBasedRoundRobin := apiCPEvent.API.AIModelBasedRoundRobin
645+
logger.LoggerMgtServer.Infof("AIModelBasedRoundRobin : %+v", aiModelBasedRoundRobin)
646+
wrr := ModelBasedRoundRobinConfig{
647+
Production: convertAIModelWeightsToModelConfigs(aiModelBasedRoundRobin.ProductionModels, apimEndpints, true),
648+
Sandbox: convertAIModelWeightsToModelConfigs(aiModelBasedRoundRobin.SandboxModels, apimEndpints, false),
649+
SuspendDuration: fmt.Sprintf("%d", aiModelBasedRoundRobin.OnQuotaExceedSuspendDuration),
650+
}
651+
jsonBytes, err := json.Marshal(wrr)
652+
if err != nil {
653+
logger.LoggerMgtServer.Errorf("Error marshaling WeightedRoundRobinConfigs to JSON: %+v", err)
654+
}
655+
jsonStr := string(jsonBytes)
656+
singleQuoted := strings.ReplaceAll(jsonStr, `"`, `'`)
657+
apiPolicy := OperationPolicy{
658+
PolicyName: constants.ModelWeightedRoundRobin,
659+
PolicyVersion: constants.V1,
660+
PolicyType: constants.CommonType,
661+
Parameters: WeightedRoundRobinConfigs{
662+
WeightedRoundRobinConfigs: singleQuoted,
663+
},
664+
}
665+
requestOperationPolicies = append(requestOperationPolicies, apiPolicy)
666+
}
667+
data["data"].(map[string]interface{})["apiPolicies"] = OperationPolicies{
668+
Request: requestOperationPolicies,
669+
}
670+
671+
logger.LoggerMgtServer.Infof("API Yaml: %+v", data)
640672
yamlBytes, _ := yaml.Marshal(data)
641-
logger.LoggerMgtServer.Debugf("Endpoint Yaml: %v", endpointsData)
673+
logger.LoggerMgtServer.Infof("Endpoint Yaml: %v", endpointsData)
642674
endpointBytes, _ := yaml.Marshal(endpointsData)
643675
return string(yamlBytes), definition, string(endpointBytes)
644676
}
@@ -691,6 +723,7 @@ type OperationPolicy struct {
691723
PolicyName string `yaml:"policyName"`
692724
PolicyVersion string `yaml:"policyVersion"`
693725
PolicyID string `yaml:"policyId,omitempty"`
726+
PolicyType string `yaml:"policyType,omitempty"`
694727
Parameters FilterParameters `yaml:"parameters"`
695728
}
696729

@@ -699,6 +732,29 @@ type FilterParameters interface {
699732
isFilterParameters()
700733
}
701734

735+
func (m WeightedRoundRobinConfigs) isFilterParameters() {}
736+
737+
// WeightedRoundRobinConfigs holds any additional parameter data for a RequestPolicy
738+
type WeightedRoundRobinConfigs struct {
739+
WeightedRoundRobinConfigs string `yaml:"weightedRoundRobinConfigs"`
740+
}
741+
742+
func (m ModelBasedRoundRobinConfig) isFilterParameters() {}
743+
744+
// ModelConfig holds the configuration details of a model
745+
type ModelConfig struct {
746+
Model string `json:"model" yaml:"model"`
747+
EndpointID string `json:"endpointId" yaml:"endpointId"`
748+
Weight int `json:"weight" yaml:"weight"`
749+
}
750+
751+
// ModelBasedRoundRobinConfig holds the configuration details of the transformer
752+
type ModelBasedRoundRobinConfig struct {
753+
Production []ModelConfig `json:"production" yaml:"production"`
754+
Sandbox []ModelConfig `json:"sandbox" yaml:"sandbox"`
755+
SuspendDuration string `json:"suspendDuration" yaml:"suspendDuration"`
756+
}
757+
702758
func (h Header) isFilterParameters() {}
703759

704760
// Header contains the request and response header modifier information
@@ -757,7 +813,30 @@ type Scope struct {
757813
Bindings []string `yaml:"bindings"`
758814
}
759815

760-
func extractOperations(event APICPEvent) ([]APIOperation, []ScopeWrapper, error) {
816+
func convertAIModelWeightsToModelConfigs(weights []AIModelWeight, apimEndpoints []APIMEndpoint, isProd bool) []ModelConfig {
817+
var configs []ModelConfig
818+
for _, weight := range weights {
819+
var endpointID string
820+
for _, endpoint := range apimEndpoints {
821+
if endpoint.EndpointConfig.ProductionEndpoints.URL == weight.Endpoint {
822+
endpointID = endpoint.EndpointUUID
823+
break
824+
}
825+
if endpoint.EndpointConfig.SandboxEndpoints.URL == weight.Endpoint {
826+
endpointID = endpoint.EndpointUUID
827+
break
828+
}
829+
}
830+
configs = append(configs, ModelConfig{
831+
Model: weight.Model,
832+
EndpointID: endpointID,
833+
Weight: weight.Weight,
834+
})
835+
}
836+
return configs
837+
}
838+
839+
func extractOperations(event APICPEvent, apimEndpoints []APIMEndpoint) ([]APIOperation, []ScopeWrapper, error) {
761840
var apiOperations []APIOperation
762841
var requestOperationPolicies []OperationPolicy
763842
var responseOperationPolicies []OperationPolicy
@@ -793,11 +872,23 @@ func extractOperations(event APICPEvent) ([]APIOperation, []ScopeWrapper, error)
793872
Name: scope,
794873
DisplayName: scope,
795874
Description: scope,
796-
Bindings: []string{},
797875
},
798876
Shared: false,
799877
}
800878
}
879+
aiModelBasedRoundRobin := operationFromDP.AIModelBasedRoundRobin
880+
if aiModelBasedRoundRobin != nil {
881+
operationPolicy := OperationPolicy{
882+
PolicyName: constants.ModelWeightedRoundRobin,
883+
PolicyVersion: constants.V1,
884+
Parameters: ModelBasedRoundRobinConfig{
885+
Production: convertAIModelWeightsToModelConfigs(aiModelBasedRoundRobin.ProductionModels, apimEndpoints, true),
886+
Sandbox: convertAIModelWeightsToModelConfigs(aiModelBasedRoundRobin.SandboxModels, apimEndpoints, false),
887+
SuspendDuration: fmt.Sprintf("%d", aiModelBasedRoundRobin.OnQuotaExceedSuspendDuration),
888+
},
889+
}
890+
requestOperationPolicies = append(requestOperationPolicies, operationPolicy)
891+
}
801892
// Process filters
802893
for _, operationLevelFilter := range operationFromDP.Filters {
803894
switch filter := operationLevelFilter.(type) {

0 commit comments

Comments
 (0)