diff --git a/src/sst/core/model/xmlToPython.py b/src/sst/core/model/xmlToPython.py index 532277ea5..a6c8915f7 100755 --- a/src/sst/core/model/xmlToPython.py +++ b/src/sst/core/model/xmlToPython.py @@ -79,16 +79,16 @@ def processParamSets(groups: ET.Element) -> None: for group in groups: params = dict() for p in group: - params[getParamName(p)] = processString(p.text.strip()) # type: ignore + params[getParamName(p)] = processString(p.text.strip()) # type: ignore [union-attr] sstParams[group.tag] = params def processVars(varNode: ET.Element) -> None: for var in varNode: - sstVars[var.tag] = processString(var.text.strip()) # type: ignore + sstVars[var.tag] = processString(var.text.strip()) # type: ignore [union-attr] def processConfig(cfg: ET.Element) -> None: - for line in cfg.text.strip().splitlines(): # type: ignore + for line in cfg.text.strip().splitlines(): # type: ignore [union-attr] var, val = line.split('=') sst.setProgramOption(var, processString(val)) # strip quotes @@ -107,7 +107,7 @@ def buildComp(compNode: ET.Element) -> None: for paramInc in paramsNode.attrib['include'].split(','): params.update(sstParams[processString(paramInc)]) for p in paramsNode: - params[getParamName(p)] = processString(p.text.strip()) # type: ignore + params[getParamName(p)] = processString(p.text.strip()) # type: ignore [union-attr] comp.addParams(params) @@ -147,7 +147,7 @@ def build(root: ET.Element) -> None: if paramSets is not None: processParamSets(paramSets) if timebase is not None: - sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore + sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore [union-attr] if cfg is not None: processConfig(cfg) if graph is not None: diff --git a/src/sst/core/testingframework/test_engine_support.py b/src/sst/core/testingframework/test_engine_support.py index 5712c850d..2a30138b9 100644 --- a/src/sst/core/testingframework/test_engine_support.py +++ b/src/sst/core/testingframework/test_engine_support.py @@ -25,7 +25,7 @@ import inspect import signal from subprocess import TimeoutExpired -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import test_engine_globals @@ -40,8 +40,14 @@ class OSCommand: """ ### - def __init__(self, cmd_str, output_file_path=None, error_file_path=None, - set_cwd=None, use_shell=False): + def __init__( + self, + cmd_str: str, + output_file_path: Optional[str] = None, + error_file_path: Optional[str] = None, + set_cwd: Optional[str] = None, + use_shell: bool = False, + ) -> None: """ Args: cmd_str (str): The command to be executed @@ -55,10 +61,13 @@ def __init__(self, cmd_str, output_file_path=None, error_file_path=None, """ self._output_file_path = None self._error_file_path = None - self._cmd_str = None - self._process = None + self._cmd_str = [""] + self._process: subprocess.Popen[str] = None # type: ignore [assignment] self._timeout_sec = 60 - self._run_status = None + # Use an invalid return code rather than None to identify an + # unintialized value. + self._run_status_sentinel = -999 + self._run_status = self._run_status_sentinel self._run_output = '' self._run_error = '' self._run_timeout = False @@ -71,7 +80,12 @@ def __init__(self, cmd_str, output_file_path=None, error_file_path=None, self._signal_sec = 3 #### - def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs): + def run(self, + timeout_sec: int = 60, + send_signal: int = signal.NSIG, + signal_sec: int = 3, + **kwargs: Any, + ) -> "OSCommandResult": """ Run a command then return and OSCmdRtn object. Args: @@ -79,8 +93,7 @@ def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs): will be terminated and a timeout error will occur. kwargs: Extra parameters e.g., timeout_sec to override the default timeout """ - if not (isinstance(timeout_sec, (int, float)) and not isinstance(timeout_sec, bool)): - raise ValueError("ERROR: Timeout must be an int or a float") + check_param_type("timeout_sec", timeout_sec, int) self._timeout_sec = timeout_sec self._signal = send_signal @@ -92,6 +105,7 @@ def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs): thread.join(self._timeout_sec) if thread.is_alive(): self._run_timeout = True + assert self._process is not None self._process.kill() thread.join() @@ -102,7 +116,7 @@ def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs): #### - def _run_cmd_in_subprocess(self, **kwargs): + def _run_cmd_in_subprocess(self, **kwargs: Any) -> None: """ Run the command in a subprocess """ file_out = None file_err = None @@ -163,20 +177,17 @@ def _run_cmd_in_subprocess(self, **kwargs): #### - def _validate_cmd_str(self, cmd_str): + def _validate_cmd_str(self, cmd_str: str) -> None: """ Validate the cmd_str """ - if isinstance(cmd_str, str): - if cmd_str != "": - cmd_str = shlex.split(cmd_str) - else: - raise ValueError("ERROR: OSCommand() cmd_str must not be empty") - else: + if not isinstance(cmd_str, str): raise ValueError("ERROR: OSCommand() cmd_str must be a string") - self._cmd_str = cmd_str + elif not cmd_str: + raise ValueError("ERROR: OSCommand() cmd_str must not be empty") + self._cmd_str = shlex.split(cmd_str) #### - def _validate_output_path(self, file_path): + def _validate_output_path(self, file_path: Optional[str]) -> Optional[str]: """ Validate the output file path """ if file_path is not None: dirpath = os.path.abspath(os.path.dirname(file_path)) @@ -188,12 +199,12 @@ def _validate_output_path(self, file_path): ################################################################################ -class OSCommandResult(): +class OSCommandResult: """ This class returns result data about the OSCommand that was executed """ - def __init__(self, cmd_str, status, output, error, timeout): + def __init__(self, cmd_str: List[str], status: int, output: str, error: str, timeout: bool) -> None: """ Args: - cmd_str (str): The command to be executed + cmd_str (list[str]): The command to be executed status (int): The return status of the command execution. output (str): The standard output of the command execution. error (str): The error output of the command execution. @@ -207,7 +218,7 @@ def __init__(self, cmd_str, status, output, error, timeout): #### - def __repr__(self): + def __repr__(self) -> str: rtn_str = (("Cmd = {0}; Status = {1}; Timeout = {2}; ") + ("Error = {3}; Output = {4}")).format(self._run_cmd_str, \ self._run_status, self._run_timeout, self._run_error, \ @@ -216,24 +227,24 @@ def __repr__(self): #### - def __str__(self): + def __str__(self) -> str: return self.__repr__() #### - def cmd(self): + def cmd(self) -> List[str]: """ return the command that was run """ return self._run_cmd_str #### - def result(self): + def result(self) -> int: """ return the run status result """ return self._run_status #### - def output(self): + def output(self) -> str: """ return the run output result """ # Sometimes the output can be a unicode or a byte string - convert it if isinstance(self._run_output, bytes): @@ -242,7 +253,7 @@ def output(self): #### - def error(self): + def error(self) -> str: """ return the run error output result """ # Sometimes the output can be a unicode or a byte string - convert it if isinstance(self._run_error, bytes): @@ -251,13 +262,13 @@ def error(self): #### - def timeout(self): + def timeout(self) -> bool: """ return true if the run timed out """ return self._run_timeout ################################################################################ -def check_param_type(varname, vardata, datatype): +def check_param_type(varname: str, vardata: Any, datatype: Type[Any]) -> None: """ Validate a parameter to ensure it is of the correct type. Args: @@ -278,96 +289,10 @@ def check_param_type(varname, vardata, datatype): ################################################################################ -def strclass(cls): +def strclass(cls: Type[Any]) -> str: """ Return the classname of a class""" return "%s" % (cls.__module__) -def strqual(cls): +def strqual(cls: Type[Any]) -> str: """ Return the qualname of a class""" - return "%s" % (_qualname(cls)) - -################################################################################ -# qualname from https://github.com/wbolster/qualname to support Py2 and Py3 -# LICENSE -> https://github.com/wbolster/qualname/blob/master/LICENSE.rst -#__all__ = ['qualname'] - -_cache = {} - -def _qualname(obj): - """Find out the qualified name for a class or function.""" - - # For Python 3.3+, this is straight-forward. - if hasattr(obj, '__qualname__'): - return obj.__qualname__ - - # For older Python versions, things get complicated. - # Obtain the filename and the line number where the - # class/method/function is defined. - try: - filename = inspect.getsourcefile(obj) - except TypeError: - return obj.__qualname__ # raises a sensible error - if not filename: - return obj.__qualname__ # raises a sensible error - if inspect.isclass(obj): - try: - _, lineno = inspect.getsourcelines(obj) - except (OSError, IOError): - return obj.__qualname__ # raises a sensible error - elif inspect.isfunction(obj) or inspect.ismethod(obj): - if hasattr(obj, 'im_func'): - # Extract function from unbound method (Python 2) - obj = obj.im_func - try: - code = obj.__code__ - except AttributeError: - code = obj.func_code - lineno = code.co_firstlineno - else: - return obj.__qualname__ # raises a sensible error - - # Re-parse the source file to figure out what the - # __qualname__ should be by analysing the abstract - # syntax tree. Use a cache to avoid doing this more - # than once for the same file. - qualnames = _cache.get(filename) - if qualnames is None: - with open(filename, 'r') as filehandle: - source = filehandle.read() - node = ast.parse(source, filename) - visitor = _Visitor() - visitor.visit(node) - _cache[filename] = qualnames = visitor.qualnames - try: - return qualnames[lineno] - except KeyError: - return obj.__qualname__ # raises a sensible error - - -class _Visitor(ast.NodeVisitor): - """Support class for qualname function""" - def __init__(self): - super(_Visitor, self).__init__() - self.stack = [] - self.qualnames = {} - - def store_qualname(self, lineno): - """Support method for _Visitor class""" - q_n = ".".join(n for n in self.stack) - self.qualnames[lineno] = q_n - - def visit_FunctionDef(self, node): - """Support method for _Visitor class""" - self.stack.append(node.name) - self.store_qualname(node.lineno) - self.stack.append('') - self.generic_visit(node) - self.stack.pop() - self.stack.pop() - - def visit_ClassDef(self, node): - """Support method for _Visitor class""" - self.stack.append(node.name) - self.store_qualname(node.lineno) - self.generic_visit(node) - self.stack.pop() + return "%s" % (cls.__qualname__)