Skip to content

Commit 35800e2

Browse files
committed
added more version migrations
1 parent 84f40e8 commit 35800e2

File tree

1 file changed

+85
-40
lines changed

1 file changed

+85
-40
lines changed

pystac/extensions/mlm.py

Lines changed: 85 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,52 +2103,76 @@ def migrate(props_obj: dict[str, Any]) -> None:
21032103

21042104
@staticmethod
21052105
def _migrate_1_1_to_1_2(obj: dict[str, Any]) -> None:
2106-
if "assets" not in obj:
2107-
return
2108-
assets = obj["assets"]
2109-
model_in_assets = any(
2110-
["mlm:model" in assets[asset]["roles"] for asset in assets]
2111-
)
2112-
if not model_in_assets:
2113-
raise pystac.errors.STACError(
2114-
'Error migrating stac:mlm version: An asset with role "mlm:model" is '
2115-
"required in mlm>=1.2. Include at least one asset with role "
2116-
'"mlm:model" '
2106+
def migrate(obj_assets: dict[str, Any]) -> None:
2107+
model_in_assets = any(
2108+
["mlm:model" in obj_assets[asset]["roles"] for asset in obj_assets]
21172109
)
2110+
if not model_in_assets:
2111+
raise pystac.errors.STACError(
2112+
'Error migrating stac:mlm version: An asset with role "mlm:model" '
2113+
"is required in mlm>=1.2. Include at least one asset with role "
2114+
'"mlm:model" '
2115+
)
2116+
2117+
if "assets" in obj:
2118+
migrate(obj["assets"])
2119+
if "item_assets" in obj:
2120+
migrate(obj["item_assets"])
21182121

21192122
@staticmethod
21202123
def _migrate_1_2_to_1_3(obj: dict[str, Any]) -> None:
2121-
bands_obj = obj["properties"]["mlm:input"]
2124+
def migrate(props_obj: dict[str, Any]) -> None:
2125+
if "mlm:input" not in props_obj:
2126+
return
21222127

2123-
if not bands_obj:
2124-
return
2128+
bands_objs_present = any("bands" in inp for inp in props_obj["mlm:input"])
21252129

2126-
if "raster:bands" not in obj["properties"]:
2127-
return
2128-
raster_bands = obj["properties"]["raster:bands"]
2130+
if not bands_objs_present:
2131+
return
21292132

2130-
# make sure all raster_bands have a name prop with length>0
2131-
names_properties_valid = all(
2132-
"name" in band and len(band["name"]) > 0 for band in raster_bands
2133-
)
2134-
if not names_properties_valid:
2135-
raise STACError(
2136-
"Error migrating stac:mlm version: In mlm>=1.3, each band in "
2137-
'raster:bands is required to have a property "name"'
2133+
if "raster:bands" not in props_obj:
2134+
return
2135+
raster_bands = props_obj["raster:bands"]
2136+
2137+
# make sure all raster_bands have a name prop with length>0
2138+
names_properties_valid = all(
2139+
"name" in band and len(band["name"]) > 0 for band in raster_bands
21382140
)
2141+
if not names_properties_valid:
2142+
raise STACError(
2143+
"Error migrating stac:mlm version: In mlm>=1.3, each band in "
2144+
'raster:bands is required to have a property "name"'
2145+
)
2146+
2147+
# no need to perform the actions below if props_obj is an asset
2148+
# this is checked by the presence of "roles" prop
2149+
if "roles" in props_obj:
2150+
return
21392151

2140-
# copy the raster:bands to assets
2141-
for asset_name in obj["assets"]:
2142-
asset = obj["assets"][asset_name]
2143-
if "mlm:model" not in asset["roles"]:
2144-
continue
2145-
asset["raster:bands"] = raster_bands
2152+
# copy the raster:bands to assets
2153+
for inner_asset_name in obj["assets"]:
2154+
inner_asset = obj["assets"][inner_asset_name]
2155+
if "mlm:model" not in inner_asset["roles"]:
2156+
continue
2157+
inner_asset["raster:bands"] = raster_bands
2158+
2159+
if obj["type"] == "Feature" and "mlm:input" in obj["properties"]:
2160+
migrate(obj["properties"])
2161+
if obj["type"] == "Collection":
2162+
migrate(obj)
2163+
if "assets" in obj:
2164+
for asset_name in obj["assets"]:
2165+
asset = obj["assets"][asset_name]
2166+
migrate(asset)
21462167

21472168
@staticmethod
21482169
def _migrate_1_3_to_1_4(obj: dict[str, Any]) -> None:
2149-
# Migrate to value_scaling
2150-
if "mlm:input" in obj["properties"]:
2151-
for input_obj in obj["properties"]["mlm:input"]:
2170+
def migrate(props_obj: dict[str, Any]) -> None:
2171+
if "mlm:input" not in props_obj:
2172+
return
2173+
2174+
# Migrate to value_scaling
2175+
for input_obj in props_obj["mlm:input"]:
21522176
if "norm_type" in input_obj and input_obj["norm_type"] is not None:
21532177
norm_type = input_obj["norm_type"]
21542178
value_scaling_list = []
@@ -2189,22 +2213,43 @@ def _migrate_1_3_to_1_4(obj: dict[str, Any]) -> None:
21892213
input_obj.pop("norm_clip", None)
21902214
input_obj.pop("statistics", None)
21912215

2216+
if obj["type"] == "Feature":
2217+
migrate(obj["properties"])
2218+
if obj["type"] == "Collection":
2219+
migrate(obj)
2220+
21922221
if "assets" in obj:
21932222
for asset in obj["assets"]:
2223+
migrate(obj["assets"][asset])
2224+
21942225
# move forbidden fields from asset to properties
21952226
if "mlm:name" in obj["assets"][asset]:
2196-
obj["properties"]["mlm:name"] = obj["assets"][asset]["mlm:name"]
2227+
mlm_name = obj["assets"][asset]["mlm:name"]
2228+
if obj["type"] == "Feature":
2229+
obj["properties"]["mlm:name"] = mlm_name
2230+
if obj["type"] == "Collection":
2231+
obj["mlm:name"] = mlm_name
21972232
obj["assets"][asset].pop("mlm:name")
21982233
if "mlm:input" in obj["assets"][asset]:
2199-
obj["properties"]["mlm:input"] = obj["assets"][asset]["mlm:input"]
2234+
inp = obj["assets"][asset]["mlm:input"]
2235+
if obj["type"] == "Feature":
2236+
obj["properties"]["mlm:input"] = inp
2237+
if obj["type"] == "Collection":
2238+
obj["mlm:input"] = inp
22002239
obj["assets"][asset].pop("mlm:input")
22012240
if "mlm:output" in obj["assets"][asset]:
2202-
obj["properties"]["mlm:output"] = obj["assets"][asset]["mlm:output"]
2241+
outp = obj["assets"][asset]["mlm:output"]
2242+
if obj["type"] == "Feature":
2243+
obj["properties"]["mlm:output"] = outp
2244+
if obj["type"] == "Collection":
2245+
obj["mlm:output"] = outp
22032246
obj["assets"][asset].pop("mlm:output")
22042247
if "mlm:hyperparameters" in obj["assets"][asset]:
2205-
obj["properties"]["mlm:hyperparameters"] = obj["assets"][asset][
2206-
"mlm:hyperparameters"
2207-
]
2248+
hyp = obj["assets"][asset]["mlm:hyperparameters"]
2249+
if obj["type"] == "Feature":
2250+
obj["properties"]["mlm:hyperparameters"] = hyp
2251+
if obj["type"] == "Collection":
2252+
obj["mlm:hyperparameters"] = hyp
22082253
obj["assets"][asset].pop("mlm:hyperparameters")
22092254

22102255
# add new REQUIRED proretie mlm:artifact_type to asset

0 commit comments

Comments
 (0)