@@ -75,6 +75,9 @@ def test_model_band() -> None:
7575
7676    assert  c .to_dict () ==  d 
7777
78+     with  pytest .raises (NotImplementedError ):
79+         _  =  c  ==  "blah" 
80+ 
7881
7982def  test_model_props () ->  None :
8083    c  =  ModelBand ({})
@@ -100,6 +103,9 @@ def test_processing_expression() -> None:
100103
101104    assert  c .to_dict () ==  d 
102105
106+     with  pytest .raises (NotImplementedError ):
107+         _  =  c  ==  "blah" 
108+ 
103109
104110def  test_processint_expression_props () ->  None :
105111    c  =  ProcessingExpression ({})
@@ -116,25 +122,43 @@ def test_processint_expression_props() -> None:
116122    assert  c .expression  ==  "B01 + B02" 
117123
118124
119- def  test_valuescaling_object () ->  None :
125+ @pytest .mark .parametrize ( 
126+     "scale_type, min_val, max_val, mean, stddev, value, format_val, expression" , 
127+     [ 
128+         (ValueScalingType .MIN_MAX , 0 , 4 , 3 , 3 , 4 , "asdf" , "asdf" ), 
129+         (ValueScalingType .MIN_MAX , 0.2 , 4.3 , 3.13 , 3.2 , 4.5 , "asdf" , "asdf" ), 
130+         (ValueScalingType .MIN_MAX , 0 , 4 , None , None , None , None , None ), 
131+         (ValueScalingType .SCALE , None , None , None , None , 2 , None , None ), 
132+     ], 
133+ ) 
134+ def  test_valuescaling_object (
135+     scale_type : ValueScalingType ,
136+     min_val : int  |  float  |  None ,
137+     max_val : int  |  float  |  None ,
138+     mean : int  |  float  |  None ,
139+     stddev : int  |  float  |  None ,
140+     value : int  |  float  |  None ,
141+     format_val : str  |  None ,
142+     expression : str  |  None ,
143+ ) ->  None :
120144    c  =  ValueScaling .create (
121-         ValueScalingType . MIN_MAX ,
122-         minimum = 0 ,
123-         maximum = 4 ,
124-         mean = 3 ,
125-         stddev = 3.141 ,
126-         value = 4 ,
127-         format = "asdf" ,
128-         expression = "asdf" ,
129-     )
130-     assert  c .type  ==  ValueScalingType . MIN_MAX 
131-     assert  c .minimum  ==  0 
132-     assert  c .maximum  ==  4 
133-     assert  c .mean  ==  3 
134-     assert  c .stddev  ==  3.141 
135-     assert  c .value  ==  4 
136-     assert  c .format  ==  "asdf" 
137-     assert  c .expression  ==  "asdf" 
145+         scale_type ,
146+         minimum = min_val ,
147+         maximum = max_val ,
148+         mean = mean ,
149+         stddev = stddev ,
150+         value = value ,
151+         format = format_val ,
152+         expression = expression ,
153+     )
154+     assert  c .type  ==  scale_type 
155+     assert  c .minimum  ==  min_val 
156+     assert  c .maximum  ==  max_val 
157+     assert  c .mean  ==  mean 
158+     assert  c .stddev  ==  stddev 
159+     assert  c .value  ==  value 
160+     assert  c .format  ==  format_val 
161+     assert  c .expression  ==  expression 
138162
139163    with  pytest .raises (STACError ):
140164        ValueScaling .create (
@@ -144,6 +168,9 @@ def test_valuescaling_object() -> None:
144168    with  pytest .raises (STACError ):
145169        ValueScaling .create (ValueScalingType .Z_SCORE , mean = 3 )  # missing param stddev 
146170
171+     with  pytest .raises (NotImplementedError ):
172+         _  =  c  ==  "blah" 
173+ 
147174
148175def  test_valuescaling_required_params () ->  None :
149176    assert  ValueScaling .get_required_props (ValueScalingType .MIN_MAX ) ==  [
@@ -178,6 +205,9 @@ def test_input_structure() -> None:
178205    assert  c .dim_order  ==  ["batch" , "channel" , "width" , "height" ]
179206    assert  c .data_type  ==  DataType .FLOAT64 
180207
208+     with  pytest .raises (NotImplementedError ):
209+         _  =  c  ==  "blah" 
210+ 
181211
182212def  test_model_input_structure_props () ->  None :
183213    c  =  InputStructure ({})
@@ -269,6 +299,9 @@ def test_model_input(
269299    assert  "resize_type"  in  d_reverse 
270300    assert  "pre_processing_function"  in  d_reverse 
271301
302+     with  pytest .raises (NotImplementedError ):
303+         _  =  c  ==  "blah" 
304+ 
272305
273306def  test_model_input_props () ->  None :
274307    c  =  ModelInput ({})
@@ -314,6 +347,9 @@ def test_result_structure() -> None:
314347    assert  c .dim_order  ==  ["time" , "width" , "height" ]
315348    assert  c .data_type  ==  DataType .FLOAT64 
316349
350+     with  pytest .raises (NotImplementedError ):
351+         _  =  c  ==  "blah" 
352+ 
317353
318354def  test_result_structure_props () ->  None :
319355    c  =  ResultStructure ({})
@@ -361,6 +397,9 @@ def test_model_output(post_proc_func: ProcessingExpression | None) -> None:
361397    ]
362398    assert  c .post_processing_function  ==  post_proc_func 
363399
400+     with  pytest .raises (NotImplementedError ):
401+         _  =  c  ==  "blah" 
402+ 
364403
365404def  test_model_output_props () ->  None :
366405    c  =  ModelOutput ({})
@@ -408,6 +447,9 @@ def test_hyperparameters() -> None:
408447        assert  key  in  c .to_dict ()
409448        assert  c .to_dict ()[key ] ==  d [key ]
410449
450+     with  pytest .raises (NotImplementedError ):
451+         _  =  c  ==  "blah" 
452+ 
411453
412454def  teest_get_schema_uri (basic_mlm_item : Item ) ->  None :
413455    with  pytest .raises (DeprecationWarning ):
@@ -571,6 +613,34 @@ def test_apply(plain_item: Item) -> None:
571613        and  MLMExtension .ext (plain_item ).hyperparameters  ==  hyp 
572614    )
573615
616+     d  =  {
617+         ** plain_item .properties ,
618+         "mlm:name" : "asdf" ,
619+         "mlm:architecture" : "ResNet" ,
620+         "mlm:tasks" : [TaskType .CLASSIFICATION ],
621+         "mlm:framework" : "PyTorch" ,
622+         "mlm:framework_version" : "1.2.3" ,
623+         "mlm:memory_size" : 3 ,
624+         "mlm:total_parameters" : 123 ,
625+         "mlm:pretrained" : True ,
626+         "mlm:pretrained_source" : "asdfasdfasdf" ,
627+         "mlm:batch_size_suggestion" : 32 ,
628+         "mlm:accelerator" : AcceleratorType .CUDA ,
629+         "mlm:accelerator_constrained" : False ,
630+         "mlm:accelerator_summary" : "This is the summary" ,
631+         "mlm:accelerator_count" : 1 ,
632+         "mlm:input" : [inp .to_dict () for  inp  in  model_input ],
633+         "mlm:output" : [out .to_dict () for  out  in  model_output ],
634+         "mlm:hyperparameters" : hyp .to_dict (),
635+     }
636+ 
637+     assert  MLMExtension .ext (plain_item ).to_dict () ==  d 
638+ 
639+ 
640+ def  test_apply_wrong_object () ->  None :
641+     with  pytest .raises (pystac .ExtensionTypeError ):
642+         _  =  MLMExtension .ext (1 , False )
643+ 
574644
575645def  test_to_from_dict (basic_item_dict : dict [str , Any ]) ->  None :
576646    d1  =  deepcopy (basic_item_dict )
@@ -774,6 +844,25 @@ def test_apply_generic_asset() -> None:
774844    assert  asset_ext .entrypoint  ==  "baz" 
775845
776846
847+ def  test_to_dict_asset_generic () ->  None :
848+     asset  =  pystac .Asset (
849+         href = "http://example.com/test.tiff" ,
850+         title = "image" ,
851+         description = "asdf" ,
852+         media_type = "application/tiff" ,
853+         roles = ["mlm:model" ],
854+     )
855+     asset_ext  =  AssetGeneralMLMExtension .ext (asset , add_if_missing = False )
856+     asset_ext .apply (artifact_type = "foo" , compile_method = "bar" , entrypoint = "baz" )
857+ 
858+     d  =  {
859+         "mlm:artifact_type" : "foo" ,
860+         "mlm:compile_method" : "bar" ,
861+         "mlm:entrypoint" : "baz" ,
862+     }
863+     assert  asset_ext .to_dict () ==  d 
864+ 
865+ 
777866def  test_add_to_detailled_asset () ->  None :
778867    model_input  =  ModelInput .create (
779868        name = "model" ,
@@ -819,6 +908,8 @@ def test_add_to_detailled_asset() -> None:
819908    assert  asset_ext .compile_method  ==  "bar" 
820909    assert  asset_ext .entrypoint  ==  "baz" 
821910
911+     assert  repr (asset_ext ) ==  f"<AssetDetailedMLMExtension Asset href={ asset .href }  >" 
912+ 
822913
823914def  test_apply_detailled_asset () ->  None :
824915    asset  =  pystac .Asset (
@@ -866,13 +957,69 @@ def test_apply_detailled_asset() -> None:
866957    assert  asset_ext .entrypoint  ==  "baz" 
867958
868959
960+ def  test_to_dict_detailed_asset () ->  None :
961+     asset  =  pystac .Asset (
962+         href = "http://example.com/test.tiff" ,
963+         title = "image" ,
964+         description = "asdf" ,
965+         media_type = "application/tiff" ,
966+         roles = ["mlm:model" ],
967+     )
968+     asset_ext  =  AssetDetailedMLMExtension .ext (asset , add_if_missing = False )
969+ 
970+     model_input  =  ModelInput .create (
971+         name = "model" ,
972+         bands = ["B02" ],
973+         input = InputStructure .create (
974+             shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64 
975+         ),
976+     )
977+     model_output  =  ModelOutput .create (
978+         name = "output" ,
979+         tasks = [TaskType .CLASSIFICATION ],
980+         result = ResultStructure .create (
981+             shape = [1 ], dim_order = ["batch" ], data_type = DataType .FLOAT64 
982+         ),
983+     )
984+ 
985+     asset_ext .apply (
986+         "asdf" ,
987+         "ResNet" ,
988+         [TaskType .CLASSIFICATION ],
989+         [model_input ],
990+         [model_output ],
991+         artifact_type = "foo" ,
992+         compile_method = "bar" ,
993+         entrypoint = "baz" ,
994+     )
995+ 
996+     d  =  {
997+         "mlm:name" : "asdf" ,
998+         "mlm:architecture" : "ResNet" ,
999+         "mlm:tasks" : [TaskType .CLASSIFICATION ],
1000+         "mlm:input" : [model_input .to_dict ()],
1001+         "mlm:output" : [model_output .to_dict ()],
1002+         "mlm:artifact_type" : "foo" ,
1003+         "mlm:compile_method" : "bar" ,
1004+         "mlm:entrypoint" : "baz" ,
1005+         "mlm:accelerator" : None ,
1006+         "mlm:pretrained_source" : None ,
1007+     }
1008+     assert  asset_ext .to_dict () ==  d 
1009+ 
1010+ 
8691011def  test_item_asset_extension (mlm_collection : Collection ) ->  None :
8701012    assert  mlm_collection .item_assets 
8711013    item_asset  =  mlm_collection .item_assets ["weights" ]
872-     MLMExtension .ext (item_asset , add_if_missing = True )
1014+     item_asset_ext   =   MLMExtension .ext (item_asset , add_if_missing = True )
8731015    assert  MLMExtension .get_schema_uri () in  mlm_collection .stac_extensions 
8741016    assert  mlm_collection .item_assets ["weights" ].ext .has ("mlm" )
8751017
1018+     assert  (
1019+         repr (item_asset_ext )
1020+         ==  f"<ItemAssetsMLMExtension ItemAssetDefinition={ item_asset }  " 
1021+     )
1022+ 
8761023
8771024def  test_collection_extension (mlm_collection : Collection ) ->  None :
8781025    coll_ext  =  MLMExtension .ext (mlm_collection , add_if_missing = True )
@@ -881,6 +1028,9 @@ def test_collection_extension(mlm_collection: Collection) -> None:
8811028
8821029    coll_ext .mlm_name  =  "asdf" 
8831030    assert  coll_ext .mlm_name  ==  "asdf" 
1031+     assert  (
1032+         repr (coll_ext ) ==  f"<CollectionMLMExtension Collection id={ mlm_collection .id }  >" 
1033+     )
8841034
8851035
8861036def  test_raise_exception_on_mlm_extension_and_asset () ->  None :
0 commit comments