@@ -1364,156 +1364,68 @@ def to_dict(self) -> dict[str, Any]:
13641364 return self .properties
13651365
13661366
1367- class MLMExtension (
1368- Generic [T ],
1369- PropertiesExtension ,
1370- ExtensionManagementMixin [pystac .Item | pystac .Collection ],
1371- ):
1372- """An abstract class that can be used to extend to properties of an
1373- :class:`pystac.Item` or :class:`pystac.Collection` with properties from the
1374- :stac-ext:`Machine Learning Model Extension <mlm>`.
1375-
1376- This class can be used to extend :class:`pystac.Item`, :class:`pystac.Collection`
1377- and :class:`pystac.ItemAssetDefinition`. For extending :class:`pystac.Asset`, use
1378- either :class:`~AssetGeneralMLMExtension`: or :class:`AssetDetailedMLMExtension`.
1379- """
1380-
1381- name : Literal ["mlm" ] = "mlm"
1367+ class _ExcludedFromAssetProps (PropertiesExtension ):
13821368 properties : dict [str , Any ]
13831369
1384- def apply (
1385- self ,
1386- name : str ,
1387- architecture : str ,
1388- tasks : list [TaskType ],
1389- input : list [ModelInput ],
1390- output : list [ModelOutput ],
1391- framework : str | None = None ,
1392- framework_version : str | None = None ,
1393- memory_size : int | None = None ,
1394- total_parameters : int | None = None ,
1395- pretrained : bool | None = None ,
1396- pretrained_source : str | None = None ,
1397- batch_size_suggestion : int | None = None ,
1398- accelerator : AcceleratorType | None = None ,
1399- accelerator_constrained : bool | None = None ,
1400- accelerator_summary : str | None = None ,
1401- accelerator_count : int | None = None ,
1402- hyperparameters : Hyperparameters | None = None ,
1403- * args : Any ,
1404- ** kwargs : Any ,
1405- ) -> None :
1370+ @property
1371+ def mlm_name (self ) -> str :
14061372 """
1407- Sets the properties of a new MLMExtension
1408-
1409- Args:
1410- name: name for the model
1411- architecture: A generic and well established architecture name of the model
1412- tasks: Specifies the Machine Learning tasks for which the model can be
1413- used for
1414- input: Describes the transformation between the EO data and the model input
1415- output: Describes each model output and how to interpret it.
1416- framework: Framework used to train the model
1417- framework_version: The ``framework`` library version
1418- memory_size: The in-memory size of the model on the accelerator during
1419- inference (bytes)
1420- total_parameters: Total number of model parameters, including trainable and
1421- non-trainable parameters.
1422- pretrained: Indicates if the model was pretrained. If the model was
1423- pretrained, consider providing ``pretrained_source`` if it is known
1424- pretrained_source: The source of the pretraining.
1425- batch_size_suggestion: A suggested batch size for the accelerator and
1426- summarized hardware.
1427- accelerator: The intended computational hardware that runs inference
1428- accelerator_constrained: Indicates if the intended ``accelerator`` is the
1429- only accelerator that can run inference
1430- accelerator_summary: A high level description of the ``accelerator``
1431- accelerator_count: A minimum amount of ``accelerator`` instances required to
1432- run the model
1433- hyperparameters: Additional hyperparameters relevant for the model
1434- *args: Unused (no effect, only here for signature compliance with apply
1435- method in derived classes
1436- **kwargs: Unused (no effect, only here for signature compliance with apply
1437- method in derived classes
1373+ Get or set the required (mlm) name property. It is named mlm_name in this
1374+ context to not break convention and overwrite the extension name class property.
14381375 """
1439- self .mlm_name = name
1440- self .architecture = architecture
1441- self .tasks = tasks
1442- self .input = input
1443- self .output = output
1444- self .framework = framework
1445- self .framework_version = framework_version
1446- self .memory_size = memory_size
1447- self .total_parameters = total_parameters
1448- self .pretrained = pretrained
1449- self .pretrained_source = pretrained_source
1450- self .batch_size_suggestion = batch_size_suggestion
1451- self .accelerator = accelerator
1452- self .accelerator_constrained = accelerator_constrained
1453- self .accelerator_summary = accelerator_summary
1454- self .accelerator_count = accelerator_count
1455- self .hyperparameters = hyperparameters
1376+ return cast (str , get_required (self .properties .get (NAME_PROP ), self , NAME_PROP ))
14561377
1457- @classmethod
1458- def get_schema_uri (cls ) -> str :
1459- """
1460- Retrieves this extension's schema URI
1378+ @mlm_name .setter
1379+ def mlm_name (self , v : str ) -> None :
1380+ self ._set_property (NAME_PROP , v )
14611381
1462- Returns:
1463- str: the schema URI
1382+ @ property
1383+ def input ( self ) -> list [ ModelInput ]:
14641384 """
1465- return SCHEMA_URI_PATTERN .format (version = DEFAULT_VERSION )
1466-
1467- @classmethod
1468- def ext (cls , obj : T , add_if_missing : bool = False ) -> MLMExtension [T ]:
1385+ Get or set the required input property
14691386 """
1470- Extend a STAC object (``obj``) with the MLMExtension
1471-
1472- Args:
1473- obj: The STAC object to be extended.
1474- add_if_missing: Defines whether this extension's URI should be added to
1475- this object's (or its parent's) list of extensions if it is not already
1476- listed there.
1387+ return [
1388+ ModelInput (inp )
1389+ for inp in get_required (
1390+ self ._get_property (INPUT_PROP , list [dict [str , Any ]]), self , INPUT_PROP
1391+ )
1392+ ]
14771393
1478- Returns:
1479- MLMExtension[T]: The extended object
1394+ @input .setter
1395+ def input (self , v : list [ModelInput ]) -> None :
1396+ self ._set_property (INPUT_PROP , [inp .to_dict () for inp in v ])
14801397
1481- Raises:
1482- TypeError: When a :class:`pystac.Asset` object is passed as the
1483- `obj` parameter
1484- pystac.ExtensionTypeError: When any unsupported object is passed as the
1485- `obj` parameter. If you see this extension in this context, please
1486- raise an issue on github.
1398+ @property
1399+ def output (self ) -> list [ModelOutput ]:
14871400 """
1488- if isinstance (obj , pystac .Item ):
1489- cls .ensure_has_extension (obj , add_if_missing )
1490- return cast (MLMExtension [T ], ItemMLMExtension (obj ))
1491- elif isinstance (obj , pystac .Collection ):
1492- cls .ensure_has_extension (obj , add_if_missing )
1493- return cast (MLMExtension [T ], CollectionMLMExtension (obj ))
1494- elif isinstance (obj , pystac .ItemAssetDefinition ):
1495- cls .ensure_owner_has_extension (obj , add_if_missing )
1496- return cast (MLMExtension [T ], ItemAssetMLMExtension (obj ))
1497- elif isinstance (obj , pystac .Asset ):
1498- raise TypeError (
1499- "This class cannot be used to extend STAC objects of type Assets. "
1500- "To extend Asset objects, use either AssetGeneralMLMExtension or "
1501- "AssetDetailedMLMExtension"
1401+ Get or set the required output property
1402+ """
1403+ return [
1404+ ModelOutput (outp )
1405+ for outp in get_required (
1406+ self ._get_property (OUTPUT_PROP , list [dict [str , Any ]]), self , OUTPUT_PROP
15021407 )
1503- else :
1504- raise pystac .ExtensionTypeError (cls ._ext_error_message (obj ))
1408+ ]
1409+
1410+ @output .setter
1411+ def output (self , v : list [ModelOutput ]) -> None :
1412+ self ._set_property (OUTPUT_PROP , [outp .to_dict () for outp in v ])
15051413
15061414 @property
1507- def mlm_name (self ) -> str :
1415+ def hyperparameters (self ) -> Hyperparameters | None :
15081416 """
1509- Get or set the required (mlm) name property. It is named mlm_name in this
1510- context to not break convention and overwrite the extension name class property.
1417+ Get or set the hyperparameters property
15111418 """
1512- return cast (str , get_required (self .properties .get (NAME_PROP ), self , NAME_PROP ))
1419+ prop = self ._get_property (HYPERPARAMETERS_PROP , dict [str , Any ])
1420+ return Hyperparameters (prop ) if prop is not None else None
15131421
1514- @mlm_name .setter
1515- def mlm_name (self , v : str ) -> None :
1516- self ._set_property (NAME_PROP , v )
1422+ @hyperparameters .setter
1423+ def hyperparameters (self , v : Hyperparameters | None ) -> None :
1424+ self ._set_property (HYPERPARAMETERS_PROP , v .to_dict () if v is not None else None )
1425+
1426+
1427+ class _IncludedInAssetProps (PropertiesExtension ):
1428+ properties : dict [str , Any ]
15171429
15181430 @property
15191431 def architecture (self ) -> str :
@@ -1666,49 +1578,146 @@ def accelerator_count(self) -> int | None:
16661578 def accelerator_count (self , v : int | None ) -> None :
16671579 self ._set_property (ACCELERATOR_COUNT_PROP , v )
16681580
1669- @property
1670- def input (self ) -> list [ModelInput ]:
1671- """
1672- Get or set the required input property
1673- """
1674- return [
1675- ModelInput (inp )
1676- for inp in get_required (
1677- self ._get_property (INPUT_PROP , list [dict [str , Any ]]), self , INPUT_PROP
1678- )
1679- ]
16801581
1681- @input .setter
1682- def input (self , v : list [ModelInput ]) -> None :
1683- self ._set_property (INPUT_PROP , [inp .to_dict () for inp in v ])
1582+ class MLMExtension (
1583+ Generic [T ],
1584+ _ExcludedFromAssetProps ,
1585+ _IncludedInAssetProps ,
1586+ ExtensionManagementMixin [pystac .Item | pystac .Collection ],
1587+ ):
1588+ """An abstract class that can be used to extend to properties of an
1589+ :class:`pystac.Item` or :class:`pystac.Collection` with properties from the
1590+ :stac-ext:`Machine Learning Model Extension <mlm>`.
16841591
1685- @property
1686- def output (self ) -> list [ModelOutput ]:
1592+ This class can be used to extend :class:`pystac.Item`, :class:`pystac.Collection`
1593+ and :class:`pystac.ItemAssetDefinition`. For extending :class:`pystac.Asset`, use
1594+ either :class:`~AssetGeneralMLMExtension`: or :class:`AssetDetailedMLMExtension`.
1595+ """
1596+
1597+ name : Literal ["mlm" ] = "mlm"
1598+ properties : dict [str , Any ]
1599+
1600+ def apply (
1601+ self ,
1602+ name : str ,
1603+ architecture : str ,
1604+ tasks : list [TaskType ],
1605+ input : list [ModelInput ],
1606+ output : list [ModelOutput ],
1607+ framework : str | None = None ,
1608+ framework_version : str | None = None ,
1609+ memory_size : int | None = None ,
1610+ total_parameters : int | None = None ,
1611+ pretrained : bool | None = None ,
1612+ pretrained_source : str | None = None ,
1613+ batch_size_suggestion : int | None = None ,
1614+ accelerator : AcceleratorType | None = None ,
1615+ accelerator_constrained : bool | None = None ,
1616+ accelerator_summary : str | None = None ,
1617+ accelerator_count : int | None = None ,
1618+ hyperparameters : Hyperparameters | None = None ,
1619+ * args : Any ,
1620+ ** kwargs : Any ,
1621+ ) -> None :
16871622 """
1688- Get or set the required output property
1623+ Sets the properties of a new MLMExtension
1624+
1625+ Args:
1626+ name: name for the model
1627+ architecture: A generic and well established architecture name of the model
1628+ tasks: Specifies the Machine Learning tasks for which the model can be
1629+ used for
1630+ input: Describes the transformation between the EO data and the model input
1631+ output: Describes each model output and how to interpret it.
1632+ framework: Framework used to train the model
1633+ framework_version: The ``framework`` library version
1634+ memory_size: The in-memory size of the model on the accelerator during
1635+ inference (bytes)
1636+ total_parameters: Total number of model parameters, including trainable and
1637+ non-trainable parameters.
1638+ pretrained: Indicates if the model was pretrained. If the model was
1639+ pretrained, consider providing ``pretrained_source`` if it is known
1640+ pretrained_source: The source of the pretraining.
1641+ batch_size_suggestion: A suggested batch size for the accelerator and
1642+ summarized hardware.
1643+ accelerator: The intended computational hardware that runs inference
1644+ accelerator_constrained: Indicates if the intended ``accelerator`` is the
1645+ only accelerator that can run inference
1646+ accelerator_summary: A high level description of the ``accelerator``
1647+ accelerator_count: A minimum amount of ``accelerator`` instances required to
1648+ run the model
1649+ hyperparameters: Additional hyperparameters relevant for the model
1650+ *args: Unused (no effect, only here for signature compliance with apply
1651+ method in derived classes
1652+ **kwargs: Unused (no effect, only here for signature compliance with apply
1653+ method in derived classes
16891654 """
1690- return [
1691- ModelOutput (outp )
1692- for outp in get_required (
1693- self ._get_property (OUTPUT_PROP , list [dict [str , Any ]]), self , OUTPUT_PROP
1694- )
1695- ]
1655+ self .mlm_name = name
1656+ self .architecture = architecture
1657+ self .tasks = tasks
1658+ self .input = input
1659+ self .output = output
1660+ self .framework = framework
1661+ self .framework_version = framework_version
1662+ self .memory_size = memory_size
1663+ self .total_parameters = total_parameters
1664+ self .pretrained = pretrained
1665+ self .pretrained_source = pretrained_source
1666+ self .batch_size_suggestion = batch_size_suggestion
1667+ self .accelerator = accelerator
1668+ self .accelerator_constrained = accelerator_constrained
1669+ self .accelerator_summary = accelerator_summary
1670+ self .accelerator_count = accelerator_count
1671+ self .hyperparameters = hyperparameters
16961672
1697- @output .setter
1698- def output (self , v : list [ModelOutput ]) -> None :
1699- self ._set_property (OUTPUT_PROP , [outp .to_dict () for outp in v ])
1673+ @classmethod
1674+ def get_schema_uri (cls ) -> str :
1675+ """
1676+ Retrieves this extension's schema URI
17001677
1701- @ property
1702- def hyperparameters ( self ) -> Hyperparameters | None :
1678+ Returns:
1679+ str: the schema URI
17031680 """
1704- Get or set the hyperparameters property
1681+ return SCHEMA_URI_PATTERN .format (version = DEFAULT_VERSION )
1682+
1683+ @classmethod
1684+ def ext (cls , obj : T , add_if_missing : bool = False ) -> MLMExtension [T ]:
17051685 """
1706- prop = self ._get_property (HYPERPARAMETERS_PROP , dict [str , Any ])
1707- return Hyperparameters (prop ) if prop is not None else None
1686+ Extend a STAC object (``obj``) with the MLMExtension
17081687
1709- @hyperparameters .setter
1710- def hyperparameters (self , v : Hyperparameters | None ) -> None :
1711- self ._set_property (HYPERPARAMETERS_PROP , v .to_dict () if v is not None else None )
1688+ Args:
1689+ obj: The STAC object to be extended.
1690+ add_if_missing: Defines whether this extension's URI should be added to
1691+ this object's (or its parent's) list of extensions if it is not already
1692+ listed there.
1693+
1694+ Returns:
1695+ MLMExtension[T]: The extended object
1696+
1697+ Raises:
1698+ TypeError: When a :class:`pystac.Asset` object is passed as the
1699+ `obj` parameter
1700+ pystac.ExtensionTypeError: When any unsupported object is passed as the
1701+ `obj` parameter. If you see this extension in this context, please
1702+ raise an issue on github.
1703+ """
1704+ if isinstance (obj , pystac .Item ):
1705+ cls .ensure_has_extension (obj , add_if_missing )
1706+ return cast (MLMExtension [T ], ItemMLMExtension (obj ))
1707+ elif isinstance (obj , pystac .Collection ):
1708+ cls .ensure_has_extension (obj , add_if_missing )
1709+ return cast (MLMExtension [T ], CollectionMLMExtension (obj ))
1710+ elif isinstance (obj , pystac .ItemAssetDefinition ):
1711+ cls .ensure_owner_has_extension (obj , add_if_missing )
1712+ return cast (MLMExtension [T ], ItemAssetMLMExtension (obj ))
1713+ elif isinstance (obj , pystac .Asset ):
1714+ raise TypeError (
1715+ "This class cannot be used to extend STAC objects of type Assets. "
1716+ "To extend Asset objects, use either AssetGeneralMLMExtension or "
1717+ "AssetDetailedMLMExtension"
1718+ )
1719+ else :
1720+ raise pystac .ExtensionTypeError (cls ._ext_error_message (obj ))
17121721
17131722 def to_dict (self ) -> dict [str , Any ]:
17141723 """
0 commit comments