Skip to content

Commit acd4999

Browse files
committed
refactor to do unit tests
1 parent 5b0367d commit acd4999

File tree

2 files changed

+291
-6
lines changed

2 files changed

+291
-6
lines changed

ext/auto-inst/parsing.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,6 @@ def get_yaml_instructions(repo_directory):
265265
"yaml_vars": yaml_vars
266266
}
267267

268-
# Debug print
269-
print("Instructions + Encodings:\n", instructions_with_encodings)
270268
return instructions_with_encodings
271269

272270
def find_json_key(instr_name, json_data):
@@ -309,7 +307,7 @@ def run_parser(json_file, repo_directory, output_file="output.txt"):
309307
# Step 2: parse JSON
310308
try:
311309
with open(json_file, 'r') as f:
312-
data = json.loads(f.read())
310+
json_data = json.loads(f.read())
313311
except Exception as e:
314312
print(f"Error reading file: {str(e)}")
315313
return None
@@ -318,12 +316,12 @@ def run_parser(json_file, repo_directory, output_file="output.txt"):
318316

319317
# Step 3: For each YAML instruction, attempt to find it in JSON by AsmString
320318
for yaml_instr_name_lower, yaml_data in instructions_with_encodings.items():
321-
json_key = find_json_key(yaml_instr_name_lower, data)
319+
json_key = find_json_key(yaml_instr_name_lower, json_data)
322320
if json_key is None:
323321
print(f"DEBUG: Instruction '{yaml_instr_name_lower}' (from YAML) not found in JSON, skipping...", file=sys.stderr)
324322
continue
325323

326-
instr_data = data.get(json_key)
324+
instr_data = json_data.get(json_key)
327325
if not isinstance(instr_data, dict):
328326
print(f"DEBUG: Instruction '{yaml_instr_name_lower}' is in JSON but not a valid dict, skipping...", file=sys.stderr)
329327
continue
@@ -334,7 +332,6 @@ def run_parser(json_file, repo_directory, output_file="output.txt"):
334332

335333
# We'll keep track of them so we can print details
336334
all_instructions.append((json_key, instr_data))
337-
338335
# Sort instructions by JSON key
339336
all_instructions.sort(key=lambda x: x[0].lower())
340337

ext/auto-inst/test_parsing.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
import pytest
2+
import json
3+
import os
4+
import re
5+
import yaml
6+
from pathlib import Path
7+
8+
def get_json_path():
    """Return, as a string, the path of the JSON file used by the tests.

    The path is built from the directory containing this file; the ``..``
    components are kept as written (no normalisation is performed).
    """
    test_dir = Path(__file__).parent
    return str(test_dir.joinpath("../../../llvm-project/build/unorder.json"))
12+
13+
def get_yaml_directory():
    """Return, as a string, the path of the YAML instruction directory.

    The path is built from the directory containing this file; the ``..``
    components are kept as written (no normalisation is performed).
    """
    test_dir = Path(__file__).parent
    return str(test_dir.joinpath("../../arch/inst/"))
17+
18+
def load_inherited_variable(var_path, repo_dir):
    """Load a variable definition referenced by a ``<file>#<anchor>`` string.

    Args:
        var_path: Reference made of a file path and an anchor separated by a
            single ``#``.  The anchor is a ``/``-separated list of keys; one
            leading slash is ignored.
        repo_dir: Directory the file part of ``var_path`` is joined onto.

    Returns:
        The object found by walking the anchor keys through the parsed YAML
        document, or ``None`` when the reference is malformed, the file is
        missing, the anchor cannot be followed, or loading fails.  A message
        is printed in each of those cases.
    """
    try:
        # Exactly one '#' is required; anything else raises ValueError,
        # which is reported by the handler below.
        path, anchor = var_path.split('#')
        if anchor.startswith('/'):
            anchor = anchor[1:]  # Remove leading slash

        full_path = os.path.join(repo_dir, path)

        if not os.path.exists(full_path):
            print(f"Warning: Inherited file not found: {full_path}")
            return None

        # Explicit encoding so the result does not depend on the platform's
        # default locale encoding.
        with open(full_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        # Navigate through the anchor path.  Only mappings can be descended
        # into; an empty document (None) or a scalar means the anchor does
        # not exist rather than being a load error.
        for key in anchor.split('/'):
            if not isinstance(data, dict) or key not in data:
                print(f"Warning: Anchor path {anchor} not found in {path}")
                return None
            data = data[key]

        return data
    except Exception as e:
        # Deliberately best-effort: any failure is reported and turned into
        # None so that one bad reference does not abort the caller.
        print(f"Error loading inherited variable {var_path}: {str(e)}")
        return None
48+
49+
def resolve_variable_definition(var, repo_dir):
    """Return ``var`` when it carries its own ``location``, otherwise ``None``.

    A variable that has ``$inherits`` but no ``location`` triggers a printed
    warning before ``None`` is returned; any other variable yields ``None``
    silently.

    NOTE(review): ``repo_dir`` is not used and the ``$inherits`` reference is
    never actually followed here -- confirm whether that is intentional.
    """
    if 'location' not in var:
        if '$inherits' in var:
            print(f"Warning: Failed to resolve inheritance for variable: {var}")
        return None
    return var
56+
57+
def parse_location(loc_str):
    """Parse a location value into a list of ``(high, low)`` bit ranges.

    The value is converted with ``str()`` and split on ``|``.  Each piece is
    either ``"<high>-<low>"`` or a single integer, which becomes
    ``(n, n)``.

    Args:
        loc_str: Location value; may be a string or a number.  ``None`` and
            empty/blank strings give an empty result.

    Returns:
        List of ``(high, low)`` integer tuples in the order they appear.
        Pieces that cannot be parsed are reported with a printed warning and
        skipped.
    """
    # Test for None explicitly: a plain truthiness check would also discard
    # the integer location 0, which is a valid single-bit position.
    if loc_str is None:
        return []

    text = str(loc_str).strip()
    if not text:
        return []

    ranges = []
    # Split on pipe if there are multiple ranges
    for piece in text.split('|'):
        piece = piece.strip()
        try:
            if '-' in piece:
                # Unpacking fails with ValueError for "1-2-3", int() for "a-b"
                high, low = map(int, piece.split('-'))
            else:
                # Single bit case
                high = low = int(piece)
        except ValueError:
            print(f"Warning: Invalid location format: {piece}")
            continue
        ranges.append((high, low))

    return ranges
81+
82+
def compare_yaml_json_encoding(instr_name, yaml_match, yaml_vars, json_encoding_str, repo_dir):
    """Compare the YAML encoding with the JSON encoding.

    Args:
        instr_name: Instruction name.  Names starting with ``c_`` or ``c.``
            (case-insensitive) are compared as 16-bit, all others as 32-bit.
        yaml_match: Match pattern, most significant bit first, made of ``0``,
            ``1`` and ``-`` (variable bit).
        yaml_vars: Iterable of variable dicts (or ``None``).  Each is passed
            to ``resolve_variable_definition`` and its ``location`` to
            ``parse_location``.
        json_encoding_str: Encoding string, most significant bit first, made
            of ``0``/``1`` and ``name[index]`` tokens.
        repo_dir: Forwarded to ``resolve_variable_definition``.

    Returns:
        A list of human-readable difference messages; empty when no
        difference was detected.
    """
    if not yaml_match:
        return ["No YAML match field available for comparison."]
    if not json_encoding_str:
        return ["No JSON encoding available for comparison."]

    # Determine expected length based on whether it's a compressed instruction (C_ or c.)
    expected_length = 16 if instr_name.lower().startswith(('c_', 'c.')) else 32

    yaml_pattern_str = yaml_match.replace('-', '.')
    if len(yaml_pattern_str) != expected_length:
        return [f"YAML match pattern length is {len(yaml_pattern_str)}, expected {expected_length}. Cannot compare properly."]

    # Process variables and their locations
    yaml_var_positions = {}
    for var in (yaml_vars or []):
        resolved_var = resolve_variable_definition(var, repo_dir)
        # Use the same fallback name everywhere so a variable without a
        # 'name' key cannot raise KeyError further down.
        var_name = var.get('name', 'unknown')
        if not resolved_var or 'location' not in resolved_var:
            print(f"Warning: Could not resolve variable definition for {var_name}")
            continue

        ranges = parse_location(resolved_var['location'])
        if ranges:
            yaml_var_positions[var_name] = ranges

    # Tokenize the JSON encoding string: one token per bit, MSB first
    tokens = re.findall(r'(?:[01]|[A-Za-z0-9]+(?:\[\d+\]|\[\?\])?)', json_encoding_str)
    if len(tokens) != expected_length:
        return [f"JSON encoding does not appear to be {expected_length} bits. Ends at bit {expected_length - len(tokens)}."]

    # Map bit position -> token for O(1) lookup, normalizing vm[...] to 'vm'
    json_bit_at = {}
    for offset, tok in enumerate(tokens):
        if re.match(r'vm\[[^\]]*\]', tok):
            tok = 'vm'
        json_bit_at[expected_length - 1 - offset] = tok

    differences = []

    # Check fixed bits
    for b in range(expected_length):
        yaml_bit = yaml_pattern_str[expected_length - 1 - b]
        json_bit_str = json_bit_at.get(b)
        if json_bit_str is None:
            differences.append(f"Bit {b}: No corresponding JSON bit found.")
            continue

        if yaml_bit in ('0', '1'):
            if json_bit_str not in ('0', '1'):
                differences.append(f"Bit {b}: YAML expects fixed bit '{yaml_bit}' but JSON has '{json_bit_str}'")
            elif json_bit_str != yaml_bit:
                differences.append(f"Bit {b}: YAML expects '{yaml_bit}' but JSON has '{json_bit_str}'")
        elif json_bit_str in ('0', '1'):
            differences.append(f"Bit {b}: YAML variable bit but JSON is fixed '{json_bit_str}'")

    # Check variable fields
    for var_name, ranges in yaml_var_positions.items():
        for high, low in ranges:
            # Ensure the variable range fits within the expected_length
            if high >= expected_length or low < 0:
                differences.append(f"Variable {var_name}: location {high}-{low} is out of range for {expected_length}-bit instruction.")
                continue

            json_var_fields = [json_bit_at.get(bb, '?') for bb in range(low, high + 1)]

            # Extract field names: every bit of one YAML variable is expected
            # to map onto a single JSON field name.
            field_names = set(re.findall(r'([A-Za-z0-9]+)(?:\[\d+\]|\[\?\])?', ' '.join(json_var_fields)))
            if len(field_names) == 0:
                differences.append(f"Variable {var_name}: No corresponding field found in JSON bits {high}-{low}")
            elif len(field_names) > 1:
                differences.append(f"Variable {var_name}: Multiple fields {field_names} found in JSON for bits {high}-{low}")

    return differences
171+
172+
@pytest.fixture
def yaml_instructions():
    """Provide the result of ``get_yaml_instructions`` for the YAML directory.

    The test is skipped when the directory returned by
    ``get_yaml_directory()`` does not exist.
    """
    from parsing import get_yaml_instructions

    yaml_root = get_yaml_directory()
    if os.path.exists(yaml_root):
        return get_yaml_instructions(yaml_root)
    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip(f"Repository directory not found at {yaml_root}")
180+
181+
@pytest.fixture
def json_data():
    """Provide the parsed contents of the JSON file from ``get_json_path()``.

    The test is skipped when that file does not exist.
    """
    source = get_json_path()
    if not os.path.exists(source):
        # pytest.skip raises, so the file is never opened on this path.
        pytest.skip(f"JSON file not found at {source}")

    with open(source, 'r') as handle:
        parsed = json.load(handle)
    return parsed
189+
190+
def pytest_configure(config):
    """Print the JSON file and YAML directory paths the tests will use.

    ``config`` is accepted but not used.

    NOTE(review): pytest normally picks up ``pytest_configure`` from
    conftest.py files or plugins -- confirm this hook is actually invoked
    when it is defined in a test module.
    """
    json_path = get_json_path()
    print(f"\nUsing JSON file: {json_path}")
    yaml_dir = get_yaml_directory()
    print(f"Using YAML directory: {yaml_dir}\n")
194+
195+
class TestEncodingComparison:
    """Compare YAML instruction encodings with the JSON instruction data."""

    def test_encoding_matches(self, yaml_instructions, json_data):
        """Test YAML-defined instructions against their JSON counterparts if they exist.

        Instructions without a ``yaml_match`` pattern, or without a JSON
        entry whose ``AsmString`` mnemonic matches, are skipped and listed.
        The test fails with a combined report when any checked instruction
        produces differences.
        """
        mismatches = []
        total_yaml_instructions = 0
        checked_instructions = 0
        skipped_instructions = []
        repo_dir = get_yaml_directory()

        for yaml_instr_name, yaml_data in yaml_instructions.items():
            total_yaml_instructions += 1

            # Skip if no YAML match pattern
            if not yaml_data.get("yaml_match"):
                skipped_instructions.append(yaml_instr_name)
                continue

            # Get JSON encoding from instruction data
            json_key = self._find_matching_instruction(yaml_instr_name, json_data)
            if not json_key:
                skipped_instructions.append(yaml_instr_name)
                continue

            checked_instructions += 1
            json_encoding = self._get_json_encoding(json_data[json_key])

            # Compare encodings using the existing function
            differences = compare_yaml_json_encoding(
                yaml_instr_name,
                yaml_data["yaml_match"],
                yaml_data["yaml_vars"],
                json_encoding,
                repo_dir
            )

            if differences and differences != ["No YAML match field available for comparison."]:
                mismatches.append({
                    'instruction': yaml_instr_name,
                    'json_key': json_key,
                    'differences': differences,
                    'yaml_match': yaml_data["yaml_match"],
                    'json_encoding': json_encoding
                })

        # Print statistics
        print(f"\nYAML instructions found: {total_yaml_instructions}")
        print(f"Instructions checked: {checked_instructions}")
        print(f"Instructions skipped: {len(skipped_instructions)}")
        print(f"Instructions with encoding mismatches: {len(mismatches)}")

        if skipped_instructions:
            print("\nSkipped instructions:")
            for instr in skipped_instructions:
                print(f" - {instr}")

        if mismatches:
            # Collect the report pieces and join once instead of growing a
            # string with repeated concatenation.
            report = ["\nEncoding mismatches found:\n"]
            for m in mismatches:
                report.append(f"\nInstruction: {m['instruction']} (JSON key: {m['json_key']})\n")
                report.append(f"YAML match: {m['yaml_match']}\n")
                report.append(f"JSON encoding: {m['json_encoding']}\n")
                report.append("Differences:\n")
                report.extend(f" - {d}\n" for d in m['differences'])
            pytest.fail("".join(report))

    def _find_matching_instruction(self, yaml_instr_name, json_data):
        """Find matching instruction in JSON data by comparing instruction names.

        The YAML name is lower-cased and stripped, then compared with the
        first whitespace-separated word of each entry's lower-cased
        ``AsmString``.  Returns the first matching JSON key, or ``None``.
        """
        wanted = yaml_instr_name.lower().strip()
        for key, value in json_data.items():
            if not isinstance(value, dict):
                continue
            asm_string = value.get('AsmString', '')
            # A null or non-text AsmString cannot match; skip it instead of
            # failing on .lower().
            if not isinstance(asm_string, str):
                continue
            asm_string = asm_string.lower().strip()
            if not asm_string:
                continue
            if asm_string.split()[0] == wanted:
                return key
        return None

    def _get_json_encoding(self, json_instr):
        """Extract encoding string from JSON instruction data.

        Each entry of ``json_instr['Inst']`` becomes ``var[index]`` when it is
        a dict (``?`` for a missing part) or ``str(entry)`` otherwise.  The
        entries are reversed before being concatenated.  Returns ``""`` when
        ``json_instr`` has no usable ``Inst`` list.
        """
        encoding_bits = []
        try:
            inst = json_instr.get('Inst', [])
            for bit in inst:
                if isinstance(bit, dict):
                    encoding_bits.append(f"{bit.get('var', '?')}[{bit.get('index', '?')}]")
                else:
                    encoding_bits.append(str(bit))
        except (AttributeError, TypeError):
            # json_instr is not a mapping, or 'Inst' is not iterable.  A bare
            # except here would also swallow KeyboardInterrupt/SystemExit.
            return ""
        encoding_bits.reverse()
        return "".join(encoding_bits)

0 commit comments

Comments
 (0)