@@ -79,6 +79,11 @@ def to_string(self, gvn: InstrId) -> str:
79
79
return type (self ).__name__
80
80
81
81
82
+ @dataclasses .dataclass (eq = False )
83
+ class Nop (Instr ):
84
+ pass
85
+
86
+
82
87
@dataclasses .dataclass (eq = False )
83
88
class Const (Instr ):
84
89
value : Object
@@ -378,6 +383,8 @@ def to_string(self, fn: IRFunction, gvn: InstrId) -> str:
378
383
result += f" { block .name ()} {{\n "
379
384
for instr in block .instrs :
380
385
instr = instr .find ()
386
+ if isinstance (instr , Nop ):
387
+ continue
381
388
if isinstance (instr , Control ):
382
389
result += f" { instr .to_string (gvn )} \n "
383
390
else :
@@ -540,7 +547,10 @@ def op(idx: int) -> str:
540
547
def _to_c (self , f : io .StringIO , block : Block , gvn : InstrId ) -> None :
541
548
f .write (f"{ block .name ()} :;\n " )
542
549
for instr in block .instrs :
543
- f .write (self ._instr_to_c (instr .find (), gvn ))
550
+ instr = instr .find ()
551
+ if isinstance (instr , Nop ):
552
+ continue
553
+ f .write (self ._instr_to_c (instr , gvn ))
544
554
545
555
546
556
class Compiler :
@@ -995,6 +1005,45 @@ def remove_unreachable_blocks(self) -> bool:
995
1005
return len (self .fn .cfg .blocks ) != num_blocks
996
1006
997
1007
1008
+ @dataclasses .dataclass
1009
+ class DeadCodeElimination :
1010
+ fn : IRFunction
1011
+
1012
+ def is_critical (self , instr : Instr ) -> bool :
1013
+ if isinstance (instr , Const ):
1014
+ return False
1015
+ if isinstance (instr , IntAdd ):
1016
+ return False
1017
+ # TODO(max): Add more. Track heap effects?
1018
+ return True
1019
+
1020
+ def run (self ) -> None :
1021
+ worklist : list [Instr ] = []
1022
+ marked : set [Instr ] = set ()
1023
+ blocks = self .fn .cfg .rpo ()
1024
+ # Mark
1025
+ for block in blocks :
1026
+ for instr in block .instrs :
1027
+ instr = instr .find ()
1028
+ if self .is_critical (instr ):
1029
+ marked .add (instr )
1030
+ worklist .append (instr )
1031
+ while worklist :
1032
+ instr = worklist .pop (0 ).find ()
1033
+ if isinstance (instr , HasOperands ):
1034
+ for op in instr .operands :
1035
+ op = op .find ()
1036
+ if op not in marked :
1037
+ marked .add (op )
1038
+ worklist .append (op )
1039
+ # Sweep
1040
+ for block in blocks :
1041
+ for instr in block .instrs :
1042
+ instr = instr .find ()
1043
+ if instr not in marked :
1044
+ instr .make_equal_to (Nop ())
1045
+
1046
+
998
1047
def _parse (source : str ) -> Object :
999
1048
return parse (tokenize (source ))
1000
1049
@@ -2144,6 +2193,46 @@ def test_const_list(self) -> None:
2144
2193
self .assertEqual (analysis .instr_type [returned ], CList ())
2145
2194
2146
2195
2196
+ class DeadCodeEliminationTests (unittest .TestCase ):
2197
+ def test_remove_const (self ) -> None :
2198
+ compiler = Compiler ()
2199
+ compiler .emit (Const (1 ))
2200
+ compiler .emit (Const (2 ))
2201
+ compiler .emit (Const (3 ))
2202
+ four = compiler .emit (Const (4 ))
2203
+ compiler .emit (Return (four ))
2204
+ DeadCodeElimination (compiler .fn ).run ()
2205
+ self .assertEqual (
2206
+ compiler .fn .to_string (InstrId ()),
2207
+ """\
2208
+ fn0 {
2209
+ bb0 {
2210
+ v0 = Const<4>
2211
+ Return v0
2212
+ }
2213
+ }""" ,
2214
+ )
2215
+
2216
+ def test_remove_int_add (self ) -> None :
2217
+ compiler = Compiler ()
2218
+ one = compiler .emit (Const (1 ))
2219
+ two = compiler .emit (Const (2 ))
2220
+ compiler .emit (IntAdd (one , two ))
2221
+ four = compiler .emit (Const (4 ))
2222
+ compiler .emit (Return (four ))
2223
+ DeadCodeElimination (compiler .fn ).run ()
2224
+ self .assertEqual (
2225
+ compiler .fn .to_string (InstrId ()),
2226
+ """\
2227
+ fn0 {
2228
+ bb0 {
2229
+ v0 = Const<4>
2230
+ Return v0
2231
+ }
2232
+ }""" ,
2233
+ )
2234
+
2235
+
2147
2236
def opt (fn : IRFunction ) -> None :
2148
2237
CleanCFG (fn ).run ()
2149
2238
instr_type = SCCP (fn ).run ()
@@ -2152,6 +2241,7 @@ def opt(fn: IRFunction) -> None:
2152
2241
match instr_type [instr ]:
2153
2242
case CInt (int (i )):
2154
2243
instr .make_equal_to (Const (Int (i )))
2244
+ DeadCodeElimination (fn ).run ()
2155
2245
2156
2246
2157
2247
class OptTests (unittest .TestCase ):
@@ -2164,12 +2254,8 @@ def test_int_add(self) -> None:
2164
2254
"""\
2165
2255
fn0 {
2166
2256
bb0 {
2167
- v0 = Const<1>
2168
- v1 = Const<2>
2169
- v2 = Const<3>
2170
- v3 = Const<5>
2171
- v4 = Const<6>
2172
- Return v4
2257
+ v0 = Const<6>
2258
+ Return v0
2173
2259
}
2174
2260
}""" ,
2175
2261
)
0 commit comments