diff --git a/go.mod b/go.mod index 472e6d593df..b25eccba055 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( ) require ( + dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/MakeNowJust/heredoc v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index dddab9f7e86..07da4c64ee9 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index 564fbedca6c..9a83e93d0b4 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "dario.cat/mergo" "gopkg.in/yaml.v2" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -444,11 +445,45 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { return nil, fmt.Errorf("failed to read config file: %w", err) } - config := newRayClusterConfigWithDefaults() - if err := yaml.UnmarshalStrict(data, &config); err != nil { + var overrideConfig RayClusterConfig + if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } + config, err := mergeWithDefaultConfig(&overrideConfig) + if err != nil { + return nil, err + } + return config, nil +} + +func mergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { + // detach worker groups from default config + overrideConfigWG := overrideConfig.WorkerGroups + overrideConfig.WorkerGroups = nil + config := newRayClusterConfigWithDefaults() + err := mergo.Merge(config, overrideConfig, mergo.WithOverride) + if err != nil { + return nil, fmt.Errorf("failed to merge config with defaults: %w", err) + } + // merge WorkerGroups and keep the default values for missing fields + // if overrideConfigWG is not nil, we will merge the worker groups from the config file + // and keep the default values for missing fields + if overrideConfigWG != nil { + for len(config.WorkerGroups) < len(overrideConfigWG) { + config.WorkerGroups = append(config.WorkerGroups, WorkerGroup{ + Replicas: util.DefaultWorkerReplicas, + CPU: ptr.To(util.DefaultWorkerCPU), + Memory: ptr.To(util.DefaultWorkerMemory), + }) + } + for i, workerGroup := range overrideConfigWG { + err := mergo.Merge(&config.WorkerGroups[i], workerGroup, mergo.WithOverride) + if err != nil { + return nil, fmt.Errorf("failed to merge worker group %d: %w", i, err) + } + } + } return config, nil } diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go index 2c16d0d1a31..2125c5c3faf 100644 --- a/kubectl-plugin/pkg/util/generation/generation_test.go +++ b/kubectl-plugin/pkg/util/generation/generation_test.go @@ -875,8 +875,11 @@ func TestParseConfigFile(t *testing.T) { }, WorkerGroups: []WorkerGroup{ { + Name: ptr.To("default-group"), Replicas: int32(1), + CPU: ptr.To("2"), GPU: ptr.To("1"), + Memory: ptr.To("4Gi"), }, }, }, @@ -984,6 +987,145 @@ gke: }, }, }, + "override ray-version, image, service-account": { + config: ` +ray-version: 4.16.0 +image: custom/image:tag +service-account: svcacct +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To("4.16.0"), + Image: ptr.To("custom/image:tag"), + ServiceAccount: ptr.To("svcacct"), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override only head CPU": { + config: ` +head: + cpu: "8" +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("8"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with only CPU": { + config: ` +worker-groups: +- cpu: "1" +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("1"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with empty name": { + config: ` +worker-groups: +- name: "" + replicas: 2 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 2, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with replicas = 0": { + config: ` +worker-groups: +- name: "wg1" + replicas: 0 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("wg1"), + Replicas: 1, // fallback to default + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override autoscaler": { + config: ` +autoscaler: + version: v2 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + Autoscaler: &Autoscaler{ + Version: AutoscalerV2, + }, + }, + }, } for name, test := range tests {