@@ -99,30 +99,31 @@ def extract_tree_predicates(
9999 inputs [value ] = position
100100
101101 # Dispatch based on position type (not value type!)
102- if isinstance (position , AttributePosition ):
103- assert isinstance (value , OpResult )
104- predicates .extend (
105- self ._extract_attribute_predicates (value .owner , position , inputs )
106- )
107- elif isinstance (position , OperationPosition ):
108- assert isinstance (value , OpResult )
109- predicates .extend (
110- self ._extract_operation_predicates (
111- value .owner , position , inputs , ignore_operand
102+ match position :
103+ case AttributePosition ():
104+ assert isinstance (value , OpResult )
105+ predicates .extend (
106+ self ._extract_attribute_predicates (value .owner , position , inputs )
112107 )
113- )
114- elif isinstance (position , TypePosition ):
115- assert isinstance (value , OpResult )
116- predicates .extend (
117- self ._extract_type_predicates (value .owner , position , inputs )
118- )
119- elif isinstance (position , OperandPosition | OperandGroupPosition ):
120- assert isinstance (value , SSAValue )
121- predicates .extend (
122- self ._extract_operand_tree_predicates (value , position , inputs )
123- )
124- else :
125- raise TypeError (f"Unexpected position kind: { type (position )} " )
108+ case OperationPosition ():
109+ assert isinstance (value , OpResult )
110+ predicates .extend (
111+ self ._extract_operation_predicates (
112+ value .owner , position , inputs , ignore_operand
113+ )
114+ )
115+ case TypePosition ():
116+ assert isinstance (value , OpResult )
117+ predicates .extend (
118+ self ._extract_type_predicates (value .owner , position , inputs )
119+ )
120+ case OperandPosition () | OperandGroupPosition ():
121+ assert isinstance (value , SSAValue )
122+ predicates .extend (
123+ self ._extract_operand_tree_predicates (value , position , inputs )
124+ )
125+ case _:
126+ raise TypeError (f"Unexpected position kind: { type (position )} " )
126127
127128 return predicates
128129
@@ -306,73 +307,82 @@ def _extract_operand_tree_predicates(
306307 defining_op = operand_value .owner
307308 is_variadic = isinstance (operand_value .type , pdl .RangeType )
308309
309- if isinstance (defining_op , pdl .OperandOp | pdl .OperandsOp ):
310- if isinstance (defining_op , pdl .OperandOp ):
311- is_not_null = Predicate .get_is_not_null ()
312- predicates .append (
313- PositionalPredicate (
314- q = is_not_null .q , a = is_not_null .a , position = operand_pos
315- )
316- )
317- elif (
318- isinstance (operand_pos , OperandGroupPosition )
319- and operand_pos .group_number is not None
320- ):
321- is_not_null = Predicate .get_is_not_null ()
322- predicates .append (
323- PositionalPredicate (
324- q = is_not_null .q , a = is_not_null .a , position = operand_pos
310+ match defining_op :
311+ case pdl .OperandOp () | pdl .OperandsOp ():
312+ match defining_op :
313+ case pdl .OperandOp ():
314+ is_not_null = Predicate .get_is_not_null ()
315+ predicates .append (
316+ PositionalPredicate (
317+ q = is_not_null .q , a = is_not_null .a , position = operand_pos
318+ )
319+ )
320+ case pdl .OperandsOp () if (
321+ isinstance (operand_pos , OperandGroupPosition )
322+ and operand_pos .group_number is not None
323+ ):
324+ is_not_null = Predicate .get_is_not_null ()
325+ predicates .append (
326+ PositionalPredicate (
327+ q = is_not_null .q , a = is_not_null .a , position = operand_pos
328+ )
329+ )
330+ case _:
331+ pass
332+
333+ if defining_op .value_type :
334+ type_pos = operand_pos .get_type ()
335+ predicates .extend (
336+ self .extract_tree_predicates (
337+ defining_op .value_type , type_pos , inputs
338+ )
325339 )
326- )
327340
328- if defining_op .value_type :
329- type_pos = operand_pos .get_type ()
330- predicates .extend (
331- self .extract_tree_predicates (
332- defining_op .value_type , type_pos , inputs
341+ case pdl .ResultOp () | pdl .ResultsOp ():
342+ index_attr = defining_op .index
343+ index = index_attr .value .data if index_attr is not None else None
344+
345+ if index is not None :
346+ is_not_null = Predicate .get_is_not_null ()
347+ predicates .append (
348+ PositionalPredicate (
349+ q = is_not_null .q , a = is_not_null .a , position = operand_pos
350+ )
333351 )
334- )
335352
336- elif isinstance ( defining_op , pdl . ResultOp | pdl . ResultsOp ):
337- index_attr = defining_op .index
338- index = index_attr . value . data if index_attr is not None else None
353+ # Get the parent operation position
354+ parent_op = defining_op .parent_
355+ defining_op_pos = operand_pos . get_defining_op ()
339356
340- if index is not None :
357+ # Parent operation should not be null
341358 is_not_null = Predicate .get_is_not_null ()
342359 predicates .append (
343360 PositionalPredicate (
344- q = is_not_null .q , a = is_not_null .a , position = operand_pos
361+ q = is_not_null .q , a = is_not_null .a , position = defining_op_pos
345362 )
346363 )
347364
348- # Get the parent operation position
349- parent_op = defining_op .parent_
350- defining_op_pos = operand_pos .get_defining_op ()
365+ match defining_op :
366+ case pdl .ResultOp ():
367+ result_pos = defining_op_pos .get_result (
368+ index if index is not None else 0
369+ )
370+ case pdl .ResultsOp (): # ResultsOp
371+ result_pos = defining_op_pos .get_result_group (
372+ index , is_variadic
373+ )
351374
352- # Parent operation should not be null
353- is_not_null = Predicate .get_is_not_null ()
354- predicates .append (
355- PositionalPredicate (
356- q = is_not_null .q , a = is_not_null .a , position = defining_op_pos
375+ equal_to = Predicate .get_equal_to (operand_pos )
376+ predicates .append (
377+ PositionalPredicate (q = equal_to .q , a = equal_to .a , position = result_pos )
357378 )
358- )
359379
360- if isinstance ( defining_op , pdl . ResultOp ):
361- result_pos = defining_op_pos . get_result (
362- index if index is not None else 0
380+ # Recursively process the parent operation
381+ predicates . extend (
382+ self . extract_tree_predicates ( parent_op , defining_op_pos , inputs )
363383 )
364- else : # ResultsOp
365- result_pos = defining_op_pos .get_result_group (index , is_variadic )
366-
367- equal_to = Predicate .get_equal_to (operand_pos )
368- predicates .append (
369- PositionalPredicate (q = equal_to .q , a = equal_to .a , position = result_pos )
370- )
371-
372- # Recursively process the parent operation
373- predicates .extend (
374- self .extract_tree_predicates (parent_op , defining_op_pos , inputs )
375- )
384+ case _:
385+ pass
376386
377387 return predicates
378388
@@ -414,87 +424,75 @@ def extract_non_tree_predicates(
414424 predicates : list [PositionalPredicate ] = []
415425
416426 for op in pattern .body .ops :
417- if isinstance (op , pdl .AttributeOp ):
418- if op .output not in inputs :
419- if op .value :
420- # Create literal position for constant attribute
421- attr_pos = AttributeLiteralPosition (value = op .value , parent = None )
422- inputs [op .output ] = attr_pos
423-
424- elif isinstance (op , pdl .ApplyNativeConstraintOp ):
425- # Collect all argument positions
426- arg_positions = tuple (inputs .get (arg ) for arg in op .args )
427- for pos in arg_positions :
428- assert pos is not None
429- arg_positions = cast (tuple [Position , ...], arg_positions )
430-
431- # Find the furthest position (deepest)
432- furthest_pos = max (
433- arg_positions , key = lambda p : p .get_operation_depth () if p else 0
434- )
435-
436- # Create the constraint predicate
437- result_types = tuple (r .type for r in op .res )
438- # TODO: is_negated is not part of the dialect definition yet
439- is_negated = False
440- constraint_pred = Predicate .get_constraint (
441- op .constraint_name .data , arg_positions , result_types , is_negated
442- )
443-
444- # Register positions for constraint results
445- for i , result in enumerate (op .results ):
446- assert isinstance (constraint_pred .q , ConstraintQuestion )
447- constraint_pos = ConstraintPosition .get_constraint (
448- constraint_pred .q , i
449- )
450- existing = inputs .get (result )
451- if existing :
452- # Add equality constraint if result already has a position
453- deeper , shallower = (
454- (constraint_pos , existing )
455- if constraint_pos .get_operation_depth ()
456- > existing .get_operation_depth ()
457- else (existing , constraint_pos )
458- )
459- eq_pred = Predicate .get_equal_to (shallower )
460- predicates .append (
461- PositionalPredicate (
462- q = eq_pred .q , a = eq_pred .a , position = deeper
427+ match op :
428+ case pdl .AttributeOp ():
429+ if op .output not in inputs :
430+ if op .value :
431+ # Create literal position for constant attribute
432+ attr_pos = AttributeLiteralPosition (
433+ value = op .value , parent = None
463434 )
464- )
465- else :
466- inputs [result ] = constraint_pos
435+ inputs [op .output ] = attr_pos
436+
437+ case pdl .ApplyNativeConstraintOp ():
438+ # Collect all argument positions
439+ arg_positions = tuple (inputs .get (arg ) for arg in op .args )
440+ for pos in arg_positions :
441+ assert pos is not None
442+ arg_positions = cast (tuple [Position , ...], arg_positions )
443+
444+ # Find the furthest position (deepest)
445+ furthest_pos = max (
446+ arg_positions , key = lambda p : p .get_operation_depth () if p else 0
447+ )
467448
468- predicates .append (
469- PositionalPredicate (
470- q = constraint_pred .q , a = constraint_pred .a , position = furthest_pos
449+ # Create the constraint predicate
450+ result_types = tuple (r .type for r in op .res )
451+ # TODO: is_negated is not part of the dialect definition yet
452+ is_negated = False
453+ constraint_pred = Predicate .get_constraint (
454+ op .constraint_name .data , arg_positions , result_types , is_negated
471455 )
472- )
473456
474- elif isinstance (op , pdl .ResultOp ):
475- # Ensure result exists
476- if op .val not in inputs :
477- assert isinstance (op .parent_ .owner , pdl .OperationOp )
478- parent_pos = inputs .get (op .parent_ .owner .op )
479- if parent_pos and isinstance (parent_pos , OperationPosition ):
480- result_pos = parent_pos .get_result (op .index .value .data )
481- is_not_null = Predicate .get_is_not_null ()
482- predicates .append (
483- PositionalPredicate (
484- q = is_not_null .q , a = is_not_null .a , position = result_pos
457+ # Register positions for constraint results
458+ for i , result in enumerate (op .results ):
459+ assert isinstance (constraint_pred .q , ConstraintQuestion )
460+ constraint_pos = ConstraintPosition .get_constraint (
461+ constraint_pred .q , i
462+ )
463+ existing = inputs .get (result )
464+ if existing :
465+ # Add equality constraint if result already has a position
466+ deeper , shallower = (
467+ (constraint_pos , existing )
468+ if constraint_pos .get_operation_depth ()
469+ > existing .get_operation_depth ()
470+ else (existing , constraint_pos )
471+ )
472+ eq_pred = Predicate .get_equal_to (shallower )
473+ predicates .append (
474+ PositionalPredicate (
475+ q = eq_pred .q , a = eq_pred .a , position = deeper
476+ )
485477 )
478+ else :
479+ inputs [result ] = constraint_pos
480+
481+ predicates .append (
482+ PositionalPredicate (
483+ q = constraint_pred .q ,
484+ a = constraint_pred .a ,
485+ position = furthest_pos ,
486486 )
487+ )
487488
488- elif isinstance (op , pdl .ResultsOp ):
489- # Handle result groups
490- if op .val not in inputs :
491- assert isinstance (op .parent_ .owner , pdl .OperationOp )
492- parent_pos = inputs .get (op .parent_ .owner .op )
493- if parent_pos and isinstance (parent_pos , OperationPosition ):
494- is_variadic = isinstance (op .val .type , pdl .RangeType )
495- index = op .index .value .data if op .index else None
496- result_pos = parent_pos .get_result_group (index , is_variadic )
497- if index is not None :
489+ case pdl .ResultOp ():
490+ # Ensure result exists
491+ if op .val not in inputs :
492+ assert isinstance (op .parent_ .owner , pdl .OperationOp )
493+ parent_pos = inputs .get (op .parent_ .owner .op )
494+ if parent_pos and isinstance (parent_pos , OperationPosition ):
495+ result_pos = parent_pos .get_result (op .index .value .data )
498496 is_not_null = Predicate .get_is_not_null ()
499497 predicates .append (
500498 PositionalPredicate (
@@ -504,20 +502,42 @@ def extract_non_tree_predicates(
504502 )
505503 )
506504
507- elif isinstance (op , pdl .TypeOp ):
508- # Handle constant types
509- if op .result not in inputs and op .constantType :
510- type_pos = TypeLiteralPosition .get_type_literal (
511- value = op .constantType
512- )
513- inputs [op .result ] = type_pos
505+ case pdl .ResultsOp ():
506+ # Handle result groups
507+ if op .val not in inputs :
508+ assert isinstance (op .parent_ .owner , pdl .OperationOp )
509+ parent_pos = inputs .get (op .parent_ .owner .op )
510+ if parent_pos and isinstance (parent_pos , OperationPosition ):
511+ is_variadic = isinstance (op .val .type , pdl .RangeType )
512+ index = op .index .value .data if op .index else None
513+ result_pos = parent_pos .get_result_group (index , is_variadic )
514+ if index is not None :
515+ is_not_null = Predicate .get_is_not_null ()
516+ predicates .append (
517+ PositionalPredicate (
518+ q = is_not_null .q ,
519+ a = is_not_null .a ,
520+ position = result_pos ,
521+ )
522+ )
514523
515- elif isinstance (op , pdl .TypesOp ):
516- # Handle constant type arrays
517- if op .result not in inputs and op .constantTypes :
518- type_pos = TypeLiteralPosition .get_type_literal (
519- value = op .constantTypes
520- )
521- inputs [op .result ] = type_pos
524+ case pdl .TypeOp ():
525+ # Handle constant types
526+ if op .result not in inputs and op .constantType :
527+ type_pos = TypeLiteralPosition .get_type_literal (
528+ value = op .constantType
529+ )
530+ inputs [op .result ] = type_pos
531+
532+ case pdl .TypesOp ():
533+ # Handle constant type arrays
534+ if op .result not in inputs and op .constantTypes :
535+ type_pos = TypeLiteralPosition .get_type_literal (
536+ value = op .constantTypes
537+ )
538+ inputs [op .result ] = type_pos
539+
540+ case _:
541+ pass
522542
523543 return predicates
0 commit comments