Skip to content

Commit 14006e6

Browse files
committed
removed forbidden properties from detailled asset tests
1 parent 04f6b32 commit 14006e6

File tree

1 file changed

+1
-69
lines changed

1 file changed

+1
-69
lines changed

tests/extensions/test_mlm.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pystac.extensions.classification import Classification
1313
from pystac.extensions.mlm import (
1414
ARCHITECTURE_PROP,
15-
NAME_PROP,
1615
TASKS_PROP,
1716
AcceleratorType,
1817
AssetDetailedMLMExtension,
@@ -770,20 +769,16 @@ def test_add_to_asset(plain_item: Item) -> None:
770769
MLMExtension.ext(plain_item, add_if_missing=True)
771770
asset = plain_item.assets["analytic"]
772771

773-
assert NAME_PROP not in asset.extra_fields.keys()
774772
assert ARCHITECTURE_PROP not in asset.extra_fields.keys()
775773
assert TASKS_PROP not in asset.extra_fields.keys()
776774

777775
asset_ext = AssetDetailedMLMExtension.ext(asset)
778-
asset_ext.mlm_name = "asdf"
779776
asset_ext.architecture = "ResNet"
780777
asset_ext.tasks = [TaskType.CLASSIFICATION]
781778

782-
assert NAME_PROP in asset.extra_fields.keys()
783779
assert ARCHITECTURE_PROP in asset.extra_fields.keys()
784780
assert TASKS_PROP in asset.extra_fields.keys()
785781

786-
assert asset.extra_fields[NAME_PROP] == "asdf"
787782
assert asset.extra_fields[ARCHITECTURE_PROP] == "ResNet"
788783
assert asset.extra_fields[TASKS_PROP] == [TaskType.CLASSIFICATION]
789784

@@ -866,33 +861,15 @@ def test_to_dict_asset_generic() -> None:
866861

867862

868863
def test_add_to_detailled_asset() -> None:
869-
model_input = ModelInput.create(
870-
name="model",
871-
bands=["B02"],
872-
input=InputStructure.create(
873-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
874-
),
875-
)
876-
model_output = ModelOutput.create(
877-
name="output",
878-
tasks=[TaskType.CLASSIFICATION],
879-
result=ResultStructure.create(
880-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
881-
),
882-
)
883-
884864
asset = pystac.Asset(
885865
href="http://example.com/test.tiff",
886866
title="image",
887867
description="asdf",
888868
media_type="application/tiff",
889869
roles=["mlm:model"],
890870
extra_fields={
891-
"mlm:name": "asdf",
892871
"mlm:architecture": "ResNet",
893872
"mlm:tasks": [TaskType.CLASSIFICATION],
894-
"mlm:input": [model_input.to_dict()],
895-
"mlm:output": [model_output.to_dict()],
896873
"mlm:artifact_type": "foo",
897874
"mlm:compile_method": "bar",
898875
"mlm:entrypoint": "baz",
@@ -901,11 +878,8 @@ def test_add_to_detailled_asset() -> None:
901878

902879
asset_ext = AssetDetailedMLMExtension.ext(asset, add_if_missing=False)
903880

904-
assert asset_ext.mlm_name == "asdf"
905881
assert asset_ext.architecture == "ResNet"
906882
assert asset_ext.tasks == [TaskType.CLASSIFICATION]
907-
assert asset_ext.input == [model_input]
908-
assert asset_ext.output == [model_output]
909883
assert asset_ext.artifact_type == "foo"
910884
assert asset_ext.compile_method == "bar"
911885
assert asset_ext.entrypoint == "baz"
@@ -930,7 +904,7 @@ def test_correct_asset_extension_is_used() -> None:
930904
asset = Asset("https://example.com")
931905
assert isinstance(asset.ext.mlm, AssetGeneralMLMExtension)
932906

933-
asset.extra_fields["mlm:name"] = "asdf"
907+
asset.extra_fields["mlm:architecture"] = "ResNet"
934908
assert isinstance(asset.ext.mlm, AssetDetailedMLMExtension)
935909

936910

@@ -951,37 +925,16 @@ def test_apply_detailled_asset() -> None:
951925
)
952926
asset_ext = AssetDetailedMLMExtension.ext(asset, add_if_missing=False)
953927

954-
model_input = ModelInput.create(
955-
name="model",
956-
bands=["B02"],
957-
input=InputStructure.create(
958-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
959-
),
960-
)
961-
model_output = ModelOutput.create(
962-
name="output",
963-
tasks=[TaskType.CLASSIFICATION],
964-
result=ResultStructure.create(
965-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
966-
),
967-
)
968-
969928
asset_ext.apply(
970-
"asdf",
971929
"ResNet",
972930
[TaskType.CLASSIFICATION],
973-
[model_input],
974-
[model_output],
975931
artifact_type="foo",
976932
compile_method="bar",
977933
entrypoint="baz",
978934
)
979935

980-
assert asset_ext.mlm_name == "asdf"
981936
assert asset_ext.architecture == "ResNet"
982937
assert asset_ext.tasks == [TaskType.CLASSIFICATION]
983-
assert asset_ext.input == [model_input]
984-
assert asset_ext.output == [model_output]
985938
assert asset_ext.artifact_type == "foo"
986939
assert asset_ext.compile_method == "bar"
987940
assert asset_ext.entrypoint == "baz"
@@ -997,38 +950,17 @@ def test_to_dict_detailed_asset() -> None:
997950
)
998951
asset_ext = AssetDetailedMLMExtension.ext(asset, add_if_missing=False)
999952

1000-
model_input = ModelInput.create(
1001-
name="model",
1002-
bands=["B02"],
1003-
input=InputStructure.create(
1004-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
1005-
),
1006-
)
1007-
model_output = ModelOutput.create(
1008-
name="output",
1009-
tasks=[TaskType.CLASSIFICATION],
1010-
result=ResultStructure.create(
1011-
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
1012-
),
1013-
)
1014-
1015953
asset_ext.apply(
1016-
"asdf",
1017954
"ResNet",
1018955
[TaskType.CLASSIFICATION],
1019-
[model_input],
1020-
[model_output],
1021956
artifact_type="foo",
1022957
compile_method="bar",
1023958
entrypoint="baz",
1024959
)
1025960

1026961
d = {
1027-
"mlm:name": "asdf",
1028962
"mlm:architecture": "ResNet",
1029963
"mlm:tasks": [TaskType.CLASSIFICATION],
1030-
"mlm:input": [model_input.to_dict()],
1031-
"mlm:output": [model_output.to_dict()],
1032964
"mlm:artifact_type": "foo",
1033965
"mlm:compile_method": "bar",
1034966
"mlm:entrypoint": "baz",

0 commit comments

Comments
 (0)