Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pystac/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Asset:
extra_fields : Optional, additional fields for this asset. This is used
by extensions as a way to serialize and deserialize properties on asset
object JSON.
stac_extensions : Optional, a list of schema URIs for STAC Extensions
implemented by this STAC Asset.
"""

href: str
Expand Down Expand Up @@ -64,6 +66,9 @@ class Asset:
"""Optional, additional fields for this asset. This is used by extensions as a
way to serialize and deserialize properties on asset object JSON."""

stac_extensions: List[str]
"""A list of schema URIs for STAC Extensions implemented by this STAC Asset."""

def __init__(
self,
href: str,
Expand All @@ -79,6 +84,7 @@ def __init__(
self.media_type = media_type
self.roles = roles
self.extra_fields = extra_fields or {}
self.stac_extensions = None

# The Item which owns this Asset.
self.owner = None
Expand Down
16 changes: 15 additions & 1 deletion pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,22 @@ def to_dict(
)
d["extent"] = self.extent.to_dict()
d["license"] = self.license

# we use the fact that in recent Python versions, dict keys are ordered
# by default
stac_extensions: Optional[Dict[str, None]] = None
if self.stac_extensions:
d["stac_extensions"] = self.stac_extensions
stac_extensions = dict.fromkeys(self.stac_extensions)

for asset in self.assets.values():
if stac_extensions and asset.stac_extensions:
stac_extensions.update(dict.fromkeys(asset.stac_extensions))
elif asset.stac_extensions:
stac_extensions = dict.fromkeys(asset.stac_extensions)

if stac_extensions is not None:
d["stac_extensions"] = list(stac_extensions.keys())

if self.keywords:
d["keywords"] = self.keywords
if self.providers:
Expand Down
50 changes: 36 additions & 14 deletions pystac/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Type,
TypeVar,
Union,
cast,
Protocol,
)

import pystac
Expand Down Expand Up @@ -96,7 +96,11 @@ def _set_property(
self.properties[prop_name] = v


S = TypeVar("S", bound=pystac.STACObject)
class STACExtendable(Protocol):
stac_extensions: List[str]


S = TypeVar("S", bound=STACExtendable)


class ExtensionManagementMixin(Generic[S], ABC):
Expand Down Expand Up @@ -150,6 +154,21 @@ def has_extension(cls, obj: S) -> bool:
:attr:`pystac.STACObject.stac_extensions` for this extension's schema URI."""
schema_startswith = VERSION_REGEX.split(cls.get_schema_uri())[0] + "/"

if isinstance(obj, (pystac.Item, pystac.Collection)):
for asset in obj.assets.values():
if asset.stac_extensions is not None and any(
uri.startswith(schema_startswith)
for uri in asset.stac_extensions
):
return True

elif isinstance(obj, pystac.Asset):
if obj.owner and obj.owner.stac_extensions is not None and any(
uri.startswith(schema_startswith)
for uri in obj.owner.stac_extensions
):
return True

return obj.stac_extensions is not None and any(
uri.startswith(schema_startswith) for uri in obj.stac_extensions
)
Expand All @@ -173,15 +192,13 @@ def validate_owner_has_extension(
STACError : If ``add_if_missing`` is ``True`` and ``asset.owner`` is
``None``.
"""
if asset.owner is None:
if add_if_missing:
raise pystac.STACError(
"Attempted to use add_if_missing=True for an Asset with no owner. "
"Use Asset.set_owner or set add_if_missing=False."
)
else:
return
return cls.ensure_has_extension(cast(S, asset.owner), add_if_missing)

warnings.warn(
"validate_owner_has_extension is deprecated and will be removed in v2.0. "
"Use ensure_has_extension instead",
DeprecationWarning,
)
return cls.ensure_has_extension(asset, add_if_missing)

@classmethod
def validate_has_extension(cls, obj: S, add_if_missing: bool = False) -> None:
Expand Down Expand Up @@ -222,10 +239,15 @@ def ensure_has_extension(cls, obj: S, add_if_missing: bool = False) -> None:
if add_if_missing:
cls.add_to(obj)

if isinstance(obj, pystac.Asset):
cls.ensure_has_extension(obj.owner)

if not cls.has_extension(obj):
raise pystac.ExtensionNotImplemented(
f"Could not find extension schema URI {cls.get_schema_uri()} in object."
)
if not obj.owner or not cls.has_extension(obj.owner):
raise pystac.ExtensionNotImplemented(
f"Could not find extension schema URI {cls.get_schema_uri()} "
"in object."
)

@classmethod
def _ext_error_message(cls, obj: Any) -> str:
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> ClassificationExtension[T]
cls.ensure_has_extension(obj, add_if_missing)
return cast(ClassificationExtension[T], ItemClassificationExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(ClassificationExtension[T], AssetClassificationExtension(obj))
elif isinstance(obj, item_assets.AssetDefinition):
cls.ensure_has_extension(
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> DatacubeExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(DatacubeExtension[T], ItemDatacubeExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(DatacubeExtension[T], AssetDatacubeExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> EOExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(EOExtension[T], ItemEOExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(EOExtension[T], AssetEOExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> FileExtension:
This extension can be applied to instances of :class:`~pystac.Asset`.
"""
if isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cls(obj)
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> PointcloudExtension[T]:
raise pystac.ExtensionTypeError(
"Pointcloud extension does not apply to Collection Assets."
)
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(PointcloudExtension[T], AssetPointcloudExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> ProjectionExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(ProjectionExtension[T], ItemProjectionExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(ProjectionExtension[T], AssetProjectionExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> RasterExtension[T]:
pystac.ExtensionTypeError : If an invalid object type is passed.
"""
if isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(RasterExtension[T], AssetRasterExtension(obj))
elif isinstance(obj, item_assets.AssetDefinition):
cls.ensure_has_extension(
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/sar.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> SarExtension[T]:
raise pystac.ExtensionTypeError(
"SAR extension does not apply to Collection Assets."
)
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(SarExtension[T], AssetSarExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> SatExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(SatExtension[T], ItemSatExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(SatExtension[T], AssetSatExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> StorageExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(StorageExtension[T], ItemStorageExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(StorageExtension[T], AssetStorageExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> TableExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(TableExtension[T], ItemTableExtension(obj))
if isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(TableExtension[T], AssetTableExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/timestamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> TimestampsExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(TimestampsExtension[T], ItemTimestampsExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(TimestampsExtension[T], AssetTimestampsExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> ViewExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(ViewExtension[T], ItemViewExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(ViewExtension[T], AssetViewExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
2 changes: 1 addition & 1 deletion pystac/extensions/xarray_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> XarrayAssetsExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return ItemXarrayAssetsExtension(obj)
if isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return AssetXarrayAssetsExtension(obj)
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))
Expand Down
17 changes: 16 additions & 1 deletion pystac/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,22 @@ def to_dict(
if self.bbox is not None:
d["bbox"] = self.bbox

if self.stac_extensions is not None:
# we use the fact that in recent Python versions, dict keys are ordered
# by default
stac_extensions: Optional[Dict[str, None]] = None
if self.stac_extensions:
stac_extensions = dict.fromkeys(self.stac_extensions)

for asset in self.assets.values():
if stac_extensions and asset.stac_extensions:
stac_extensions.update(dict.fromkeys(asset.stac_extensions))
elif asset.stac_extensions:
stac_extensions = dict.fromkeys(asset.stac_extensions)

if stac_extensions is not None:
d["stac_extensions"] = list(stac_extensions.keys())

if stac_extensions is not None:
d["stac_extensions"] = self.stac_extensions

if self.collection_id:
Expand Down
10 changes: 9 additions & 1 deletion tests/extensions/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_schema_uri(cls) -> str:
@classmethod
def ext(cls, obj: T, add_if_missing: bool = False) -> "CustomExtension[T]":
if isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_has_extension(obj, add_if_missing)
return cast(CustomExtension[T], AssetCustomExtension(obj))
if isinstance(obj, pystac.Item):
cls.ensure_has_extension(obj, add_if_missing)
Expand Down Expand Up @@ -152,6 +152,14 @@ def test_add_to_catalog(self) -> None:
catalog_as_dict = catalog.to_dict()
assert catalog_as_dict["test:prop"] == "foo"

def test_add_to_asset_no_owner(self) -> None:
asset = Asset("http://pystac.test/asset.tif")
custom = CustomExtension.ext(asset, add_if_missing=True)
assert CustomExtension.has_extension(asset)
custom.apply("bar")
asset_as_dict = asset.to_dict()
assert asset_as_dict["test:prop"] == "bar"

def test_add_to_collection(self) -> None:
collection = Collection(
"an-id",
Expand Down