@@ -56,6 +56,18 @@ def get_output(self):
5656 setup += '\n \n '
5757 return setup + self .output
5858
59+
60+ class Preprocessor :
61+
62+ def transform (self , input ):
63+ output = ''
64+ for line in input .splitlines ():
65+ strip = line .strip ()
66+ if strip .startswith ('#' ):
67+ continue
68+ output += line + '\n '
69+ return output
70+
5971class Compiler :
6072 def __init__ (self ):
6173 self .types = {}
@@ -247,6 +259,13 @@ def __init__(self, writer):
247259 self .writer .write_constant (reg .name , self .global_offset )
248260 self .global_offset += 1 # size always 1
249261
262+ self .writer .write_constant ('csp' , self .global_offset )
263+ self .writer .write_instruction ('MOV' , '#0' , str (self .global_offset ))
264+ self .global_offset += 1
265+ self .writer .write_constant ('cbp' , self .global_offset )
266+ self .writer .write_instruction ('MOV' , '#0' , str (self .global_offset ))
267+ self .global_offset += 1
268+
250269 self .current_function = None
251270
252271 def visit_program (self , program ):
@@ -360,6 +379,9 @@ def visit_func_decl(self, stmt):
360379 self .current_function = func
361380 self .functions [name ] = func
362381 self .visit_statements (stmt .body )
382+ if not stmt .body or not isinstance (stmt .body [- 1 ], ReturnStmt ) \
383+ and name != 'main' : # main function doesn't return
384+ self .write ('RET' )
363385 self .current_function = None
364386 self .writer .end_subroutine ()
365387
@@ -433,7 +455,7 @@ def visit_var_decl(self, stmt):
433455 def clear_locals (self ):
434456 self .locals = {}
435457 self .local_labels = {}
436- local_offset = 0
458+ self . local_offset = 0
437459
438460 def add_global (self , type , name ):
439461 assert name not in self .globals
@@ -449,9 +471,10 @@ def add_local(self, type, name):
449471 self .local_offset += type .size
450472 return var
451473
452- def new_temporary_var (self , copy_from = None ):
453- type = copy_from .type if copy_from else Types .types ['int' ]
454- if len (self .free_gpr ) > 0 :
474+ def new_temporary_var (self , copy_from = None , type = None ):
475+ type = type if type else copy_from .type if copy_from else Types .types ['int' ]
476+ assert type .size > 0
477+ if len (self .free_gpr ) > 0 and type .size == 1 :
455478 return self .free_gpr .pop ()
456479 tmp = self .add_local (type , 'tmp_%d' % len (self .locals ))
457480 self .temporary_names .add (tmp .name )
@@ -576,9 +599,15 @@ def move(self, src, dest):
576599
577600 if type (src ) == int :
578601 s_ref , s_off = '#%d' % src , None
602+ size = 1
579603 elif type (src ) == str :
580604 s_ref , s_off = src , None
605+ size = 1
581606 else :
607+ if isinstance (src , Register ):
608+ size = 1
609+ else :
610+ size = src .type .size
582611 src_addr = self .load_address (src )
583612 steps = unwrap (src_addr , lambda : self .__next_volatile ().name )
584613 s_ref , s_off = act_out (steps )
@@ -590,11 +619,21 @@ def move(self, src, dest):
590619 steps = unwrap (dest_addr , lambda : self .__next_volatile ().name )
591620 d_ref , d_off = act_out (steps )
592621
593- move (s_ref , s_off , d_ref , d_off )
622+ def shift (ref , off , shift ):
623+ if shift == 0 :
624+ return (ref , off )
625+ if off is None :
626+ assert type (ref ) == int , type (ref )
627+ return (ref + shift , off )
628+ else :
629+ return (ref , off + shift )
630+
631+ for sh in range (size ):
632+ move (* (shift (s_ref , s_off , sh ) + shift (d_ref , d_off , sh )))
594633
595634 def load_address (self , ref ):
596635 if isinstance (ref , Variable ):
597- return Relative (rel_to = 'sp ' , offset = ref .index )
636+ return Relative (rel_to = 'csp ' , offset = ref .index )
598637 if isinstance (ref , Dereference ):
599638 addr = self .load_address (ref .ref )
600639 return Indirect (addr )
@@ -604,6 +643,8 @@ def load_address(self, ref):
604643 return Direct (ref .name )
605644 if isinstance (ref , Global ):
606645 return Direct (ref .loc )
646+ if isinstance (ref , (Relative , Direct , Offset , Indirect )):
647+ return ref
607648 assert False , type (ref )
608649
609650 def dereference (self , ref ):
@@ -759,7 +800,6 @@ def visit_switch_stmt(self, stmt):
759800 label = self .local_label ('switch_case' )
760801 self .writer .write_local_sub (label )
761802 choice = self .visit_expression (case .choice )
762- print (choice )
763803 assert type (choice ) == int , "not a constant"
764804 cases [choice ] = label
765805 else : # This is the default case
@@ -795,13 +835,18 @@ def visit_break_stmt(self, stmt):
795835 self .break_jump = (label , True )
796836
797837 def visit_return_stmt (self , stmt ):
838+ ret_type = self .current_function .ret_type
798839 if stmt .expr :
799- assert self . current_function . ret_type != 'void'
840+ assert ret_type != Types . types [ 'void' ]
800841 ret = self .visit_expression (stmt .expr )
801- reg = self .gpr [- 1 ].name # Return value always in last register
802- self .write ('MOV' , ret , reg )
842+ if ret_type .size == 1 :
843+ reg = self .gpr [- 1 ].name # Return value always in last register
844+ self .write ('MOV' , ret , reg )
845+ else :
846+ # Move value to position allocated before calling this function
847+ self .write ('MOV' , ret , Relative (rel_to = 'csp' , offset = - ret_type .size ))
803848 else :
804- assert self . current_function . ret_type == 'void'
849+ assert ret_type == Types . types [ 'void' ]
805850 self .write ('RET' )
806851
807852 def visit_goto_stmt (self , stmt ):
@@ -1001,20 +1046,39 @@ def visit_func_call_expr(self, expr):
10011046 assert name in self .functions
10021047 func = self .functions [name ]
10031048 assert len (expr .args ) == len (func .param_types )
1049+ large_ret = func .ret_type .size > 1 # Can't fit in single register
1050+ ret_dest = None
1051+ if large_ret :
1052+ ret_dest = self .new_temporary_var (type = func .ret_type )
1053+
1054+ # shift base pointer to new stack region
1055+ self .write ('MOV' , 'csp' , 'cbp' )
1056+ self .write ('ADD' , self .local_offset , 'cbp' )
1057+ i = 0
10041058 for arg in expr .args :
10051059 var = self .visit_expression (arg )
1006- self .write ('MOV' , var , 'sr' )
1007- self .write ('PUSH' )
1008- if func .ret_type != Types .types ['void' ]:
1009- ret_reg = self .gpr [- 1 ]
1060+ self .write ('MOV' , var , Relative (rel_to = 'cbp' , offset = i ))
1061+ if type (var ) == int :
1062+ i += 1
1063+ else :
1064+ i += var .type .size
1065+ register_saved = False
1066+ if ret_dest is None and func .ret_type != Types .types ['void' ]:
1067+ ret_dest = self .gpr [- 1 ]
10101068 if len (self .free_gpr ) == 0 :
1011- # TODO restore tmp
1012- tmp = self .new_temporary_var ()
1013- self .write ('MOV' , ret_reg , tmp )
1014- else :
1015- ret_reg = None
1069+ register_saved = True
1070+ self .write ('MOV' , ret_reg , 'sr' )
1071+ self .write ('PUSH' )
1072+ self .write ('MOV' , 'csp' , 'sr' )
1073+ self .write ('PUSH' )
1074+ self .write ('MOV' , 'cbp' , 'csp' )
10161075 self .write ('CALL' , name )
1017- return ret_reg
1076+ self .write ('POP' )
1077+ self .write ('MOV' , 'sr' , 'csp' )
1078+ if register_saved :
1079+ self .write ('POP' )
1080+ #TODO delay this self.write('MOV', 'sr', ret_reg)
1081+ return ret_dest
10181082
10191083 def func_printf (self , expr ):
10201084 args = []
@@ -1026,6 +1090,9 @@ def func_printf(self, expr):
10261090
10271091 def string_format (self , template , args ):
10281092 ret = []
1093+ if template == '' :
1094+ assert not args
1095+ return ['""' ]
10291096 section = template
10301097 ind = section .find ('%' )
10311098 while ind != - 1 and args :
0 commit comments