|
| 1 | +import shutil |
| 2 | +import os |
| 3 | +import shlex |
| 4 | +import pathlib |
| 5 | + |
| 6 | +""" |
| 7 | +This file provides the `diff_test_updater` function, which is invoked on failed RUN lines when lit is executed with --update-tests. |
| 8 | +It checks whether the failed command is `diff` and, if so, uses heuristics to determine which file is the checked-in reference file and which file is output from the test case. |
| 9 | +The heuristics are currently as follows: |
| 10 | + - if exactly one file originates from the `split-file` command, that file is the reference file and the other is the output file |
| 11 | + - if exactly one file ends with ".expected" (common pattern in LLVM), that file is the reference file and the other is the output file |
| 12 | + - if exactly one file path contains ".tmp" (e.g. because it contains the expansion of "%t"), that file is the output file and the other is the reference file |
| 13 | +If the command matches one of these patterns the output file content is copied to the reference file to make the test pass. |
| 14 | +If the reference file originated in `split-file`, the output file content is instead copied to the corresponding slice of the test file. |
| 15 | +Otherwise the test is ignored. |
| 16 | +
|
| 17 | +Possible improvements: |
| 18 | + - Support stdin patterns like "my_binary %s | diff expected.txt" |
| 19 | + - Scan RUN lines to see if a file is the source of output from a previous command (other than `split-file`). |
| 20 | + If it is, then it is not a reference file that can be copied to, regardless of name, since the test will overwrite it anyway. |
| 21 | + - Only update the parts that need updating (based on the diff output). Could help avoid noisy updates when e.g. whitespace changes are ignored. |
| 22 | +""" |
| 23 | + |
| 24 | + |
class NormalFileTarget:
    """Reference file on disk that is overwritten wholesale with test output."""

    def __init__(self, target):
        # Path of the reference file that copyFrom() overwrites.
        self.target = target

    def copyFrom(self, source):
        """Copy the file at `source` over the reference file."""
        shutil.copy(source, self.target)

    def __str__(self):
        # __str__ must return a real str; coerce so that path-like targets
        # (e.g. pathlib.Path) do not raise TypeError when formatted.
        return str(self.target)
| 34 | + |
| 35 | + |
class SplitFileTarget:
    """Reference content stored as a `split-file` slice inside the test file.

    A slice starts on the line after a split marker line (a one-character or
    "//" comment prefix followed by "--- <name>") and runs until the next
    marker line or the end of the file.
    """

    def __init__(self, slice_start_idx, test_path, lines):
        # Index in `lines` of the split marker line that opens the slice.
        self.slice_start_idx = slice_start_idx
        # Path of the test file that contains the slice.
        self.test_path = test_path
        # Lines of the test file (with line endings, as from readlines()).
        self.lines = lines

    def copyFrom(self, source):
        """Replace the body of the slice in the test file with `source`.

        Everything up to and including the marker line, and everything from
        the next marker line onwards, is kept. `self.lines` is refreshed to
        the new file content so the target stays usable afterwards.
        """
        body_start = self.slice_start_idx + 1
        lines_before = self.lines[:body_start]
        # The slice ends at the next split marker, or at the end of the file.
        body_end = len(self.lines)
        for i in range(body_start, len(self.lines)):
            if SplitFileTarget._get_split_line_path(self.lines[i]) is not None:
                body_end = i
                break
        lines_after = self.lines[body_end:]

        # Read the source before opening the test file for writing, so a
        # missing source cannot truncate the test file.
        with open(source, "r") as f:
            new_body = f.readlines()

        # A marker line at end of file may lack a newline; terminate it so the
        # new content does not end up on the marker line itself.
        if new_body and lines_before and not lines_before[-1].endswith("\n"):
            lines_before[-1] += "\n"
        # Likewise keep the following marker on its own line when the new
        # content does not end with a newline.
        if new_body and lines_after and not new_body[-1].endswith("\n"):
            new_body[-1] += "\n"

        new_lines = lines_before + new_body + lines_after
        with open(self.test_path, "w") as f:
            f.writelines(new_lines)
        self.lines = new_lines

    def __str__(self):
        return f"slice in {self.test_path}"

    @staticmethod
    def get_target_dir(commands, test_path):
        """Return the output directory of the `split-file` command run on `test_path`.

        Scans `commands` for one containing a literal `split-file` token whose
        first argument is the same file as `test_path`, and returns its second
        argument with surrounding quotes removed. Returns None if there is no
        such command.
        """
        for cmd in commands:
            # posix=True breaks Windows paths because \ is treated as an
            # escaping character.
            words = shlex.split(cmd, posix=False)
            if "split-file" not in words:
                continue
            words = words[words.index("split-file") :]
            if len(words) < 3:
                continue
            p = unquote(words[1].strip())
            if not test_path.samefile(p):
                continue
            return unquote(words[2].strip())
        return None

    @staticmethod
    def create(path, commands, test_path, target_dir):
        """Return a SplitFileTarget for the slice that produced `path`, or None.

        `path` matches a slice when it is the same file as `target_dir` joined
        with the name on one of the split marker lines in `test_path`.
        `commands` is accepted for interface compatibility and is not used.
        """
        path = pathlib.Path(path)
        with open(test_path, "r") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            p = SplitFileTarget._get_split_line_path(line)
            if not p:
                continue
            try:
                if path.samefile(os.path.join(target_dir, p)):
                    return SplitFileTarget(i, test_path, lines)
            except FileNotFoundError:
                # samefile() needs both files to exist; a missing file cannot
                # be this slice, so treat it as "no match" rather than failing.
                continue
        return None

    @staticmethod
    def _get_split_line_path(line):
        """Return the name on a split marker line, or None if `line` is not one."""
        if len(line) < 6:
            return None
        # The marker is prefixed by "//" or by any single character.
        rest = line[2:] if line.startswith("//") else line[1:]
        if not rest.startswith("--- "):
            return None
        return rest[4:].rstrip()
| 107 | + |
| 108 | + |
def unquote(s):
    """Strip one pair of matching surrounding quotes (' or ") from `s`.

    Strings that are shorter than two characters, or whose first and last
    characters are not the same quote character, are returned unchanged.
    """
    if len(s) < 2:
        return s
    first, last = s[0], s[-1]
    if first == last and first in ('"', "'"):
        return s[1:-1]
    return s
| 113 | + |
| 114 | + |
def get_source_and_target(a, b, test_path, commands):
    """
    Try to figure out which file is the test output and which is the reference.

    Returns a `(source, target)` pair, where `source` is the path of the test
    output and `target` is an object with a `copyFrom` method representing the
    reference, or None when no heuristic singles out exactly one of the files.
    """
    # Heuristic 1: exactly one file is a slice produced by `split-file`.
    split_dir = SplitFileTarget.get_target_dir(commands, test_path)
    if split_dir:
        slice_a = SplitFileTarget.create(a, commands, test_path, split_dir)
        slice_b = SplitFileTarget.create(b, commands, test_path, split_dir)
        if slice_a and slice_b:
            # Both files are reference slices; there is no output to copy.
            return None
        if slice_a:
            return b, slice_a
        if slice_b:
            return a, slice_b

    # Heuristic 2: exactly one file name ends with ".expected" (the reference).
    a_is_expected = a.endswith(".expected")
    b_is_expected = b.endswith(".expected")
    if a_is_expected != b_is_expected:
        if a_is_expected:
            return b, NormalFileTarget(a)
        return a, NormalFileTarget(b)

    # Heuristic 3: exactly one path contains ".tmp" (the test output).
    a_is_tmp = ".tmp" in a
    b_is_tmp = ".tmp" in b
    if a_is_tmp != b_is_tmp:
        if a_is_tmp:
            return a, NormalFileTarget(b)
        return b, NormalFileTarget(a)

    return None
| 143 | + |
| 144 | + |
def filter_flags(args):
    """Return the entries of `args` that do not start with "-", in order."""
    positional = []
    for arg in args:
        if arg.startswith("-"):
            # Anything starting with a dash is treated as a flag and dropped.
            continue
        positional.append(arg)
    return positional
| 147 | + |
| 148 | + |
def diff_test_updater(result, test, commands):
    """Update the reference file of a failed `diff` command with the test output.

    Returns None when the command, after dropping flags, is not of the form
    `diff <a> <b>`. Otherwise returns a message saying either that the
    source and target could not be deduced, or which file was copied where.
    """
    positional = filter_flags(result.command.args)
    if len(positional) != 3 or positional[0] != "diff":
        return None
    _, a, b = positional

    deduced = get_source_and_target(
        a, b, pathlib.Path(test.getFilePath()), commands
    )
    if not deduced:
        return f"update-diff-test: could not deduce source and target from {a} and {b}"

    source, target = deduced
    target.copyFrom(source)
    return f"update-diff-test: copied {source} to {target}"
0 commit comments