@@ -2103,52 +2103,76 @@ def migrate(props_obj: dict[str, Any]) -> None:
21032103
21042104 @staticmethod
21052105 def _migrate_1_1_to_1_2 (obj : dict [str , Any ]) -> None :
2106- if "assets" not in obj :
2107- return
2108- assets = obj ["assets" ]
2109- model_in_assets = any (
2110- ["mlm:model" in assets [asset ]["roles" ] for asset in assets ]
2111- )
2112- if not model_in_assets :
2113- raise pystac .errors .STACError (
2114- 'Error migrating stac:mlm version: An asset with role "mlm:model" is '
2115- "required in mlm>=1.2. Include at least one asset with role "
2116- '"mlm:model" '
2106+ def migrate (obj_assets : dict [str , Any ]) -> None :
2107+ model_in_assets = any (
2108+ ["mlm:model" in obj_assets [asset ]["roles" ] for asset in obj_assets ]
21172109 )
2110+ if not model_in_assets :
2111+ raise pystac .errors .STACError (
2112+ 'Error migrating stac:mlm version: An asset with role "mlm:model" '
2113+ "is required in mlm>=1.2. Include at least one asset with role "
2114+ '"mlm:model" '
2115+ )
2116+
2117+ if "assets" in obj :
2118+ migrate (obj ["assets" ])
2119+ if "item_assets" in obj :
2120+ migrate (obj ["item_assets" ])
21182121
21192122 @staticmethod
21202123 def _migrate_1_2_to_1_3 (obj : dict [str , Any ]) -> None :
2121- bands_obj = obj ["properties" ]["mlm:input" ]
2124+ def migrate (props_obj : dict [str , Any ]) -> None :
2125+ if "mlm:input" not in props_obj :
2126+ return
21222127
2123- if not bands_obj :
2124- return
2128+ bands_objs_present = any ("bands" in inp for inp in props_obj ["mlm:input" ])
21252129
2126- if "raster:bands" not in obj ["properties" ]:
2127- return
2128- raster_bands = obj ["properties" ]["raster:bands" ]
2130+ if not bands_objs_present :
2131+ return
21292132
2130- # make sure all raster_bands have a name prop with length>0
2131- names_properties_valid = all (
2132- "name" in band and len (band ["name" ]) > 0 for band in raster_bands
2133- )
2134- if not names_properties_valid :
2135- raise STACError (
2136- "Error migrating stac:mlm version: In mlm>=1.3, each band in "
2137- 'raster:bands is required to have a property "name"'
2133+ if "raster:bands" not in props_obj :
2134+ return
2135+ raster_bands = props_obj ["raster:bands" ]
2136+
2137+ # make sure all raster_bands have a name prop with length>0
2138+ names_properties_valid = all (
2139+ "name" in band and len (band ["name" ]) > 0 for band in raster_bands
21382140 )
2141+ if not names_properties_valid :
2142+ raise STACError (
2143+ "Error migrating stac:mlm version: In mlm>=1.3, each band in "
2144+ 'raster:bands is required to have a property "name"'
2145+ )
2146+
2147+ # no need to perform the actions below if props_obj is an asset
2148+ # this is checked by the presence of "roles" prop
2149+ if "roles" in props_obj :
2150+ return
21392151
2140- # copy the raster:bands to assets
2141- for asset_name in obj ["assets" ]:
2142- asset = obj ["assets" ][asset_name ]
2143- if "mlm:model" not in asset ["roles" ]:
2144- continue
2145- asset ["raster:bands" ] = raster_bands
2152+ # copy the raster:bands to assets
2153+ for inner_asset_name in obj ["assets" ]:
2154+ inner_asset = obj ["assets" ][inner_asset_name ]
2155+ if "mlm:model" not in inner_asset ["roles" ]:
2156+ continue
2157+ inner_asset ["raster:bands" ] = raster_bands
2158+
2159+ if obj ["type" ] == "Feature" and "mlm:input" in obj ["properties" ]:
2160+ migrate (obj ["properties" ])
2161+ if obj ["type" ] == "Collection" :
2162+ migrate (obj )
2163+ if "assets" in obj :
2164+ for asset_name in obj ["assets" ]:
2165+ asset = obj ["assets" ][asset_name ]
2166+ migrate (asset )
21462167
21472168 @staticmethod
21482169 def _migrate_1_3_to_1_4 (obj : dict [str , Any ]) -> None :
2149- # Migrate to value_scaling
2150- if "mlm:input" in obj ["properties" ]:
2151- for input_obj in obj ["properties" ]["mlm:input" ]:
2170+ def migrate (props_obj : dict [str , Any ]) -> None :
2171+ if "mlm:input" not in props_obj :
2172+ return
2173+
2174+ # Migrate to value_scaling
2175+ for input_obj in props_obj ["mlm:input" ]:
21522176 if "norm_type" in input_obj and input_obj ["norm_type" ] is not None :
21532177 norm_type = input_obj ["norm_type" ]
21542178 value_scaling_list = []
@@ -2189,22 +2213,43 @@ def _migrate_1_3_to_1_4(obj: dict[str, Any]) -> None:
21892213 input_obj .pop ("norm_clip" , None )
21902214 input_obj .pop ("statistics" , None )
21912215
2216+ if obj ["type" ] == "Feature" :
2217+ migrate (obj ["properties" ])
2218+ if obj ["type" ] == "Collection" :
2219+ migrate (obj )
2220+
21922221 if "assets" in obj :
21932222 for asset in obj ["assets" ]:
2223+ migrate (obj ["assets" ][asset ])
2224+
21942225 # move forbidden fields from asset to properties
21952226 if "mlm:name" in obj ["assets" ][asset ]:
2196- obj ["properties" ]["mlm:name" ] = obj ["assets" ][asset ]["mlm:name" ]
2227+ mlm_name = obj ["assets" ][asset ]["mlm:name" ]
2228+ if obj ["type" ] == "Feature" :
2229+ obj ["properties" ]["mlm:name" ] = mlm_name
2230+ if obj ["type" ] == "Collection" :
2231+ obj ["mlm:name" ] = mlm_name
21972232 obj ["assets" ][asset ].pop ("mlm:name" )
21982233 if "mlm:input" in obj ["assets" ][asset ]:
2199- obj ["properties" ]["mlm:input" ] = obj ["assets" ][asset ]["mlm:input" ]
2234+ inp = obj ["assets" ][asset ]["mlm:input" ]
2235+ if obj ["type" ] == "Feature" :
2236+ obj ["properties" ]["mlm:input" ] = inp
2237+ if obj ["type" ] == "Collection" :
2238+ obj ["mlm:input" ] = inp
22002239 obj ["assets" ][asset ].pop ("mlm:input" )
22012240 if "mlm:output" in obj ["assets" ][asset ]:
2202- obj ["properties" ]["mlm:output" ] = obj ["assets" ][asset ]["mlm:output" ]
2241+ outp = obj ["assets" ][asset ]["mlm:output" ]
2242+ if obj ["type" ] == "Feature" :
2243+ obj ["properties" ]["mlm:output" ] = outp
2244+ if obj ["type" ] == "Collection" :
2245+ obj ["mlm:output" ] = outp
22032246 obj ["assets" ][asset ].pop ("mlm:output" )
22042247 if "mlm:hyperparameters" in obj ["assets" ][asset ]:
2205- obj ["properties" ]["mlm:hyperparameters" ] = obj ["assets" ][asset ][
2206- "mlm:hyperparameters"
2207- ]
2248+ hyp = obj ["assets" ][asset ]["mlm:hyperparameters" ]
2249+ if obj ["type" ] == "Feature" :
2250+ obj ["properties" ]["mlm:hyperparameters" ] = hyp
2251+ if obj ["type" ] == "Collection" :
2252+ obj ["mlm:hyperparameters" ] = hyp
22082253 obj ["assets" ][asset ].pop ("mlm:hyperparameters" )
22092254
22102255 # add new REQUIRED proretie mlm:artifact_type to asset
0 commit comments