Skip to content

Commit e86281f

Browse files
committed
Update miniAMR app
1 parent e4308a0 commit e86281f

File tree

1 file changed

+46
-40
lines changed
  • examples/evaluation/miniAMR

1 file changed

+46
-40
lines changed

examples/evaluation/miniAMR/app.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,36 @@ class miniAMR(App):
1717
def __init__(
1818
self,
1919
num_ranks: int,
20-
source: Path = BASE_PATH / "stencil.c",
2120
run_args: Optional[list[str]] = None,
22-
compiler_options: list = None,
2321
base: Path = BASE_PATH,
22+
*,
23+
source: Path = BASE_PATH / "stencil.c",
24+
translator: Optional[Translator] = None,
25+
compiler_options: list = None,
26+
ephemeral: bool = False,
27+
populate_scops: bool = True,
2428
):
2529
self.base = base
2630
self.num_ranks = num_ranks
2731
if not run_args:
2832
run_args = []
2933
self.run_args = run_args
34+
# todo, move this to amend_compiler_options
3035
include_paths = (
31-
self.mpich_includes()
32-
+ self.gcc_includes("gcc")
33-
+ self.gcc_includes("mpicc")
36+
self._mpich_includes()
37+
+ self._gcc_includes("gcc")
38+
+ self._gcc_includes("mpicc")
3439
)
35-
self._finalize_object(
36-
source=source,
37-
include_paths=include_paths,
40+
super().__init__(
41+
source,
42+
translator,
3843
compiler_options=compiler_options,
44+
ephemeral=ephemeral,
45+
populate_scops=populate_scops,
3946
)
4047

4148
@staticmethod
42-
def mpich_includes():
49+
def _mpich_includes():
4350
cmd = ["mpicc", "-compile_info"]
4451
result = run(cmd, stdout=PIPE, stderr=DEVNULL, check=False)
4552
if result.returncode == 1:
@@ -50,7 +57,7 @@ def mpich_includes():
5057
return include_paths
5158

5259
@staticmethod
53-
def gcc_includes(compiler):
60+
def _gcc_includes(compiler):
5461
cmd = [compiler, "-xc", "-E", "-v", "/dev/null"]
5562
result = run(cmd, stdout=DEVNULL, stderr=PIPE, check=False)
5663
if result.returncode == 1:
@@ -67,21 +74,12 @@ def gcc_includes(compiler):
6774
collect = True
6875
return include_paths
6976

70-
def generate_code(self, alt_source: str = None, ephemeral=True):
71-
if alt_source:
72-
assert str(alt_source).endswith(".c")
73-
assert Path(alt_source).name == str(alt_source)
74-
new_file = self.source.parent / alt_source
75-
else:
76-
new_file = self.make_new_filename()
77-
self.scops.generate_code(self.source, Path(new_file))
78-
kwargs = {
79-
"source": new_file,
80-
"base": self.base,
77+
def codegen_init_args(self):
78+
return {
8179
"num_ranks": self.num_ranks,
80+
"base": self.base,
8281
"run_args": self.run_args,
8382
}
84-
return self.make_new_app(ephemeral, **kwargs)
8583

8684
@property
8785
def compile_cmd(self) -> list[str]:
@@ -96,7 +94,15 @@ def compile_cmd(self) -> list[str]:
9694

9795
@property
9896
def run_cmd(self) -> list[str]:
99-
cmd = ["mpirun", "-N", str(self.num_ranks), str(self.output_binary), "--stencil", "0", *self.run_args,]
97+
cmd = [
98+
"mpirun",
99+
"-N",
100+
str(self.num_ranks),
101+
str(self.output_binary),
102+
"--stencil",
103+
"0",
104+
*self.run_args,
105+
]
100106
return cmd
101107

102108
def extract_runtime(self, stdout: str) -> float:
@@ -125,30 +131,30 @@ def main():
125131

126132
node = app.scops[scop_idx].schedule_tree[16]
127133
tr = [
128-
[16, TrEnum.FULL_SPLIT],
129-
[20, TrEnum.INTERCHANGE],
130-
[15, TrEnum.FULL_FUSE],
131-
[11, TrEnum.FULL_SPLIT],
132-
[10, TrEnum.FULL_FUSE],
133-
[6, TrEnum.FULL_SPLIT],
134-
[17, TrEnum.FULL_SPLIT],
135-
# [5, TrEnum.FULL_FUSE],
136-
]
137-
legals = app.scops[scop_idx].transform_list(tr)
138-
print(f"{legals=}")
139-
if not all(legals):
134+
[16, TrEnum.FULL_SPLIT],
135+
[20, TrEnum.INTERCHANGE],
136+
[15, TrEnum.FULL_FUSE],
137+
[11, TrEnum.FULL_SPLIT],
138+
[10, TrEnum.FULL_FUSE],
139+
[6, TrEnum.FULL_SPLIT],
140+
[17, TrEnum.FULL_SPLIT],
141+
# [5, TrEnum.FULL_FUSE],
142+
]
143+
app.scops[scop_idx].transform_list(tr)
144+
print(f"{app.legal=}")
145+
if not app.legal:
140146
return
141147
for i, node in enumerate(app.scops[scop_idx].schedule_tree):
142148
at = node.available_transformations
143149
if at:
144150
print(f"{i} {at}")
145151
# return
146152

147-
repeat=10
153+
repeat = 10
148154
app.compile()
149155
orig_time = app.measure(repeat)
150156
for ts in [61]:
151-
#tr = [
157+
# tr = [
152158
# [16, TrEnum.FULL_SPLIT],
153159
# [20, TrEnum.INTERCHANGE],
154160
# [15, TrEnum.FULL_FUSE],
@@ -157,9 +163,9 @@ def main():
157163
# [6, TrEnum.FULL_SPLIT],
158164
# ]
159165
app.scops[scop_idx].reset()
160-
legals = app.scops[scop_idx].transform_list(tr)
161-
print(f"{legals=}")
162-
if not all(legals):
166+
app.scops[scop_idx].transform_list(tr)
167+
print(f"{app.legal=}")
168+
if not app.legal:
163169
continue
164170
tapp = app.generate_code(f"{ts=}.c", ephemeral=False)
165171
tapp.compile()

0 commit comments

Comments
 (0)