@@ -17,29 +17,36 @@ class miniAMR(App):
1717 def __init__ (
1818 self ,
1919 num_ranks : int ,
20- source : Path = BASE_PATH / "stencil.c" ,
2120 run_args : Optional [list [str ]] = None ,
22- compiler_options : list = None ,
2321 base : Path = BASE_PATH ,
22+ * ,
23+ source : Path = BASE_PATH / "stencil.c" ,
24+ translator : Optional [Translator ] = None ,
25+ compiler_options : list = None ,
26+ ephemeral : bool = False ,
27+ populate_scops : bool = True ,
2428 ):
2529 self .base = base
2630 self .num_ranks = num_ranks
2731 if not run_args :
2832 run_args = []
2933 self .run_args = run_args
34+ # todo, move this to amend_compiler_options
3035 include_paths = (
31- self .mpich_includes ()
32- + self .gcc_includes ("gcc" )
33- + self .gcc_includes ("mpicc" )
36+ self ._mpich_includes ()
37+ + self ._gcc_includes ("gcc" )
38+ + self ._gcc_includes ("mpicc" )
3439 )
35- self . _finalize_object (
36- source = source ,
37- include_paths = include_paths ,
40+ super (). __init__ (
41+ source ,
42+ translator ,
3843 compiler_options = compiler_options ,
44+ ephemeral = ephemeral ,
45+ populate_scops = populate_scops ,
3946 )
4047
4148 @staticmethod
42- def mpich_includes ():
49+ def _mpich_includes ():
4350 cmd = ["mpicc" , "-compile_info" ]
4451 result = run (cmd , stdout = PIPE , stderr = DEVNULL , check = False )
4552 if result .returncode == 1 :
@@ -50,7 +57,7 @@ def mpich_includes():
5057 return include_paths
5158
5259 @staticmethod
53- def gcc_includes (compiler ):
60+ def _gcc_includes (compiler ):
5461 cmd = [compiler , "-xc" , "-E" , "-v" , "/dev/null" ]
5562 result = run (cmd , stdout = DEVNULL , stderr = PIPE , check = False )
5663 if result .returncode == 1 :
@@ -67,21 +74,12 @@ def gcc_includes(compiler):
6774 collect = True
6875 return include_paths
6976
70- def generate_code (self , alt_source : str = None , ephemeral = True ):
71- if alt_source :
72- assert str (alt_source ).endswith (".c" )
73- assert Path (alt_source ).name == str (alt_source )
74- new_file = self .source .parent / alt_source
75- else :
76- new_file = self .make_new_filename ()
77- self .scops .generate_code (self .source , Path (new_file ))
78- kwargs = {
79- "source" : new_file ,
80- "base" : self .base ,
77+ def codegen_init_args (self ):
78+ return {
8179 "num_ranks" : self .num_ranks ,
80+ "base" : self .base ,
8281 "run_args" : self .run_args ,
8382 }
84- return self .make_new_app (ephemeral , ** kwargs )
8583
8684 @property
8785 def compile_cmd (self ) -> list [str ]:
@@ -96,7 +94,15 @@ def compile_cmd(self) -> list[str]:
9694
9795 @property
9896 def run_cmd (self ) -> list [str ]:
99- cmd = ["mpirun" , "-N" , str (self .num_ranks ), str (self .output_binary ), "--stencil" , "0" , * self .run_args ,]
97+ cmd = [
98+ "mpirun" ,
99+ "-N" ,
100+ str (self .num_ranks ),
101+ str (self .output_binary ),
102+ "--stencil" ,
103+ "0" ,
104+ * self .run_args ,
105+ ]
100106 return cmd
101107
102108 def extract_runtime (self , stdout : str ) -> float :
@@ -125,30 +131,30 @@ def main():
125131
126132 node = app .scops [scop_idx ].schedule_tree [16 ]
127133 tr = [
128- [16 , TrEnum .FULL_SPLIT ],
129- [20 , TrEnum .INTERCHANGE ],
130- [15 , TrEnum .FULL_FUSE ],
131- [11 , TrEnum .FULL_SPLIT ],
132- [10 , TrEnum .FULL_FUSE ],
133- [6 , TrEnum .FULL_SPLIT ],
134- [17 , TrEnum .FULL_SPLIT ],
135- # [5, TrEnum.FULL_FUSE],
136- ]
137- legals = app .scops [scop_idx ].transform_list (tr )
138- print (f"{ legals = } " )
139- if not all ( legals ) :
134+ [16 , TrEnum .FULL_SPLIT ],
135+ [20 , TrEnum .INTERCHANGE ],
136+ [15 , TrEnum .FULL_FUSE ],
137+ [11 , TrEnum .FULL_SPLIT ],
138+ [10 , TrEnum .FULL_FUSE ],
139+ [6 , TrEnum .FULL_SPLIT ],
140+ [17 , TrEnum .FULL_SPLIT ],
141+ # [5, TrEnum.FULL_FUSE],
142+ ]
143+ app .scops [scop_idx ].transform_list (tr )
144+ print (f"{ app . legal = } " )
145+ if not app . legal :
140146 return
141147 for i , node in enumerate (app .scops [scop_idx ].schedule_tree ):
142148 at = node .available_transformations
143149 if at :
144150 print (f"{ i } { at } " )
145151 # return
146152
147- repeat = 10
153+ repeat = 10
148154 app .compile ()
149155 orig_time = app .measure (repeat )
150156 for ts in [61 ]:
151- #tr = [
157+ # tr = [
152158 # [16, TrEnum.FULL_SPLIT],
153159 # [20, TrEnum.INTERCHANGE],
154160 # [15, TrEnum.FULL_FUSE],
@@ -157,9 +163,9 @@ def main():
157163 # [6, TrEnum.FULL_SPLIT],
158164 # ]
159165 app .scops [scop_idx ].reset ()
160- legals = app .scops [scop_idx ].transform_list (tr )
161- print (f"{ legals = } " )
162- if not all ( legals ) :
166+ app .scops [scop_idx ].transform_list (tr )
167+ print (f"{ app . legal = } " )
168+ if not app . legal :
163169 continue
164170 tapp = app .generate_code (f"{ ts = } .c" , ephemeral = False )
165171 tapp .compile ()
0 commit comments