From 416e69ab05120fa071f2fe7d4411d617bf6d014e Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Fri, 20 Jun 2025 02:10:38 +0800 Subject: [PATCH 1/6] fix: input config does not keep default value --- .../pkg/util/generation/generation.go | 142 +++++++++- .../pkg/util/generation/generation_test.go | 251 ++++++++++++++++++ 2 files changed, 390 insertions(+), 3 deletions(-) diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index 564fbedca6c..67ed46331bc 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -444,14 +444,150 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { return nil, fmt.Errorf("failed to read config file: %w", err) } - config := newRayClusterConfigWithDefaults() - if err := yaml.UnmarshalStrict(data, &config); err != nil { + var overrideConfig RayClusterConfig + if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } - + config := MergeWithDefaults(&overrideConfig) return config, nil } +func MergeWithDefaults(overrideConfig *RayClusterConfig) *RayClusterConfig { + config := newRayClusterConfigWithDefaults() + + if overrideConfig.Namespace != nil { + config.Namespace = overrideConfig.Namespace + } + if overrideConfig.Name != nil { + config.Name = overrideConfig.Name + } + if overrideConfig.Labels != nil { + if config.Labels == nil { + config.Labels = make(map[string]string) + } + maps.Copy(config.Labels, overrideConfig.Labels) + } + if overrideConfig.Annotations != nil { + if config.Annotations == nil { + config.Annotations = make(map[string]string) + } + maps.Copy(config.Annotations, overrideConfig.Annotations) + } + if overrideConfig.RayVersion != nil { + config.RayVersion = overrideConfig.RayVersion + } + if overrideConfig.Image != nil { + config.Image = overrideConfig.Image + } + if overrideConfig.ServiceAccount != nil { + config.ServiceAccount = overrideConfig.ServiceAccount + } + if overrideConfig.Head != nil { + if overrideConfig.Head.CPU != nil { + config.Head.CPU = overrideConfig.Head.CPU + } + if overrideConfig.Head.GPU != nil { + config.Head.GPU = overrideConfig.Head.GPU + } + if overrideConfig.Head.Memory != nil { + config.Head.Memory = overrideConfig.Head.Memory + } + if overrideConfig.Head.EphemeralStorage != nil { + config.Head.EphemeralStorage = overrideConfig.Head.EphemeralStorage + } + if overrideConfig.Head.RayStartParams != nil { + if config.Head.RayStartParams == nil { + config.Head.RayStartParams = make(map[string]string) + } + maps.Copy(config.Head.RayStartParams, overrideConfig.Head.RayStartParams) + } + if overrideConfig.Head.NodeSelectors != nil { + if config.Head.NodeSelectors == nil { + config.Head.NodeSelectors = make(map[string]string) + } + maps.Copy(config.Head.NodeSelectors, overrideConfig.Head.NodeSelectors) + } + } + if overrideConfig.GKE != nil { + if config.GKE == nil { + config.GKE = &GKE{} + } + if overrideConfig.GKE.GCSFuse != nil { + if config.GKE.GCSFuse == nil { + config.GKE.GCSFuse = &GCSFuse{} + } + config.GKE.GCSFuse.MountOptions = overrideConfig.GKE.GCSFuse.MountOptions + config.GKE.GCSFuse.DisableMetrics = overrideConfig.GKE.GCSFuse.DisableMetrics + config.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount = overrideConfig.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount + config.GKE.GCSFuse.SkipCSIBucketAccessCheck = overrideConfig.GKE.GCSFuse.SkipCSIBucketAccessCheck + if overrideConfig.GKE.GCSFuse.Resources != nil { + if config.GKE.GCSFuse.Resources == nil { + config.GKE.GCSFuse.Resources = &GCSFuseResources{} + } + config.GKE.GCSFuse.Resources.CPU = overrideConfig.GKE.GCSFuse.Resources.CPU + config.GKE.GCSFuse.Resources.Memory = overrideConfig.GKE.GCSFuse.Resources.Memory + config.GKE.GCSFuse.Resources.EphemeralStorage = overrideConfig.GKE.GCSFuse.Resources.EphemeralStorage + } + config.GKE.GCSFuse.BucketName = overrideConfig.GKE.GCSFuse.BucketName + config.GKE.GCSFuse.MountPath = overrideConfig.GKE.GCSFuse.MountPath + } + } + if overrideConfig.Autoscaler != nil { + if config.Autoscaler == nil { + config.Autoscaler = &Autoscaler{} + } + config.Autoscaler.Version = overrideConfig.Autoscaler.Version + } + if overrideConfig.WorkerGroups != nil { + for len(config.WorkerGroups) < len(overrideConfig.WorkerGroups) { + config.WorkerGroups = append(config.WorkerGroups, WorkerGroup{ + Replicas: util.DefaultWorkerReplicas, + CPU: ptr.To(util.DefaultWorkerCPU), + Memory: ptr.To(util.DefaultWorkerMemory), + }) + } + for i, workerGroup := range overrideConfig.WorkerGroups { + if workerGroup.Name != nil && *workerGroup.Name != "" { + config.WorkerGroups[i].Name = workerGroup.Name + } + if workerGroup.CPU != nil { + config.WorkerGroups[i].CPU = workerGroup.CPU + } + if workerGroup.GPU != nil { + config.WorkerGroups[i].GPU = workerGroup.GPU + } + if workerGroup.TPU != nil { + config.WorkerGroups[i].TPU = workerGroup.TPU + } + if workerGroup.NumOfHosts != nil { + config.WorkerGroups[i].NumOfHosts = workerGroup.NumOfHosts + } + if workerGroup.Memory != nil { + config.WorkerGroups[i].Memory = workerGroup.Memory + } + if workerGroup.EphemeralStorage != nil { + config.WorkerGroups[i].EphemeralStorage = workerGroup.EphemeralStorage + } + if workerGroup.RayStartParams != nil { + if config.WorkerGroups[i].RayStartParams == nil { + config.WorkerGroups[i].RayStartParams = make(map[string]string) + } + maps.Copy(config.WorkerGroups[i].RayStartParams, workerGroup.RayStartParams) + } + if workerGroup.NodeSelectors != nil { + if config.WorkerGroups[i].NodeSelectors == nil { + config.WorkerGroups[i].NodeSelectors = make(map[string]string) + } + maps.Copy(config.WorkerGroups[i].NodeSelectors, workerGroup.NodeSelectors) + } + if workerGroup.Replicas > 0 { + config.WorkerGroups[i].Replicas = workerGroup.Replicas + } + } + } + return config +} + // ValidateConfig validates the RayClusterConfig object func ValidateConfig(config *RayClusterConfig) error { // Validate head resource quantities diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go index 2c16d0d1a31..5232f3ec327 100644 --- a/kubectl-plugin/pkg/util/generation/generation_test.go +++ b/kubectl-plugin/pkg/util/generation/generation_test.go @@ -875,8 +875,11 @@ func TestParseConfigFile(t *testing.T) { }, WorkerGroups: []WorkerGroup{ { + Name: ptr.To("default-group"), Replicas: int32(1), + CPU: ptr.To("2"), GPU: ptr.To("1"), + Memory: ptr.To("4Gi"), }, }, }, @@ -1108,3 +1111,251 @@ func TestGetGCSFuseVolumeAttributes(t *testing.T) { result := getGCSFuseVolumeAttributes(config) assert.Equal(t, expected, result) } + +func TestMergeWithDefaults(t *testing.T) { + defaultRayVersion := util.RayVersion + defaultImage := fmt.Sprintf("rayproject/ray:%s", util.RayVersion) + + t.Run("Empty RayClusterConfig and return default RayClusterConfig", func(t *testing.T) { + result := MergeWithDefaults(&RayClusterConfig{}) + expected := newRayClusterConfigWithDefaults() + assert.Equal(t, expected, result) + }) + + t.Run("Override namespace, name, labels, annotations", func(t *testing.T) { + inputNamespace := ptr.To("test-namespace") + inputName := ptr.To("test-name") + inputLabels := map[string]string{"key1": "value1", "key2": "value2"} + inputAnnotations := map[string]string{"annotation1": "value1", "annotation2": "value2"} + + override := &RayClusterConfig{ + Namespace: inputNamespace, + Name: inputName, + Labels: inputLabels, + Annotations: inputAnnotations, + } + result := MergeWithDefaults(override) + assert.Equal(t, inputNamespace, result.Namespace) + assert.Equal(t, inputName, result.Name) + assert.Equal(t, inputLabels, result.Labels) + assert.Equal(t, inputAnnotations, result.Annotations) + }) + + t.Run("Override RayVersion, Image, ServiceAccount", func(t *testing.T) { + inputRayVersion := ptr.To("4.16.0") + inputImage := ptr.To("custom/image:tag") + inputServiceAccount := ptr.To("svcacct") + + override := &RayClusterConfig{ + RayVersion: inputRayVersion, + Image: inputImage, + ServiceAccount: inputServiceAccount, + } + result := MergeWithDefaults(override) + assert.Equal(t, inputRayVersion, result.RayVersion) + assert.Equal(t, inputImage, result.Image) + assert.Equal(t, inputServiceAccount, result.ServiceAccount) + }) + + t.Run("Override Head fields", func(t *testing.T) { + headCPU := ptr.To("4") + headGPU := ptr.To("2") + headMemory := ptr.To("8Gi") + headEphemeralStorage := ptr.To("20Gi") + headRayStartParams := map[string]string{"foo": "bar"} + headNodeSelectors := map[string]string{"disktype": "ssd"} + + override := &RayClusterConfig{ + Head: &Head{ + CPU: headCPU, + GPU: headGPU, + Memory: headMemory, + EphemeralStorage: headEphemeralStorage, + RayStartParams: headRayStartParams, + NodeSelectors: headNodeSelectors, + }, + } + result := MergeWithDefaults(override) + assert.Equal(t, headCPU, result.Head.CPU) + assert.Equal(t, headGPU, result.Head.GPU) + assert.Equal(t, headMemory, result.Head.Memory) + assert.Equal(t, headEphemeralStorage, result.Head.EphemeralStorage) + assert.Equal(t, headRayStartParams, result.Head.RayStartParams) + assert.Equal(t, headNodeSelectors, result.Head.NodeSelectors) + }) + + t.Run("Override only some fields in Head, others remain default", func(t *testing.T) { + headCPU := ptr.To("8") + + override := &RayClusterConfig{ + Head: &Head{ + CPU: headCPU, + }, + } + result := MergeWithDefaults(override) + assert.Equal(t, headCPU, result.Head.CPU) + assert.Equal(t, ptr.To("4Gi"), result.Head.Memory) + assert.Equal(t, defaultRayVersion, *result.RayVersion) + assert.Equal(t, defaultImage, *result.Image) + }) + + t.Run("Override GKE.GCSFuse fields", func(t *testing.T) { + gcsFuseMountOption := ptr.To("opt1") + gcsFuseDisableMetrics := ptr.To(true) + gcsFuseMetadataPrefetchOnMount := ptr.To(true) + gcsFuseSkipCSIBucketAccessCheck := ptr.To(true) + gcsFuseBucketName := "bucket" + gcsFuseMountPath := "/mnt/path" + gcsFuseCPU := ptr.To("1") + gcsFuseMemory := ptr.To("2Gi") + gcsFuseEphemeralStorage := ptr.To("3Gi") + gcsFuseResources := &GCSFuseResources{ + CPU: gcsFuseCPU, + Memory: gcsFuseMemory, + EphemeralStorage: gcsFuseEphemeralStorage, + } + + override := &RayClusterConfig{ + GKE: &GKE{ + GCSFuse: &GCSFuse{ + MountOptions: gcsFuseMountOption, + DisableMetrics: gcsFuseDisableMetrics, + GCSFuseMetadataPrefetchOnMount: gcsFuseMetadataPrefetchOnMount, + SkipCSIBucketAccessCheck: gcsFuseSkipCSIBucketAccessCheck, + BucketName: gcsFuseBucketName, + MountPath: gcsFuseMountPath, + Resources: &GCSFuseResources{ + CPU: gcsFuseCPU, + Memory: gcsFuseMemory, + EphemeralStorage: gcsFuseEphemeralStorage, + }, + }, + }, + } + result := MergeWithDefaults(override) + assert.NotNil(t, result.GKE) + assert.NotNil(t, result.GKE.GCSFuse) + assert.Equal(t, gcsFuseMountOption, result.GKE.GCSFuse.MountOptions) + assert.Equal(t, gcsFuseDisableMetrics, result.GKE.GCSFuse.DisableMetrics) + assert.Equal(t, gcsFuseMetadataPrefetchOnMount, result.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount) + assert.Equal(t, gcsFuseSkipCSIBucketAccessCheck, result.GKE.GCSFuse.SkipCSIBucketAccessCheck) + assert.Equal(t, gcsFuseBucketName, result.GKE.GCSFuse.BucketName) + assert.Equal(t, gcsFuseMountPath, result.GKE.GCSFuse.MountPath) + assert.Equal(t, gcsFuseResources, result.GKE.GCSFuse.Resources) + assert.Equal(t, gcsFuseCPU, result.GKE.GCSFuse.Resources.CPU) + assert.Equal(t, gcsFuseMemory, result.GKE.GCSFuse.Resources.Memory) + assert.Equal(t, gcsFuseEphemeralStorage, result.GKE.GCSFuse.Resources.EphemeralStorage) + }) + + t.Run("Override Autoscaler", func(t *testing.T) { + override := &RayClusterConfig{ + Autoscaler: &Autoscaler{Version: AutoscalerV2}, + } + result := MergeWithDefaults(override) + assert.NotNil(t, result.Autoscaler) + assert.Equal(t, AutoscalerV2, result.Autoscaler.Version) + }) + + t.Run("Override WorkerGroups fields", func(t *testing.T) { + wgName1 := ptr.To("wg1") + wgCPU := ptr.To("5") + wgGPU := ptr.To("1") + wgTPU := ptr.To("2") + wgNumOfHosts := ptr.To(int32(3)) + wgMemory := ptr.To("16Gi") + wgEphemeralStorage := ptr.To("30Gi") + wgRayStartParams := map[string]string{"param": "val"} + wgNodeSelectors := map[string]string{"zone": "us-central1-a"} + wgReplicas := int32(7) + + override := &RayClusterConfig{ + WorkerGroups: []WorkerGroup{ + { + Name: wgName1, + CPU: wgCPU, + GPU: wgGPU, + TPU: wgTPU, + NumOfHosts: wgNumOfHosts, + Memory: wgMemory, + EphemeralStorage: wgEphemeralStorage, + RayStartParams: wgRayStartParams, + NodeSelectors: wgNodeSelectors, + Replicas: wgReplicas, + }, + }, + } + result := MergeWithDefaults(override) + require.Len(t, result.WorkerGroups, 1) + wg := result.WorkerGroups[0] + assert.Equal(t, wgName1, wg.Name) + assert.Equal(t, wgCPU, wg.CPU) + assert.Equal(t, wgGPU, wg.GPU) + assert.Equal(t, wgTPU, wg.TPU) + assert.Equal(t, wgNumOfHosts, wg.NumOfHosts) + assert.Equal(t, wgMemory, wg.Memory) + assert.Equal(t, wgEphemeralStorage, wg.EphemeralStorage) + assert.Equal(t, wgRayStartParams, wg.RayStartParams) + assert.Equal(t, wgNodeSelectors, wg.NodeSelectors) + assert.Equal(t, wgReplicas, wg.Replicas) + }) + + t.Run("Override WorkerGroups with more groups than defaults", func(t *testing.T) { + wg1Name := ptr.To("wg1") + wg2Name := ptr.To("wg2") + wg1Replicas := int32(2) + wg2Replicas := int32(3) + + override := &RayClusterConfig{ + WorkerGroups: []WorkerGroup{ + {Name: wg1Name, Replicas: wg1Replicas}, + {Name: wg2Name, Replicas: wg2Replicas}, + }, + } + result := MergeWithDefaults(override) + require.Len(t, result.WorkerGroups, 2) + assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) + assert.Equal(t, wg1Replicas, result.WorkerGroups[0].Replicas) + assert.Equal(t, wg2Name, result.WorkerGroups[1].Name) + assert.Equal(t, wg2Replicas, result.WorkerGroups[1].Replicas) + }) + + t.Run("Override WorkerGroups with zero replicas keeps default", func(t *testing.T) { + wg1Name := ptr.To("wg1") + + override := &RayClusterConfig{ + WorkerGroups: []WorkerGroup{ + {Name: wg1Name, Replicas: 0}, + }, + } + result := MergeWithDefaults(override) + require.Len(t, result.WorkerGroups, 1) + assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) + assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) + }) + + t.Run("Override WorkerGroups with empty name keeps default name", func(t *testing.T) { + override := &RayClusterConfig{ + WorkerGroups: []WorkerGroup{ + {Name: nil, Replicas: 2}, + }, + } + result := MergeWithDefaults(override) + require.Len(t, result.WorkerGroups, 1) + assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) + assert.Equal(t, int32(2), result.WorkerGroups[0].Replicas) + }) + + t.Run("Override only WorkerGroups CPU", func(t *testing.T) { + override := &RayClusterConfig{ + WorkerGroups: []WorkerGroup{ + {CPU: ptr.To("1")}, + }, + } + result := MergeWithDefaults(override) + require.Len(t, result.WorkerGroups, 1) + assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) + assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) + assert.Equal(t, ptr.To("1"), result.WorkerGroups[0].CPU) + assert.Equal(t, ptr.To("4Gi"), result.WorkerGroups[0].Memory) + }) +} From 5434bc270cc914e81daa71366ed9e9bdd6bc7420 Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Sat, 21 Jun 2025 02:30:49 +0800 Subject: [PATCH 2/6] refactor: merge nested struct config --- go.mod | 1 + go.sum | 2 + .../pkg/util/generation/generation.go | 137 ++++-------------- .../pkg/util/generation/generation_test.go | 48 ++++-- 4 files changed, 68 insertions(+), 120 deletions(-) diff --git a/go.mod b/go.mod index 472e6d593df..b25eccba055 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( ) require ( + dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/MakeNowJust/heredoc v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index dddab9f7e86..07da4c64ee9 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index 67ed46331bc..637cdfa630e 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "dario.cat/mergo" "gopkg.in/yaml.v2" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -448,96 +449,48 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } - config := MergeWithDefaults(&overrideConfig) + config, err := MergeWithDefaultConfig(&overrideConfig) + if err != nil { + return nil, fmt.Errorf("failed to merge config with defaults: %w", err) + } return config, nil } -func MergeWithDefaults(overrideConfig *RayClusterConfig) *RayClusterConfig { +func MergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { config := newRayClusterConfigWithDefaults() - if overrideConfig.Namespace != nil { - config.Namespace = overrideConfig.Namespace - } - if overrideConfig.Name != nil { - config.Name = overrideConfig.Name + // The defaults are not set in the default raycluster config, + // so we directly copy the values from overrideConfig + config.Namespace = overrideConfig.Namespace + config.Name = overrideConfig.Name + config.ServiceAccount = overrideConfig.ServiceAccount + config.GKE = overrideConfig.GKE + config.Autoscaler = overrideConfig.Autoscaler + + if overrideConfig.RayVersion != nil { + config.RayVersion = overrideConfig.RayVersion } + if overrideConfig.Labels != nil { - if config.Labels == nil { - config.Labels = make(map[string]string) - } + config.Labels = make(map[string]string) maps.Copy(config.Labels, overrideConfig.Labels) } if overrideConfig.Annotations != nil { - if config.Annotations == nil { - config.Annotations = make(map[string]string) - } + config.Annotations = make(map[string]string) maps.Copy(config.Annotations, overrideConfig.Annotations) } - if overrideConfig.RayVersion != nil { - config.RayVersion = overrideConfig.RayVersion - } + if overrideConfig.Image != nil { config.Image = overrideConfig.Image } - if overrideConfig.ServiceAccount != nil { - config.ServiceAccount = overrideConfig.ServiceAccount - } + if overrideConfig.Head != nil { - if overrideConfig.Head.CPU != nil { - config.Head.CPU = overrideConfig.Head.CPU - } - if overrideConfig.Head.GPU != nil { - config.Head.GPU = overrideConfig.Head.GPU - } - if overrideConfig.Head.Memory != nil { - config.Head.Memory = overrideConfig.Head.Memory - } - if overrideConfig.Head.EphemeralStorage != nil { - config.Head.EphemeralStorage = overrideConfig.Head.EphemeralStorage - } - if overrideConfig.Head.RayStartParams != nil { - if config.Head.RayStartParams == nil { - config.Head.RayStartParams = make(map[string]string) - } - maps.Copy(config.Head.RayStartParams, overrideConfig.Head.RayStartParams) - } - if overrideConfig.Head.NodeSelectors != nil { - if config.Head.NodeSelectors == nil { - config.Head.NodeSelectors = make(map[string]string) - } - maps.Copy(config.Head.NodeSelectors, overrideConfig.Head.NodeSelectors) - } - } - if overrideConfig.GKE != nil { - if config.GKE == nil { - config.GKE = &GKE{} - } - if overrideConfig.GKE.GCSFuse != nil { - if config.GKE.GCSFuse == nil { - config.GKE.GCSFuse = &GCSFuse{} - } - config.GKE.GCSFuse.MountOptions = overrideConfig.GKE.GCSFuse.MountOptions - config.GKE.GCSFuse.DisableMetrics = overrideConfig.GKE.GCSFuse.DisableMetrics - config.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount = overrideConfig.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount - config.GKE.GCSFuse.SkipCSIBucketAccessCheck = overrideConfig.GKE.GCSFuse.SkipCSIBucketAccessCheck - if overrideConfig.GKE.GCSFuse.Resources != nil { - if config.GKE.GCSFuse.Resources == nil { - config.GKE.GCSFuse.Resources = &GCSFuseResources{} - } - config.GKE.GCSFuse.Resources.CPU = overrideConfig.GKE.GCSFuse.Resources.CPU - config.GKE.GCSFuse.Resources.Memory = overrideConfig.GKE.GCSFuse.Resources.Memory - config.GKE.GCSFuse.Resources.EphemeralStorage = overrideConfig.GKE.GCSFuse.Resources.EphemeralStorage - } - config.GKE.GCSFuse.BucketName = overrideConfig.GKE.GCSFuse.BucketName - config.GKE.GCSFuse.MountPath = overrideConfig.GKE.GCSFuse.MountPath + err := mergo.Merge(config.Head, overrideConfig.Head, mergo.WithOverride) + if err != nil { + return nil, fmt.Errorf("failed to merge head config: %w", err) } } - if overrideConfig.Autoscaler != nil { - if config.Autoscaler == nil { - config.Autoscaler = &Autoscaler{} - } - config.Autoscaler.Version = overrideConfig.Autoscaler.Version - } + if overrideConfig.WorkerGroups != nil { for len(config.WorkerGroups) < len(overrideConfig.WorkerGroups) { config.WorkerGroups = append(config.WorkerGroups, WorkerGroup{ @@ -547,45 +500,13 @@ func MergeWithDefaults(overrideConfig *RayClusterConfig) *RayClusterConfig { }) } for i, workerGroup := range overrideConfig.WorkerGroups { - if workerGroup.Name != nil && *workerGroup.Name != "" { - config.WorkerGroups[i].Name = workerGroup.Name - } - if workerGroup.CPU != nil { - config.WorkerGroups[i].CPU = workerGroup.CPU - } - if workerGroup.GPU != nil { - config.WorkerGroups[i].GPU = workerGroup.GPU - } - if workerGroup.TPU != nil { - config.WorkerGroups[i].TPU = workerGroup.TPU - } - if workerGroup.NumOfHosts != nil { - config.WorkerGroups[i].NumOfHosts = workerGroup.NumOfHosts - } - if workerGroup.Memory != nil { - config.WorkerGroups[i].Memory = workerGroup.Memory - } - if workerGroup.EphemeralStorage != nil { - config.WorkerGroups[i].EphemeralStorage = workerGroup.EphemeralStorage - } - if workerGroup.RayStartParams != nil { - if config.WorkerGroups[i].RayStartParams == nil { - config.WorkerGroups[i].RayStartParams = make(map[string]string) - } - maps.Copy(config.WorkerGroups[i].RayStartParams, workerGroup.RayStartParams) - } - if workerGroup.NodeSelectors != nil { - if config.WorkerGroups[i].NodeSelectors == nil { - config.WorkerGroups[i].NodeSelectors = make(map[string]string) - } - maps.Copy(config.WorkerGroups[i].NodeSelectors, workerGroup.NodeSelectors) - } - if workerGroup.Replicas > 0 { - config.WorkerGroups[i].Replicas = workerGroup.Replicas + err := mergo.Merge(&config.WorkerGroups[i], workerGroup, mergo.WithOverride) + if err != nil { + return nil, fmt.Errorf("failed to merge worker group %d: %w", i, err) } } } - return config + return config, nil } // ValidateConfig validates the RayClusterConfig object diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go index 5232f3ec327..d3fa3921bd4 100644 --- a/kubectl-plugin/pkg/util/generation/generation_test.go +++ b/kubectl-plugin/pkg/util/generation/generation_test.go @@ -1117,7 +1117,9 @@ func TestMergeWithDefaults(t *testing.T) { defaultImage := fmt.Sprintf("rayproject/ray:%s", util.RayVersion) t.Run("Empty RayClusterConfig and return default RayClusterConfig", func(t *testing.T) { - result := MergeWithDefaults(&RayClusterConfig{}) + result, err := MergeWithDefaultConfig(&RayClusterConfig{}) + require.NoError(t, err) + assert.NotNil(t, result) expected := newRayClusterConfigWithDefaults() assert.Equal(t, expected, result) }) @@ -1134,7 +1136,9 @@ func TestMergeWithDefaults(t *testing.T) { Labels: inputLabels, Annotations: inputAnnotations, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.Equal(t, inputNamespace, result.Namespace) assert.Equal(t, inputName, result.Name) assert.Equal(t, inputLabels, result.Labels) @@ -1151,7 +1155,9 @@ func TestMergeWithDefaults(t *testing.T) { Image: inputImage, ServiceAccount: inputServiceAccount, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.Equal(t, inputRayVersion, result.RayVersion) assert.Equal(t, inputImage, result.Image) assert.Equal(t, inputServiceAccount, result.ServiceAccount) @@ -1175,7 +1181,9 @@ func TestMergeWithDefaults(t *testing.T) { NodeSelectors: headNodeSelectors, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.Equal(t, headCPU, result.Head.CPU) assert.Equal(t, headGPU, result.Head.GPU) assert.Equal(t, headMemory, result.Head.Memory) @@ -1192,7 +1200,9 @@ func TestMergeWithDefaults(t *testing.T) { CPU: headCPU, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.Equal(t, headCPU, result.Head.CPU) assert.Equal(t, ptr.To("4Gi"), result.Head.Memory) assert.Equal(t, defaultRayVersion, *result.RayVersion) @@ -1232,7 +1242,9 @@ func TestMergeWithDefaults(t *testing.T) { }, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.NotNil(t, result.GKE) assert.NotNil(t, result.GKE.GCSFuse) assert.Equal(t, gcsFuseMountOption, result.GKE.GCSFuse.MountOptions) @@ -1251,7 +1263,9 @@ func TestMergeWithDefaults(t *testing.T) { override := &RayClusterConfig{ Autoscaler: &Autoscaler{Version: AutoscalerV2}, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) assert.NotNil(t, result.Autoscaler) assert.Equal(t, AutoscalerV2, result.Autoscaler.Version) }) @@ -1284,7 +1298,9 @@ func TestMergeWithDefaults(t *testing.T) { }, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) wg := result.WorkerGroups[0] assert.Equal(t, wgName1, wg.Name) @@ -1311,7 +1327,9 @@ func TestMergeWithDefaults(t *testing.T) { {Name: wg2Name, Replicas: wg2Replicas}, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 2) assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) assert.Equal(t, wg1Replicas, result.WorkerGroups[0].Replicas) @@ -1327,7 +1345,9 @@ func TestMergeWithDefaults(t *testing.T) { {Name: wg1Name, Replicas: 0}, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) @@ -1339,7 +1359,9 @@ func TestMergeWithDefaults(t *testing.T) { {Name: nil, Replicas: 2}, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) assert.Equal(t, int32(2), result.WorkerGroups[0].Replicas) @@ -1351,7 +1373,9 @@ func TestMergeWithDefaults(t *testing.T) { {CPU: ptr.To("1")}, }, } - result := MergeWithDefaults(override) + result, err := MergeWithDefaultConfig(override) + require.NoError(t, err) + assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) From 0b1bfb5678842fd782ef03d04701cb43870b6afe Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Sun, 13 Jul 2025 00:04:49 +0800 Subject: [PATCH 3/6] chore: change to private func Signed-off-by: Cheyu Wu --- .../pkg/util/generation/generation.go | 4 ++-- .../pkg/util/generation/generation_test.go | 24 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index 637cdfa630e..a86f26fb73a 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -449,14 +449,14 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } - config, err := MergeWithDefaultConfig(&overrideConfig) + config, err := mergeWithDefaultConfig(&overrideConfig) if err != nil { return nil, fmt.Errorf("failed to merge config with defaults: %w", err) } return config, nil } -func MergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { +func mergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { config := newRayClusterConfigWithDefaults() // The defaults are not set in the default raycluster config, diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go index d3fa3921bd4..ce2318b28b9 100644 --- a/kubectl-plugin/pkg/util/generation/generation_test.go +++ b/kubectl-plugin/pkg/util/generation/generation_test.go @@ -1117,7 +1117,7 @@ func TestMergeWithDefaults(t *testing.T) { defaultImage := fmt.Sprintf("rayproject/ray:%s", util.RayVersion) t.Run("Empty RayClusterConfig and return default RayClusterConfig", func(t *testing.T) { - result, err := MergeWithDefaultConfig(&RayClusterConfig{}) + result, err := mergeWithDefaultConfig(&RayClusterConfig{}) require.NoError(t, err) assert.NotNil(t, result) expected := newRayClusterConfigWithDefaults() @@ -1136,7 +1136,7 @@ func TestMergeWithDefaults(t *testing.T) { Labels: inputLabels, Annotations: inputAnnotations, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, inputNamespace, result.Namespace) @@ -1155,7 +1155,7 @@ func TestMergeWithDefaults(t *testing.T) { Image: inputImage, ServiceAccount: inputServiceAccount, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, inputRayVersion, result.RayVersion) @@ -1181,7 +1181,7 @@ func TestMergeWithDefaults(t *testing.T) { NodeSelectors: headNodeSelectors, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, headCPU, result.Head.CPU) @@ -1200,7 +1200,7 @@ func TestMergeWithDefaults(t *testing.T) { CPU: headCPU, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, headCPU, result.Head.CPU) @@ -1242,7 +1242,7 @@ func TestMergeWithDefaults(t *testing.T) { }, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.NotNil(t, result.GKE) @@ -1263,7 +1263,7 @@ func TestMergeWithDefaults(t *testing.T) { override := &RayClusterConfig{ Autoscaler: &Autoscaler{Version: AutoscalerV2}, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) assert.NotNil(t, result.Autoscaler) @@ -1298,7 +1298,7 @@ func TestMergeWithDefaults(t *testing.T) { }, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) @@ -1327,7 +1327,7 @@ func TestMergeWithDefaults(t *testing.T) { {Name: wg2Name, Replicas: wg2Replicas}, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 2) @@ -1345,7 +1345,7 @@ func TestMergeWithDefaults(t *testing.T) { {Name: wg1Name, Replicas: 0}, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) @@ -1359,7 +1359,7 @@ func TestMergeWithDefaults(t *testing.T) { {Name: nil, Replicas: 2}, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) @@ -1373,7 +1373,7 @@ func TestMergeWithDefaults(t *testing.T) { {CPU: ptr.To("1")}, }, } - result, err := MergeWithDefaultConfig(override) + result, err := mergeWithDefaultConfig(override) require.NoError(t, err) assert.NotNil(t, result) require.Len(t, result.WorkerGroups, 1) From b38d2e3e387f65ab09dfd761bdb60bd09479cab0 Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Sun, 13 Jul 2025 20:57:02 +0800 Subject: [PATCH 4/6] fix: remove private func testing use TestParseConfigFile instead --- .../pkg/util/generation/generation_test.go | 411 ++++++------------ 1 file changed, 139 insertions(+), 272 deletions(-) diff --git a/kubectl-plugin/pkg/util/generation/generation_test.go b/kubectl-plugin/pkg/util/generation/generation_test.go index ce2318b28b9..2125c5c3faf 100644 --- a/kubectl-plugin/pkg/util/generation/generation_test.go +++ b/kubectl-plugin/pkg/util/generation/generation_test.go @@ -987,6 +987,145 @@ gke: }, }, }, + "override ray-version, image, service-account": { + config: ` +ray-version: 4.16.0 +image: custom/image:tag +service-account: svcacct +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To("4.16.0"), + Image: ptr.To("custom/image:tag"), + ServiceAccount: ptr.To("svcacct"), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override only head CPU": { + config: ` +head: + cpu: "8" +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("8"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with only CPU": { + config: ` +worker-groups: +- cpu: "1" +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("1"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with empty name": { + config: ` +worker-groups: +- name: "" + replicas: 2 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 2, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override worker group with replicas = 0": { + config: ` +worker-groups: +- name: "wg1" + replicas: 0 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("wg1"), + Replicas: 1, // fallback to default + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + }, + }, + "override autoscaler": { + config: ` +autoscaler: + version: v2 +`, + expected: &RayClusterConfig{ + RayVersion: ptr.To(util.RayVersion), + Image: ptr.To(fmt.Sprintf("rayproject/ray:%s", util.RayVersion)), + Head: &Head{ + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + WorkerGroups: []WorkerGroup{ + { + Name: ptr.To("default-group"), + Replicas: 1, + CPU: ptr.To("2"), + Memory: ptr.To("4Gi"), + }, + }, + Autoscaler: &Autoscaler{ + Version: AutoscalerV2, + }, + }, + }, } for name, test := range tests { @@ -1111,275 +1250,3 @@ func TestGetGCSFuseVolumeAttributes(t *testing.T) { result := getGCSFuseVolumeAttributes(config) assert.Equal(t, expected, result) } - -func TestMergeWithDefaults(t *testing.T) { - defaultRayVersion := util.RayVersion - defaultImage := fmt.Sprintf("rayproject/ray:%s", util.RayVersion) - - t.Run("Empty RayClusterConfig and return default RayClusterConfig", func(t *testing.T) { - result, err := mergeWithDefaultConfig(&RayClusterConfig{}) - require.NoError(t, err) - assert.NotNil(t, result) - expected := newRayClusterConfigWithDefaults() - assert.Equal(t, expected, result) - }) - - t.Run("Override namespace, name, labels, annotations", func(t *testing.T) { - inputNamespace := ptr.To("test-namespace") - inputName := ptr.To("test-name") - inputLabels := map[string]string{"key1": "value1", "key2": "value2"} - inputAnnotations := map[string]string{"annotation1": "value1", "annotation2": "value2"} - - override := &RayClusterConfig{ - Namespace: inputNamespace, - Name: inputName, - Labels: inputLabels, - Annotations: inputAnnotations, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, inputNamespace, result.Namespace) - assert.Equal(t, inputName, result.Name) - assert.Equal(t, inputLabels, result.Labels) - assert.Equal(t, inputAnnotations, result.Annotations) - }) - - t.Run("Override RayVersion, Image, ServiceAccount", func(t *testing.T) { - inputRayVersion := ptr.To("4.16.0") - inputImage := ptr.To("custom/image:tag") - inputServiceAccount := ptr.To("svcacct") - - override := &RayClusterConfig{ - RayVersion: inputRayVersion, - Image: inputImage, - ServiceAccount: inputServiceAccount, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, inputRayVersion, result.RayVersion) - assert.Equal(t, inputImage, result.Image) - assert.Equal(t, inputServiceAccount, result.ServiceAccount) - }) - - t.Run("Override Head fields", func(t *testing.T) { - headCPU := ptr.To("4") - headGPU := ptr.To("2") - headMemory := ptr.To("8Gi") - headEphemeralStorage := ptr.To("20Gi") - headRayStartParams := map[string]string{"foo": "bar"} - headNodeSelectors := map[string]string{"disktype": "ssd"} - - override := &RayClusterConfig{ - Head: &Head{ - CPU: headCPU, - GPU: headGPU, - Memory: headMemory, - EphemeralStorage: headEphemeralStorage, - RayStartParams: headRayStartParams, - NodeSelectors: headNodeSelectors, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, headCPU, result.Head.CPU) - assert.Equal(t, headGPU, result.Head.GPU) - assert.Equal(t, headMemory, result.Head.Memory) - assert.Equal(t, headEphemeralStorage, result.Head.EphemeralStorage) - assert.Equal(t, headRayStartParams, result.Head.RayStartParams) - assert.Equal(t, headNodeSelectors, result.Head.NodeSelectors) - }) - - t.Run("Override only some fields in Head, others remain default", func(t *testing.T) { - headCPU := ptr.To("8") - - override := &RayClusterConfig{ - Head: &Head{ - CPU: headCPU, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, headCPU, result.Head.CPU) - assert.Equal(t, ptr.To("4Gi"), result.Head.Memory) - assert.Equal(t, defaultRayVersion, *result.RayVersion) - assert.Equal(t, defaultImage, *result.Image) - }) - - t.Run("Override GKE.GCSFuse fields", func(t *testing.T) { - gcsFuseMountOption := ptr.To("opt1") - gcsFuseDisableMetrics := ptr.To(true) - gcsFuseMetadataPrefetchOnMount := ptr.To(true) - gcsFuseSkipCSIBucketAccessCheck := ptr.To(true) - gcsFuseBucketName := "bucket" - gcsFuseMountPath := "/mnt/path" - gcsFuseCPU := ptr.To("1") - gcsFuseMemory := ptr.To("2Gi") - gcsFuseEphemeralStorage := ptr.To("3Gi") - gcsFuseResources := &GCSFuseResources{ - CPU: gcsFuseCPU, - Memory: gcsFuseMemory, - EphemeralStorage: gcsFuseEphemeralStorage, - } - - override := &RayClusterConfig{ - GKE: &GKE{ - GCSFuse: &GCSFuse{ - MountOptions: gcsFuseMountOption, - DisableMetrics: gcsFuseDisableMetrics, - GCSFuseMetadataPrefetchOnMount: gcsFuseMetadataPrefetchOnMount, - SkipCSIBucketAccessCheck: gcsFuseSkipCSIBucketAccessCheck, - BucketName: gcsFuseBucketName, - MountPath: gcsFuseMountPath, - Resources: &GCSFuseResources{ - CPU: gcsFuseCPU, - Memory: gcsFuseMemory, - EphemeralStorage: gcsFuseEphemeralStorage, - }, - }, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.NotNil(t, result.GKE) - assert.NotNil(t, result.GKE.GCSFuse) - assert.Equal(t, gcsFuseMountOption, result.GKE.GCSFuse.MountOptions) - assert.Equal(t, gcsFuseDisableMetrics, result.GKE.GCSFuse.DisableMetrics) - assert.Equal(t, gcsFuseMetadataPrefetchOnMount, result.GKE.GCSFuse.GCSFuseMetadataPrefetchOnMount) - assert.Equal(t, gcsFuseSkipCSIBucketAccessCheck, result.GKE.GCSFuse.SkipCSIBucketAccessCheck) - assert.Equal(t, gcsFuseBucketName, result.GKE.GCSFuse.BucketName) - assert.Equal(t, gcsFuseMountPath, result.GKE.GCSFuse.MountPath) - assert.Equal(t, gcsFuseResources, result.GKE.GCSFuse.Resources) - assert.Equal(t, gcsFuseCPU, result.GKE.GCSFuse.Resources.CPU) - assert.Equal(t, gcsFuseMemory, result.GKE.GCSFuse.Resources.Memory) - assert.Equal(t, gcsFuseEphemeralStorage, result.GKE.GCSFuse.Resources.EphemeralStorage) - }) - - t.Run("Override Autoscaler", func(t *testing.T) { - override := &RayClusterConfig{ - Autoscaler: &Autoscaler{Version: AutoscalerV2}, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - assert.NotNil(t, result.Autoscaler) - assert.Equal(t, AutoscalerV2, result.Autoscaler.Version) - }) - - t.Run("Override WorkerGroups fields", func(t *testing.T) { - wgName1 := ptr.To("wg1") - wgCPU := ptr.To("5") - wgGPU := ptr.To("1") - wgTPU := ptr.To("2") - wgNumOfHosts := ptr.To(int32(3)) - wgMemory := ptr.To("16Gi") - wgEphemeralStorage := ptr.To("30Gi") - wgRayStartParams := map[string]string{"param": "val"} - wgNodeSelectors := map[string]string{"zone": "us-central1-a"} - wgReplicas := int32(7) - - override := &RayClusterConfig{ - WorkerGroups: []WorkerGroup{ - { - Name: wgName1, - CPU: wgCPU, - GPU: wgGPU, - TPU: wgTPU, - NumOfHosts: wgNumOfHosts, - Memory: wgMemory, - EphemeralStorage: wgEphemeralStorage, - RayStartParams: wgRayStartParams, - NodeSelectors: wgNodeSelectors, - Replicas: wgReplicas, - }, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.WorkerGroups, 1) - wg := result.WorkerGroups[0] - assert.Equal(t, wgName1, wg.Name) - assert.Equal(t, wgCPU, wg.CPU) - assert.Equal(t, wgGPU, wg.GPU) - assert.Equal(t, wgTPU, wg.TPU) - assert.Equal(t, wgNumOfHosts, wg.NumOfHosts) - assert.Equal(t, wgMemory, wg.Memory) - assert.Equal(t, wgEphemeralStorage, wg.EphemeralStorage) - assert.Equal(t, wgRayStartParams, wg.RayStartParams) - assert.Equal(t, wgNodeSelectors, wg.NodeSelectors) - assert.Equal(t, wgReplicas, wg.Replicas) - }) - - t.Run("Override WorkerGroups with more groups than defaults", func(t *testing.T) { - wg1Name := ptr.To("wg1") - wg2Name := ptr.To("wg2") - wg1Replicas := int32(2) - wg2Replicas := int32(3) - - override := &RayClusterConfig{ - WorkerGroups: []WorkerGroup{ - {Name: wg1Name, Replicas: wg1Replicas}, - {Name: wg2Name, Replicas: wg2Replicas}, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.WorkerGroups, 2) - assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) - assert.Equal(t, wg1Replicas, result.WorkerGroups[0].Replicas) - assert.Equal(t, wg2Name, result.WorkerGroups[1].Name) - assert.Equal(t, wg2Replicas, result.WorkerGroups[1].Replicas) - }) - - t.Run("Override WorkerGroups with zero replicas keeps default", func(t *testing.T) { - wg1Name := ptr.To("wg1") - - override := &RayClusterConfig{ - WorkerGroups: []WorkerGroup{ - {Name: wg1Name, Replicas: 0}, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.WorkerGroups, 1) - assert.Equal(t, wg1Name, result.WorkerGroups[0].Name) - assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) - }) - - t.Run("Override WorkerGroups with empty name keeps default name", func(t *testing.T) { - override := &RayClusterConfig{ - WorkerGroups: []WorkerGroup{ - {Name: nil, Replicas: 2}, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.WorkerGroups, 1) - assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) - assert.Equal(t, int32(2), result.WorkerGroups[0].Replicas) - }) - - t.Run("Override only WorkerGroups CPU", func(t *testing.T) { - override := &RayClusterConfig{ - WorkerGroups: []WorkerGroup{ - {CPU: ptr.To("1")}, - }, - } - result, err := mergeWithDefaultConfig(override) - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.WorkerGroups, 1) - assert.Equal(t, result.WorkerGroups[0].Name, ptr.To("default-group")) - assert.Equal(t, int32(1), result.WorkerGroups[0].Replicas) - assert.Equal(t, ptr.To("1"), result.WorkerGroups[0].CPU) - assert.Equal(t, ptr.To("4Gi"), result.WorkerGroups[0].Memory) - }) -} From 8312c533441aa0c2926c7adda26b1e171920c781 Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Thu, 24 Jul 2025 01:45:22 +0800 Subject: [PATCH 5/6] refactor: use detach workergroup to simplify the code --- .../pkg/util/generation/generation.go | 54 +++++-------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index a86f26fb73a..9d74574731b 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -449,57 +449,27 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } - config, err := mergeWithDefaultConfig(&overrideConfig) - if err != nil { - return nil, fmt.Errorf("failed to merge config with defaults: %w", err) - } - return config, nil -} + // detach worker groups from default config + overrideConfigWG := overrideConfig.WorkerGroups + overrideConfig.WorkerGroups = nil -func mergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { config := newRayClusterConfigWithDefaults() - - // The defaults are not set in the default raycluster config, - // so we directly copy the values from overrideConfig - config.Namespace = overrideConfig.Namespace - config.Name = overrideConfig.Name - config.ServiceAccount = overrideConfig.ServiceAccount - config.GKE = overrideConfig.GKE - config.Autoscaler = overrideConfig.Autoscaler - - if overrideConfig.RayVersion != nil { - config.RayVersion = overrideConfig.RayVersion - } - - if overrideConfig.Labels != nil { - config.Labels = make(map[string]string) - maps.Copy(config.Labels, overrideConfig.Labels) - } - if overrideConfig.Annotations != nil { - config.Annotations = make(map[string]string) - maps.Copy(config.Annotations, overrideConfig.Annotations) - } - - if overrideConfig.Image != nil { - config.Image = overrideConfig.Image - } - - if overrideConfig.Head != nil { - err := mergo.Merge(config.Head, overrideConfig.Head, mergo.WithOverride) - if err != nil { - return nil, fmt.Errorf("failed to merge head config: %w", err) - } + err = mergo.Merge(config, &overrideConfig, mergo.WithOverride) + if err != nil { + return nil, fmt.Errorf("failed to merge config with defaults: %w", err) } - - if overrideConfig.WorkerGroups != nil { - for len(config.WorkerGroups) < len(overrideConfig.WorkerGroups) { + // merge WorkerGroups and keep the default values for missing fields + // if overrideConfigWG is not nil, we will merge the worker groups from the config file + // and keep the default values for missing fields + if overrideConfigWG != nil { + for len(config.WorkerGroups) < len(overrideConfigWG) { config.WorkerGroups = append(config.WorkerGroups, WorkerGroup{ Replicas: util.DefaultWorkerReplicas, CPU: ptr.To(util.DefaultWorkerCPU), Memory: ptr.To(util.DefaultWorkerMemory), }) } - for i, workerGroup := range overrideConfig.WorkerGroups { + for i, workerGroup := range overrideConfigWG { err := mergo.Merge(&config.WorkerGroups[i], workerGroup, mergo.WithOverride) if err != nil { return nil, fmt.Errorf("failed to merge worker group %d: %w", i, err) From 1fefc4dfd7a2b7e558acfac39b27acda880183d0 Mon Sep 17 00:00:00 2001 From: Cheyu Wu Date: Thu, 24 Jul 2025 23:28:43 +0800 Subject: [PATCH 6/6] refactor: merging code for readability mv to private func --- kubectl-plugin/pkg/util/generation/generation.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/kubectl-plugin/pkg/util/generation/generation.go b/kubectl-plugin/pkg/util/generation/generation.go index 9d74574731b..9a83e93d0b4 100644 --- a/kubectl-plugin/pkg/util/generation/generation.go +++ b/kubectl-plugin/pkg/util/generation/generation.go @@ -449,12 +449,20 @@ func ParseConfigFile(filePath string) (*RayClusterConfig, error) { if err := yaml.UnmarshalStrict(data, &overrideConfig); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } + config, err := mergeWithDefaultConfig(&overrideConfig) + if err != nil { + return nil, err + } + return config, nil +} + +func mergeWithDefaultConfig(overrideConfig *RayClusterConfig) (*RayClusterConfig, error) { // detach worker groups from default config overrideConfigWG := overrideConfig.WorkerGroups overrideConfig.WorkerGroups = nil config := newRayClusterConfigWithDefaults() - err = mergo.Merge(config, &overrideConfig, mergo.WithOverride) + err := mergo.Merge(config, overrideConfig, mergo.WithOverride) if err != nil { return nil, fmt.Errorf("failed to merge config with defaults: %w", err) }