Skip to content

Commit 2999925

Browse files
committed
Add Collection.from_items
1 parent 542b9fb commit 2999925

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

pystac/collection.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,76 @@ def from_dict(
710710

711711
return collection
712712

713+
@classmethod
714+
def from_items(
715+
cls: type[C],
716+
items: Iterable[Item] | pystac.ItemCollection,
717+
*,
718+
id: str | None = None,
719+
strategy: HrefLayoutStrategy | None = None,
720+
) -> C:
721+
"""Create a :class:`Collection` from items or an :class:`ItemCollection`.
722+
723+
Will try to pull collection attributes from :attr:`ItemCollection.extra_fields`
724+
and items when possible.
725+
726+
Args:
727+
items : Iterable of :class:`~pystac.Item` instances to include in the
728+
:class:`ItemCollection`. This can be an :class:`pystac.ItemCollection`.
729+
id : Identifier for the collection. If not set, must be available on the
730+
items and they must all match.
731+
strategy : The layout strategy to use for setting the
732+
HREFs of the catalog child objections and items.
733+
If not provided, it will default to strategy of the parent and fallback
734+
to :class:`~pystac.layout.BestPracticesLayoutStrategy`.
735+
"""
736+
737+
def extract(attr: str) -> Any:
738+
"""Extract attrs from items or item.properties as long as they all match"""
739+
value = None
740+
values = {getattr(item, attr, None) for item in items}
741+
if len(values) == 1:
742+
value = next(iter(values))
743+
if value is None:
744+
values = {item.properties.get(attr, None) for item in items}
745+
if len(values) == 1:
746+
value = next(iter(values))
747+
return value
748+
749+
if isinstance(items, pystac.ItemCollection):
750+
extra_fields = deepcopy(items.extra_fields)
751+
links = extra_fields.pop("links", {})
752+
providers = extra_fields.pop("providers", None)
753+
if providers is not None:
754+
providers = [pystac.Provider.from_dict(p) for p in providers]
755+
else:
756+
extra_fields = {}
757+
links = {}
758+
providers = []
759+
760+
id = id or extract("collection_id")
761+
if id is None:
762+
raise ValueError(
763+
"Collection id must be defined. Either by specifying collection_id "
764+
"on every item, or as a keyword argument to this function."
765+
)
766+
767+
collection = cls(
768+
id=id,
769+
description=extract("description"),
770+
extent=Extent.from_items(items),
771+
title=extract("title"),
772+
providers=providers,
773+
extra_fields=extra_fields,
774+
strategy=strategy,
775+
)
776+
collection.add_items(items)
777+
778+
for link in links:
779+
collection.add_link(Link.from_dict(link))
780+
781+
return collection
782+
713783
def get_item(self, id: str, recursive: bool = False) -> Item | None:
714784
"""Returns an item with a given ID.
715785

tests/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from pystac import Asset, Catalog, Collection, Item, Link
12+
from pystac import Asset, Catalog, Collection, Item, ItemCollection, Link
1313

1414
from .utils import ARBITRARY_BBOX, ARBITRARY_EXTENT, ARBITRARY_GEOM, TestCases
1515

@@ -76,6 +76,18 @@ def sample_item() -> Item:
7676
return Item.from_file(TestCases.get_path("data-files/item/sample-item.json"))
7777

7878

79+
@pytest.fixture
80+
def sample_item_collection() -> ItemCollection:
81+
return ItemCollection.from_file(
82+
TestCases.get_path("data-files/item-collection/sample-item-collection.json")
83+
)
84+
85+
86+
@pytest.fixture
87+
def sample_items(sample_item_collection: ItemCollection) -> list[Item]:
88+
return list(sample_item_collection)
89+
90+
7991
@pytest.fixture(scope="function")
8092
def tmp_asset(tmp_path: Path) -> Asset:
8193
"""Copy the entirety of test-case-2 to tmp and"""

tests/test_collection.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Collection,
2020
Extent,
2121
Item,
22+
ItemCollection,
2223
Provider,
2324
SpatialExtent,
2425
TemporalExtent,
@@ -711,3 +712,58 @@ def test_permissive_temporal_extent_deserialization(collection: Collection) -> N
711712
]["interval"][0]
712713
with pytest.warns(UserWarning):
713714
Collection.from_dict(collection_dict)
715+
716+
717+
@pytest.mark.parametrize("fixture_name", ("sample_item_collection", "sample_items"))
718+
def test_from_items(fixture_name: str, request: pytest.FixtureRequest) -> None:
719+
items = request.getfixturevalue(fixture_name)
720+
collection = Collection.from_items(items)
721+
722+
for item in items:
723+
assert collection.id == item.collection_id
724+
assert collection.extent.spatial.bboxes[0][0] <= item.bbox[0]
725+
assert collection.extent.spatial.bboxes[0][1] <= item.bbox[1]
726+
assert collection.extent.spatial.bboxes[0][2] >= item.bbox[2]
727+
assert collection.extent.spatial.bboxes[0][3] >= item.bbox[3]
728+
729+
start = collection.extent.temporal.intervals[0][0]
730+
end = collection.extent.temporal.intervals[0][1]
731+
assert start and start <= datetime.fromisoformat(
732+
item.properties["start_datetime"]
733+
)
734+
assert end and end >= datetime.fromisoformat(item.properties["end_datetime"])
735+
736+
if isinstance(items, ItemCollection):
737+
expected = {(link["rel"], link["href"]) for link in items.extra_fields["links"]}
738+
actual = {(link.rel, link.href) for link in collection.links}
739+
assert expected.issubset(actual)
740+
741+
742+
def test_from_items_pulls_from_properties() -> None:
743+
item1 = Item(
744+
id="test-item-1",
745+
geometry=ARBITRARY_GEOM,
746+
bbox=[-10, -20, 0, -10],
747+
datetime=datetime(2000, 2, 1, 12, 0, 0, 0, tzinfo=tz.UTC),
748+
collection="test-collection-1",
749+
properties={"title": "Test Item", "description": "Extra words describing"},
750+
)
751+
collection = Collection.from_items([item1])
752+
assert collection.id == item1.collection_id
753+
assert collection.title == item1.properties["title"]
754+
assert collection.description == item1.properties["description"]
755+
756+
757+
def test_from_items_without_collection_id() -> None:
758+
item1 = Item(
759+
id="test-item-1",
760+
geometry=ARBITRARY_GEOM,
761+
bbox=[-10, -20, 0, -10],
762+
datetime=datetime(2000, 2, 1, 12, 0, 0, 0, tzinfo=tz.UTC),
763+
properties={},
764+
)
765+
with pytest.raises(ValueError, match="Collection id must be defined."):
766+
Collection.from_items([item1])
767+
768+
collection = Collection.from_items([item1], id="test-collection")
769+
assert collection.id == "test-collection"

0 commit comments

Comments
 (0)