Skip to content

Commit bd81c78

Browse files
committed
improved test coverage
1 parent 3cefe9f commit bd81c78

File tree

1 file changed

+169
-19
lines changed

1 file changed

+169
-19
lines changed

tests/extensions/test_mlm.py

Lines changed: 169 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def test_model_band() -> None:
7575

7676
assert c.to_dict() == d
7777

78+
with pytest.raises(NotImplementedError):
79+
_ = c == "blah"
80+
7881

7982
def test_model_props() -> None:
8083
c = ModelBand({})
@@ -100,6 +103,9 @@ def test_processing_expression() -> None:
100103

101104
assert c.to_dict() == d
102105

106+
with pytest.raises(NotImplementedError):
107+
_ = c == "blah"
108+
103109

104110
def test_processint_expression_props() -> None:
105111
c = ProcessingExpression({})
@@ -116,25 +122,43 @@ def test_processint_expression_props() -> None:
116122
assert c.expression == "B01 + B02"
117123

118124

119-
def test_valuescaling_object() -> None:
125+
@pytest.mark.parametrize(
126+
"scale_type, min_val, max_val, mean, stddev, value, format_val, expression",
127+
[
128+
(ValueScalingType.MIN_MAX, 0, 4, 3, 3, 4, "asdf", "asdf"),
129+
(ValueScalingType.MIN_MAX, 0.2, 4.3, 3.13, 3.2, 4.5, "asdf", "asdf"),
130+
(ValueScalingType.MIN_MAX, 0, 4, None, None, None, None, None),
131+
(ValueScalingType.SCALE, None, None, None, None, 2, None, None),
132+
],
133+
)
134+
def test_valuescaling_object(
135+
scale_type: ValueScalingType,
136+
min_val: int | float | None,
137+
max_val: int | float | None,
138+
mean: int | float | None,
139+
stddev: int | float | None,
140+
value: int | float | None,
141+
format_val: str | None,
142+
expression: str | None,
143+
) -> None:
120144
c = ValueScaling.create(
121-
ValueScalingType.MIN_MAX,
122-
minimum=0,
123-
maximum=4,
124-
mean=3,
125-
stddev=3.141,
126-
value=4,
127-
format="asdf",
128-
expression="asdf",
129-
)
130-
assert c.type == ValueScalingType.MIN_MAX
131-
assert c.minimum == 0
132-
assert c.maximum == 4
133-
assert c.mean == 3
134-
assert c.stddev == 3.141
135-
assert c.value == 4
136-
assert c.format == "asdf"
137-
assert c.expression == "asdf"
145+
scale_type,
146+
minimum=min_val,
147+
maximum=max_val,
148+
mean=mean,
149+
stddev=stddev,
150+
value=value,
151+
format=format_val,
152+
expression=expression,
153+
)
154+
assert c.type == scale_type
155+
assert c.minimum == min_val
156+
assert c.maximum == max_val
157+
assert c.mean == mean
158+
assert c.stddev == stddev
159+
assert c.value == value
160+
assert c.format == format_val
161+
assert c.expression == expression
138162

139163
with pytest.raises(STACError):
140164
ValueScaling.create(
@@ -144,6 +168,9 @@ def test_valuescaling_object() -> None:
144168
with pytest.raises(STACError):
145169
ValueScaling.create(ValueScalingType.Z_SCORE, mean=3) # missing param stddev
146170

171+
with pytest.raises(NotImplementedError):
172+
_ = c == "blah"
173+
147174

148175
def test_valuescaling_required_params() -> None:
149176
assert ValueScaling.get_required_props(ValueScalingType.MIN_MAX) == [
@@ -178,6 +205,9 @@ def test_input_structure() -> None:
178205
assert c.dim_order == ["batch", "channel", "width", "height"]
179206
assert c.data_type == DataType.FLOAT64
180207

208+
with pytest.raises(NotImplementedError):
209+
_ = c == "blah"
210+
181211

182212
def test_model_input_structure_props() -> None:
183213
c = InputStructure({})
@@ -269,6 +299,9 @@ def test_model_input(
269299
assert "resize_type" in d_reverse
270300
assert "pre_processing_function" in d_reverse
271301

302+
with pytest.raises(NotImplementedError):
303+
_ = c == "blah"
304+
272305

273306
def test_model_input_props() -> None:
274307
c = ModelInput({})
@@ -314,6 +347,9 @@ def test_result_structure() -> None:
314347
assert c.dim_order == ["time", "width", "height"]
315348
assert c.data_type == DataType.FLOAT64
316349

350+
with pytest.raises(NotImplementedError):
351+
_ = c == "blah"
352+
317353

318354
def test_result_structure_props() -> None:
319355
c = ResultStructure({})
@@ -361,6 +397,9 @@ def test_model_output(post_proc_func: ProcessingExpression | None) -> None:
361397
]
362398
assert c.post_processing_function == post_proc_func
363399

400+
with pytest.raises(NotImplementedError):
401+
_ = c == "blah"
402+
364403

365404
def test_model_output_props() -> None:
366405
c = ModelOutput({})
@@ -408,6 +447,9 @@ def test_hyperparameters() -> None:
408447
assert key in c.to_dict()
409448
assert c.to_dict()[key] == d[key]
410449

450+
with pytest.raises(NotImplementedError):
451+
_ = c == "blah"
452+
411453

412454
def teest_get_schema_uri(basic_mlm_item: Item) -> None:
413455
with pytest.raises(DeprecationWarning):
@@ -571,6 +613,34 @@ def test_apply(plain_item: Item) -> None:
571613
and MLMExtension.ext(plain_item).hyperparameters == hyp
572614
)
573615

616+
d = {
617+
**plain_item.properties,
618+
"mlm:name": "asdf",
619+
"mlm:architecture": "ResNet",
620+
"mlm:tasks": [TaskType.CLASSIFICATION],
621+
"mlm:framework": "PyTorch",
622+
"mlm:framework_version": "1.2.3",
623+
"mlm:memory_size": 3,
624+
"mlm:total_parameters": 123,
625+
"mlm:pretrained": True,
626+
"mlm:pretrained_source": "asdfasdfasdf",
627+
"mlm:batch_size_suggestion": 32,
628+
"mlm:accelerator": AcceleratorType.CUDA,
629+
"mlm:accelerator_constrained": False,
630+
"mlm:accelerator_summary": "This is the summary",
631+
"mlm:accelerator_count": 1,
632+
"mlm:input": [inp.to_dict() for inp in model_input],
633+
"mlm:output": [out.to_dict() for out in model_output],
634+
"mlm:hyperparameters": hyp.to_dict(),
635+
}
636+
637+
assert MLMExtension.ext(plain_item).to_dict() == d
638+
639+
640+
def test_apply_wrong_object() -> None:
641+
with pytest.raises(pystac.ExtensionTypeError):
642+
_ = MLMExtension.ext(1, False)
643+
574644

575645
def test_to_from_dict(basic_item_dict: dict[str, Any]) -> None:
576646
d1 = deepcopy(basic_item_dict)
@@ -774,6 +844,25 @@ def test_apply_generic_asset() -> None:
774844
assert asset_ext.entrypoint == "baz"
775845

776846

847+
def test_to_dict_asset_generic() -> None:
848+
asset = pystac.Asset(
849+
href="http://example.com/test.tiff",
850+
title="image",
851+
description="asdf",
852+
media_type="application/tiff",
853+
roles=["mlm:model"],
854+
)
855+
asset_ext = AssetGeneralMLMExtension.ext(asset, add_if_missing=False)
856+
asset_ext.apply(artifact_type="foo", compile_method="bar", entrypoint="baz")
857+
858+
d = {
859+
"mlm:artifact_type": "foo",
860+
"mlm:compile_method": "bar",
861+
"mlm:entrypoint": "baz",
862+
}
863+
assert asset_ext.to_dict() == d
864+
865+
777866
def test_add_to_detailled_asset() -> None:
778867
model_input = ModelInput.create(
779868
name="model",
@@ -819,6 +908,8 @@ def test_add_to_detailled_asset() -> None:
819908
assert asset_ext.compile_method == "bar"
820909
assert asset_ext.entrypoint == "baz"
821910

911+
assert repr(asset_ext) == f"<AssetDetailedMLMExtension Asset href={asset.href}>"
912+
822913

823914
def test_apply_detailled_asset() -> None:
824915
asset = pystac.Asset(
@@ -866,13 +957,69 @@ def test_apply_detailled_asset() -> None:
866957
assert asset_ext.entrypoint == "baz"
867958

868959

960+
def test_to_dict_detailed_asset() -> None:
961+
asset = pystac.Asset(
962+
href="http://example.com/test.tiff",
963+
title="image",
964+
description="asdf",
965+
media_type="application/tiff",
966+
roles=["mlm:model"],
967+
)
968+
asset_ext = AssetDetailedMLMExtension.ext(asset, add_if_missing=False)
969+
970+
model_input = ModelInput.create(
971+
name="model",
972+
bands=["B02"],
973+
input=InputStructure.create(
974+
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
975+
),
976+
)
977+
model_output = ModelOutput.create(
978+
name="output",
979+
tasks=[TaskType.CLASSIFICATION],
980+
result=ResultStructure.create(
981+
shape=[1], dim_order=["batch"], data_type=DataType.FLOAT64
982+
),
983+
)
984+
985+
asset_ext.apply(
986+
"asdf",
987+
"ResNet",
988+
[TaskType.CLASSIFICATION],
989+
[model_input],
990+
[model_output],
991+
artifact_type="foo",
992+
compile_method="bar",
993+
entrypoint="baz",
994+
)
995+
996+
d = {
997+
"mlm:name": "asdf",
998+
"mlm:architecture": "ResNet",
999+
"mlm:tasks": [TaskType.CLASSIFICATION],
1000+
"mlm:input": [model_input.to_dict()],
1001+
"mlm:output": [model_output.to_dict()],
1002+
"mlm:artifact_type": "foo",
1003+
"mlm:compile_method": "bar",
1004+
"mlm:entrypoint": "baz",
1005+
"mlm:accelerator": None,
1006+
"mlm:pretrained_source": None,
1007+
}
1008+
assert asset_ext.to_dict() == d
1009+
1010+
8691011
def test_item_asset_extension(mlm_collection: Collection) -> None:
8701012
assert mlm_collection.item_assets
8711013
item_asset = mlm_collection.item_assets["weights"]
872-
MLMExtension.ext(item_asset, add_if_missing=True)
1014+
item_asset_ext = MLMExtension.ext(item_asset, add_if_missing=True)
8731015
assert MLMExtension.get_schema_uri() in mlm_collection.stac_extensions
8741016
assert mlm_collection.item_assets["weights"].ext.has("mlm")
8751017

1018+
assert (
1019+
repr(item_asset_ext)
1020+
== f"<ItemAssetsMLMExtension ItemAssetDefinition={item_asset}"
1021+
)
1022+
8761023

8771024
def test_collection_extension(mlm_collection: Collection) -> None:
8781025
coll_ext = MLMExtension.ext(mlm_collection, add_if_missing=True)
@@ -881,6 +1028,9 @@ def test_collection_extension(mlm_collection: Collection) -> None:
8811028

8821029
coll_ext.mlm_name = "asdf"
8831030
assert coll_ext.mlm_name == "asdf"
1031+
assert (
1032+
repr(coll_ext) == f"<CollectionMLMExtension Collection id={mlm_collection.id}>"
1033+
)
8841034

8851035

8861036
def test_raise_exception_on_mlm_extension_and_asset() -> None:

0 commit comments

Comments
 (0)