Skip to content

Commit ef20407

Browse files
committed
match statements
1 parent bfdf0e0 commit ef20407

File tree

1 file changed

+183
-163
lines changed

1 file changed

+183
-163
lines changed

xdsl/transforms/convert_pdl_to_pdl_interp/conversion.py

Lines changed: 183 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -99,30 +99,31 @@ def extract_tree_predicates(
9999
inputs[value] = position
100100

101101
# Dispatch based on position type (not value type!)
102-
if isinstance(position, AttributePosition):
103-
assert isinstance(value, OpResult)
104-
predicates.extend(
105-
self._extract_attribute_predicates(value.owner, position, inputs)
106-
)
107-
elif isinstance(position, OperationPosition):
108-
assert isinstance(value, OpResult)
109-
predicates.extend(
110-
self._extract_operation_predicates(
111-
value.owner, position, inputs, ignore_operand
102+
match position:
103+
case AttributePosition():
104+
assert isinstance(value, OpResult)
105+
predicates.extend(
106+
self._extract_attribute_predicates(value.owner, position, inputs)
112107
)
113-
)
114-
elif isinstance(position, TypePosition):
115-
assert isinstance(value, OpResult)
116-
predicates.extend(
117-
self._extract_type_predicates(value.owner, position, inputs)
118-
)
119-
elif isinstance(position, OperandPosition | OperandGroupPosition):
120-
assert isinstance(value, SSAValue)
121-
predicates.extend(
122-
self._extract_operand_tree_predicates(value, position, inputs)
123-
)
124-
else:
125-
raise TypeError(f"Unexpected position kind: {type(position)}")
108+
case OperationPosition():
109+
assert isinstance(value, OpResult)
110+
predicates.extend(
111+
self._extract_operation_predicates(
112+
value.owner, position, inputs, ignore_operand
113+
)
114+
)
115+
case TypePosition():
116+
assert isinstance(value, OpResult)
117+
predicates.extend(
118+
self._extract_type_predicates(value.owner, position, inputs)
119+
)
120+
case OperandPosition() | OperandGroupPosition():
121+
assert isinstance(value, SSAValue)
122+
predicates.extend(
123+
self._extract_operand_tree_predicates(value, position, inputs)
124+
)
125+
case _:
126+
raise TypeError(f"Unexpected position kind: {type(position)}")
126127

127128
return predicates
128129

@@ -306,73 +307,82 @@ def _extract_operand_tree_predicates(
306307
defining_op = operand_value.owner
307308
is_variadic = isinstance(operand_value.type, pdl.RangeType)
308309

309-
if isinstance(defining_op, pdl.OperandOp | pdl.OperandsOp):
310-
if isinstance(defining_op, pdl.OperandOp):
311-
is_not_null = Predicate.get_is_not_null()
312-
predicates.append(
313-
PositionalPredicate(
314-
q=is_not_null.q, a=is_not_null.a, position=operand_pos
315-
)
316-
)
317-
elif (
318-
isinstance(operand_pos, OperandGroupPosition)
319-
and operand_pos.group_number is not None
320-
):
321-
is_not_null = Predicate.get_is_not_null()
322-
predicates.append(
323-
PositionalPredicate(
324-
q=is_not_null.q, a=is_not_null.a, position=operand_pos
310+
match defining_op:
311+
case pdl.OperandOp() | pdl.OperandsOp():
312+
match defining_op:
313+
case pdl.OperandOp():
314+
is_not_null = Predicate.get_is_not_null()
315+
predicates.append(
316+
PositionalPredicate(
317+
q=is_not_null.q, a=is_not_null.a, position=operand_pos
318+
)
319+
)
320+
case pdl.OperandsOp() if (
321+
isinstance(operand_pos, OperandGroupPosition)
322+
and operand_pos.group_number is not None
323+
):
324+
is_not_null = Predicate.get_is_not_null()
325+
predicates.append(
326+
PositionalPredicate(
327+
q=is_not_null.q, a=is_not_null.a, position=operand_pos
328+
)
329+
)
330+
case _:
331+
pass
332+
333+
if defining_op.value_type:
334+
type_pos = operand_pos.get_type()
335+
predicates.extend(
336+
self.extract_tree_predicates(
337+
defining_op.value_type, type_pos, inputs
338+
)
325339
)
326-
)
327340

328-
if defining_op.value_type:
329-
type_pos = operand_pos.get_type()
330-
predicates.extend(
331-
self.extract_tree_predicates(
332-
defining_op.value_type, type_pos, inputs
341+
case pdl.ResultOp() | pdl.ResultsOp():
342+
index_attr = defining_op.index
343+
index = index_attr.value.data if index_attr is not None else None
344+
345+
if index is not None:
346+
is_not_null = Predicate.get_is_not_null()
347+
predicates.append(
348+
PositionalPredicate(
349+
q=is_not_null.q, a=is_not_null.a, position=operand_pos
350+
)
333351
)
334-
)
335352

336-
elif isinstance(defining_op, pdl.ResultOp | pdl.ResultsOp):
337-
index_attr = defining_op.index
338-
index = index_attr.value.data if index_attr is not None else None
353+
# Get the parent operation position
354+
parent_op = defining_op.parent_
355+
defining_op_pos = operand_pos.get_defining_op()
339356

340-
if index is not None:
357+
# Parent operation should not be null
341358
is_not_null = Predicate.get_is_not_null()
342359
predicates.append(
343360
PositionalPredicate(
344-
q=is_not_null.q, a=is_not_null.a, position=operand_pos
361+
q=is_not_null.q, a=is_not_null.a, position=defining_op_pos
345362
)
346363
)
347364

348-
# Get the parent operation position
349-
parent_op = defining_op.parent_
350-
defining_op_pos = operand_pos.get_defining_op()
365+
match defining_op:
366+
case pdl.ResultOp():
367+
result_pos = defining_op_pos.get_result(
368+
index if index is not None else 0
369+
)
370+
case pdl.ResultsOp(): # ResultsOp
371+
result_pos = defining_op_pos.get_result_group(
372+
index, is_variadic
373+
)
351374

352-
# Parent operation should not be null
353-
is_not_null = Predicate.get_is_not_null()
354-
predicates.append(
355-
PositionalPredicate(
356-
q=is_not_null.q, a=is_not_null.a, position=defining_op_pos
375+
equal_to = Predicate.get_equal_to(operand_pos)
376+
predicates.append(
377+
PositionalPredicate(q=equal_to.q, a=equal_to.a, position=result_pos)
357378
)
358-
)
359379

360-
if isinstance(defining_op, pdl.ResultOp):
361-
result_pos = defining_op_pos.get_result(
362-
index if index is not None else 0
380+
# Recursively process the parent operation
381+
predicates.extend(
382+
self.extract_tree_predicates(parent_op, defining_op_pos, inputs)
363383
)
364-
else: # ResultsOp
365-
result_pos = defining_op_pos.get_result_group(index, is_variadic)
366-
367-
equal_to = Predicate.get_equal_to(operand_pos)
368-
predicates.append(
369-
PositionalPredicate(q=equal_to.q, a=equal_to.a, position=result_pos)
370-
)
371-
372-
# Recursively process the parent operation
373-
predicates.extend(
374-
self.extract_tree_predicates(parent_op, defining_op_pos, inputs)
375-
)
384+
case _:
385+
pass
376386

377387
return predicates
378388

@@ -414,87 +424,75 @@ def extract_non_tree_predicates(
414424
predicates: list[PositionalPredicate] = []
415425

416426
for op in pattern.body.ops:
417-
if isinstance(op, pdl.AttributeOp):
418-
if op.output not in inputs:
419-
if op.value:
420-
# Create literal position for constant attribute
421-
attr_pos = AttributeLiteralPosition(value=op.value, parent=None)
422-
inputs[op.output] = attr_pos
423-
424-
elif isinstance(op, pdl.ApplyNativeConstraintOp):
425-
# Collect all argument positions
426-
arg_positions = tuple(inputs.get(arg) for arg in op.args)
427-
for pos in arg_positions:
428-
assert pos is not None
429-
arg_positions = cast(tuple[Position, ...], arg_positions)
430-
431-
# Find the furthest position (deepest)
432-
furthest_pos = max(
433-
arg_positions, key=lambda p: p.get_operation_depth() if p else 0
434-
)
435-
436-
# Create the constraint predicate
437-
result_types = tuple(r.type for r in op.res)
438-
# TODO: is_negated is not part of the dialect definition yet
439-
is_negated = False
440-
constraint_pred = Predicate.get_constraint(
441-
op.constraint_name.data, arg_positions, result_types, is_negated
442-
)
443-
444-
# Register positions for constraint results
445-
for i, result in enumerate(op.results):
446-
assert isinstance(constraint_pred.q, ConstraintQuestion)
447-
constraint_pos = ConstraintPosition.get_constraint(
448-
constraint_pred.q, i
449-
)
450-
existing = inputs.get(result)
451-
if existing:
452-
# Add equality constraint if result already has a position
453-
deeper, shallower = (
454-
(constraint_pos, existing)
455-
if constraint_pos.get_operation_depth()
456-
> existing.get_operation_depth()
457-
else (existing, constraint_pos)
458-
)
459-
eq_pred = Predicate.get_equal_to(shallower)
460-
predicates.append(
461-
PositionalPredicate(
462-
q=eq_pred.q, a=eq_pred.a, position=deeper
427+
match op:
428+
case pdl.AttributeOp():
429+
if op.output not in inputs:
430+
if op.value:
431+
# Create literal position for constant attribute
432+
attr_pos = AttributeLiteralPosition(
433+
value=op.value, parent=None
463434
)
464-
)
465-
else:
466-
inputs[result] = constraint_pos
435+
inputs[op.output] = attr_pos
436+
437+
case pdl.ApplyNativeConstraintOp():
438+
# Collect all argument positions
439+
arg_positions = tuple(inputs.get(arg) for arg in op.args)
440+
for pos in arg_positions:
441+
assert pos is not None
442+
arg_positions = cast(tuple[Position, ...], arg_positions)
443+
444+
# Find the furthest position (deepest)
445+
furthest_pos = max(
446+
arg_positions, key=lambda p: p.get_operation_depth() if p else 0
447+
)
467448

468-
predicates.append(
469-
PositionalPredicate(
470-
q=constraint_pred.q, a=constraint_pred.a, position=furthest_pos
449+
# Create the constraint predicate
450+
result_types = tuple(r.type for r in op.res)
451+
# TODO: is_negated is not part of the dialect definition yet
452+
is_negated = False
453+
constraint_pred = Predicate.get_constraint(
454+
op.constraint_name.data, arg_positions, result_types, is_negated
471455
)
472-
)
473456

474-
elif isinstance(op, pdl.ResultOp):
475-
# Ensure result exists
476-
if op.val not in inputs:
477-
assert isinstance(op.parent_.owner, pdl.OperationOp)
478-
parent_pos = inputs.get(op.parent_.owner.op)
479-
if parent_pos and isinstance(parent_pos, OperationPosition):
480-
result_pos = parent_pos.get_result(op.index.value.data)
481-
is_not_null = Predicate.get_is_not_null()
482-
predicates.append(
483-
PositionalPredicate(
484-
q=is_not_null.q, a=is_not_null.a, position=result_pos
457+
# Register positions for constraint results
458+
for i, result in enumerate(op.results):
459+
assert isinstance(constraint_pred.q, ConstraintQuestion)
460+
constraint_pos = ConstraintPosition.get_constraint(
461+
constraint_pred.q, i
462+
)
463+
existing = inputs.get(result)
464+
if existing:
465+
# Add equality constraint if result already has a position
466+
deeper, shallower = (
467+
(constraint_pos, existing)
468+
if constraint_pos.get_operation_depth()
469+
> existing.get_operation_depth()
470+
else (existing, constraint_pos)
471+
)
472+
eq_pred = Predicate.get_equal_to(shallower)
473+
predicates.append(
474+
PositionalPredicate(
475+
q=eq_pred.q, a=eq_pred.a, position=deeper
476+
)
485477
)
478+
else:
479+
inputs[result] = constraint_pos
480+
481+
predicates.append(
482+
PositionalPredicate(
483+
q=constraint_pred.q,
484+
a=constraint_pred.a,
485+
position=furthest_pos,
486486
)
487+
)
487488

488-
elif isinstance(op, pdl.ResultsOp):
489-
# Handle result groups
490-
if op.val not in inputs:
491-
assert isinstance(op.parent_.owner, pdl.OperationOp)
492-
parent_pos = inputs.get(op.parent_.owner.op)
493-
if parent_pos and isinstance(parent_pos, OperationPosition):
494-
is_variadic = isinstance(op.val.type, pdl.RangeType)
495-
index = op.index.value.data if op.index else None
496-
result_pos = parent_pos.get_result_group(index, is_variadic)
497-
if index is not None:
489+
case pdl.ResultOp():
490+
# Ensure result exists
491+
if op.val not in inputs:
492+
assert isinstance(op.parent_.owner, pdl.OperationOp)
493+
parent_pos = inputs.get(op.parent_.owner.op)
494+
if parent_pos and isinstance(parent_pos, OperationPosition):
495+
result_pos = parent_pos.get_result(op.index.value.data)
498496
is_not_null = Predicate.get_is_not_null()
499497
predicates.append(
500498
PositionalPredicate(
@@ -504,20 +502,42 @@ def extract_non_tree_predicates(
504502
)
505503
)
506504

507-
elif isinstance(op, pdl.TypeOp):
508-
# Handle constant types
509-
if op.result not in inputs and op.constantType:
510-
type_pos = TypeLiteralPosition.get_type_literal(
511-
value=op.constantType
512-
)
513-
inputs[op.result] = type_pos
505+
case pdl.ResultsOp():
506+
# Handle result groups
507+
if op.val not in inputs:
508+
assert isinstance(op.parent_.owner, pdl.OperationOp)
509+
parent_pos = inputs.get(op.parent_.owner.op)
510+
if parent_pos and isinstance(parent_pos, OperationPosition):
511+
is_variadic = isinstance(op.val.type, pdl.RangeType)
512+
index = op.index.value.data if op.index else None
513+
result_pos = parent_pos.get_result_group(index, is_variadic)
514+
if index is not None:
515+
is_not_null = Predicate.get_is_not_null()
516+
predicates.append(
517+
PositionalPredicate(
518+
q=is_not_null.q,
519+
a=is_not_null.a,
520+
position=result_pos,
521+
)
522+
)
514523

515-
elif isinstance(op, pdl.TypesOp):
516-
# Handle constant type arrays
517-
if op.result not in inputs and op.constantTypes:
518-
type_pos = TypeLiteralPosition.get_type_literal(
519-
value=op.constantTypes
520-
)
521-
inputs[op.result] = type_pos
524+
case pdl.TypeOp():
525+
# Handle constant types
526+
if op.result not in inputs and op.constantType:
527+
type_pos = TypeLiteralPosition.get_type_literal(
528+
value=op.constantType
529+
)
530+
inputs[op.result] = type_pos
531+
532+
case pdl.TypesOp():
533+
# Handle constant type arrays
534+
if op.result not in inputs and op.constantTypes:
535+
type_pos = TypeLiteralPosition.get_type_literal(
536+
value=op.constantTypes
537+
)
538+
inputs[op.result] = type_pos
539+
540+
case _:
541+
pass
522542

523543
return predicates

0 commit comments

Comments
 (0)