Skip to content

Commit 21eb6c7

Browse files
committed
added regex constraint for mlm:framework to migration
1 parent d450c56 commit 21eb6c7

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

pystac/extensions/mlm.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import warnings
1112
from abc import ABC
1213
from collections.abc import Iterable
1314
from typing import Any, Generic, Literal, TypeVar, cast
@@ -2055,10 +2056,39 @@ class MLMExtensionHooks(ExtensionHooks):
20552056
@staticmethod
def _migrate_1_0_to_1_1(obj: dict[str, Any]) -> None:
    """Migrate the ``mlm:framework`` property from MLM v1.0 to v1.1 in place.

    Two adjustments are applied, each announced with a ``SyntaxWarning``:

    * leading and trailing ``.``, ``_``, ``-`` and ASCII whitespace
      characters are stripped from the value;
    * the framework names ``Scikit-learn`` and ``Huggingface`` are renamed
      to ``scikit-learn`` and ``Hugging Face`` respectively.

    Objects without ``mlm:framework`` and objects whose value is not a
    string (e.g. ``None``) are left untouched and produce no warning.

    Args:
        obj: Dictionary holding a ``"properties"`` mapping; modified in place.
    """
    properties = obj["properties"]
    if "mlm:framework" not in properties:
        return

    framework = properties["mlm:framework"]
    if not isinstance(framework, str):
        # Only strings can be normalised; indexing/stripping anything else
        # (such as None) would raise instead of migrating.
        return

    # Characters that must not appear at the beginning or end of the value.
    forbidden_chars = "._- \t\n\r\f\v"

    # str.strip is safe for empty input and for values consisting solely of
    # forbidden characters (those collapse to "" instead of raising
    # IndexError the way manual first/last-character indexing would).
    stripped = framework.strip(forbidden_chars)
    if stripped != framework:
        warnings.warn(
            "Value for mlm:framework is invalid in mlm>=1.1, as it must "
            "not start or end with one of the following characters: "
            "._- and whitespace. These characters are therefore removed while "
            "migrating the STAC object to v1.1.",
            SyntaxWarning,
        )
        framework = stripped
        properties["mlm:framework"] = framework

    # Framework names whose spelling changed between v1.0 and v1.1.
    renamed_frameworks = {
        "Scikit-learn": "scikit-learn",
        "Huggingface": "Hugging Face",
    }
    new_name = renamed_frameworks.get(framework)
    if new_name is not None:
        properties["mlm:framework"] = new_name
        warnings.warn(
            f"mlm:framework value {framework} is no longer valid in mlm>=1.1. "
            f"Renaming it to {new_name}",
            SyntaxWarning,
        )
20622092

20632093
@staticmethod
20642094
def _migrate_1_1_to_1_2(obj: dict[str, Any]) -> None:

tests/extensions/test_mlm.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import re
34
from copy import deepcopy
45
from typing import Any, cast
56

@@ -1075,20 +1076,35 @@ def test_raise_exception_on_mlm_extension_and_asset() -> None:
10751076

10761077

10771078
@pytest.mark.parametrize(
    "framework_old, framework_new, valid",
    (
        ("Scikit-learn", "scikit-learn", False),
        ("Huggingface", "Hugging Face", False),
        ("-_ .asdf", "asdf", False),
        ("asdf-_ .", "asdf", False),
        ("-._ asdf-.", "asdf", False),
        ("test_framework", "test_framework", True),
    ),
)
def test_migration_1_0_to_1_1(
    framework_old: None | str, framework_new: None | str, valid: bool
) -> None:
    """Check the v1.0 -> v1.1 migration of ``mlm:framework``.

    ``valid`` states whether ``framework_old`` is already acceptable for
    v1.1; values that are not must trigger a ``SyntaxWarning`` while being
    rewritten to ``framework_new``.
    """
    item: dict[str, Any] = {"properties": {}}

    # Migrating an object without the field must not introduce it.
    MLMExtensionHooks._migrate_1_0_to_1_1(item)
    assert "mlm:framework" not in item["properties"]

    item["properties"]["mlm:framework"] = framework_old

    if not valid:
        # Values that need fixing must be reported to the user.
        with pytest.warns(SyntaxWarning):
            MLMExtensionHooks._migrate_1_0_to_1_1(item)
    else:
        MLMExtensionHooks._migrate_1_0_to_1_1(item)

    migrated = item["properties"]["mlm:framework"]
    assert migrated == framework_new

    # The migrated value must neither start nor end with whitespace or one
    # of the characters ".", "_", "-".
    framework_pattern = r"^(?=[^\s._\-]).*[^\s._\-]$"
    assert re.match(framework_pattern, migrated) is not None
10921108

10931109

10941110
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)