Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/cases/05-VirtualTables/test_vtable_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,6 @@ def test_vtable_join(self):
self.sqlFile = etool.curFile(__file__, f"in/{testCase}.in")
self.ansFile = etool.curFile(__file__, f"ans/{testCase}.ans")

tdCom.compare_testcase_result(self.sqlFile, self.ansFile, testCase)


tdCom.compare_testcase_result(
self.sqlFile, self.ansFile, testCase, float_tolerance=2e-5
)
144 changes: 139 additions & 5 deletions test/new_test_framework/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .constant import *
from .epath import *
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
from typing import List
from datetime import datetime, timedelta
import re
Expand Down Expand Up @@ -3039,7 +3040,119 @@ def generate_query_result(self, inputfile, test_case):
)
return self.query_result_file

def compare_result_files(self, file1, file2):
def _get_numeric_compare_tolerance(self, token1, token2, float_tolerance):
if float_tolerance > 0.0:
return Decimal(str(float_tolerance))

def count_decimal_places(token):
mantissa = token.lower().split("e", 1)[0]
if "." not in mantissa:
return 0
return len(mantissa.split(".", 1)[1])

precision = max(count_decimal_places(token1), count_decimal_places(token2))
if precision <= 0:
return Decimal("0")
return Decimal(1).scaleb(-precision)

def _normalize_result_line_for_compare(self, line):
"""Normalize CLI-only suffixes before answer/result file comparison.

Args:
line: A single line from an answer or result file.

Returns:
The normalized line with runtime-only Windows suffixes removed.
"""

normalized = line.rstrip()
normalized = re.sub(r"\s*\([0-9]+\.[0-9]+s\)$", "", normalized)
normalized = re.sub(r"cost=[0-9]+\.[0-9]+\.\.[0-9]+\.[0-9]+", "", normalized)
normalized = re.sub(r"Planning Time: [0-9]+\.[0-9]+ ms", "", normalized)
normalized = re.sub(r"Execution Time: [0-9]+\.[0-9]+ ms", "", normalized)
normalized = re.sub(r"max_row_task=[0-9]+, ", "", normalized)
return normalized.rstrip()

def _compare_normalized_result_lines(self, file1, file2):
"""Compare result files after stripping platform-specific CLI noise.

Args:
file1: Expected result file path.
file2: Actual result file path.

Returns:
True when the normalized result lines are identical.
"""

with open(file1, "r", encoding="utf-8", errors="ignore") as f1:
lines1 = f1.read().splitlines()
with open(file2, "r", encoding="utf-8", errors="ignore") as f2:
lines2 = f2.read().splitlines()

if len(lines1) != len(lines2):
return False

for line1, line2 in zip(lines1, lines2):
if self._normalize_result_line_for_compare(
line1
) != self._normalize_result_line_for_compare(line2):
return False

return True

def _compare_file_lines_with_float_tolerance(self, file1, file2, float_tolerance):
number_pattern = re.compile(r"[-+]?(?:\d+\.\d+|\d+|\.\d+)(?:[eE][-+]?\d+)?")

with open(file1, "r", encoding="utf-8", errors="ignore") as f1:
lines1 = f1.read().splitlines()
with open(file2, "r", encoding="utf-8", errors="ignore") as f2:
lines2 = f2.read().splitlines()

if len(lines1) != len(lines2):
return False

for line1, line2 in zip(lines1, lines2):
line1 = self._normalize_result_line_for_compare(line1)
line2 = self._normalize_result_line_for_compare(line2)

if line1 == line2:
continue

matches1 = list(number_pattern.finditer(line1))
matches2 = list(number_pattern.finditer(line2))
if len(matches1) != len(matches2):
return False

cursor1 = 0
cursor2 = 0
for match1, match2 in zip(matches1, matches2):
if line1[cursor1:match1.start()] != line2[cursor2:match2.start()]:
return False

token1 = match1.group(0)
token2 = match2.group(0)
try:
value1 = Decimal(token1)
value2 = Decimal(token2)
except InvalidOperation:
if token1 != token2:
return False
else:
tolerance = self._get_numeric_compare_tolerance(
token1, token2, float_tolerance
)
if abs(value1 - value2) > tolerance:
return False

cursor1 = match1.end()
cursor2 = match2.end()

if line1[cursor1:] != line2[cursor2:]:
return False

return True

def compare_result_files(self, file1, file2, float_tolerance=0.0):
try:
# use subprocess.run to execute diff/fc commands
# print(file1, file2)
Expand Down Expand Up @@ -3067,6 +3180,19 @@ def compare_result_files(self, file1, file2):
return True
# if result is not empty, print the differences and files name. Otherwise, the files are identical.
if result.returncode != 0:
if self._compare_normalized_result_lines(file1, file2):
tdLog.info("Result files matched after output normalization.")
return True
if platform.system().lower() == "windows" and self._compare_file_lines_with_float_tolerance(
file1, file2, float_tolerance
):
tdLog.info(
"Result files matched after Windows output normalization."
if float_tolerance <= 0.0
else "Result files matched after Windows output normalization "
f"with float tolerance {float_tolerance}."
)
return True
tdLog.info(f"{cmd} result.returncode: {result.returncode}")
tdLog.info(f"{cmd} result.stdout: {result.stdout}")
tdLog.info(f"{cmd} result.stderr: {result.stderr}")
Expand All @@ -3087,9 +3213,13 @@ def compare_result_files(self, file1, file2):
except Exception as e:
tdLog.debug(f"An error occurred: {e}")

def compare_query_with_result_file(self, idx, sql, resultFile, test_case):
def compare_query_with_result_file(
self, idx, sql, resultFile, test_case, float_tolerance=0.0
):
self.generate_query_result_file(test_case, idx, sql)
if self.compare_result_files(resultFile, self.query_result_file):
if self.compare_result_files(
resultFile, self.query_result_file, float_tolerance=float_tolerance
):
tdLog.info("Test passed: Result files are identical.")
# os.system(f"rm -f {self.query_result_file}")
else:
Expand All @@ -3098,10 +3228,14 @@ def compare_query_with_result_file(self, idx, sql, resultFile, test_case):
f"{caller.lineno}(line:{caller.lineno}) failed: expect_file:{resultFile} != reult_file:{self.query_result_file} "
)

def compare_testcase_result(self, inputfile, expected_file, test_case):
def compare_testcase_result(
self, inputfile, expected_file, test_case, float_tolerance=0.0
):
test_reulst_file = self.generate_query_result(inputfile, test_case)

if self.compare_result_files(expected_file, test_reulst_file):
if self.compare_result_files(
expected_file, test_reulst_file, float_tolerance=float_tolerance
):
tdLog.info("Test passed: Result files are identical.")
os.system(f"rm -f {test_reulst_file}")
else:
Expand Down
Loading