diff --git a/sarpy/__about__.py b/sarpy/__about__.py index 6c779bf7..8860bafe 100644 --- a/sarpy/__about__.py +++ b/sarpy/__about__.py @@ -28,7 +28,7 @@ from sarpy.__details__ import __classification__, _post_identifier -__version__ = "1.3.61.1" +__version__ = "1.3.62" __author__ = "National Geospatial-Intelligence Agency" __url__ = "https://github.com/ngageoint/sarpy" diff --git a/sarpy/io/DEM/DTED.py b/sarpy/io/DEM/DTED.py index 80fe245a..5e6ca253 100644 --- a/sarpy/io/DEM/DTED.py +++ b/sarpy/io/DEM/DTED.py @@ -651,7 +651,7 @@ def from_reference_point(cls, ref_point, dted_list, dem_type=None, geoid_file=No DTEDInterpolator """ - pad_value = float(pad_value) + pad_value = abs( float(pad_value) ) if pad_value > 0.5: pad_value = 0.5 if pad_value < 0.05: @@ -660,7 +660,7 @@ def from_reference_point(cls, ref_point, dted_list, dem_type=None, geoid_file=No lat_max = min(ref_point[0] + lat_diff, 90) lat_min = max(ref_point[0] - lat_diff, -90) - lon_diff = min(15, lat_diff/(numpy.sin(numpy.deg2rad(ref_point[0])))) + lon_diff = min(15, abs( lat_diff/(numpy.cos(numpy.deg2rad(ref_point[0]))))) lon_max = ref_point[1] + lon_diff if lon_max > 180: lon_max -= 360 diff --git a/sarpy/io/general/nitf.py b/sarpy/io/general/nitf.py index d9b7443c..cb08c888 100644 --- a/sarpy/io/general/nitf.py +++ b/sarpy/io/general/nitf.py @@ -2693,7 +2693,9 @@ def item_bytes(self, value: Union[bytes, Sequence]) -> None: 'item_bytes input has size {},\n\t' 'but item_size has been defined as {}.'.format(len(value), self._item_size)) self._item_bytes = value - self.item_size = len(value) + if self._item_size is None or self.item_size != len(value): + self.item_size = len(value) + @property def item_written(self) -> bool: @@ -4125,9 +4127,7 @@ def get_data_segments(self) -> List[DataSegment]: def flush(self, force: bool = False) -> None: self._validate_closed() - BaseWriter.flush(self, force=force) - try: if self._in_memory: if self._image_segment_data_segments is not None: diff --git 
a/sarpy/io/general/slice_parsing.py b/sarpy/io/general/slice_parsing.py index 6be541cd..e4799728 100644 --- a/sarpy/io/general/slice_parsing.py +++ b/sarpy/io/general/slice_parsing.py @@ -147,10 +147,12 @@ def verify_subscript( subscript = subscript[:ellipsis_location] elif ellipsis_location == 0: init_pad = ndim - len(subscript) + 1 - subscript = tuple([None, ]*init_pad) + subscript[1:] + subscript = tuple([None, ]*init_pad) + tuple(subscript[1:]) else: # ellipsis in the middle middle_pad = ndim - len(subscript) + 1 - subscript = subscript[:ellipsis_location] + tuple([None, ]*middle_pad) + subscript[ellipsis_location+1:] + subscript = tuple(subscript[:ellipsis_location]) + \ + tuple([None, ]*middle_pad) + \ + tuple(subscript[ellipsis_location+1:]) if len(subscript) > ndim: raise ValueError('More subscript entries ({}) than shape dimensions ({}).'.format(len(subscript), ndim)) diff --git a/sarpy/io/product/sidd.py b/sarpy/io/product/sidd.py index 41026e63..ca8c6889 100644 --- a/sarpy/io/product/sidd.py +++ b/sarpy/io/product/sidd.py @@ -656,7 +656,17 @@ def _get_ftitle(self, index: int = 0) -> str: ftitle = 'SIDD: Unknown' return ftitle + # File Creation DateTime def _get_fdt(self, index: int) -> Optional[str]: + sidd = self.sidd_meta[index] + if sidd.ProductCreation.ProcessorInformation.ProcessingDateTime is not None: + the_time = sidd.ProductCreation.ProcessorInformation.ProcessingDateTime.astype('datetime64[s]') + return re.sub(r'[^0-9]', '', str(the_time)) + else: + return None + + # Image Acquisition (Collection) Datetime + def _get_collection_datetime(self, index: int) -> Optional[str]: sidd = self.sidd_meta[index] if sidd.ExploitationFeatures.Collections[0].Information.CollectionDateTime is not None: the_time = sidd.ExploitationFeatures.Collections[0].Information.CollectionDateTime.astype('datetime64[s]') @@ -730,7 +740,7 @@ def _create_image_segment_for_sidd( 'IC': 'NC', 'IID2': self._get_iid2(sidd_index), 'ISORCE': self._get_isorce(sidd_index), - 'IDATIM': 
self._get_fdt(sidd_index) + 'IDATIM': self._get_collection_datetime(sidd_index) } if sidd.Display.PixelType == 'MONO8I': diff --git a/sarpy/io/product/sidd2_elements/Compression.py b/sarpy/io/product/sidd2_elements/Compression.py index 098557a7..0ce5cd07 100644 --- a/sarpy/io/product/sidd2_elements/Compression.py +++ b/sarpy/io/product/sidd2_elements/Compression.py @@ -7,6 +7,7 @@ from typing import Union +import xml.etree.ElementTree import numpy from sarpy.io.xml.base import Serializable @@ -16,6 +17,8 @@ from .base import DEFAULT_STRICT, FLOAT_FORMAT +# default_strict is set to False which is why invalid types don't throw an error for NumWaveletLevels, NumBands, or LayerInfo + class J2KSubtype(Serializable): """ The Jpeg 2000 subtype. @@ -55,9 +58,38 @@ def __init__(self, NumWaveletLevels=None, NumBands=None, LayerInfo=None, **kwarg self._xml_ns_key = kwargs['_xml_ns_key'] self.NumWaveletLevels = NumWaveletLevels self.NumBands = NumBands - self.LayerInfo = LayerInfo + self.setLayerInfoType(LayerInfo) super(J2KSubtype, self).__init__(**kwargs) + def setLayerInfoType(self, obj): + """ + Since the LayerInfo is defined as an array of bit rates in the definition above and SarPy returns an + element tree with the bit rates nested within different layers, this function is used to parse and return + LayerInfo as an array of bit rates + """ + # if LayerInfo is an ElementTree as expected, then it gets parsed to return an array of bit rates + ET = xml.etree.ElementTree + if isinstance( obj, ET.Element): + numLayers = int(obj.attrib['numLayers']) + if (numLayers == 0): + return + bitrates = numpy.zeros(numLayers) + + for i in range(numLayers): + bitrates[i] = float(obj[i][0].text) + self.LayerInfo = bitrates + + # if LayerInfo is a numpy.ndarray, a list, or a tuple per the above definition, return LayerInfo + elif isinstance(obj, (list, tuple, numpy.ndarray)): + self.LayerInfo = obj + + # none object handler since it states that LayerInfo can be None in the definition above 
and it isn't a required field + elif obj is None: + self.LayerInfo = None + + # throw an error if LayerInfo is an invalid type and an array of bitrates is unable to be generated + else: + raise TypeError(f'Invalid input type for LayerInfo: {type(obj)}. Must be an ElementTree, list, tuple, ndarray, or None.') class J2KType(Serializable): """ @@ -110,4 +142,4 @@ def __init__(self, J2K=None, **kwargs): if '_xml_ns_key' in kwargs: self._xml_ns_key = kwargs['_xml_ns_key'] self.J2K = J2K - super(CompressionType, self).__init__(**kwargs) + super(CompressionType, self).__init__(**kwargs) \ No newline at end of file diff --git a/sarpy/io/xml/base.py b/sarpy/io/xml/base.py index 69cfbcef..06490de5 100644 --- a/sarpy/io/xml/base.py +++ b/sarpy/io/xml/base.py @@ -5,17 +5,16 @@ __classification__ = "UNCLASSIFIED" __author__ = "Thomas McCullough" -import logging -from xml.etree import ElementTree -import json -from datetime import date, datetime -from collections import OrderedDict import copy -import re -from io import StringIO -from typing import Dict, Optional - +import json +import logging import numpy +import re +from collections import OrderedDict +from datetime import date, datetime +from io import StringIO +from typing import Dict, Optional +from xml.etree import ElementTree from sarpy.compliance import bytes_to_string @@ -37,7 +36,8 @@ def get_node_value(nod: ElementTree.Element) -> Optional[str]: """ - XML parsing helper for extracting text value from an ElementTree Element. No error checking performed. + XML parsing helper for extracting text value from an ElementTree Element. + No error checking performed. 
Parameters ---------- @@ -123,10 +123,10 @@ def create_text_node( def find_first_child( - node: ElementTree.Element, - tag: str, - xml_ns: Optional[Dict[str, str]], - ns_key: Optional[str]) -> ElementTree.Element: + node : ElementTree.Element, + tag : str, + xml_ns: Optional[Dict[str, str]] = None, + ns_key: Optional[str] = None) -> ElementTree.Element: """ Finds the first child node @@ -134,8 +134,8 @@ def find_first_child( ---------- node : ElementTree.Element tag : str - xml_ns : None|dict - ns_key : None|str + xml_ns : None|dict XML namespace of the node in question + ns_key : None|str Namespace key to use to preface the tag Returns ------- @@ -150,7 +150,11 @@ def find_first_child( return node.find('{}:{}'.format(ns_key, tag), xml_ns) -def find_children(node, tag, xml_ns, ns_key): +def find_children( + node : ElementTree.Element, + tag : str, + xml_ns: Optional[Dict[str, str]] = None, + ns_key: Optional[str] = None): """ Finds the collection of children nodes @@ -290,6 +294,40 @@ def validate_xml_from_file(xml_path, xsd_path, output_logger=None): def parse_str(value, name, instance): + """ + The parse_str function is not a generic string parser. It is a helper + function specifically for parsing values within XML elements in + sarpy.io.xml.base. It's used internally by the library to extract text + values from XML ElementTree objects. + The function is not intended for public use for general string parsing, and + attempting to use it on a simple text string will just return the initial + string as parse_str expects an XML element as its input. + + Parameters + ---------- + value : ElementTree.Element|None|str + The ElementTree.Element entity that you want to get the value from. + None returns None. + A string value will return the given string unchanged. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. This is only used in the raised error message. 
+ + Returns + ------- + None | str + Returns None if value passed is None. Returns the string passed if value + is a string. Returns the string value of a node when passed an + ElementTree.Element + + Raises + ------- + TypeError + When passed a value with a type other than the expected input types. + """ + if value is None: return None if isinstance(value, str): @@ -303,6 +341,35 @@ def parse_str(value, name, instance): def parse_bool(value, name, instance): + """ + The parse_bool function is a helper function specifically for parsing boolean + values within XML elements in sarpy.io.xml.base. + The function is not intended for public use. + + Parameters + ---------- + value : ElementTree.Element|None|str[0, 1, true, false] + The ElementTree.Element entity that you want to get the value from. + None returns None. + A string value will return the boolean value for the string if it can be + converted to a boolean. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. This is only used in the raised error message. + + Returns + ------- + None | bool + Returns None if value passed is None. Returns the boolean value + of a node when passed an ElementTree.Element + + Raises + ------- + ValueError + When passed a value with a type other than the expected input types. + """ def parse_string(val): if val.lower() in ['0', 'false']: return False @@ -330,6 +397,36 @@ def parse_string(val): def parse_int(value, name, instance): + """ + The parse_int function is a helper function specifically for parsing integer + values within XML elements in sarpy.io.xml.base. + The function is not intended for public use. This function is recursive. + + Parameters + ---------- + value : ElementTree.Element|None|int|str + The ElementTree.Element entity that you want to get the value from. + None returns None. + int returns the integer. 
+ A string value will return the integer value for the string if it can be + converted to an integer. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. This is only used in the raised error message. + + Returns + ------- + None | bool + Returns None if value passed is None. Returns the boolean value + of a node when passed an ElementTree.Element + + Raises + ------- + TypeError + When passed a value with a type other than the expected input types. + """ if value is None: return None if isinstance(value, int): @@ -357,6 +454,36 @@ def parse_int(value, name, instance): # noinspection PyUnusedLocal def parse_float(value, name, instance): + """ + The parse_float function is a helper function specifically for parsing float + values within XML elements in sarpy.io.xml.base. + The function is not intended for public use. This function is recursive. + + Parameters + ---------- + value : ElementTree.Element|None|float|str + The ElementTree.Element entity that you want to get the value from. + None returns None. + Float returns a float. + A string value will return the float value for the string if it can be + converted to a float. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. This is only used in the raised error message. + + Returns + ------- + None | bool + Returns None if value passed is None. Returns the float value + of a node when passed an ElementTree.Element + + Raises + ------- + TypeError + When passed a value with a type other than the expected input types. 
+ """ if value is None: return None if isinstance(value, float): @@ -370,6 +497,36 @@ def parse_float(value, name, instance): def parse_complex(value, name, instance): + """ + The parse_complex function is a helper function specifically for parsing + complex number (numbers with both a real and an imaginary component) values + within XML elements in sarpy.io.xml.base. + The function is not intended for public use. This function is recursive. + + Parameters + ---------- + value : ElementTree.Element|None|complex|dict + The ElementTree.Element entity that you want to get the value from. + None returns None. + complex returns a complex. + dict is a dictionary representation of a complex number. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. This is only used in the raised error message. + + Returns + ------- + None | bool + Returns None if value passed is None. Returns the complex value + of a node when passed an ElementTree.Element + + Raises + ------- + TypeError + When passed a value with a type other than the expected input types. + """ if value is None: return None if isinstance(value, complex): @@ -416,6 +573,36 @@ def parse_complex(value, name, instance): def parse_datetime(value, name, instance, units='us'): + """ + The parse_datetime function is a helper function specifically for parsing + datetime values within XML elements in sarpy.io.xml.base. + The function is not intended for public use. This function is recursive. + + Parameters + ---------- + value : ElementTree.Element|None|datetime|dict + The ElementTree.Element entity that you want to get the value from. + None returns None. + datetime returns a datetime. + A string value will return the datetime value for the string if it can be + converted to a datetime. + name : str + Name of the field to return the value of. This is only used in the + raised error message + instance : + The class of the variable. 
This is only used in the raised error message. + + Returns + ------- + None | bool + Returns None if value passed is None. Returns the float value + of a node when passed an ElementTree.Element + + Raises + ------- + TypeError + When passed a value with a type other than the expected input types. + """ if value is None: return None if isinstance(value, numpy.datetime64): diff --git a/tests/io/DEM/geoid.json b/tests/io/DEM/geoid.json index 783da0f1..92679c2c 100644 --- a/tests/io/DEM/geoid.json +++ b/tests/io/DEM/geoid.json @@ -10,6 +10,10 @@ {"path": "dem/geoid/egm2008-1.pgm", "path_type": "relative"} ], "dted_with_null": [ - {"path": "dem/dted/s04_w061_3arc_v1.dt1", "path_type": "relative"} + {"path": "dem/dted/s04_w061_3arc_v1.dt1", "path_type": "relative", "comment" : "orginal"}, + {"path": "dem/dted/n33_w119_3arc_v1.dt1", "path_type": "relative", "comment" : "Catinlia island"}, + {"path": "dem/dted/s01_w070_3arc_v1.dt1", "path_type": "relative", "comment" : "issue tile"}, + {"path": "dem/dted/n27_e084_3arc_v1.dt1", "path_type": "relative", "comment" : "Nepal"}, + {"path": "dem/dted/s36_e149_3arc_v1.dt1", "path_type": "relative", "comment" : "Austrialia"} ] } diff --git a/tests/io/DEM/test_dted.py b/tests/io/DEM/test_dted.py index 9af76c31..79bd70c1 100644 --- a/tests/io/DEM/test_dted.py +++ b/tests/io/DEM/test_dted.py @@ -6,7 +6,9 @@ from sarpy.io.DEM.geoid import GeoidHeight import tests - +# Note +# set this for your storage of dted and egm files +# export SARPY_TEST_PATH= test_data = tests.find_test_data_files(pathlib.Path(__file__).parent / "geoid.json") egm96_file = test_data["geoid_files"][0] if test_data["geoid_files"] else None @@ -17,7 +19,7 @@ def test_interpolator_no_readers(): llb = [10.0, 20.0, 10.5, 20.5] geoid = GeoidHeight(egm96_file) dtedinterp = sarpy_dted.DTEDInterpolator([], geoid_file=geoid, lat_lon_box=llb) - + assert dtedinterp.get_max_geoid(llb) == 0 assert dtedinterp.get_max_hae(llb) == geoid(10, 10.5) @@ -35,3 +37,151 @@ def 
test_dted_reader(): } for index, expected_value in known_values.items(): assert dted_reader[index] == expected_value + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_south_west(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][0]) # belive wants the s file + + # From entity ID: SRTM3S04W061V1, date updated: 2013-04-17T12:16:47-05 + # Acquired from https://earthexplorer.usgs.gov/ on 2024-08-21 + # to follow along in qgis + # know_value is one of known_values index + # qgis row = 1200 - known_value[ 1 ] # dted1 data in 1200 blocks + # qgis col = known_value[ 0 ] + known_values = { + (1000, 800): -32767, # null + (1000, 799): 7, + (3, 841): -5, + (1004, 797): 7, # a value among the voids displayed in qgis + } + for index, expected_value in known_values.items(): + assert dted_reader[index] == expected_value + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_north_west(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][1]) # belive wants the northern file + + # From entity ID: SRTM3N33W119V1, date updated: 2013-04-17T12:16:47-05 + # Acquired from https://earthexplorer.usgs.gov/ on 2024-08-21 + # to follow along in qgis + # know_value is one of known_values index + # qgis row = 1200 - known_value[ 1 ] # dted1 data in 1200 blocks + # qgis col = known_value[ 0 ] + known_values = { + (812, 927): -32767, # null + (813, 927): -32767, # null + (811, 927): 79, # a value among the voids displayed in qgi + (813, 926): 110 # a value among the voids displayed in qgi + } + for index, expected_value in known_values.items(): + assert dted_reader[index] == expected_value + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_north_east(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][3]) # belive wants the northern file 
in Nepal + + # From entity ID: SRTM3N27E084V1, date updated: 2005-02-01 00:00:00-06 + # Acquired from https://earthexplorer.usgs.gov/ on 2024-08-21 + # to follow along in qgis + # known_value is one of known_values/index + # qgis row = 1200 - known_value[ 1 ] # dted1 data in 1200 blocks + # qgis col = known_value[ 0 ] + known_values = { + (927, 681): -32767, # null /void Nepal : 27.56751, 84.77241 [ lat/lon ] + (928, 681): 830, # a value east the void displayed in qgis + (927, 680): 756 # a value south the void displayed in qgis + } + for index, expected_value in known_values.items(): + assert dted_reader[index] == expected_value + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_south_east(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][4]) # belive wants the Austrial + + # From entity ID: SRTM3S36E149V1, date updated: 2005-02-01 00:00:00-06 Austrialia south west of Sydney + # Acquired from https://earthexplorer.usgs.gov/ on 2025-08-28 + # to follow along in qgis + # know_value is one of known_values index + # qgis row = 1200 - known_value[ 1 ] # dted1 data in 1200 blocks + # qgis col = known_value[ 0 ] + known_values = { + (547, 649): -32767, # null near -35.45881, 149.45613 + (547, 648): 752, # a value south of void displayable via QGIS + (546, 649): 756, # a value west of void displayable via QGIS + } + for index, expected_value in known_values.items(): + assert dted_reader[index] == expected_value + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_get_elevation_northern_pt(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][1]) # the northern file + llbx = [ 33.3748, -118.4187 ] + assert dted_reader.get_elevation( llbx[ 0], llbx[ 1 ]) == pytest.approx( 588.23, abs=0.01 ) + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def 
test_dted_reader_get_elevation_northern_box(): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][1]) # the northern file + + lats = [ 33.3748, 33.405 ] + lons = [ -118.4187, -118.4027 ] + assert dted_reader.get_elevation( lats, lons ) == pytest.approx( [ 588.2368, 468.36 ] , abs=0.01 ) + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_get_elevation_southern_pt (): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][2]) # second souther file + llbx = [ -1.0, -70.0 ] # lat long, + assert dted_reader.get_elevation( llbx[ 0], llbx[ 1 ]) == pytest.approx( 145.0, abs=0.01 ) + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_reader_get_elevation_southern_box (): + dted_reader = sarpy_dted.DTEDReader(test_data["dted_with_null"][2]) # second southern file + + lats = [ -0.92, -0.90 ] + lons = [ -69.9, -69.8 ] + assert dted_reader.get_elevation( lats, lons ) == pytest.approx( [ 81.0, 111.0 ] , abs=0.1 ) + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_interpolator_get_elevation_hae_north_west(): + ll = [ 33.3748, -118.4187 ] # catinlia island off California coast + geoid = GeoidHeight(egm96_file) + files = test_data["dted_with_null"][1] # dem/dted/n33_w119_3arc_v1.dt1 + dem_interpolator = sarpy_dted.DTEDInterpolator(files=files, geoid_file=geoid, lat_lon_box=ll) + assert dem_interpolator.get_elevation_hae(ll[0], ll[1]) == pytest.approx( 551.87, abs=0.01 ) + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_interpolator_get_elevation_hae_north_east(): + ll = [ 27.57071, 84.77881 ] # Nepal Near void used above + geoid = GeoidHeight(egm96_file) + files = test_data["dted_with_null"][3] # dem/dted/n27_e084_3arc_v1.dt1` + dem_interpolator = sarpy_dted.DTEDInterpolator(files=files, 
geoid_file=geoid, lat_lon_box=ll) + assert dem_interpolator.get_elevation_hae(ll[0], ll[1]) == pytest.approx( 954.70, abs=0.01 ) + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_interpolator_get_elevation_hae_south_west(): + ll = [-1, -70] # to get this point on -1,-70 tile with a tighter tolerance From kjurka Apr 29, 2025 github issue #587 + geoid = GeoidHeight(egm96_file) + files = test_data["dted_with_null"][2] # dem/dted/s01_w070_3arc_v1.dt1' + dem_interpolator = sarpy_dted.DTEDInterpolator(files=files, geoid_file=geoid, lat_lon_box=ll) + assert dem_interpolator.get_elevation_hae(ll[0], ll[1]) == pytest.approx( 159.98, abs=0.01 ) + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_interpolator_get_elevation_hae_south_west(): + ll = [ -35.4237, 149.5331 ] # Austrialia, south west of Sydney, this point is north east of the void used above in the reader test + geoid = GeoidHeight(egm96_file) + files = test_data["dted_with_null"][4] # dem/dted/s36_e149_3arc_v1.dt1' + dem_interpolator = sarpy_dted.DTEDInterpolator(files=files, geoid_file=geoid, lat_lon_box=ll) + assert dem_interpolator.get_elevation_hae(ll[0], ll[1]) == pytest.approx( 1283.93, abs=0.01 ) + + +@pytest.mark.skipif(not test_data["dted_with_null"], reason="DTED with null data does not exist") +def test_dted_interpolator_get_elevation_hae_south_east_cross_equator(): + ll = [ -0.01, -70.0 ] # lat long, + geoid = GeoidHeight(egm96_file) + files = test_data["dted_with_null"][2] + dem_interpolator = sarpy_dted.DTEDInterpolator.from_reference_point( ll, files, geoid_file=geoid, pad_value=1.0 ) + assert dem_interpolator.get_elevation_hae(ll[0], ll[1]) == pytest.approx( 13.53, abs=0.01 ) diff --git a/tests/io/complex/complex_file_types.json b/tests/io/complex/complex_file_types.json index 3b27478c..a3ec186c 100644 --- a/tests/io/complex/complex_file_types.json +++ 
b/tests/io/complex/complex_file_types.json @@ -29,6 +29,7 @@ "PALSAR": [ {"path": "palsar/1000000001_001001_ALOS2150521570-170307", "path_type": "relative"}], "Capella": [ + {"path": "capella/CAPELLA_C02_SS_GEC_HH_20210220023903_20210220023919.tif", "path_type": "relative"}, {"path": "capella/CAPELLA_C02_SM_SLC_HH_20210220171045_20210220171049.tif", "path_type": "relative"}, {"path": "capella/CAPELLA_C02_SS_SLC_HH_20210223023836_20210223023852.tif", "path_type": "relative"}, {"path": "capella/CAPELLA_C03_SP_SLC_HH_20210313173209_20210313173212.tif", "path_type": "relative"} diff --git a/tests/io/general/test_slice_parsing.py b/tests/io/general/test_slice_parsing.py new file mode 100644 index 00000000..f1f85b9d --- /dev/null +++ b/tests/io/general/test_slice_parsing.py @@ -0,0 +1,283 @@ +__classification__ = "UNCLASSIFIED" +__author__ = "Tex Peterson" + +import pytest, os, re +from unittest import TestCase + +from sarpy.io.general.slice_parsing import validate_slice_int, verify_slice, \ + verify_subscript, get_slice_result_size, get_subscript_result_size + + +class Test_validate_slice_int(TestCase): + def setUp(self): + self.the_int = 5 + self.bound = 20 + self.include = True + + def testNoParamsFail(self): + with self.assertRaisesRegex(TypeError, + re.escape( + "validate_slice_int() missing 2 " + \ + "required positional arguments: " + \ + "'the_int' and 'bound'")): + validate_slice_int() + + def testBoundAsZeroFail(self): + with self.assertRaisesRegex(TypeError, 'bound must be a positive integer.'): + validate_slice_int(self.the_int, 0, self.include) + + def testBoundAsFloatFail(self): + with self.assertRaisesRegex(TypeError, 'bound must be a positive integer.'): + validate_slice_int(self.the_int, 11.5, self.include) + + def testTheIntOutOfBoundFail(self): + with self.assertRaisesRegex(ValueError, 'Slice argument 5 does not fit with bound 2'): + validate_slice_int(self.the_int, 2, self.include) + + def testTheIntLessThanZeroSuccess(self): + 
self.assertEqual(validate_slice_int(-2, 4, self.include), 2) + + def testValidTheIntSuccess(self): + self.assertEqual(validate_slice_int(self.the_int, self.bound, self.include), self.the_int) + + def testBadBoundsFail(self): + with self.assertRaisesRegex(ValueError, + re.escape("Slice argument 2 does not fit " + \ + "with bound 1")): + self.assertEqual(validate_slice_int(2, 1, False), 2) + +class Test_verify_slice(TestCase): + def setUp(self): + self.item = [1, 60, 3] + self.max_element = 111 + + def testNoParamsFail(self): + with self.assertRaisesRegex(TypeError, + re.escape( + "verify_slice() missing 2 required " + \ + "positional arguments: " + \ + "'item' and 'max_element'")): + verify_slice() + + def testMaxElementFloatFail(self): + with self.assertRaisesRegex(ValueError, + re.escape( + "slice verification requires a " + \ + "positive integer limit")): + verify_slice(None, 3.5) + + def testMaxElementLessThan1Fail(self): + with self.assertRaisesRegex(ValueError, + re.escape( + "slice verification requires a " + \ + "positive integer limit")): + verify_slice(None, 0) + + def testItemNonePass(self): + self.assertEqual(verify_slice(None, self.max_element), \ + slice(0, self.max_element, 1)) + + def testItemPositiveIntPass(self): + self.assertEqual(verify_slice(3, self.max_element), \ + slice(3, 4, 1)) + + def testItemNegativeIntPass(self): + self.assertEqual(verify_slice(-3, self.max_element), \ + slice(108, 109, 1)) + + def testItemOutOfBoundsPositiveFail(self): + with self.assertRaisesRegex(ValueError, + re.escape( + "Got out of bounds argument (888) in " + \ + "slice limited by `111`")): + verify_slice(888, self.max_element) + + def testItemOutOfBoundsNegativeFail(self): + with self.assertRaisesRegex(ValueError, + re.escape( + "Got out of bounds argument (-888) in " + \ + "slice limited by `111`")): + verify_slice(-888, self.max_element) + + def testSlicePass(self): + verify_slice(self.item, self.max_element) + + def testSliceItemOutOfBoundsPositiveFail(self): 
# ********************
# verify_subscript tests
# ********************
class Test_verify_subscript(TestCase):
    """Tests for verify_subscript against a 3-d corresponding shape."""

    def setUp(self):
        # shape passed as `corresponding_shape` throughout
        self.item = [1, 60, 3]
        self.max_element = 111

    def testNoParamsFail(self):
        with self.assertRaisesRegex(
                TypeError,
                re.escape("verify_subscript() missing 2 required "
                          "positional arguments: 'subscript' "
                          "and 'corresponding_shape'")):
            verify_subscript()

    def testSubscriptNoneSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        return_value = verify_subscript(None, self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptElipsisSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        return_value = verify_subscript(..., self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptIntSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        return_value = verify_subscript(0, self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSliceSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        return_value = verify_subscript(slice(0, 1), self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSequenceListSuccess(self):
        expected_value = (slice(0, 1, 1), slice(1, 2, 1), slice(0, 3, 1))
        return_value = verify_subscript([0, 1], self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSequenceListWithTwoElipsisFail(self):
        # dead `expected_value`/`return_value` locals removed (the call raises)
        with self.assertRaisesRegex(
                KeyError,
                re.escape("slice definition cannot contain more "
                          "than one ellipsis")):
            verify_subscript([0, ..., ...], self.item)

    def testSubscriptSequenceListWithOneElipsisSubscriptTooBigFail(self):
        with self.assertRaisesRegex(
                ValueError,
                re.escape("More subscript entries (4) than shape "
                          "dimensions (3)")):
            verify_subscript([0, ..., 1, 5], self.item)

    def testSubscriptSequenceListWithLastElipsisSuccess(self):
        expected_value = (slice(0, 1, 1), slice(1, 2, 1), slice(0, 3, 1))
        return_value = verify_subscript([0, 1, ...], self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSequenceListWithFirstElipsisSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(1, 2, 1))
        return_value = verify_subscript([..., 1], self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSequenceListWithMiddleElipsisSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(2, 3, 1))
        return_value = verify_subscript([0, ..., 2], self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptSequenceListNoElipsisSubscriptTooBigFail(self):
        with self.assertRaisesRegex(
                ValueError,
                re.escape("More subscript entries (4) than shape "
                          "dimensions (3)")):
            verify_subscript([0, 4, 1, 5], self.item)

    def testSubscriptSequenceListNoElipsisSuccess(self):
        expected_value = (slice(0, 1, 1), slice(1, 2, 1), slice(0, 3, 1))
        return_value = verify_subscript([0, 1], self.item)
        self.assertEqual(expected_value, return_value)

    def testSubscriptFloatFail(self):
        with self.assertRaisesRegex(
                ValueError,
                re.escape("Got unhandled subscript 4.5")):
            verify_subscript(4.5, self.item)

    def testSubscriptStringFail(self):
        with self.assertRaisesRegex(
                TypeError,
                re.escape("'<=' not supported between "
                          "instances of 'int' and 'str'")):
            verify_subscript("bob", self.item)

    def testSubscriptRangeFail(self):
        with self.assertRaisesRegex(
                ValueError,
                re.escape("More subscript entries (5) than "
                          "shape dimensions (3).")):
            verify_subscript(range(5), self.item)

    def testSubscriptSequenceRangeSuccess(self):
        expected_value = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        return_value = verify_subscript(range(1), self.item)
        self.assertEqual(expected_value, return_value)


# ********************
# get_slice_result_size tests
# ********************
class Test_get_slice_result_size(TestCase):
    """Tests for get_slice_result_size."""
    # setUp removed: the fixture values were never used by these tests

    def testNoParamsFail(self):
        with self.assertRaisesRegex(
                TypeError,
                re.escape("get_slice_result_size() missing 1 "
                          "required positional argument: "
                          "'slice_in'")):
            get_slice_result_size()

    def testFullSliceSuccess(self):
        return_value = get_slice_result_size(slice(0, 5, 1))
        self.assertEqual(5, return_value)

    def testSliceNegativeStepNoneStopSuccess(self):
        return_value = get_slice_result_size(slice(0, None, -1))
        self.assertEqual(1, return_value)

    def testSliceNegativeStepSuccess(self):
        # NOTE(review): a negative "size" for slice(0, 5, -1) looks like a quirk
        # of get_slice_result_size; this pins current behavior -- confirm upstream.
        return_value = get_slice_result_size(slice(0, 5, -1))
        self.assertEqual(-5, return_value)


# ********************
# get_subscript_result_size tests
# ********************
class Test_get_subscript_result_size(TestCase):
    """Tests for get_subscript_result_size."""

    def setUp(self):
        self.item = [1, 60, 3]

    def testNoParamsFail(self):
        with self.assertRaisesRegex(
                TypeError,
                re.escape("get_subscript_result_size() missing "
                          "2 required positional arguments: "
                          "'subscript' and "
                          "'corresponding_shape'")):
            get_subscript_result_size()

    def testSubscriptNoneSuccess(self):
        expected_subscript = (slice(0, 1, 1), slice(0, 60, 1), slice(0, 3, 1))
        expected_shape = (1, 60, 3)
        return_subscript, return_shape = get_subscript_result_size(None, self.item)
        self.assertEqual(expected_subscript, return_subscript)
        self.assertEqual(expected_shape, return_shape)
def test_j2ksubtype_set_layer_info_type_element():
    """setLayerInfoType should replace an existing LayerInfo from an XML element."""
    bitrates = [1.0, 2.0, 3.0]
    xml_element = create_layerinfo_xml_helper(3, bitrates)

    layer_info = [0.0, 1.0, 2.0]
    j2ksubtype = J2KSubtype(NumWaveletLevels=3, NumBands=5, LayerInfo=layer_info)

    j2ksubtype.setLayerInfoType(xml_element)

    assert numpy.array_equal(j2ksubtype.LayerInfo, bitrates)


def test_j2ksubtype_set_layer_info_no_layer_info_passed():
    """setLayerInfoType should populate LayerInfo when none was given at init."""
    bitrates = [1.0, 2.0, 3.0]
    xml_element = create_layerinfo_xml_helper(3, bitrates)

    j2ksubtype = J2KSubtype(NumWaveletLevels=3, NumBands=5)

    j2ksubtype.setLayerInfoType(xml_element)

    assert numpy.array_equal(j2ksubtype.LayerInfo, bitrates)


def test_j2ksubtype_init_failure_layer_info():
    # makes sure that an error is thrown if LayerInfo is neither an
    # ElementTree, list, tuple, ndarray, or None
    # NOTE(review): the expected message appears to have lost a "<class 'str'>"
    # token to markup stripping -- confirm against the implementation.
    layer_info = "abc"
    with pytest.raises(
            TypeError,
            match=re.escape("Invalid input type for LayerInfo: . Must be an "
                            "ElementTree, list, tuple, ndarray, or None.")):
        J2KSubtype(NumWaveletLevels=3, NumBands=7, LayerInfo=layer_info)


@pytest.fixture()
def setup_j2ksubtype():
    layer_info = [0.0, 1.0, 2.0]
    yield J2KSubtype(NumWaveletLevels=3, NumBands=5, LayerInfo=layer_info)


def test_j2ktype_init(setup_j2ksubtype):
    # unused local `layer_info` removed; the fixture supplies the subtype
    J2KType(setup_j2ksubtype)


@pytest.fixture()
def setup_j2ktype(setup_j2ksubtype):
    yield J2KType(setup_j2ksubtype)


def test_j2ktype_init_w_parsed(setup_j2ksubtype):
    layer_info = [0.0, 1.0, 2.0]
    original = setup_j2ksubtype
    parsed = J2KSubtype(NumWaveletLevels=2, NumBands=4, LayerInfo=layer_info)
    J2KType(original, parsed)


def test_j2ktype_init_kwargs(setup_j2ksubtype):
    # unused local `layer_info` removed
    ns_value = "test_namespace_value"
    ns_key = "test_namespace_key"

    j2ksubtype = setup_j2ksubtype
    j2ktype = J2KType(j2ksubtype, _xml_ns=ns_value, _xml_ns_key=ns_key)

    # checks if the xml ns value and keys are assigned correctly
    assert j2ktype._xml_ns == ns_value
    assert j2ktype._xml_ns_key == ns_key


def test_compressiontype_init(setup_j2ktype):
    j2ktype = setup_j2ktype
    CompressionType(j2ktype)


def test_compressiontype_init_kwargs(setup_j2ktype):
    j2ktype = setup_j2ktype
    ns_value = "test_namespace_value"
    ns_key = "test_namespace_key"

    compressiontype = CompressionType(j2ktype, _xml_ns=ns_value, _xml_ns_key=ns_key)

    # checks if the xml ns value and keys are assigned correctly
    assert compressiontype._xml_ns == ns_value
    assert compressiontype._xml_ns_key == ns_key
# ********************
# get_node_value tests
# ********************
class TestGetNodeValue(unittest.TestCase):
    """Tests for base.get_node_value."""

    def setUp(self):
        # only the country fixture is used by this class; the unused
        # actor fixture parsing was removed
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_get_node_value_success_with_text(self):
        branch = base.get_node_value(self.root[0][1])
        self.assertEqual(branch, '2008')

    def test_get_node_value_success_none(self):
        branch = base.get_node_value(self.root[0])
        self.assertIsNone(branch)

    def test_get_node_value_success_empty(self):
        branch = base.get_node_value(self.root[0][3])
        self.assertIsNone(branch)


# ********************
# create_new_node tests
# ********************
class TestCreateNewNode(unittest.TestCase):
    """Tests for base.create_new_node."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_create_new_node_no_parent_success(self):
        new_node_tag = "country"
        self.assertEqual(len(self.root), 3,
                         "Root should have 3 children before adding new node")
        new_node = base.create_new_node(self.tree, new_node_tag)
        self.assertEqual(len(self.root), 4,
                         "Root should have 4 children after adding new node")
        self.assertEqual(self.root[-1].tag, new_node_tag,
                         "Last child should have the new node tag")
        self.assertIs(self.root[-1], new_node,
                      "Returned node should be the last child of root")

    def test_create_new_node_with_parent_success(self):
        new_node_tag = "ocean"
        self.assertEqual(len(self.root[1]), 5,
                         "Parent should have 5 children before adding new node")
        new_node = base.create_new_node(self.tree, new_node_tag, self.root[1])
        self.assertEqual(len(self.root[1]), 6,
                         "Parent should have 6 children after adding new node")
        self.assertEqual(self.root[1][5].tag, new_node_tag,
                         "Last child should have the new node tag")
        self.assertIs(self.root[1][5], new_node,
                      "Returned node should be the last child of parent")


# ********************
# create_text_node tests
# ********************
class TestCreateTextNode(unittest.TestCase):
    """Tests for base.create_text_node."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_create_text_node_no_parent_success(self):
        new_node_tag = "country"
        new_node_value = "Costa Rica"
        self.assertEqual(len(self.root), 3,
                         "Root should have 3 children before adding new "
                         "text node")
        new_node = base.create_text_node(self.tree, new_node_tag, new_node_value)
        self.assertEqual(len(self.root), 4,
                         "Root should have 4 children after adding new "
                         "text node")
        self.assertEqual(self.root[3].tag, new_node_tag,
                         "Last child should have the new node tag")
        self.assertEqual(self.root[3].text, new_node_value,
                         "Last child's text should match the new node value")
        self.assertIs(self.root[3], new_node,
                      "Returned node should be the last child of root")

    def test_create_text_node_with_parent_success(self):
        new_node_tag = "ocean"
        new_node_value = "Pacific"
        self.assertEqual(len(self.root[2]), 6,
                         "Parent should have 6 children before adding new "
                         "text node")
        new_node = base.create_text_node(self.tree, new_node_tag,
                                         new_node_value, self.root[2])
        self.assertEqual(len(self.root[2]), 7,
                         "Parent should have 7 children after adding new "
                         "text node")
        self.assertEqual(self.root[2][6].tag, new_node_tag,
                         "Last child should have the new node tag")
        self.assertEqual(self.root[2][6].text, new_node_value,
                         "Last child's text should match the new node value")
        self.assertIs(self.root[2][6], new_node,
                      "Returned node should be the last child of parent")


# ********************
# find_first_child tests
# ********************
class TestFindFirstChild(unittest.TestCase):
    """Tests for base.find_first_child."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()
        # For xml, ns is an abbreviation for name space
        self.actor_root, self.actor_ns_dict = \
            base.parse_xml_from_file('tests/io/xml/actor_test_data.xml')

    def test_find_first_child_no_optional_params_success(self):
        found_node = base.find_first_child(self.root, "country")
        self.assertIsNotNone(found_node, "Should find a node with tag 'country'")
        self.assertEqual(found_node.attrib, self.root[0].attrib,
                         "Found node's attributes should match the first "
                         "'country' node")

    def test_find_first_child_namespace_params_success(self):
        found_node = base.find_first_child(self.actor_root, "actor",
                                           self.actor_ns_dict)
        self.assertIsNotNone(found_node,
                             "Should find a node with tag 'actor' using namespace")
        self.assertEqual(found_node.attrib,
                         self.actor_root[0].attrib,
                         "Found node's attributes should match the first "
                         "'actor' node")

    def test_find_first_child_namespace_nskey_params_success(self):
        found_actor_node = base.find_first_child(self.actor_root, "actor",
                                                 self.actor_ns_dict)
        found_node = base.find_first_child(found_actor_node, "character",
                                           self.actor_ns_dict, "fictional")
        self.assertIsNotNone(
            found_node,
            "Should find a node with tag 'character' using namespace and nskey")
        # NOTE(review): comparing a 'character' node against the first 'actor'
        # node's attributes only works because both attrib dicts are empty in
        # the fixture -- consider a stronger assertion.
        self.assertEqual(found_node.attrib,
                         self.actor_root[0].attrib,
                         "Found node's attributes should match the first "
                         "'actor' node")
# ********************
# find_children tests
# ********************
class TestFindChildren(unittest.TestCase):
    """Tests for base.find_children."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()
        # For xml, ns is an abbreviation for name space
        self.actor_root, self.actor_ns_dict = \
            base.parse_xml_from_file('tests/io/xml/actor_test_data.xml')

    def test_find_children_no_optional_params_success(self):
        found_nodes = base.find_children(self.root, "country")
        self.assertEqual(found_nodes,
                         self.root.findall("country"),
                         "Should find all 'country' nodes without namespace")

    def test_find_children_namespace_params_success(self):
        found_node = base.find_children(self.actor_root, "actor",
                                        self.actor_ns_dict)
        self.assertEqual(found_node,
                         self.actor_root.findall('actor', self.actor_ns_dict))

    def test_find_children_namespace_nskey_params_success(self):
        found_actor_node = base.find_first_child(self.actor_root, "actor",
                                                 self.actor_ns_dict)
        found_nodes = base.find_children(found_actor_node, "character",
                                         self.actor_ns_dict, "fictional")
        self.assertEqual(
            found_nodes,
            found_actor_node.findall('fictional:character', self.actor_ns_dict),
            "Should find all 'character' nodes with namespace and nskey")


# ********************
# parse_xml_from_string tests
# ********************
class TestParseXmlFromString(unittest.TestCase):
    """Tests for base.parse_xml_from_string."""

    def setUp(self):
        # unused actor fixture parsing removed
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_xml_from_string_success(self):
        xml_string = ET.tostring(self.root, encoding='unicode', method='xml')
        root_node, ns_dict = base.parse_xml_from_string(xml_string)
        self.assertEqual(root_node.attrib, self.root.attrib,
                         "Parsed root node's attributes should match the "
                         "original root")


# ********************
# parse_xml_from_file tests
# ********************
class TestParseXmlFromFile(unittest.TestCase):
    """Tests for base.parse_xml_from_file."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_xml_from_file_success(self):
        test_root, test_ns_dict = \
            base.parse_xml_from_file('tests/io/xml/country_data.xml')
        self.assertEqual(test_root.attrib, self.root.attrib)


# ********************
# validate_xml_from_string tests
# ********************
class TestValidateXmlFromString(unittest.TestCase):
    """Tests for base.validate_xml_from_string."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_validate_xml_from_string_success(self):
        xml_string = ET.tostring(self.root, encoding='unicode', method='xml')
        xsd_path = 'tests/io/xml/country.xsd'
        self.assertTrue(base.validate_xml_from_string(xml_string, xsd_path),
                        "XML string should validate against the provided XSD")

    def test_validate_xml_from_string_with_logger_success(self):
        xml_string = ET.tostring(self.root, encoding='unicode', method='xml')
        xsd_path = 'tests/io/xml/country.xsd'
        self.assertTrue(base.validate_xml_from_string(xml_string, xsd_path,
                                                      base.logger),
                        "XML string should validate against the provided "
                        "XSD with logger")
# ********************
# validate_xml_from_file tests
# ********************
class TestValidateXmlFromFile(unittest.TestCase):
    """Tests for base.validate_xml_from_file."""
    # setUp removed: these tests build all inputs from literal paths

    def test_validate_xml_from_file_success(self):
        xml_path = 'tests/io/xml/country_data.xml'
        xsd_path = 'tests/io/xml/country.xsd'
        self.assertTrue(base.validate_xml_from_file(xml_path, xsd_path),
                        "XML file should validate against the provided XSD")

    def test_validate_xml_from_file_with_logger_success(self):
        xml_path = 'tests/io/xml/country_data.xml'
        xsd_path = 'tests/io/xml/country.xsd'
        self.assertTrue(
            base.validate_xml_from_file(xml_path, xsd_path, base.logger),
            "XML file should validate against the provided XSD with logger")


# ********************
# parse_str tests
# ********************
class TestParseStr(unittest.TestCase):
    """Tests for base.parse_str."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_str_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_str\(\) missing 3 required positional arguments: "
                r"'value', 'name', and 'instance'$"):
            base.parse_str()

    def test_parse_str_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_str\(\) missing 2 required positional arguments: "
                r"'name' and 'instance'$"):
            base.parse_str("Test")

    def test_parse_str_missing_instance_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_str\(\) missing 1 required positional argument: "
                r"'instance'$"):
            base.parse_str("Test", "Bob")

    def test_parse_str_value_param_is_string_success(self):
        self.assertEqual(base.parse_str("Test", "Bob", "base"), "Test")

    def test_parse_str_value_param_is_None_success(self):
        self.assertIsNone(base.parse_str(None, "Bob", "base"))

    def test_parse_str_value_param_is_xml_with_value_success(self):
        self.assertEqual(base.parse_str(self.root[0][2], "text", "base"),
                         self.root[0][2].text)

    def test_parse_str_value_param_is_xml_empty_value_success(self):
        self.assertEqual(base.parse_str(self.root[0], "text", "base"),
                         self.root[0].text.strip())

    def test_parse_str_bad_value_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"field Bob of class str requires a string value."):
            base.parse_str(1, "Bob", "base")


# ********************
# parse_bool tests
# ********************
class TestParseBool(unittest.TestCase):
    """Tests for base.parse_bool."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_bool_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_bool\(\) missing 3 required positional arguments: "
                r"'value', 'name', and 'instance'$"):
            base.parse_bool()

    def test_parse_bool_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_bool\(\) missing 2 required positional arguments: "
                r"'name' and 'instance'$"):
            base.parse_bool("Test")

    def test_parse_bool_missing_instance_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_bool\(\) missing 1 required positional argument: "
                r"'instance'$"):
            base.parse_bool("Test", "Bob")

    def test_parse_bool_value_param_is_None_success(self):
        self.assertIsNone(base.parse_bool(None, "Bob", "base"))

    def test_parse_bool_value_param_is_bool_success(self):
        self.assertTrue(base.parse_bool(True, "Bob", "base"))

    def test_parse_bool_value_param_is_int_success(self):
        self.assertTrue(base.parse_bool(1, "Bob", "base"))

    def test_parse_bool_value_param_is_np_bool_success(self):
        arr_bool = np.array([True, False, True, False], dtype=bool)
        self.assertTrue(base.parse_bool(arr_bool[0], "Bob", "base"))

    def test_parse_bool_value_param_is_xml_success(self):
        self.assertTrue(base.parse_bool(self.root[0][0], "Bob", "base"))

    def test_parse_bool_value_param_is_string_true_success(self):
        self.assertTrue(base.parse_bool('trUe', "Bob", "base"))

    def test_parse_bool_value_param_is_string_1_success(self):
        self.assertTrue(base.parse_bool('1', "Bob", "base"))

    def test_parse_bool_value_param_is_string_false_success(self):
        self.assertFalse(base.parse_bool('FALSE', "Bob", "base"))

    def test_parse_bool_value_param_is_string_0_success(self):
        self.assertFalse(base.parse_bool('0', "Bob", "base"))

    def test_parse_bool_value_param_is_float_fail(self):
        # NOTE(review): the expected message appears to have lost a
        # "<class 'float'>" token to markup stripping -- confirm upstream.
        with self.assertRaisesRegex(
                ValueError,
                r"Boolean field Bob of class str cannot assign from type ."):
            base.parse_bool(3.5, "Bob", "base")


# ********************
# parse_int tests
# ********************
class TestParseInt(unittest.TestCase):
    """Tests for base.parse_int."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_int_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_int\(\) missing 3 required positional arguments: "
                r"'value', 'name', and 'instance'$"):
            base.parse_int()

    def test_parse_int_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_int\(\) missing 2 required positional arguments: "
                r"'name' and 'instance'$"):
            base.parse_int("Test")

    def test_parse_int_missing_instance_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_int\(\) missing 1 required positional argument: "
                r"'instance'$"):
            base.parse_int("Test", "Bob")

    def test_parse_int_value_param_is_None_success(self):
        self.assertIsNone(base.parse_int(None, "Bob", "base"))

    def test_parse_int_value_param_is_int_success(self):
        self.assertEqual(base.parse_int(1, "Bob", "base"), 1)

    def test_parse_int_value_param_is_xml_success(self):
        self.assertEqual(base.parse_int(self.root[0][0], "Bob", "base"), 1)

    def test_parse_int_value_param_is_string_1_success(self):
        self.assertEqual(base.parse_int('1', "Bob", "base"), 1)

    # renamed from *_success: this test asserts an exception is raised
    def test_parse_int_value_param_is_string_non_int_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"invalid literal for int\(\) with base 10: 'Bob'"):
            base.parse_int('Bob', "Bob", "base")

    # renamed from *_success: this test asserts an exception is raised
    def test_parse_int_value_param_is_list_non_int_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"int\(\) argument must be a string, a bytes-like object "
                r"or a real number, not 'list'"):
            base.parse_int([3.5], "Bob", "base")
# ********************
# parse_float tests
# ********************
class TestParseFloat(unittest.TestCase):
    """Tests for base.parse_float."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_float_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_float\(\) missing 3 required positional arguments: "
                r"'value', 'name', and 'instance'$"):
            base.parse_float()

    def test_parse_float_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_float\(\) missing 2 required positional arguments: "
                r"'name' and 'instance'$"):
            base.parse_float("Test")

    def test_parse_float_missing_instance_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_float\(\) missing 1 required positional argument: "
                r"'instance'$"):
            base.parse_float("Test", "Bob")

    def test_parse_float_value_param_is_None_success(self):
        self.assertIsNone(base.parse_float(None, "Bob", "base"))

    def test_parse_float_value_param_is_float_success(self):
        self.assertEqual(base.parse_float(1.5, "Bob", "base"), 1.5)

    def test_parse_float_value_param_is_xml_success(self):
        self.assertEqual(base.parse_float(self.root[0][0], "Bob", "base"), 1.0)

    def test_parse_float_value_param_is_string_1dot5_success(self):
        self.assertEqual(base.parse_float('1.5', "Bob", "base"), 1.5)

    # renamed from *_success: this test asserts an exception is raised
    def test_parse_float_value_param_is_string_non_float_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"could not convert string to float: 'Bob'"):
            base.parse_float('Bob', "Bob", "base")

    # renamed from *_success: this test asserts an exception is raised
    def test_parse_float_value_param_is_list_non_float_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"float\(\) argument must be a string or a real number, "
                r"not 'list'"):
            base.parse_float([3.5], "Bob", "base")


# ********************
# parse_complex tests
# ********************
class TestParseComplex(unittest.TestCase):
    """Tests for base.parse_complex."""

    def setUp(self):
        self.tree = ET.parse('tests/io/xml/country_data.xml')
        self.root = self.tree.getroot()

    def test_parse_complex_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_complex\(\) missing 3 required positional arguments: "
                r"'value', 'name', and 'instance'$"):
            base.parse_complex()

    def test_parse_complex_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_complex\(\) missing 2 required positional arguments: "
                r"'name' and 'instance'$"):
            base.parse_complex("Test")

    def test_parse_complex_missing_instance_param_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_complex\(\) missing 1 required positional argument: "
                r"'instance'$"):
            base.parse_complex("Test", "Bob")

    def test_parse_complex_value_param_is_None_success(self):
        self.assertIsNone(base.parse_complex(None, "Bob", "base"))

    def test_parse_complex_value_param_is_complex_success(self):
        test_complex = 3 + 2j
        self.assertEqual(base.parse_complex(test_complex, "Bob", "base"),
                         test_complex)

    def test_parse_complex_value_param_is_xml_success(self):
        test_complex = 3 + 2j
        self.assertEqual(base.parse_complex(self.root[0][4], "Bob", "base"),
                         test_complex)

    def test_parse_complex_value_param_is_xml_2_real_fail(self):
        # unused `test_complex` local removed
        with self.assertRaisesRegex(
                ValueError,
                r"There must be exactly one Real component of a complex "
                r"type node defined for field Bob of class str."):
            base.parse_complex(self.root[1][3], "Bob", "base")

    def test_parse_complex_value_param_is_xml_2_imag_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"There must be exactly one Imag component of a complex "
                r"type node defined for field Bob of class str."):
            base.parse_complex(self.root[2][3], "Bob", "base")

    def test_parse_complex_value_param_is_complex_dict_1_success(self):
        test_complex = 3 + 2j
        self.assertEqual(base.parse_complex({"real": 3, "imag": 2}, "Bob", "base"),
                         test_complex)

    def test_parse_complex_value_param_is_complex_dict_2_success(self):
        test_complex = 3 + 2j
        self.assertEqual(base.parse_complex({"Real": 3, "Imag": 2}, "Bob", "base"),
                         test_complex)

    def test_parse_complex_value_param_is_complex_dict_3_success(self):
        test_complex = 3 + 2j
        self.assertEqual(base.parse_complex({"re": 3, "im": 2}, "Bob", "base"),
                         test_complex)

    def test_parse_complex_value_param_is_complex_dict_4_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"Cannot convert dict {'not': 3, 'valid': 2} to a complex "
                r"number for field Bob of class str."):
            base.parse_complex({"not": 3, "valid": 2}, "Bob", "base")

    def test_parse_complex_value_param_is_complex_dict_5_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"Cannot convert dict {'real': None, 'imag': 2} to a complex "
                r"number for field Bob of class str."):
            base.parse_complex({"real": None, "imag": 2}, "Bob", "base")

    def test_parse_complex_value_param_is_complex_dict_6_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"Cannot convert dict {'real': 4, 'imag': None} to a complex "
                r"number for field Bob of class str."):
            base.parse_complex({"real": 4, "imag": None}, "Bob", "base")

    def test_parse_complex_value_param_is_string_non_complex_fail(self):
        with self.assertRaisesRegex(
                ValueError,
                r"complex\(\) arg is a malformed string"):
            base.parse_complex('Bob', "Bob", "base")

    def test_parse_complex_value_param_is_list_non_complex_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"complex\(\) first argument must be a string or a number, "
                r"not 'list'"):
            base.parse_complex([3.5], "Bob", "base")


# ******************** 
# parse_serializable tests
# ********************

class ParseSerializableDummyType:
    """Minimal stand-in exposing the from_dict/from_node/from_array hooks."""

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    @classmethod
    def from_node(cls, node, xml_ns, ns_key=None):
        return cls(node=node, xml_ns=xml_ns, ns_key=ns_key)

    @classmethod
    def from_array(cls, arr):
        return cls(arr=arr)

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


class ParseSerializableDummyArrayable(base.Arrayable):
    """Arrayable stand-in that just stores the provided array."""

    @classmethod
    def from_array(cls, arr):
        return cls(arr)

    def __init__(self, arr):
        self.arr = arr


class ParseSerializableDummyInstance:
    """Instance stand-in supplying the namespace attributes parse_serializable reads."""
    _xml_ns = {'default': 'urn:test'}
    _xml_ns_key = 'default'
    _child_xml_ns_key = {'foo': 'default'}
class TestParseSerializable(unittest.TestCase):
    """Tests for base.parse_serializable."""
    # setUp removed: these tests construct all inputs inline

    def test_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable\(\) missing 4 required positional "
                r"arguments: 'value', 'name', 'instance', and 'the_type'$"):
            base.parse_serializable()

    def test_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable\(\) missing 3 required positional "
                r"arguments: 'name', 'instance', and 'the_type'$"):
            base.parse_serializable("Test")

    def test_value_name_params_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable\(\) missing 2 required positional "
                r"arguments: 'instance' and 'the_type'$"):
            base.parse_serializable("Test", "foo")

    def test_value_name_instance_params_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable\(\) missing 1 required positional "
                r"argument: 'the_type'$"):
            base.parse_serializable("Test", "foo", ParseSerializableDummyInstance())

    def test_none(self):
        self.assertIsNone(base.parse_serializable(None, 'foo',
                                                  ParseSerializableDummyInstance(),
                                                  ParseSerializableDummyType))

    def test_instance(self):
        obj = ParseSerializableDummyType(a=1)
        self.assertIs(base.parse_serializable(obj, 'foo',
                                              ParseSerializableDummyInstance(),
                                              ParseSerializableDummyType), obj)

    def test_dict(self):
        result = base.parse_serializable({'a': 1}, 'foo',
                                         ParseSerializableDummyInstance(),
                                         ParseSerializableDummyType)
        self.assertIsInstance(result, ParseSerializableDummyType)
        self.assertEqual(result.a, 1)

    def test_element(self):
        elem = ET.Element('Dummy')
        result = base.parse_serializable(elem, 'foo',
                                         ParseSerializableDummyInstance(),
                                         ParseSerializableDummyType)
        self.assertIsInstance(result, ParseSerializableDummyType)
        self.assertEqual(result.node, elem)
        self.assertEqual(result.xml_ns, ParseSerializableDummyInstance._xml_ns)
        self.assertEqual(result.ns_key,
                         ParseSerializableDummyInstance._child_xml_ns_key['foo'])

    def test_arrayable_ndarray(self):
        arr = np.array([1, 2, 3])
        result = base.parse_serializable(arr, 'foo',
                                         ParseSerializableDummyInstance(),
                                         ParseSerializableDummyArrayable)
        self.assertIsInstance(result, ParseSerializableDummyArrayable)
        np.testing.assert_array_equal(result.arr, arr)

    def test_arrayable_ndarray_bad_type_fail(self):
        # dead `result =` assignment removed (the call raises)
        # NOTE(review): the expected message appears to have lost
        # "<class ...>" tokens to markup stripping -- confirm upstream.
        arr = np.array([1, 2, 3])
        with self.assertRaisesRegex(
                TypeError,
                r"Field foo of class ParseSerializableDummyInstance is of "
                r"type \(not a subclass of Arrayable\) and got an argument "
                r"of type .$"):
            base.parse_serializable(arr, 'foo',
                                    ParseSerializableDummyInstance(),
                                    int)

    def test_arrayable_list(self):
        arr = [1, 2, 3]
        result = base.parse_serializable(arr, 'foo',
                                         ParseSerializableDummyInstance(),
                                         ParseSerializableDummyArrayable)
        self.assertIsInstance(result, ParseSerializableDummyArrayable)
        self.assertEqual(result.arr, arr)

    def test_arrayable_tuple(self):
        arr = (1, 2, 3)
        result = base.parse_serializable(arr, 'foo',
                                         ParseSerializableDummyInstance(),
                                         ParseSerializableDummyArrayable)
        self.assertIsInstance(result, ParseSerializableDummyArrayable)
        self.assertEqual(result.arr, arr)

    def test_non_arrayable_array(self):
        arr = [1, 2, 3]
        with self.assertRaises(TypeError):
            base.parse_serializable(arr, 'foo',
                                    ParseSerializableDummyInstance(),
                                    ParseSerializableDummyType)

    def test_invalid_type(self):
        # NOTE(review): message text appears truncated by markup stripping.
        with self.assertRaisesRegex(
                TypeError,
                r"Field foo of class ParseSerializableDummyInstance is "
                r"expecting type , but got an instance of incompatible "
                r"type .$"):
            base.parse_serializable(123.456, 'foo',
                                    ParseSerializableDummyInstance(),
                                    ParseSerializableDummyType)


class ParseSerializableArrayDummyArrayable(base.Arrayable):
    """Arrayable stand-in storing its data as a numpy array."""

    def __init__(self, arr):
        self.arr = np.array(arr)

    @classmethod
    def from_array(cls, arr):
        return cls(arr)

    def get_array(self, dtype=None):
        return self.arr


class ParseSerializableArrayDummySerializable:
    """Serializable stand-in recording only the node tag."""

    @classmethod
    def from_node(cls, node, xml_ns, ns_key=None):
        return cls(node.tag)

    @classmethod
    def from_dict(cls, d):
        return cls(d['tag'])

    def __init__(self, tag):
        self.tag = tag


class ParseSerializableArrayDummyInstance:
    """Instance stand-in with empty namespace attributes."""
    _xml_ns = None
    _xml_ns_key = None
    _child_xml_ns_key = {}


class TestParseSerializableArray(unittest.TestCase):
    """Tests for base.parse_serializable_array (continues beyond this chunk)."""

    def test_no_params_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable_array\(\) missing 5 required positional "
                r"arguments: 'value', 'name', 'instance', 'child_type', "
                r"and 'child_tag'$"):
            base.parse_serializable_array()

    def test_value_param_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable_array\(\) missing 4 required positional "
                r"arguments: 'name', 'instance', 'child_type', "
                r"and 'child_tag'$"):
            base.parse_serializable_array('foo')

    def test_value_name_params_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable_array\(\) missing 3 required positional "
                r"arguments: 'instance', 'child_type', and 'child_tag'$"):
            base.parse_serializable_array('foo', 'bar')

    def test_value_name_instance_params_only_fail(self):
        with self.assertRaisesRegex(
                TypeError,
                r"parse_serializable_array\(\) missing 2 required positional "
                r"arguments: 'child_type' and 'child_tag'$"):
            base.parse_serializable_array('foo', 'bar',
                                          ParseSerializableArrayDummyInstance())
test_value_name_instance_child_type_params_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_serializable_array\(\) " + \ + "missing 1 required positional argument: " + \ + "'child_tag'$"): + base.parse_serializable_array('foo', 'bar', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable) + + def test_none_returns_none(self): + self.assertIsNone(base.\ + parse_serializable_array(None, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child')) + + def test_single_child_type(self): + obj = ParseSerializableArrayDummySerializable('child') + arr = base.parse_serializable_array(obj, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.size, 1) + self.assertIs(arr[0], obj) + + def test_ndarray_of_arrayable(self): + arr = np.array([[1, 2], [3, 4]]) + result = base.parse_serializable_array(arr, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummyArrayable, + 'child') + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.size, 2) + self.assertTrue(all( + isinstance(x, ParseSerializableArrayDummyArrayable) for x in result) + ) + + def test_ndarray_wrong_dtype(self): + arr = np.array([1, 2, 3], dtype=int) + with self.assertRaisesRegex(ValueError, r"Attribute test of array " + \ + "type functionality belonging to class " + \ + "ParseSerializableArrayDummyInstance got " + \ + "an ndarray of dtype int64,and child " + \ + "type is not a subclass of Arrayable.$"): + base.parse_serializable_array(arr, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + + def test_ndarray_wrong_shape(self): + arr = np.empty((2,2), dtype=object) + arr[0,0] = ParseSerializableArrayDummySerializable('child') + arr[0,1] = ParseSerializableArrayDummySerializable('child') + arr[1,0] = 
ParseSerializableArrayDummySerializable('child') + arr[1,1] = ParseSerializableArrayDummySerializable('child') + with self.assertRaisesRegex(ValueError, r"Attribute test of array " + \ + "type functionality belonging to class " + \ + "ParseSerializableArrayDummyInstance got " + \ + "an ndarray of shape \(2, 2\),but requires " + \ + "a one dimensional array.$"): + base.parse_serializable_array(arr, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + + def test_ndarray_wrong_type(self): + arr = np.array([1, 2, 3], dtype=object) + with self.assertRaisesRegex(TypeError, r"Attribute test of array type " + \ + "functionality belonging to class " + \ + "ParseSerializableArrayDummyInstance got " + \ + "an ndarray containing first element of " + \ + "incompatible type .$"): + base.parse_serializable_array(arr, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + + def test_xml_element(self): + xml = "" + elem = ET.fromstring(xml) + result = base.parse_serializable_array(elem, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.size, 2) + self.assertTrue( + all(isinstance(x, ParseSerializableArrayDummySerializable) + for x in result)) + + def test_xml_element_wrong_size(self): + xml = "" + elem = ET.fromstring(xml) + with self.assertRaisesRegex(ValueError, r"Attribute test of array " + \ + "type functionality belonging to class " + \ + "ParseSerializableArrayDummyInstance got " + \ + "a ElementTree element with size " + \ + "attribute 3, but has 2 child nodes " + \ + "with tag child.$"): + base.parse_serializable_array(elem, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + + def test_list_of_child_type(self): + objs = [ParseSerializableArrayDummySerializable('child'), + 
ParseSerializableArrayDummySerializable('child')] + arr = base.parse_serializable_array(objs, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.size, 2) + self.assertTrue(all(isinstance(x, ParseSerializableArrayDummySerializable) + for x in arr)) + + def test_list_of_dict(self): + dicts = [{'tag': 'child'}, {'tag': 'child2'}] + arr = base.parse_serializable_array(dicts, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.size, 2) + self.assertTrue(all( + isinstance(x, ParseSerializableArrayDummySerializable) for x in arr) + ) + + def test_list_of_arrayable(self): + arrs = [[1,2], [3,4]] + arr = base.parse_serializable_array(arrs, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummyArrayable, + 'child') + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.size, 2) + self.assertTrue(all( + isinstance(x, ParseSerializableArrayDummyArrayable) for x in arr)) + + def test_list_of_incompatible_type(self): + arrs = [1, 2] + with self.assertRaisesRegex(TypeError, r"Attribute test of array type " + \ + "functionality belonging to class " + \ + "ParseSerializableArrayDummyInstance got " + \ + "a list containing first element of " + \ + "incompatible type .$"): + base.parse_serializable_array(arrs, 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + + def test_empty_list(self): + arr = base.parse_serializable_array([], 'test', + ParseSerializableArrayDummyInstance(), + ParseSerializableArrayDummySerializable, + 'child') + self.assertIsInstance(arr, np.ndarray) + self.assertEqual(arr.size, 0) + +class ParseSerializableListDummySerializable: + @classmethod + def from_node(cls, node, xml_ns, ns_key=None): + return cls(node.tag) + @classmethod + def from_dict(cls, 
d): + return cls(d['tag']) + def __init__(self, tag): + self.tag = tag + +class ParseSerializableListDummyInstance: + _xml_ns = None + _xml_ns_key = None + _child_xml_ns_key = {} + +class TestParseSerializableList(unittest.TestCase): + def test_no_params_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_serializable_list\(\) " + \ + "missing 4 required positional arguments: " + \ + "'value', 'name', 'instance', and 'child_type'$"): + base.parse_serializable_list() + + def test_value_param_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_serializable_list\(\) " + \ + "missing 3 required positional arguments: " + \ + "'name', 'instance', and 'child_type'$"): + base.parse_serializable_list('foo') + + def test_value_name_params_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_serializable_list\(\) " + \ + "missing 2 required positional arguments: " + \ + "'instance' and 'child_type'$"): + base.parse_serializable_list('foo', 'bar') + + def test_value_name_instance_params_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_serializable_list\(\) " + \ + "missing 1 required positional argument: " + \ + "'child_type'$"): + base.parse_serializable_list('foo', 'bar', + ParseSerializableListDummyInstance()) + + def test_none_returns_none(self): + self.assertIsNone(base.parse_serializable_list(None, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable)) + + def test_single_child_type(self): + obj = ParseSerializableListDummySerializable('child') + result = base.parse_serializable_list(obj, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIs(result[0], obj) + + def test_xml_element(self): + xml = "" + elem = ET.fromstring(xml) + result = base.parse_serializable_list(elem, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) 
+ self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertIsInstance(result[0], ParseSerializableListDummySerializable) + self.assertEqual(result[0].tag, "child") + + def test_list_of_child_type(self): + objs = [ParseSerializableListDummySerializable('child'), + ParseSerializableListDummySerializable('child')] + result = base.parse_serializable_list(objs, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertTrue(all( + isinstance(x, ParseSerializableListDummySerializable) for x in result) + ) + + def test_list_of_dict(self): + dicts = [{'tag': 'child'}, {'tag': 'child2'}] + result = base.parse_serializable_list(dicts, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertTrue(all( + isinstance(x, ParseSerializableListDummySerializable) for x in result) + ) + self.assertEqual(result[0].tag, 'child') + self.assertEqual(result[1].tag, 'child2') + + def test_list_of_xml_elements(self): + xml = "" + elem = ET.fromstring(xml) + children = list(elem) + result = base.parse_serializable_list(children, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertTrue(all( + isinstance(x, ParseSerializableListDummySerializable) for x in result) + ) + + def test_list_of_incompatible_type(self): + arrs = [1, 2] + with self.assertRaisesRegex(TypeError, r"Field test of list type " + \ + "functionality belonging to class " + \ + "ParseSerializableListDummyInstance got a " + \ + "list containing first element of " + \ + "incompatible type .$"): + base.parse_serializable_list(arrs, 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + + def test_empty_list(self): + 
result = base.parse_serializable_list([], 'test', + ParseSerializableListDummyInstance(), + ParseSerializableListDummySerializable) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + +class ParseParametersCollectionDummyInstance: + pass + +class TestParseParametersCollection(unittest.TestCase): + def test_no_params_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_parameters_collection\(\) " + \ + "missing 3 required positional arguments: " + \ + "'value', 'name', and 'instance'$"): + base.parse_parameters_collection() + + def test_value_param_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_parameters_collection\(\) " + \ + "missing 2 required positional arguments: " + \ + "'name' and 'instance'$"): + base.parse_parameters_collection('foo') + + def test_value_name_params_only_fail(self): + with self.assertRaisesRegex(TypeError, r"parse_parameters_collection\(\) " + \ + "missing 1 required positional argument: " + \ + "'instance'$"): + base.parse_parameters_collection('foo', 'bar') + + def test_none_returns_none(self): + self.assertIsNone(base.parse_parameters_collection(None, 'params', + ParseParametersCollectionDummyInstance())) + + def test_dict_returns_dict(self): + d = {'a': '1', 'b': '2'} + result = base.parse_parameters_collection(d, 'params', + ParseParametersCollectionDummyInstance()) + self.assertIsInstance(result, dict) + self.assertEqual(result, d) + + def test_empty_list_returns_empty_ordereddict(self): + result = base.parse_parameters_collection([], 'params', + ParseParametersCollectionDummyInstance()) + self.assertIsInstance(result, OrderedDict) + self.assertEqual(len(result), 0) + + def test_list_of_xml_elements(self): + xml = """ + + A + B + + """ + elem = ET.fromstring(xml) + params = list(elem) + result = base.parse_parameters_collection(params, 'params', + ParseParametersCollectionDummyInstance()) + self.assertIsInstance(result, OrderedDict) + self.assertEqual(result['alpha'], 'A') + 
self.assertEqual(result['beta'], 'B') + + def test_list_of_xml_elements_empty_text(self): + xml = """ + + + + + """ + elem = ET.fromstring(xml) + params = list(elem) + result = base.parse_parameters_collection(params, 'params', + ParseParametersCollectionDummyInstance()) + self.assertIsInstance(result, OrderedDict) + self.assertIsNone(result['alpha']) + self.assertIsNone(result['beta']) + + def test_list_of_incompatible_type_raises(self): + with self.assertRaisesRegex(TypeError, r"Field params of list type " + \ + "functionality belonging to class " + \ + "ParseParametersCollectionDummyInstance " + \ + "got a list containing first element of " + \ + "incompatible type .$"): + base.parse_parameters_collection([1, 2], 'params', + ParseParametersCollectionDummyInstance()) + + def test_incompatible_type_raises(self): + with self.assertRaisesRegex(TypeError, r"Field params of class " + \ + "ParseParametersCollectionDummyInstance " + \ + "got incompatible type .$"): + base.parse_parameters_collection("not_a_list_or_dict", 'params', + ParseParametersCollectionDummyInstance()) + diff --git a/tests/processing/sidd/test_sidd_nitf_header_creation.py b/tests/processing/sidd/test_sidd_nitf_header_creation.py new file mode 100644 index 00000000..b27b05de --- /dev/null +++ b/tests/processing/sidd/test_sidd_nitf_header_creation.py @@ -0,0 +1,150 @@ +''' +This unit test determines if the NITF header value of FDT (FileDateTime) is properly updated when the SIDD product is created. +The NITF header includes (at least) two datetime values. IDATIM is the datetime that the image was collected (aka acquired). +FDT is the datetime that the NITF (SIDD) file was created. IDATIM and FDT should have different values. 
+''' + +import json +import os + +import pytest +import logging +import unittest +from tests import parse_file_entry # fails unless run using pytest +from datetime import datetime, timezone + +from sarpy.io.complex.converter import conversion_utility, open_complex +from sarpy.processing.ortho_rectify import NearestNeighborMethod +from sarpy.processing.sidd.sidd_product_creation import \ + create_detected_image_sidd, create_dynamic_image_sidd +import sarpy.visualization.remap as remap + +from sarpy.utils import create_product + +from sarpy.io.general.nitf import NITFDetails + +complex_file_types = {} +this_loc = os.path.abspath(__file__) + +# JSON file that specifies test file locations on the local system +file_reference = os.path.join(os.path.split(this_loc)[0], \ + 'complex_file_types.json') +# Find valid files from file_reference and add them to the valid_entries list +if os.path.isfile(file_reference): + with open(file_reference, 'r') as local_file: + test_files_list = json.load(local_file) + for test_files_type in test_files_list: + valid_entries = [] + for entry in test_files_list[test_files_type]: + the_file = parse_file_entry(entry) + if the_file is not None: + valid_entries.append(the_file) + complex_file_types[test_files_type] = valid_entries + +sicd_files = complex_file_types.get('SICD', []) + +# Determine a valid file reader for the complex input image +def get_test_reader(idx): + if idx >= len(sicd_files): + return None + input_file = sicd_files[idx] + reader = open_complex(input_file) + return reader + +''' Read a single input SICD image, create the detected-image NITF SIDD product, then +read the NITF header from the newly created SIDD product and test that the FDT value +is different than the IDATIM value and that the FDT value is close to the current time +(because the NITF SIDD product was just created).
+''' +def test_nitf_fdt_updated_for_detected_image_sidd(tmp_path): + local_reader = get_test_reader(0) + ortho_helper = NearestNeighborMethod(local_reader, index=0) + output_directory = tmp_path + output_file = 'output.sidd' + + # create SIDD product + test_sidd = create_detected_image_sidd(ortho_helper, output_directory, output_file) + + # Full path to the created SIDD file + sidd_file = os.path.join(*output_directory.parts,output_file) + + # Get NITF header data from created SIDD file + details = NITFDetails(sidd_file) + + # Get datetime values from the NITF header data + fdt_datetime = datetime.strptime(details.nitf_header.FDT,"%Y%m%d%H%M%S%f") + collection_datetime = datetime.strptime(details.img_headers[0].IDATIM,"%Y%m%d%H%M%S%f") + + # Since fdt_datetime and collection_datetime are relative to Zulu but that information + # is not represented in their datetime python objects, we need to determine the + # current datetime relative to Zulu, but then remove the timezone info from the + # object in order to compute time deltas later.
+ current_time_zulu = datetime.now(timezone.utc).replace(tzinfo=None) + + # Compute time difference between FDT (presumably current time) and IDATIM (collection time) + time_delta = fdt_datetime - collection_datetime + + # boolean representing collection time is before file creation time + fdt_gt_cdt = time_delta.total_seconds() > 0 + + # Compute time difference between the current time and the FDT from the NITF header + recent_time_delta = current_time_zulu - fdt_datetime + + # Boolean representing that the SIDD file was created within the past 2 minutes + fdt_is_recent = recent_time_delta.total_seconds() < 120 + + assert (fdt_gt_cdt and fdt_is_recent), 'NITF FileDateTime should be greater than IDATIM: {} > {}?'.format(details.nitf_header.FDT,details.img_headers[0].IDATIM) + +''' Read a single input SICD image, create the dynamic-image NITF SIDD product, then +read the NITF header from the newly created SIDD product and test that the FDT value +is different than the IDATIM value and that the FDT value is close to the current time +(because the NITF SIDD product was just created). It was observed during debug that +the FDT and IDATIM values are recomputed for each subaperture of the dynamic-image, but +that information may not be represented in the NITF header data.
+''' +def test_nitf_fdt_updated_for_dynamic_image_sidd(tmp_path): + local_reader = get_test_reader(0) + ortho_helper = NearestNeighborMethod(local_reader, index=0) + output_directory = tmp_path + output_file = 'output.sidd' + + # create SIDD product + test_sidd = create_dynamic_image_sidd(ortho_helper, output_directory, output_file) + + # Full path to the created SIDD file + sidd_file = os.path.join(*output_directory.parts,output_file) + + # Get NITF header data from created SIDD file + details = NITFDetails(sidd_file) + + # Get datetime values from the NITF header data + fdt_datetime = datetime.strptime(details.nitf_header.FDT,"%Y%m%d%H%M%S%f") + collection_datetime = datetime.strptime(details.img_headers[0].IDATIM,"%Y%m%d%H%M%S%f") + + # Since fdt_datetime and collection_datetime are relative to Zulu but that information + # is not represented in their datetime python objects, we need to determine the + # current datetime relative to Zulu, but then remove the timezone info from the + # object in order to compute time deltas later.
+ current_time_zulu = datetime.now(timezone.utc).replace(tzinfo=None) + + # Compute time difference between FDT (presumably current time) and IDATIM (collection time) + time_delta = fdt_datetime - collection_datetime + + # boolean representing collection time is before file creation time + fdt_gt_cdt = time_delta.total_seconds() > 0 + + # Compute time difference between the current time and the FDT from the NITF header + recent_time_delta = current_time_zulu - fdt_datetime + + # Boolean representing that the SIDD file was created within the past 2 minutes + fdt_is_recent = recent_time_delta.total_seconds() < 120 + + assert (fdt_gt_cdt and fdt_is_recent), 'NITF FileDateTime should be greater than IDATIM: {} > {}?'.format(details.nitf_header.FDT,details.img_headers[0].IDATIM) + +if __name__ == '__main__': + unittest.main() + + + + + diff --git a/tests/test_compliance.py b/tests/test_compliance.py new file mode 100644 index 00000000..89a3c2b1 --- /dev/null +++ b/tests/test_compliance.py @@ -0,0 +1,35 @@ +__classification__ = "UNCLASSIFIED" +__author__ = "Tex Peterson" + +import pytest, os +from unittest import TestCase + +from sarpy.compliance import SarpyError, bytes_to_string + +# Test the SarpyError class. +# The SarpyError class is a pass through. +# The only way to test it is to try an operation that fails and check if it is +# raised. +def test_SarpyError() : + file_name = "bad/filename" + try: + os.path.exists(file_name) + assert False + except: + pytest.raises(SarpyError) + + +class Test_bytes_to_string(TestCase): + def setUp(self): + self.text_string = "Hello, world!" + self.byte_data = self.text_string.encode('utf-8') + + def testStringInputSuccess(self): + self.assertEqual(self.text_string, bytes_to_string(self.text_string)) + + def testBadInputFail(self): + with self.assertRaisesRegex(TypeError, 'Input is required to be bytes.
Got type*'): + bytes_to_string(11) + + def testByteInputSuccess(self): + self.assertEqual(self.text_string, bytes_to_string(self.byte_data)) \ No newline at end of file