Skip to content

Commit 420f876

Browse files
committed
refactor: move properties to parent classes
1 parent 71e5d1e commit 420f876

File tree

1 file changed

+176
-167
lines changed

1 file changed

+176
-167
lines changed

pystac/extensions/mlm.py

Lines changed: 176 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,156 +1364,68 @@ def to_dict(self) -> dict[str, Any]:
13641364
return self.properties
13651365

13661366

1367-
class MLMExtension(
1368-
Generic[T],
1369-
PropertiesExtension,
1370-
ExtensionManagementMixin[pystac.Item | pystac.Collection],
1371-
):
1372-
"""An abstract class that can be used to extend to properties of an
1373-
:class:`pystac.Item` or :class:`pystac.Collection` with properties from the
1374-
:stac-ext:`Machine Learning Model Extension <mlm>`.
1375-
1376-
This class can be used to extend :class:`pystac.Item`, :class:`pystac.Collection`
1377-
and :class:`pystac.ItemAssetDefinition`. For extending :class:`pystac.Asset`, use
1378-
either :class:`~AssetGeneralMLMExtension`: or :class:`AssetDetailedMLMExtension`.
1379-
"""
1380-
1381-
name: Literal["mlm"] = "mlm"
1367+
class _ExcludedFromAssetProps(PropertiesExtension):
13821368
properties: dict[str, Any]
13831369

1384-
def apply(
1385-
self,
1386-
name: str,
1387-
architecture: str,
1388-
tasks: list[TaskType],
1389-
input: list[ModelInput],
1390-
output: list[ModelOutput],
1391-
framework: str | None = None,
1392-
framework_version: str | None = None,
1393-
memory_size: int | None = None,
1394-
total_parameters: int | None = None,
1395-
pretrained: bool | None = None,
1396-
pretrained_source: str | None = None,
1397-
batch_size_suggestion: int | None = None,
1398-
accelerator: AcceleratorType | None = None,
1399-
accelerator_constrained: bool | None = None,
1400-
accelerator_summary: str | None = None,
1401-
accelerator_count: int | None = None,
1402-
hyperparameters: Hyperparameters | None = None,
1403-
*args: Any,
1404-
**kwargs: Any,
1405-
) -> None:
1370+
@property
1371+
def mlm_name(self) -> str:
14061372
"""
1407-
Sets the properties of a new MLMExtension
1408-
1409-
Args:
1410-
name: name for the model
1411-
architecture: A generic and well established architecture name of the model
1412-
tasks: Specifies the Machine Learning tasks for which the model can be
1413-
used for
1414-
input: Describes the transformation between the EO data and the model input
1415-
output: Describes each model output and how to interpret it.
1416-
framework: Framework used to train the model
1417-
framework_version: The ``framework`` library version
1418-
memory_size: The in-memory size of the model on the accelerator during
1419-
inference (bytes)
1420-
total_parameters: Total number of model parameters, including trainable and
1421-
non-trainable parameters.
1422-
pretrained: Indicates if the model was pretrained. If the model was
1423-
pretrained, consider providing ``pretrained_source`` if it is known
1424-
pretrained_source: The source of the pretraining.
1425-
batch_size_suggestion: A suggested batch size for the accelerator and
1426-
summarized hardware.
1427-
accelerator: The intended computational hardware that runs inference
1428-
accelerator_constrained: Indicates if the intended ``accelerator`` is the
1429-
only accelerator that can run inference
1430-
accelerator_summary: A high level description of the ``accelerator``
1431-
accelerator_count: A minimum amount of ``accelerator`` instances required to
1432-
run the model
1433-
hyperparameters: Additional hyperparameters relevant for the model
1434-
*args: Unused (no effect, only here for signature compliance with apply
1435-
method in derived classes
1436-
**kwargs: Unused (no effect, only here for signature compliance with apply
1437-
method in derived classes
1373+
Get or set the required (mlm) name property. It is named mlm_name in this
1374+
context to not break convention and overwrite the extension name class property.
14381375
"""
1439-
self.mlm_name = name
1440-
self.architecture = architecture
1441-
self.tasks = tasks
1442-
self.input = input
1443-
self.output = output
1444-
self.framework = framework
1445-
self.framework_version = framework_version
1446-
self.memory_size = memory_size
1447-
self.total_parameters = total_parameters
1448-
self.pretrained = pretrained
1449-
self.pretrained_source = pretrained_source
1450-
self.batch_size_suggestion = batch_size_suggestion
1451-
self.accelerator = accelerator
1452-
self.accelerator_constrained = accelerator_constrained
1453-
self.accelerator_summary = accelerator_summary
1454-
self.accelerator_count = accelerator_count
1455-
self.hyperparameters = hyperparameters
1376+
return cast(str, get_required(self.properties.get(NAME_PROP), self, NAME_PROP))
14561377

1457-
@classmethod
1458-
def get_schema_uri(cls) -> str:
1459-
"""
1460-
Retrieves this extension's schema URI
1378+
@mlm_name.setter
1379+
def mlm_name(self, v: str) -> None:
1380+
self._set_property(NAME_PROP, v)
14611381

1462-
Returns:
1463-
str: the schema URI
1382+
@property
1383+
def input(self) -> list[ModelInput]:
14641384
"""
1465-
return SCHEMA_URI_PATTERN.format(version=DEFAULT_VERSION)
1466-
1467-
@classmethod
1468-
def ext(cls, obj: T, add_if_missing: bool = False) -> MLMExtension[T]:
1385+
Get or set the required input property
14691386
"""
1470-
Extend a STAC object (``obj``) with the MLMExtension
1471-
1472-
Args:
1473-
obj: The STAC object to be extended.
1474-
add_if_missing: Defines whether this extension's URI should be added to
1475-
this object's (or its parent's) list of extensions if it is not already
1476-
listed there.
1387+
return [
1388+
ModelInput(inp)
1389+
for inp in get_required(
1390+
self._get_property(INPUT_PROP, list[dict[str, Any]]), self, INPUT_PROP
1391+
)
1392+
]
14771393

1478-
Returns:
1479-
MLMExtension[T]: The extended object
1394+
@input.setter
1395+
def input(self, v: list[ModelInput]) -> None:
1396+
self._set_property(INPUT_PROP, [inp.to_dict() for inp in v])
14801397

1481-
Raises:
1482-
TypeError: When a :class:`pystac.Asset` object is passed as the
1483-
`obj` parameter
1484-
pystac.ExtensionTypeError: When any unsupported object is passed as the
1485-
`obj` parameter. If you see this extension in this context, please
1486-
raise an issue on github.
1398+
@property
1399+
def output(self) -> list[ModelOutput]:
14871400
"""
1488-
if isinstance(obj, pystac.Item):
1489-
cls.ensure_has_extension(obj, add_if_missing)
1490-
return cast(MLMExtension[T], ItemMLMExtension(obj))
1491-
elif isinstance(obj, pystac.Collection):
1492-
cls.ensure_has_extension(obj, add_if_missing)
1493-
return cast(MLMExtension[T], CollectionMLMExtension(obj))
1494-
elif isinstance(obj, pystac.ItemAssetDefinition):
1495-
cls.ensure_owner_has_extension(obj, add_if_missing)
1496-
return cast(MLMExtension[T], ItemAssetMLMExtension(obj))
1497-
elif isinstance(obj, pystac.Asset):
1498-
raise TypeError(
1499-
"This class cannot be used to extend STAC objects of type Assets. "
1500-
"To extend Asset objects, use either AssetGeneralMLMExtension or "
1501-
"AssetDetailedMLMExtension"
1401+
Get or set the required output property
1402+
"""
1403+
return [
1404+
ModelOutput(outp)
1405+
for outp in get_required(
1406+
self._get_property(OUTPUT_PROP, list[dict[str, Any]]), self, OUTPUT_PROP
15021407
)
1503-
else:
1504-
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
1408+
]
1409+
1410+
@output.setter
1411+
def output(self, v: list[ModelOutput]) -> None:
1412+
self._set_property(OUTPUT_PROP, [outp.to_dict() for outp in v])
15051413

15061414
@property
1507-
def mlm_name(self) -> str:
1415+
def hyperparameters(self) -> Hyperparameters | None:
15081416
"""
1509-
Get or set the required (mlm) name property. It is named mlm_name in this
1510-
context to not break convention and overwrite the extension name class property.
1417+
Get or set the hyperparameters property
15111418
"""
1512-
return cast(str, get_required(self.properties.get(NAME_PROP), self, NAME_PROP))
1419+
prop = self._get_property(HYPERPARAMETERS_PROP, dict[str, Any])
1420+
return Hyperparameters(prop) if prop is not None else None
15131421

1514-
@mlm_name.setter
1515-
def mlm_name(self, v: str) -> None:
1516-
self._set_property(NAME_PROP, v)
1422+
@hyperparameters.setter
1423+
def hyperparameters(self, v: Hyperparameters | None) -> None:
1424+
self._set_property(HYPERPARAMETERS_PROP, v.to_dict() if v is not None else None)
1425+
1426+
1427+
class _IncludedInAssetProps(PropertiesExtension):
1428+
properties: dict[str, Any]
15171429

15181430
@property
15191431
def architecture(self) -> str:
@@ -1666,49 +1578,146 @@ def accelerator_count(self) -> int | None:
16661578
def accelerator_count(self, v: int | None) -> None:
16671579
self._set_property(ACCELERATOR_COUNT_PROP, v)
16681580

1669-
@property
1670-
def input(self) -> list[ModelInput]:
1671-
"""
1672-
Get or set the required input property
1673-
"""
1674-
return [
1675-
ModelInput(inp)
1676-
for inp in get_required(
1677-
self._get_property(INPUT_PROP, list[dict[str, Any]]), self, INPUT_PROP
1678-
)
1679-
]
16801581

1681-
@input.setter
1682-
def input(self, v: list[ModelInput]) -> None:
1683-
self._set_property(INPUT_PROP, [inp.to_dict() for inp in v])
1582+
class MLMExtension(
1583+
Generic[T],
1584+
_ExcludedFromAssetProps,
1585+
_IncludedInAssetProps,
1586+
ExtensionManagementMixin[pystac.Item | pystac.Collection],
1587+
):
1588+
"""An abstract class that can be used to extend to properties of an
1589+
:class:`pystac.Item` or :class:`pystac.Collection` with properties from the
1590+
:stac-ext:`Machine Learning Model Extension <mlm>`.
16841591
1685-
@property
1686-
def output(self) -> list[ModelOutput]:
1592+
This class can be used to extend :class:`pystac.Item`, :class:`pystac.Collection`
1593+
and :class:`pystac.ItemAssetDefinition`. For extending :class:`pystac.Asset`, use
1594+
either :class:`~AssetGeneralMLMExtension`: or :class:`AssetDetailedMLMExtension`.
1595+
"""
1596+
1597+
name: Literal["mlm"] = "mlm"
1598+
properties: dict[str, Any]
1599+
1600+
def apply(
1601+
self,
1602+
name: str,
1603+
architecture: str,
1604+
tasks: list[TaskType],
1605+
input: list[ModelInput],
1606+
output: list[ModelOutput],
1607+
framework: str | None = None,
1608+
framework_version: str | None = None,
1609+
memory_size: int | None = None,
1610+
total_parameters: int | None = None,
1611+
pretrained: bool | None = None,
1612+
pretrained_source: str | None = None,
1613+
batch_size_suggestion: int | None = None,
1614+
accelerator: AcceleratorType | None = None,
1615+
accelerator_constrained: bool | None = None,
1616+
accelerator_summary: str | None = None,
1617+
accelerator_count: int | None = None,
1618+
hyperparameters: Hyperparameters | None = None,
1619+
*args: Any,
1620+
**kwargs: Any,
1621+
) -> None:
16871622
"""
1688-
Get or set the required output property
1623+
Sets the properties of a new MLMExtension
1624+
1625+
Args:
1626+
name: name for the model
1627+
architecture: A generic and well established architecture name of the model
1628+
tasks: Specifies the Machine Learning tasks for which the model can be
1629+
used for
1630+
input: Describes the transformation between the EO data and the model input
1631+
output: Describes each model output and how to interpret it.
1632+
framework: Framework used to train the model
1633+
framework_version: The ``framework`` library version
1634+
memory_size: The in-memory size of the model on the accelerator during
1635+
inference (bytes)
1636+
total_parameters: Total number of model parameters, including trainable and
1637+
non-trainable parameters.
1638+
pretrained: Indicates if the model was pretrained. If the model was
1639+
pretrained, consider providing ``pretrained_source`` if it is known
1640+
pretrained_source: The source of the pretraining.
1641+
batch_size_suggestion: A suggested batch size for the accelerator and
1642+
summarized hardware.
1643+
accelerator: The intended computational hardware that runs inference
1644+
accelerator_constrained: Indicates if the intended ``accelerator`` is the
1645+
only accelerator that can run inference
1646+
accelerator_summary: A high level description of the ``accelerator``
1647+
accelerator_count: A minimum amount of ``accelerator`` instances required to
1648+
run the model
1649+
hyperparameters: Additional hyperparameters relevant for the model
1650+
*args: Unused (no effect, only here for signature compliance with apply
1651+
method in derived classes
1652+
**kwargs: Unused (no effect, only here for signature compliance with apply
1653+
method in derived classes
16891654
"""
1690-
return [
1691-
ModelOutput(outp)
1692-
for outp in get_required(
1693-
self._get_property(OUTPUT_PROP, list[dict[str, Any]]), self, OUTPUT_PROP
1694-
)
1695-
]
1655+
self.mlm_name = name
1656+
self.architecture = architecture
1657+
self.tasks = tasks
1658+
self.input = input
1659+
self.output = output
1660+
self.framework = framework
1661+
self.framework_version = framework_version
1662+
self.memory_size = memory_size
1663+
self.total_parameters = total_parameters
1664+
self.pretrained = pretrained
1665+
self.pretrained_source = pretrained_source
1666+
self.batch_size_suggestion = batch_size_suggestion
1667+
self.accelerator = accelerator
1668+
self.accelerator_constrained = accelerator_constrained
1669+
self.accelerator_summary = accelerator_summary
1670+
self.accelerator_count = accelerator_count
1671+
self.hyperparameters = hyperparameters
16961672

1697-
@output.setter
1698-
def output(self, v: list[ModelOutput]) -> None:
1699-
self._set_property(OUTPUT_PROP, [outp.to_dict() for outp in v])
1673+
@classmethod
1674+
def get_schema_uri(cls) -> str:
1675+
"""
1676+
Retrieves this extension's schema URI
17001677
1701-
@property
1702-
def hyperparameters(self) -> Hyperparameters | None:
1678+
Returns:
1679+
str: the schema URI
17031680
"""
1704-
Get or set the hyperparameters property
1681+
return SCHEMA_URI_PATTERN.format(version=DEFAULT_VERSION)
1682+
1683+
@classmethod
1684+
def ext(cls, obj: T, add_if_missing: bool = False) -> MLMExtension[T]:
17051685
"""
1706-
prop = self._get_property(HYPERPARAMETERS_PROP, dict[str, Any])
1707-
return Hyperparameters(prop) if prop is not None else None
1686+
Extend a STAC object (``obj``) with the MLMExtension
17081687
1709-
@hyperparameters.setter
1710-
def hyperparameters(self, v: Hyperparameters | None) -> None:
1711-
self._set_property(HYPERPARAMETERS_PROP, v.to_dict() if v is not None else None)
1688+
Args:
1689+
obj: The STAC object to be extended.
1690+
add_if_missing: Defines whether this extension's URI should be added to
1691+
this object's (or its parent's) list of extensions if it is not already
1692+
listed there.
1693+
1694+
Returns:
1695+
MLMExtension[T]: The extended object
1696+
1697+
Raises:
1698+
TypeError: When a :class:`pystac.Asset` object is passed as the
1699+
`obj` parameter
1700+
pystac.ExtensionTypeError: When any unsupported object is passed as the
1701+
`obj` parameter. If you see this extension in this context, please
1702+
raise an issue on github.
1703+
"""
1704+
if isinstance(obj, pystac.Item):
1705+
cls.ensure_has_extension(obj, add_if_missing)
1706+
return cast(MLMExtension[T], ItemMLMExtension(obj))
1707+
elif isinstance(obj, pystac.Collection):
1708+
cls.ensure_has_extension(obj, add_if_missing)
1709+
return cast(MLMExtension[T], CollectionMLMExtension(obj))
1710+
elif isinstance(obj, pystac.ItemAssetDefinition):
1711+
cls.ensure_owner_has_extension(obj, add_if_missing)
1712+
return cast(MLMExtension[T], ItemAssetMLMExtension(obj))
1713+
elif isinstance(obj, pystac.Asset):
1714+
raise TypeError(
1715+
"This class cannot be used to extend STAC objects of type Assets. "
1716+
"To extend Asset objects, use either AssetGeneralMLMExtension or "
1717+
"AssetDetailedMLMExtension"
1718+
)
1719+
else:
1720+
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
17121721

17131722
def to_dict(self) -> dict[str, Any]:
17141723
"""

0 commit comments

Comments
 (0)