Skip to content

Commit 05d14d8

Browse files
committed
Formatting
1 parent 610a721 commit 05d14d8

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

test/codegen/compare_codegen.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ROOT
44
from enum import Enum, auto
55

6+
67
class ComparisonResult(Enum):
78
OK = 0
89
NO_LIMIT_TREE = auto()
@@ -16,7 +17,7 @@ def detect_keys(startpath) -> set:
1617
dirs = os.listdir(startpath)
1718
for d in dirs:
1819
if os.path.isdir(os.path.join(startpath, d)):
19-
keys.add(d.replace('_codegen', ''))
20+
keys.add(d.replace("_codegen", ""))
2021
return sorted(keys, key=lambda k: os.path.getmtime(os.path.join(startpath, k)))
2122

2223

@@ -41,10 +42,7 @@ def check_codegen_counterpart_files(startpath, keys, missing_codegen) -> dict[st
4142
# Union should equal both sets
4243
missing_in_codegen = nominal_files - codegen_files
4344
missing_in_nominal = codegen_files - nominal_files
44-
discrepancies[key] = {
45-
"missing_in_codegen": missing_in_codegen,
46-
"missing_in_nominal": missing_in_nominal
47-
}
45+
discrepancies[key] = {"missing_in_codegen": missing_in_codegen, "missing_in_nominal": missing_in_nominal}
4846
return discrepancies
4947

5048

@@ -55,19 +53,19 @@ def compare_file_contents(file1, file2, tol: float = 1e-3) -> ComparisonResult:
5553
# For now, just check the 'limit' tree
5654
keys1 = [k.GetName() for k in f1.GetListOfKeys()]
5755
keys2 = [k.GetName() for k in f2.GetListOfKeys()]
58-
if 'limit' not in keys1 or 'limit' not in keys2:
59-
if 'fitDiagnosticsTest' in file1:
56+
if "limit" not in keys1 or "limit" not in keys2:
57+
if "fitDiagnosticsTest" in file1:
6058
return ComparisonResult.OK # Skip fitDiagnosticsTest for now
6159
return ComparisonResult.NO_LIMIT_TREE
6260

63-
tree1 = f1.Get('limit')
64-
tree2 = f2.Get('limit')
61+
tree1 = f1.Get("limit")
62+
tree2 = f2.Get("limit")
6563
if tree1.GetEntries() != tree2.GetEntries():
6664
return ComparisonResult.DIFFERENT_ENTRIES
6765

6866
# Check POIs match
69-
pois_1 = [b.GetName() for b in tree1.GetListOfBranches() if 'r_' in b.GetName() or 'r' == b.GetName()]
70-
pois_2 = [b.GetName() for b in tree2.GetListOfBranches() if 'r_' in b.GetName() or 'r' == b.GetName()]
67+
pois_1 = [b.GetName() for b in tree1.GetListOfBranches() if "r_" in b.GetName() or "r" == b.GetName()]
68+
pois_2 = [b.GetName() for b in tree2.GetListOfBranches() if "r_" in b.GetName() or "r" == b.GetName()]
7169
if set(pois_1) != set(pois_2):
7270
return ComparisonResult.DIFFERENT_POIS
7371

@@ -78,8 +76,8 @@ def compare_file_contents(file1, file2, tol: float = 1e-3) -> ComparisonResult:
7876
for poi in pois_1:
7977
val1 = getattr(tree1, poi)
8078
val2 = getattr(tree2, poi)
81-
deltaNLL1 = getattr(tree1, 'deltaNLL')
82-
deltaNLL2 = getattr(tree2, 'deltaNLL')
79+
deltaNLL1 = getattr(tree1, "deltaNLL")
80+
deltaNLL2 = getattr(tree2, "deltaNLL")
8381
if abs(val1 - val2) > tol or abs(deltaNLL1 - deltaNLL2) > tol:
8482
return ComparisonResult.VALUE_MISMATCH
8583

@@ -92,26 +90,24 @@ def compare_file_contents(file1, file2, tol: float = 1e-3) -> ComparisonResult:
9290
sys.exit(1)
9391
comparison_input_dir = sys.argv[1]
9492
keys = detect_keys(comparison_input_dir)
95-
status = {k: 'OK' for k in keys}
93+
status = {k: "OK" for k in keys}
9694

9795
# ---- Comparisons ---
9896

9997
# 1. Check every directory has a codegen counterpart
10098
missing_codegen = check_codegen_counterparts(comparison_input_dir, keys)
10199
for key in missing_codegen:
102-
status[key] = 'MISSING_CODEGEN'
100+
status[key] = "MISSING_CODEGEN"
103101

104102
# 2. Check every codegen directory has the same files as its counterpart
105103
discrepancies = check_codegen_counterpart_files(comparison_input_dir, keys, missing_codegen)
106104
for key, diff in discrepancies.items():
107105
if diff["missing_in_codegen"] or diff["missing_in_nominal"]:
108-
status[key] = 'FILE_MISMATCH'
106+
status[key] = "FILE_MISMATCH"
109107

110108
# 3. Check file contents
111109
for key in keys:
112-
if key in missing_codegen or \
113-
key in discrepancies and (discrepancies[key]["missing_in_codegen"] \
114-
or discrepancies[key]["missing_in_nominal"]):
110+
if key in missing_codegen or key in discrepancies and (discrepancies[key]["missing_in_codegen"] or discrepancies[key]["missing_in_nominal"]):
115111
continue
116112

117113
files = os.listdir(os.path.join(comparison_input_dir, key))

0 commit comments

Comments
 (0)