Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 114 additions & 63 deletions src/sst/core/testingframework/test_engine_junit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
import re
import xml.etree.ElementTree as ET
import xml.dom.minidom
from typing import IO, List, Mapping, Optional
from typing import DefaultDict, Dict, IO, List, Mapping, Optional, Union

Entry = Dict[str, Optional[str]]

################################################################################

Expand Down Expand Up @@ -88,15 +90,27 @@

################################################################################

class JUnitTestSuite(object):
class JUnitTestSuite:
"""
Suite of test cases.
Can handle unicode strings or binary strings if their encoding is provided.
"""

def __init__(self, name, test_cases=None, hostname=None, id=None,
package=None, timestamp=None, properties=None, file=None,
log=None, url=None, stdout=None, stderr=None):
def __init__(
self,
name: str,
test_cases: Optional[List["JUnitTestCase"]] = None,
hostname: Optional[str] = None,
id: Optional[str] = None,
package: Optional[str] = None,
timestamp: Optional[str] = None,
properties: Optional[Mapping[str, str]] = None,
file: Optional[str] = None,
log: Optional[str] = None,
url: Optional[str] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
) -> None:
self.name = name
if not test_cases:
test_cases = []
Expand All @@ -118,7 +132,7 @@ def __init__(self, name, test_cases=None, hostname=None, id=None,

####

def junit_build_xml_doc(self, encoding=None):
def junit_build_xml_doc(self, encoding: Optional[str] = None) -> ET.Element:
"""
Builds the XML document for the JUnit test suite.
Produces clean unicode strings and decodes non-unicode with the help of encoding.
Expand Down Expand Up @@ -184,35 +198,35 @@ def junit_build_xml_doc(self, encoding=None):
stderr_element.text = _junit_decode(self.stderr, encoding)

# test cases
for case in self.test_cases:
for test_case in self.test_cases:
test_case_attributes = dict()
test_case_attributes["name"] = _junit_decode(case.name, encoding)
if case.assertions:
test_case_attributes["name"] = _junit_decode(test_case.name, encoding)
if test_case.assertions:
# Number of assertions in the test case
test_case_attributes["assertions"] = "%d" % case.assertions
if case.elapsed_sec:
test_case_attributes["time"] = "%f" % case.elapsed_sec
if case.timestamp:
test_case_attributes["timestamp"] = _junit_decode(case.timestamp, encoding)
if case.classname:
test_case_attributes["classname"] = _junit_decode(case.classname, encoding)
if case.status:
test_case_attributes["status"] = _junit_decode(case.status, encoding)
if case.category:
test_case_attributes["class"] = _junit_decode(case.category, encoding)
if case.file:
test_case_attributes["file"] = _junit_decode(case.file, encoding)
if case.line:
test_case_attributes["line"] = _junit_decode(case.line, encoding)
if case.log:
test_case_attributes["log"] = _junit_decode(case.log, encoding)
if case.url:
test_case_attributes["url"] = _junit_decode(case.url, encoding)
test_case_attributes["assertions"] = "%d" % test_case.assertions # type: ignore [str-format]
if test_case.elapsed_sec:
test_case_attributes["time"] = "%f" % test_case.elapsed_sec
if test_case.timestamp:
test_case_attributes["timestamp"] = _junit_decode(test_case.timestamp, encoding)
if test_case.classname:
test_case_attributes["classname"] = _junit_decode(test_case.classname, encoding)
if test_case.status:
test_case_attributes["status"] = _junit_decode(test_case.status, encoding)
if test_case.category:
test_case_attributes["class"] = _junit_decode(test_case.category, encoding)
if test_case.file:
test_case_attributes["file"] = _junit_decode(test_case.file, encoding)
if test_case.line:
test_case_attributes["line"] = _junit_decode(test_case.line, encoding)
if test_case.log:
test_case_attributes["log"] = _junit_decode(test_case.log, encoding)
if test_case.url:
test_case_attributes["url"] = _junit_decode(test_case.url, encoding)

test_case_element = ET.SubElement(xml_element, "testcase", test_case_attributes)

# failures
for failure in case.failures:
for failure in test_case.failures:
if failure["output"] or failure["message"]:
attrs = {"type": "failure"}
if failure["message"]:
Expand All @@ -225,7 +239,7 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(failure_element)

# errors
for error in case.errors:
for error in test_case.errors:
if error["message"] or error["output"]:
attrs = {"type": "error"}
if error["message"]:
Expand All @@ -238,7 +252,7 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(error_element)

# skippeds
for skipped in case.skipped:
for skipped in test_case.skipped:
attrs = {"type": "skipped"}
if skipped["message"]:
attrs["message"] = _junit_decode(skipped["message"], encoding)
Expand All @@ -248,28 +262,41 @@ def junit_build_xml_doc(self, encoding=None):
test_case_element.append(skipped_element)

# test stdout
if case.stdout:
if test_case.stdout:
stdout_element = ET.Element("system-out")
stdout_element.text = _junit_decode(case.stdout, encoding)
stdout_element.text = _junit_decode(test_case.stdout, encoding)
test_case_element.append(stdout_element)

# test stderr
if case.stderr:
if test_case.stderr:
stderr_element = ET.Element("system-err")
stderr_element.text = _junit_decode(case.stderr, encoding)
stderr_element.text = _junit_decode(test_case.stderr, encoding)
test_case_element.append(stderr_element)

return xml_element

####

class JUnitTestCase(object):
class JUnitTestCase:
"""A JUnit test case with a result and possibly some stdout or stderr"""

def __init__(self, name, classname=None, elapsed_sec=None, stdout=None,
stderr=None, assertions=None, timestamp=None, status=None,
category=None, file=None, line=None, log=None, url=None,
allow_multiple_subelements=False):
def __init__(
self,
name: str,
classname: Optional[str] = None,
elapsed_sec: Optional[float] = None,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
assertions: Optional[str] = None,
timestamp: Optional[str] = None,
status: Optional[str] = None,
category: Optional[str] = None,
file: Optional[str] = None,
line: Optional[str] = None,
log: Optional[str] = None,
url: Optional[str] = None,
allow_multiple_subelements: bool = False,
) -> None:
self.name = name
self.assertions = assertions
self.elapsed_sec = elapsed_sec
Expand All @@ -284,18 +311,23 @@ def __init__(self, name, classname=None, elapsed_sec=None, stdout=None,
self.stdout = stdout
self.stderr = stderr
self.is_enabled = True
self.errors = []
self.failures = []
self.skipped = []
self.allow_multiple_subalements = allow_multiple_subelements

def junit_add_error_info(self, message=None, output=None, error_type=None):
self.errors: List[Entry] = []
self.failures: List[Entry] = []
self.skipped: List[Entry] = []
self.allow_multiple_subelements = allow_multiple_subelements

def junit_add_error_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
error_type: Optional[str] = None,
) -> None:
"""Adds an error message, output, or both to the test case"""
error = {}
error["message"] = message
error["output"] = output
error["type"] = error_type
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.errors.append(error)
elif not len(self.errors):
Expand All @@ -308,13 +340,18 @@ def junit_add_error_info(self, message=None, output=None, error_type=None):
if error_type:
self.errors[0]["type"] = error_type

def junit_add_failure_info(self, message=None, output=None, failure_type=None):
def junit_add_failure_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
failure_type: Optional[str] = None,
) -> None:
"""Adds a failure message, output, or both to the test case"""
failure = {}
failure["message"] = message
failure["output"] = output
failure["type"] = failure_type
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.failures.append(failure)
elif not len(self.failures):
Expand All @@ -327,12 +364,16 @@ def junit_add_failure_info(self, message=None, output=None, failure_type=None):
if failure_type:
self.failures[0]["type"] = failure_type

def junit_add_skipped_info(self, message=None, output=None):
def junit_add_skipped_info(
self,
message: Optional[str] = None,
output: Optional[str] = None,
) -> None:
"""Adds a skipped message, output, or both to the test case"""
skipped = {}
skipped["message"] = message
skipped["output"] = output
if self.allow_multiple_subalements:
if self.allow_multiple_subelements:
if message or output:
self.skipped.append(skipped)
elif not len(self.skipped):
Expand All @@ -343,25 +384,29 @@ def junit_add_skipped_info(self, message=None, output=None):
if output:
self.skipped[0]["output"] = output

def junit_add_elapsed_sec(self, elapsed_sec):
def junit_add_elapsed_sec(self, elapsed_sec: float) -> None:
"""Add the elapsed time to the testcase"""
self.elapsed_sec = elapsed_sec

def junit_is_failure(self):
def junit_is_failure(self) -> bool:
"""returns true if this test case is a failure"""
return sum(1 for f in self.failures if f["message"] or f["output"]) > 0

def junit_is_error(self):
def junit_is_error(self) -> bool:
"""returns true if this test case is an error"""
return sum(1 for e in self.errors if e["message"] or e["output"]) > 0

def junit_is_skipped(self):
def junit_is_skipped(self) -> bool:
"""returns true if this test case has been skipped"""
return len(self.skipped) > 0

####

def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
def junit_to_xml_report_string(
test_suites: List["JUnitTestSuite"],
prettyprint: bool = True,
encoding: Optional[str] = None,
) -> str:
"""
Returns the string representation of the JUnit XML document.
@param encoding: The encoding of the input.
Expand All @@ -374,7 +419,7 @@ def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
raise TypeError("test_suites must be a list of test suites")

xml_element = ET.Element("testsuites")
attributes = defaultdict(int)
attributes: DefaultDict[str, Union[int, float]] = defaultdict(int)
for ts in test_suites:
ts_xml = ts.junit_build_xml_doc(encoding=encoding)
for key in ["disabled", "errors", "failures", "tests"]:
Expand All @@ -396,18 +441,24 @@ def junit_to_xml_report_string(test_suites, prettyprint=True, encoding=None):
if prettyprint:
# minidom.parseString() works just on correctly encoded binary strings
xml_string = xml_string.encode(encoding or "utf-8")
xml_string = xml.dom.minidom.parseString(xml_string)
xml_string_document = xml.dom.minidom.parseString(xml_string)
# toprettyxml() produces unicode if no encoding is being passed
# or binary string with an encoding
xml_string = xml_string.toprettyxml(encoding=encoding)
if encoding:
xml_string = xml_string_document.toprettyxml(encoding=encoding)
if isinstance(xml_string, bytes):
assert encoding is not None
xml_string = xml_string.decode(encoding)
# is unicode now
return xml_string

####

def junit_to_xml_report_file(file_descriptor, test_suites, prettyprint=True, encoding=None):
def junit_to_xml_report_file(
file_descriptor: IO[str],
test_suites: List["JUnitTestSuite"],
prettyprint: bool = True,
encoding: Optional[str] = None,
) -> None:
"""
Writes the JUnit XML document to a file.
"""
Expand All @@ -417,15 +468,15 @@ def junit_to_xml_report_file(file_descriptor, test_suites, prettyprint=True, enc

####

def _junit_decode(var, encoding):
def _junit_decode(var: Optional[str], encoding: Optional[str]) -> str:
"""
If not already unicode, decode it.
"""
return str(var)

####

def _junit_clean_illegal_xml_chars(string_to_clean):
def _junit_clean_illegal_xml_chars(string_to_clean: str) -> str:
"""
Removes any illegal unicode characters from the given XML string.
@see: http://stackoverflow.com/questions/1707890/fast-way-to-filter-
Expand Down
Loading