diff --git a/cloudpub/ms_azure/utils.py b/cloudpub/ms_azure/utils.py index 916f926..7d47b65 100644 --- a/cloudpub/ms_azure/utils.py +++ b/cloudpub/ms_azure/utils.py @@ -280,7 +280,7 @@ def update_skus( disk_versions (list) List of existing DiskVersion in the technical config generation (str) - The main generation for publishing + The main generation for publishing when there are no old_skus plan-name (str) The destination plan name. old_skus (list, optional) @@ -306,14 +306,26 @@ def update_skus( # The alternate plan name ends with the suffix "-genX" and we can't change that once # the offer is live, otherwise it will raise "BadRequest" with the message: # "The property 'PlanId' is locked by a previous submission". - default_gen = "V2" - alt_gen = "V1" - for osku in old_skus: - if osku.security_type is not None: - security_type = osku.security_type - if osku.id.endswith("-gen2"): # alternate is gen2 hence V1 is the default. - default_gen = "V1" - alt_gen = "V2" + osku = old_skus[0] + # Get the security type for all gens + if osku.security_type is not None: + security_type = osku.security_type + + # Default Gen2 cases + if osku.image_type.endswith("Gen1") and osku.id.endswith("gen1"): + default_gen = "V2" + alt_gen = "V1" + elif osku.image_type.endswith("Gen2") and not osku.id.endswith("gen2"): + default_gen = "V2" + alt_gen = "V1" + + # Default Gen1 cases + elif osku.image_type.endswith("Gen1") and not osku.id.endswith("gen1"): + default_gen = "V1" + alt_gen = "V2" + elif osku.image_type.endswith("Gen2") and osku.id.endswith("gen2"): + default_gen = "V1" + alt_gen = "V2" return _build_skus( disk_versions, diff --git a/tests/ms_azure/test_utils.py b/tests/ms_azure/test_utils.py index ab527e4..7f6199f 100644 --- a/tests/ms_azure/test_utils.py +++ b/tests/ms_azure/test_utils.py @@ -277,6 +277,46 @@ def test_update_existing_skus_gen1_default( ] ] + @pytest.mark.parametrize("generation", ["V1", "V2"]) + def test_update_existing_skus_gen1_single( + self, generation: str, technical_config_obj: VMIPlanTechConfig + ) -> None: + skus = [VMISku.from_json({"imageType": "x64Gen1", "skuId": "plan1"})] + technical_config_obj.skus = skus + res = update_skus( + disk_versions=technical_config_obj.disk_versions, + generation=generation, + plan_name="plan1", + old_skus=technical_config_obj.skus, + ) + assert res == [ + VMISku.from_json(x) + for x in [ + {"imageType": "x64Gen1", "skuId": "plan1", "securityType": None}, + {"imageType": "x64Gen2", "skuId": "plan1-gen2", "securityType": None}, + ] + ] + + @pytest.mark.parametrize("generation", ["V1", "V2"]) + def test_update_existing_skus_gen2_single( + self, generation: str, technical_config_obj: VMIPlanTechConfig + ) -> None: + skus = [VMISku.from_json({"imageType": "x64Gen2", "skuId": "plan1"})] + technical_config_obj.skus = skus + res = update_skus( + disk_versions=technical_config_obj.disk_versions, + generation=generation, + plan_name="plan1", + old_skus=technical_config_obj.skus, + ) + assert res == [ + VMISku.from_json(x) + for x in [ + {"imageType": "x64Gen2", "skuId": "plan1", "securityType": None}, + {"imageType": "x64Gen1", "skuId": "plan1-gen1", "securityType": None}, + ] + ] + def test_create_disk_version_from_scratch( self, disk_version_obj: DiskVersion,