1212from pystac .extensions .classification import Classification
1313from pystac .extensions .mlm import (
1414 ARCHITECTURE_PROP ,
15- NAME_PROP ,
1615 TASKS_PROP ,
1716 AcceleratorType ,
1817 AssetDetailedMLMExtension ,
@@ -770,20 +769,16 @@ def test_add_to_asset(plain_item: Item) -> None:
770769 MLMExtension .ext (plain_item , add_if_missing = True )
771770 asset = plain_item .assets ["analytic" ]
772771
773- assert NAME_PROP not in asset .extra_fields .keys ()
774772 assert ARCHITECTURE_PROP not in asset .extra_fields .keys ()
775773 assert TASKS_PROP not in asset .extra_fields .keys ()
776774
777775 asset_ext = AssetDetailedMLMExtension .ext (asset )
778- asset_ext .mlm_name = "asdf"
779776 asset_ext .architecture = "ResNet"
780777 asset_ext .tasks = [TaskType .CLASSIFICATION ]
781778
782- assert NAME_PROP in asset .extra_fields .keys ()
783779 assert ARCHITECTURE_PROP in asset .extra_fields .keys ()
784780 assert TASKS_PROP in asset .extra_fields .keys ()
785781
786- assert asset .extra_fields [NAME_PROP ] == "asdf"
787782 assert asset .extra_fields [ARCHITECTURE_PROP ] == "ResNet"
788783 assert asset .extra_fields [TASKS_PROP ] == [TaskType .CLASSIFICATION ]
789784
@@ -866,33 +861,15 @@ def test_to_dict_asset_generic() -> None:
866861
867862
868863def test_add_to_detailled_asset () -> None :
869- model_input = ModelInput .create (
870- name = "model" ,
871- bands = ["B02" ],
872- input = InputStructure .create (
873- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
874- ),
875- )
876- model_output = ModelOutput .create (
877- name = "output" ,
878- tasks = [TaskType .CLASSIFICATION ],
879- result = ResultStructure .create (
880- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
881- ),
882- )
883-
884864 asset = pystac .Asset (
885865 href = "http://example.com/test.tiff" ,
886866 title = "image" ,
887867 description = "asdf" ,
888868 media_type = "application/tiff" ,
889869 roles = ["mlm:model" ],
890870 extra_fields = {
891- "mlm:name" : "asdf" ,
892871 "mlm:architecture" : "ResNet" ,
893872 "mlm:tasks" : [TaskType .CLASSIFICATION ],
894- "mlm:input" : [model_input .to_dict ()],
895- "mlm:output" : [model_output .to_dict ()],
896873 "mlm:artifact_type" : "foo" ,
897874 "mlm:compile_method" : "bar" ,
898875 "mlm:entrypoint" : "baz" ,
@@ -901,11 +878,8 @@ def test_add_to_detailled_asset() -> None:
901878
902879 asset_ext = AssetDetailedMLMExtension .ext (asset , add_if_missing = False )
903880
904- assert asset_ext .mlm_name == "asdf"
905881 assert asset_ext .architecture == "ResNet"
906882 assert asset_ext .tasks == [TaskType .CLASSIFICATION ]
907- assert asset_ext .input == [model_input ]
908- assert asset_ext .output == [model_output ]
909883 assert asset_ext .artifact_type == "foo"
910884 assert asset_ext .compile_method == "bar"
911885 assert asset_ext .entrypoint == "baz"
@@ -930,7 +904,7 @@ def test_correct_asset_extension_is_used() -> None:
930904 asset = Asset ("https://example.com" )
931905 assert isinstance (asset .ext .mlm , AssetGeneralMLMExtension )
932906
933- asset .extra_fields ["mlm:name " ] = "asdf "
907+ asset .extra_fields ["mlm:architecture " ] = "ResNet "
934908 assert isinstance (asset .ext .mlm , AssetDetailedMLMExtension )
935909
936910
@@ -951,37 +925,16 @@ def test_apply_detailled_asset() -> None:
951925 )
952926 asset_ext = AssetDetailedMLMExtension .ext (asset , add_if_missing = False )
953927
954- model_input = ModelInput .create (
955- name = "model" ,
956- bands = ["B02" ],
957- input = InputStructure .create (
958- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
959- ),
960- )
961- model_output = ModelOutput .create (
962- name = "output" ,
963- tasks = [TaskType .CLASSIFICATION ],
964- result = ResultStructure .create (
965- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
966- ),
967- )
968-
969928 asset_ext .apply (
970- "asdf" ,
971929 "ResNet" ,
972930 [TaskType .CLASSIFICATION ],
973- [model_input ],
974- [model_output ],
975931 artifact_type = "foo" ,
976932 compile_method = "bar" ,
977933 entrypoint = "baz" ,
978934 )
979935
980- assert asset_ext .mlm_name == "asdf"
981936 assert asset_ext .architecture == "ResNet"
982937 assert asset_ext .tasks == [TaskType .CLASSIFICATION ]
983- assert asset_ext .input == [model_input ]
984- assert asset_ext .output == [model_output ]
985938 assert asset_ext .artifact_type == "foo"
986939 assert asset_ext .compile_method == "bar"
987940 assert asset_ext .entrypoint == "baz"
@@ -997,38 +950,17 @@ def test_to_dict_detailed_asset() -> None:
997950 )
998951 asset_ext = AssetDetailedMLMExtension .ext (asset , add_if_missing = False )
999952
1000- model_input = ModelInput .create (
1001- name = "model" ,
1002- bands = ["B02" ],
1003- input = InputStructure .create (
1004- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
1005- ),
1006- )
1007- model_output = ModelOutput .create (
1008- name = "output" ,
1009- tasks = [TaskType .CLASSIFICATION ],
1010- result = ResultStructure .create (
1011- shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64
1012- ),
1013- )
1014-
1015953 asset_ext .apply (
1016- "asdf" ,
1017954 "ResNet" ,
1018955 [TaskType .CLASSIFICATION ],
1019- [model_input ],
1020- [model_output ],
1021956 artifact_type = "foo" ,
1022957 compile_method = "bar" ,
1023958 entrypoint = "baz" ,
1024959 )
1025960
1026961 d = {
1027- "mlm:name" : "asdf" ,
1028962 "mlm:architecture" : "ResNet" ,
1029963 "mlm:tasks" : [TaskType .CLASSIFICATION ],
1030- "mlm:input" : [model_input .to_dict ()],
1031- "mlm:output" : [model_output .to_dict ()],
1032964 "mlm:artifact_type" : "foo" ,
1033965 "mlm:compile_method" : "bar" ,
1034966 "mlm:entrypoint" : "baz" ,
0 commit comments