Skip to content

Commit 25bc3ea

Browse files
committed
migrate now forbidden properties in assets
1 parent 117dac3 commit 25bc3ea

File tree

2 files changed

+96
-41
lines changed

2 files changed

+96
-41
lines changed

pystac/extensions/mlm.py

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,46 +2135,70 @@ def _migrate_1_2_to_1_3(obj: dict[str, Any]) -> None:
21352135

21362136
@staticmethod
21372137
def _migrate_1_3_to_1_4(obj: dict[str, Any]) -> None:
2138-
for input_obj in obj["properties"]["mlm:input"]:
2139-
if "norm_type" in input_obj and input_obj["norm_type"] is not None:
2140-
norm_type = input_obj["norm_type"]
2141-
value_scaling_list = []
2142-
if norm_type == "min-max":
2143-
for band_statistic in input_obj["statistics"]:
2144-
value_scaling_obj = {
2145-
"type": "min-max",
2146-
"minimum": band_statistic["minimum"],
2147-
"maximum": band_statistic["maximum"],
2148-
}
2149-
value_scaling_list.append(value_scaling_obj)
2150-
elif norm_type == "z-score":
2151-
for band_statistic in input_obj["statistics"]:
2152-
value_scaling_obj = {
2153-
"type": "z-score",
2154-
"mean": band_statistic["mean"],
2155-
"stddev": band_statistic["stddev"],
2156-
}
2157-
value_scaling_list.append(value_scaling_obj)
2158-
elif norm_type == "clip":
2159-
for clip_value in input_obj["norm_clip"]:
2160-
value_scaling_obj = {
2161-
"type": "processing",
2162-
"format": "gdal-calc",
2163-
"expression": f"numpy.clip(A / {clip_value}, 0, 1)",
2164-
}
2165-
value_scaling_list.append(value_scaling_obj)
2166-
else:
2167-
raise NotImplementedError(
2168-
f"Normalization type {norm_type} is not supported in stac:mlm"
2169-
f" >= 1.3. Therefore an automatic migration is not possible. "
2170-
f"Please migrate this normalization manually using "
2171-
f'type="processing".'
2172-
)
2173-
input_obj["value_scaling"] = value_scaling_list
2174-
input_obj.pop("norm_by_channel", None)
2175-
input_obj.pop("norm_type", None)
2176-
input_obj.pop("norm_clip", None)
2177-
input_obj.pop("statistics", None)
2138+
# Migrate to value_scaling
2139+
if "mlm:input" in obj["properties"]:
2140+
for input_obj in obj["properties"]["mlm:input"]:
2141+
if "norm_type" in input_obj and input_obj["norm_type"] is not None:
2142+
norm_type = input_obj["norm_type"]
2143+
value_scaling_list = []
2144+
if norm_type == "min-max":
2145+
for band_statistic in input_obj["statistics"]:
2146+
value_scaling_obj = {
2147+
"type": "min-max",
2148+
"minimum": band_statistic["minimum"],
2149+
"maximum": band_statistic["maximum"],
2150+
}
2151+
value_scaling_list.append(value_scaling_obj)
2152+
elif norm_type == "z-score":
2153+
for band_statistic in input_obj["statistics"]:
2154+
value_scaling_obj = {
2155+
"type": "z-score",
2156+
"mean": band_statistic["mean"],
2157+
"stddev": band_statistic["stddev"],
2158+
}
2159+
value_scaling_list.append(value_scaling_obj)
2160+
elif norm_type == "clip":
2161+
for clip_value in input_obj["norm_clip"]:
2162+
value_scaling_obj = {
2163+
"type": "processing",
2164+
"format": "gdal-calc",
2165+
"expression": f"numpy.clip(A / {clip_value}, 0, 1)",
2166+
}
2167+
value_scaling_list.append(value_scaling_obj)
2168+
else:
2169+
raise NotImplementedError(
2170+
f"Normalization type {norm_type} is not supported in "
2171+
f"stac:mlm >= 1.3. Therefore an automatic migration is not "
2172+
f"possible. Please migrate this normalization manually "
2173+
f'using type="processing".'
2174+
)
2175+
input_obj["value_scaling"] = value_scaling_list
2176+
input_obj.pop("norm_by_channel", None)
2177+
input_obj.pop("norm_type", None)
2178+
input_obj.pop("norm_clip", None)
2179+
input_obj.pop("statistics", None)
2180+
2181+
if "assets" in obj:
2182+
for asset in obj["assets"]:
2183+
# move forbidden fields from asset to properties
2184+
if "mlm:name" in obj["assets"][asset]:
2185+
obj["properties"]["mlm:name"] = obj["assets"][asset]["mlm:name"]
2186+
obj["assets"][asset].pop("mlm:name")
2187+
if "mlm:input" in obj["assets"][asset]:
2188+
obj["properties"]["mlm:input"] = obj["assets"][asset]["mlm:input"]
2189+
obj["assets"][asset].pop("mlm:input")
2190+
if "mlm:output" in obj["assets"][asset]:
2191+
obj["properties"]["mlm:output"] = obj["assets"][asset]["mlm:output"]
2192+
obj["assets"][asset].pop("mlm:output")
2193+
if "mlm:hyperparameters" in obj["assets"][asset]:
2194+
obj["properties"]["mlm:hyperparameters"] = obj["assets"][asset][
2195+
"mlm:hyperparameters"
2196+
]
2197+
obj["assets"][asset].pop("mlm:hyperparameters")
2198+
2199+
# add new REQUIRED proretie mlm:artifact_type to asset
2200+
if "mlm:model" in obj["assets"][asset]["roles"]:
2201+
obj["assets"][asset]["mlm:artifact_type"] = ""
21782202

21792203
def migrate(
21802204
self, obj: dict[str, Any], version: STACVersionID, info: STACJSONDescription

tests/extensions/test_mlm.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,7 @@ def test_migration_1_2_to_1_3(
12431243
),
12441244
),
12451245
)
1246-
def test_migration_1_3_to_1_4(
1246+
def test_migration_1_3_to_1_4_value_scaling(
12471247
norm_by_channel: bool | None,
12481248
norm_type: str | None,
12491249
norm_clip: list[int] | None,
@@ -1289,3 +1289,34 @@ def test_migration_1_3_to_1_4_failure(norm_type: str) -> None:
12891289

12901290
with pytest.raises(NotImplementedError):
12911291
MLMExtensionHooks._migrate_1_3_to_1_4(data)
1292+
1293+
1294+
def test_migration_1_3_to_1_4_assets() -> None:
1295+
data: dict[str, Any] = {
1296+
"properties": {},
1297+
"assets": {
1298+
"asset1": {
1299+
"mlm:name": "asdf",
1300+
"mlm:input": {},
1301+
"mlm:output": {},
1302+
"mlm:hyperparameters": {},
1303+
"roles": ["mlm:model"],
1304+
}
1305+
},
1306+
}
1307+
1308+
MLMExtensionHooks._migrate_1_3_to_1_4(data)
1309+
1310+
assert "mlm:name" not in data["assets"]["asset1"]
1311+
assert "mlm:name" in data["properties"]
1312+
1313+
assert "mlm:input" not in data["assets"]["asset1"]
1314+
assert "mlm:input" in data["properties"]
1315+
1316+
assert "mlm:output" not in data["assets"]["asset1"]
1317+
assert "mlm:output" in data["properties"]
1318+
1319+
assert "mlm:hyperparameters" not in data["assets"]["asset1"]
1320+
assert "mlm:hyperparameters" in data["properties"]
1321+
1322+
assert "mlm:artifact_type" in data["assets"]["asset1"]

0 commit comments

Comments
 (0)