Skip to content

Commit 117dac3

Browse files
committed
added migration from 1.2 to 1.3
1 parent 6ca98f0 commit 117dac3

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

pystac/extensions/mlm.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2107,7 +2107,31 @@ def _migrate_1_1_to_1_2(obj: dict[str, Any]) -> None:
21072107

21082108
@staticmethod
21092109
def _migrate_1_2_to_1_3(obj: dict[str, Any]) -> None:
2110-
pass
2110+
bands_obj = obj["properties"]["mlm:input"]
2111+
2112+
if not bands_obj:
2113+
return
2114+
2115+
if "raster:bands" not in obj["properties"]:
2116+
return
2117+
raster_bands = obj["properties"]["raster:bands"]
2118+
2119+
# make sure all raster_bands have a name prop with length>0
2120+
names_properties_valid = all(
2121+
"name" in band and len(band["name"]) > 0 for band in raster_bands
2122+
)
2123+
if not names_properties_valid:
2124+
raise STACError(
2125+
"Error migrating stac:mlm version: In mlm>=1.3, each band in "
2126+
'raster:bands is required to have a property "name"'
2127+
)
2128+
2129+
# copy the raster:bands to assets
2130+
for asset_name in obj["assets"]:
2131+
asset = obj["assets"][asset_name]
2132+
if "mlm:model" not in asset["roles"]:
2133+
continue
2134+
asset["raster:bands"] = raster_bands
21112135

21122136
@staticmethod
21132137
def _migrate_1_3_to_1_4(obj: dict[str, Any]) -> None:

tests/extensions/test_mlm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,69 @@ def test_migration_1_1_to_1_2() -> None:
11211121
MLMExtensionHooks._migrate_1_1_to_1_2(data)
11221122

11231123

1124+
@pytest.mark.parametrize(
1125+
"inp_bands, raster_bands, valid",
1126+
(
1127+
([], None, True),
1128+
(
1129+
["B02", "B03"],
1130+
[
1131+
{"name": "B02", "data_type": "float64"},
1132+
{"name": "B03", "data_type": "float64"},
1133+
],
1134+
True,
1135+
),
1136+
(
1137+
["B02", "B03"],
1138+
[
1139+
{"name": "", "data_type": "float64"},
1140+
{"name": "", "data_type": "float64"},
1141+
],
1142+
False,
1143+
),
1144+
(
1145+
["B02", "B03"],
1146+
[
1147+
{"name": "B02", "data_type": "float64"},
1148+
{"name": "", "data_type": "float64"},
1149+
],
1150+
False,
1151+
),
1152+
(
1153+
["B02", "B03"],
1154+
[{"name": "B02", "data_type": "float64"}, {"data_type": "float64"}],
1155+
False,
1156+
),
1157+
(
1158+
["B02", "B03"],
1159+
[{"name": "", "data_type": "float64"}, {"data_type": "float64"}],
1160+
False,
1161+
),
1162+
(["B02", "B03"], [{"data_type": "float64"}, {"data_type": "float64"}], False),
1163+
),
1164+
)
1165+
def test_migration_1_2_to_1_3(
1166+
inp_bands: list[str], raster_bands: list[dict[str, Any]], valid: bool
1167+
) -> None:
1168+
data: dict[str, Any] = {
1169+
"properties": {"mlm:input": {}},
1170+
"assets": {"asset1": {"roles": ["data"]}, "asset2": {"roles": ["mlm:model"]}},
1171+
}
1172+
1173+
if inp_bands:
1174+
data["properties"]["mlm:input"]["bands"] = inp_bands
1175+
data["properties"]["raster:bands"] = raster_bands
1176+
1177+
if valid:
1178+
MLMExtensionHooks._migrate_1_2_to_1_3(data)
1179+
if raster_bands:
1180+
assert "raster:bands" not in data["assets"]["asset1"]
1181+
assert "raster:bands" in data["assets"]["asset2"]
1182+
else:
1183+
with pytest.raises(STACError):
1184+
MLMExtensionHooks._migrate_1_2_to_1_3(data)
1185+
1186+
11241187
@pytest.mark.parametrize(
11251188
("norm_by_channel", "norm_type", "norm_clip", "statistics", "value_scaling"),
11261189
(

0 commit comments

Comments
 (0)