Skip to content

Commit 473cee7

Browse files
committed
update stac-model with corresponding scaling definitions from JSON schema
1 parent 984ee77 commit 473cee7

File tree

3 files changed

+62
-34
lines changed

3 files changed

+62
-34
lines changed

stac_model/examples.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pystac.extensions.file import FileExtension
88

99
from stac_model.base import ProcessingExpression
10-
from stac_model.input import InputStructure, MLMStatistic, ModelInput
10+
from stac_model.input import InputStructure, ScalingObject, ModelInput
1111
from stac_model.output import MLMClassification, ModelOutput, ModelResult
1212
from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties
1313

@@ -63,21 +63,23 @@ def eurosat_resnet() -> ItemMLModelExtension:
6363
761.30323499,
6464
1231.58581042,
6565
]
66-
stats = [
67-
MLMStatistic(
68-
mean=mean,
69-
stddev=stddev,
66+
scaling = [
67+
cast(
68+
ScalingObject,
69+
dict(
70+
type="z-score",
71+
mean=mean,
72+
stddev=stddev,
73+
)
7074
)
7175
for mean, stddev in zip(stats_mean, stats_stddev)
7276
]
7377
model_input = ModelInput(
7478
name="13 Band Sentinel-2 Batch",
7579
bands=band_names,
7680
input=input_struct,
77-
norm_by_channel=True,
78-
norm_type="z-score",
7981
resize_type=None,
80-
statistics=stats,
82+
scaling=scaling,
8183
pre_processing_function=ProcessingExpression(
8284
format="python",
8385
expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn",

stac_model/input.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88

99
class InputStructure(MLMBaseModel):
10-
shape: List[Union[int, float]] = Field(min_items=1)
11-
dim_order: List[str] = Field(min_items=1)
10+
shape: List[Union[int, float]] = Field(min_length=1)
11+
dim_order: List[str] = Field(min_length=1)
1212
data_type: DataType
1313

1414
@model_validator(mode="after")
@@ -18,27 +18,56 @@ def validate_dimensions(self) -> Self:
1818
return self
1919

2020

21-
class MLMStatistic(MLMBaseModel): # FIXME: add 'Statistics' dep from raster extension (cases required to be triggered)
22-
minimum: Annotated[Optional[Number], OmitIfNone] = None
23-
maximum: Annotated[Optional[Number], OmitIfNone] = None
24-
mean: Annotated[Optional[Number], OmitIfNone] = None
25-
stddev: Annotated[Optional[Number], OmitIfNone] = None
26-
count: Annotated[Optional[int], OmitIfNone] = None
27-
valid_percent: Annotated[Optional[Number], OmitIfNone] = None
21+
# Scaling definition tagged "clip-min": carries only a lower bound.
# The ``type`` literal is what distinguishes this entry from the other
# scaling definitions declared in this module.
class ScalingClipMin(MLMBaseModel):
    type: Literal["clip-min"] = "clip-min"  # fixed tag, defaulted so callers may omit it
    minimum: Number  # required lower bound
2824

2925

30-
NormalizeType: TypeAlias = Optional[
31-
Literal[
32-
"min-max",
33-
"z-score",
34-
"l1",
35-
"l2",
36-
"l2sqr",
37-
"hamming",
38-
"hamming2",
39-
"type-mask",
40-
"relative",
41-
"inf",
26+
# Scaling definition tagged "clip-max": carries only an upper bound.
class ScalingClipMax(MLMBaseModel):
    type: Literal["clip-max"] = "clip-max"  # fixed tag, defaulted so callers may omit it
    maximum: Number  # required upper bound
29+
30+
31+
# Scaling definition tagged "clip": inherits ``minimum`` from ScalingClipMin
# and ``maximum`` from ScalingClipMax, and re-declares ``type`` so that the
# tag inherited from the bases is replaced by "clip".
class ScalingClip(ScalingClipMin, ScalingClipMax):
    type: Literal["clip"] = "clip"  # overrides the "clip-min"/"clip-max" tags of the bases
33+
34+
35+
# Scaling definition tagged "min-max": both bounds are required.
class ScalingMinMax(MLMBaseModel):
    type: Literal["min-max"] = "min-max"  # fixed tag, defaulted so callers may omit it
    minimum: Number  # required
    maximum: Number  # required
39+
40+
41+
# Scaling definition tagged "z-score": both statistics are required.
# NOTE(review): no constraint on ``stddev`` (e.g. non-zero) is declared here;
# confirm whether the JSON schema expects one.
class ScalingZScore(MLMBaseModel):
    type: Literal["z-score"] = "z-score"  # fixed tag, defaulted so callers may omit it
    mean: Number  # required
    stddev: Number  # required
45+
46+
47+
# Scaling definition tagged "offset": a single required ``value``.
# Structurally identical to ScalingScale; only the ``type`` tag differs.
class ScalingOffset(MLMBaseModel):
    type: Literal["offset"] = "offset"  # fixed tag, defaulted so callers may omit it
    value: Number  # required
50+
51+
52+
# Scaling definition tagged "scale": a single required ``value``.
# Structurally identical to ScalingOffset; only the ``type`` tag differs.
class ScalingScale(MLMBaseModel):
    type: Literal["scale"] = "scale"  # fixed tag, defaulted so callers may omit it
    value: Number  # required
55+
56+
57+
# Scaling definition tagged "processing": a ProcessingExpression with an
# added ``type`` tag. All other fields come from the ProcessingExpression
# base class, which is declared outside this module section.
class ScalingProcessingExpression(ProcessingExpression):
    type: Literal["processing"] = "processing"  # fixed tag, defaulted so callers may omit it
59+
60+
61+
# One scaling entry: any of the tagged scaling definitions, or ``None``.
# Written as a flat Union with a trailing ``None`` member, which is exactly
# what ``Optional[Union[...]]`` expands to; member order is unchanged.
ScalingObject: TypeAlias = Union[
    ScalingMinMax,
    ScalingZScore,
    ScalingClip,
    ScalingClipMin,
    ScalingClipMax,
    ScalingOffset,
    ScalingScale,
    ScalingProcessingExpression,
    None,
]
4473

@@ -107,9 +136,6 @@ class ModelInput(MLMBaseModel):
107136
],
108137
)
109138
input: InputStructure
110-
norm_by_channel: Annotated[Optional[bool], OmitIfNone] = None
111-
norm_type: Annotated[Optional[NormalizeType], OmitIfNone] = None
112-
norm_clip: Annotated[Optional[List[Union[float, int]]], OmitIfNone] = None
139+
scaling: Annotated[Optional[List[ScalingObject]], OmitIfNone] = None
113140
resize_type: Annotated[Optional[ResizeType], OmitIfNone] = None
114-
statistics: Annotated[Optional[List[MLMStatistic]], OmitIfNone] = None
115141
pre_processing_function: Optional[ProcessingExpression] = None

tests/test_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_mlm_input_scaling_combination(
118118
mlm_item = pystac.Item.from_dict(mlm_data)
119119
pystac.validation.validate(mlm_item, validator=mlm_validator) # ensure original is valid
120120

121-
mlm_data["properties"]["mlm:input"][0]["scaling"] = test_scaling
121+
mlm_data["properties"]["mlm:input"][0]["scaling"] = test_scaling # type: ignore
122122
mlm_item = pystac.Item.from_dict(mlm_data)
123123
if is_valid:
124124
pystac.validation.validate(mlm_item, validator=mlm_validator)

0 commit comments

Comments
 (0)