Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/sst/core/model/xmlToPython.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ def processParamSets(groups: ET.Element) -> None:
for group in groups:
params = dict()
for p in group:
params[getParamName(p)] = processString(p.text.strip()) # type: ignore
params[getParamName(p)] = processString(p.text.strip()) # type: ignore [union-attr]
sstParams[group.tag] = params


def processVars(varNode: ET.Element) -> None:
for var in varNode:
sstVars[var.tag] = processString(var.text.strip()) # type: ignore
sstVars[var.tag] = processString(var.text.strip()) # type: ignore [union-attr]

def processConfig(cfg: ET.Element) -> None:
for line in cfg.text.strip().splitlines(): # type: ignore
for line in cfg.text.strip().splitlines(): # type: ignore [union-attr]
var, val = line.split('=')
sst.setProgramOption(var, processString(val)) # strip quotes

Expand All @@ -107,7 +107,7 @@ def buildComp(compNode: ET.Element) -> None:
for paramInc in paramsNode.attrib['include'].split(','):
params.update(sstParams[processString(paramInc)])
for p in paramsNode:
params[getParamName(p)] = processString(p.text.strip()) # type: ignore
params[getParamName(p)] = processString(p.text.strip()) # type: ignore [union-attr]

comp.addParams(params)

Expand Down Expand Up @@ -147,7 +147,7 @@ def build(root: ET.Element) -> None:
if paramSets is not None:
processParamSets(paramSets)
if timebase is not None:
sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore
sst.setProgramOption('timebase', timebase.text.strip()) # type: ignore [union-attr]
if cfg is not None:
processConfig(cfg)
if graph is not None:
Expand Down
163 changes: 44 additions & 119 deletions src/sst/core/testingframework/test_engine_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import inspect
import signal
from subprocess import TimeoutExpired
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type

import test_engine_globals

Expand All @@ -40,8 +40,14 @@ class OSCommand:
"""
###

def __init__(self, cmd_str, output_file_path=None, error_file_path=None,
set_cwd=None, use_shell=False):
def __init__(
self,
cmd_str: str,
output_file_path: Optional[str] = None,
error_file_path: Optional[str] = None,
set_cwd: Optional[str] = None,
use_shell: bool = False,
) -> None:
"""
Args:
cmd_str (str): The command to be executed
Expand All @@ -55,10 +61,13 @@ def __init__(self, cmd_str, output_file_path=None, error_file_path=None,
"""
self._output_file_path = None
self._error_file_path = None
self._cmd_str = None
self._process = None
self._cmd_str = [""]
self._process: subprocess.Popen[str] = None # type: ignore [assignment]
self._timeout_sec = 60
self._run_status = None
# Use an invalid return code rather than None to identify an
# unintialized value.
self._run_status_sentinel = -999
self._run_status = self._run_status_sentinel
self._run_output = ''
self._run_error = ''
self._run_timeout = False
Expand All @@ -71,16 +80,20 @@ def __init__(self, cmd_str, output_file_path=None, error_file_path=None,
self._signal_sec = 3
####

def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs):
def run(self,
timeout_sec: int = 60,
send_signal: int = signal.NSIG,
signal_sec: int = 3,
**kwargs: Any,
) -> "OSCommandResult":
""" Run a command then return and OSCmdRtn object.

Args:
timeout_sec (int): The maximum runtime in seconds before thread
will be terminated and a timeout error will occur.
kwargs: Extra parameters e.g., timeout_sec to override the default timeout
"""
if not (isinstance(timeout_sec, (int, float)) and not isinstance(timeout_sec, bool)):
raise ValueError("ERROR: Timeout must be an int or a float")
check_param_type("timeout_sec", timeout_sec, int)

self._timeout_sec = timeout_sec
self._signal = send_signal
Expand All @@ -92,6 +105,7 @@ def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs):
thread.join(self._timeout_sec)
if thread.is_alive():
self._run_timeout = True
assert self._process is not None
self._process.kill()
thread.join()

Expand All @@ -102,7 +116,7 @@ def run(self, timeout_sec=60, send_signal=signal.NSIG, signal_sec=3, **kwargs):

####

def _run_cmd_in_subprocess(self, **kwargs):
def _run_cmd_in_subprocess(self, **kwargs: Any) -> None:
""" Run the command in a subprocess """
file_out = None
file_err = None
Expand Down Expand Up @@ -163,20 +177,17 @@ def _run_cmd_in_subprocess(self, **kwargs):

####

def _validate_cmd_str(self, cmd_str):
def _validate_cmd_str(self, cmd_str: str) -> None:
""" Validate the cmd_str """
if isinstance(cmd_str, str):
if cmd_str != "":
cmd_str = shlex.split(cmd_str)
else:
raise ValueError("ERROR: OSCommand() cmd_str must not be empty")
else:
if not isinstance(cmd_str, str):
raise ValueError("ERROR: OSCommand() cmd_str must be a string")
self._cmd_str = cmd_str
elif not cmd_str:
raise ValueError("ERROR: OSCommand() cmd_str must not be empty")
self._cmd_str = shlex.split(cmd_str)

####

def _validate_output_path(self, file_path):
def _validate_output_path(self, file_path: Optional[str]) -> Optional[str]:
""" Validate the output file path """
if file_path is not None:
dirpath = os.path.abspath(os.path.dirname(file_path))
Expand All @@ -188,12 +199,12 @@ def _validate_output_path(self, file_path):

################################################################################

class OSCommandResult():
class OSCommandResult:
""" This class returns result data about the OSCommand that was executed """
def __init__(self, cmd_str, status, output, error, timeout):
def __init__(self, cmd_str: List[str], status: int, output: str, error: str, timeout: bool) -> None:
"""
Args:
cmd_str (str): The command to be executed
cmd_str (list[str]): The command to be executed
status (int): The return status of the command execution.
output (str): The standard output of the command execution.
error (str): The error output of the command execution.
Expand All @@ -207,7 +218,7 @@ def __init__(self, cmd_str, status, output, error, timeout):

####

def __repr__(self):
def __repr__(self) -> str:
rtn_str = (("Cmd = {0}; Status = {1}; Timeout = {2}; ") +
("Error = {3}; Output = {4}")).format(self._run_cmd_str, \
self._run_status, self._run_timeout, self._run_error, \
Expand All @@ -216,24 +227,24 @@ def __repr__(self):

####

def __str__(self):
def __str__(self) -> str:
return self.__repr__()

####

def cmd(self):
def cmd(self) -> List[str]:
""" return the command that was run """
return self._run_cmd_str

####

def result(self):
def result(self) -> int:
""" return the run status result """
return self._run_status

####

def output(self):
def output(self) -> str:
""" return the run output result """
# Sometimes the output can be a unicode or a byte string - convert it
if isinstance(self._run_output, bytes):
Expand All @@ -242,7 +253,7 @@ def output(self):

####

def error(self):
def error(self) -> str:
""" return the run error output result """
# Sometimes the output can be a unicode or a byte string - convert it
if isinstance(self._run_error, bytes):
Expand All @@ -251,13 +262,13 @@ def error(self):

####

def timeout(self):
def timeout(self) -> bool:
""" return true if the run timed out """
return self._run_timeout

################################################################################

def check_param_type(varname, vardata, datatype):
def check_param_type(varname: str, vardata: Any, datatype: Type[Any]) -> None:
""" Validate a parameter to ensure it is of the correct type.

Args:
Expand All @@ -278,96 +289,10 @@ def check_param_type(varname, vardata, datatype):

################################################################################

def strclass(cls):
def strclass(cls: Type[Any]) -> str:
""" Return the classname of a class"""
return "%s" % (cls.__module__)

def strqual(cls):
def strqual(cls: Type[Any]) -> str:
""" Return the qualname of a class"""
return "%s" % (_qualname(cls))

################################################################################
# qualname from https://github.com/wbolster/qualname to support Py2 and Py3
# LICENSE -> https://github.com/wbolster/qualname/blob/master/LICENSE.rst
#__all__ = ['qualname']

_cache = {}

def _qualname(obj):
"""Find out the qualified name for a class or function."""

# For Python 3.3+, this is straight-forward.
if hasattr(obj, '__qualname__'):
return obj.__qualname__

# For older Python versions, things get complicated.
# Obtain the filename and the line number where the
# class/method/function is defined.
try:
filename = inspect.getsourcefile(obj)
except TypeError:
return obj.__qualname__ # raises a sensible error
if not filename:
return obj.__qualname__ # raises a sensible error
if inspect.isclass(obj):
try:
_, lineno = inspect.getsourcelines(obj)
except (OSError, IOError):
return obj.__qualname__ # raises a sensible error
elif inspect.isfunction(obj) or inspect.ismethod(obj):
if hasattr(obj, 'im_func'):
# Extract function from unbound method (Python 2)
obj = obj.im_func
try:
code = obj.__code__
except AttributeError:
code = obj.func_code
lineno = code.co_firstlineno
else:
return obj.__qualname__ # raises a sensible error

# Re-parse the source file to figure out what the
# __qualname__ should be by analysing the abstract
# syntax tree. Use a cache to avoid doing this more
# than once for the same file.
qualnames = _cache.get(filename)
if qualnames is None:
with open(filename, 'r') as filehandle:
source = filehandle.read()
node = ast.parse(source, filename)
visitor = _Visitor()
visitor.visit(node)
_cache[filename] = qualnames = visitor.qualnames
try:
return qualnames[lineno]
except KeyError:
return obj.__qualname__ # raises a sensible error


class _Visitor(ast.NodeVisitor):
"""Support class for qualname function"""
def __init__(self):
super(_Visitor, self).__init__()
self.stack = []
self.qualnames = {}

def store_qualname(self, lineno):
"""Support method for _Visitor class"""
q_n = ".".join(n for n in self.stack)
self.qualnames[lineno] = q_n

def visit_FunctionDef(self, node):
"""Support method for _Visitor class"""
self.stack.append(node.name)
self.store_qualname(node.lineno)
self.stack.append('<locals>')
self.generic_visit(node)
self.stack.pop()
self.stack.pop()

def visit_ClassDef(self, node):
"""Support method for _Visitor class"""
self.stack.append(node.name)
self.store_qualname(node.lineno)
self.generic_visit(node)
self.stack.pop()
return "%s" % (cls.__qualname__)
Loading