Skip to content

Commit 629e9ab

Browse files
authored
sst_unittest.py: add type annotations to run_sst (#1197)
* test_LookupTable.py: remove unused imports * testsuite_default_UnitAlgebra.py: idiomatic boolean comparison * sst_unittest.py: missing multiprocessing import * sst_unittest.py: avoid shadowing variable names * sst_unittest.py: add type annotations * sst_unittest.py: enforce run_sst timeout_sec as integer
1 parent 7a2b424 commit 629e9ab

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

src/sst/core/testingframework/sst_unittest.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import threading
2929
import signal
3030
import time
31+
import multiprocessing
3132
from typing import Optional
3233

3334
import test_engine_globals
@@ -58,11 +59,10 @@ class SSTTestCase(unittest.TestCase):
5859
def __init__(self, methodName: str) -> None:
5960
# NOTE: __init__ is called at startup for all tests before any
6061
# setUpModules(), setUpClass(), setUp() and the like are called.
61-
super(SSTTestCase, self).__init__(methodName)
62+
super().__init__(methodName)
6263
self.testname = methodName
63-
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore
64+
parent_module_path: str = os.path.dirname(sys.modules[self.__class__.__module__].__file__) # type: ignore [assignment,type-var]
6465
self._testsuite_dirpath = parent_module_path
65-
#log_forced("SSTTestCase: __init__() - {0}".format(self.testname))
6666
self.initializeClass(self.testname)
6767
self._start_test_time = time.time()
6868
self._stop_test_time = time.time()
@@ -195,7 +195,7 @@ def get_testsuite_dir(self) -> str:
195195
""" Return the directory path of the testsuite that is being run
196196
197197
Returns:
198-
(str)The path of the testsite directory
198+
(str) The path of the testsite directory
199199
"""
200200
return self._testsuite_dirpath
201201

@@ -235,9 +235,23 @@ def get_test_runtime_sec(self) -> float:
235235
### Method to run an SST simulation
236236
################################################################################
237237

238-
def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files="",
239-
other_args="", num_ranks=None, num_threads=None, global_args=None,
240-
timeout_sec=120, expected_rc=0, check_sdl_file=True, send_signal=signal.NSIG, signal_sec=3):
238+
def run_sst(
239+
self,
240+
sdl_file: str,
241+
out_file: str,
242+
err_file: Optional[str] = None,
243+
set_cwd: Optional[str] = None,
244+
mpi_out_files: str = "",
245+
other_args: str = "",
246+
num_ranks: Optional[int] = None,
247+
num_threads: Optional[int] = None,
248+
global_args: Optional[str] = None,
249+
timeout_sec: int = 120,
250+
expected_rc: int = 0,
251+
check_sdl_file: bool = True,
252+
send_signal: int = signal.NSIG,
253+
signal_sec: int = 3
254+
) -> str:
241255
""" Launch sst with with the command line and send output to the
242256
output file. The SST execution will be monitored for result errors and
243257
timeouts. On an error or timeout, a SSTTestCase.assert() will be generated
@@ -288,8 +302,7 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
288302
check_param_type("num_threads", num_threads, int)
289303
if global_args is not None:
290304
check_param_type("global_args", global_args, str)
291-
if not (isinstance(timeout_sec, (int, float)) and not isinstance(timeout_sec, bool)):
292-
raise ValueError("ERROR: Timeout_sec must be a postive int or a float")
305+
check_param_type("timeout_sec", timeout_sec, int)
293306
if expected_rc is not None:
294307
check_param_type("expected_rc", expected_rc, int)
295308

@@ -331,8 +344,8 @@ def run_sst(self, sdl_file, out_file, err_file=None, set_cwd=None, mpi_out_files
331344
numa_param = ""
332345
if num_ranks > 1:
333346
# Check to see if mpirun is available
334-
rtn = os.system("which mpirun > /dev/null 2>&1")
335-
if rtn == 0:
347+
rtn_mpirun = os.system("which mpirun > /dev/null 2>&1")
348+
if rtn_mpirun == 0:
336349
mpi_avail = True
337350

338351
numa_param = "-map-by numa:PE={0}".format(num_threads)
@@ -433,7 +446,7 @@ def tearDownModule() -> None:
433446

434447
###################
435448

436-
def setUpModuleConcurrent(test):
449+
def setUpModuleConcurrent(test: SSTTestCase) -> None:
437450
""" Perform setup functions before the testing Module loads.
438451
439452
This function is called by the Frameworks before tests in any TestCase
@@ -461,7 +474,7 @@ def setUpModuleConcurrent(test):
461474

462475
###
463476

464-
def tearDownModuleConcurrent(test):
477+
def tearDownModuleConcurrent(test: SSTTestCase) -> None:
465478
""" Perform teardown functions immediately after a testing Module finishes.
466479
467480
This function is called by the Frameworks after all tests in all TestCases

tests/test_LookupTable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# information, see the LICENSE file in the top level directory of the
1010
# distribution.
1111
import sst
12-
import inspect, os, sys
12+
import inspect
1313

1414
currentframe = inspect.currentframe()
1515
assert currentframe is not None

tests/testsuite_default_UnitAlgebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def unitalgebra_test_template(self, testtype):
4949

5050
# Perform the test
5151
cmp_result = testing_compare_sorted_diff(testtype, outfile, reffile)
52-
if (cmp_result == False):
52+
if not cmp_result:
5353
diffdata = testing_get_diff_data(testtype)
5454
log_failure(diffdata)
5555
self.assertTrue(cmp_result, "Output/Compare file {0} does not match Reference File {1}".format(outfile, reffile))

0 commit comments

Comments
 (0)