Skip to content

Commit fea3124

Browse files
authored
Merge pull request #133 from vatai/132-insert-polly-in-middle-of-o3-list
132 insert polly in middle of o3 list
2 parents 2b27eba + b13cb0e commit fea3124

File tree

4 files changed

+192
-64
lines changed

4 files changed

+192
-64
lines changed

examples/inputs/fdepnodep.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
program main
22
implicit none
3-
integer, parameter :: N = 1000
3+
integer, parameter :: N = 3000
44
real(8) :: A(N, N)
55
integer :: i, count_rate, count_start, count_end
66
real(8) :: walltime

tadashi/isl.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ cdef extern from "isl/union_map.h":
7676
isl_union_map *isl_union_map_intersect_domain(isl_union_map *m1, isl_union_map *m2)
7777
isl_union_map *isl_union_map_intersect_domain_union_set(isl_union_map *map, isl_union_set *set)
7878
const char *isl_union_map_to_str(isl_union_map *map)
79+
ctypedef isl_union_map* union_map
7980

8081
# ---
8182
# aff

tadashi/passesparser.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#!/bin/env python
2+
import subprocess
3+
from pprint import pprint
4+
from typing import Optional
5+
6+
7+
class PassParser:
8+
_pass_tree: Optional[list[str | tuple]]
9+
10+
def __init__(self):
11+
cmd = [
12+
"opt",
13+
"-S",
14+
"-O3",
15+
"--print-pipeline-passes",
16+
"/dev/null",
17+
"-o",
18+
"/dev/null",
19+
]
20+
proc = subprocess.run(cmd, capture_output=True, check=True)
21+
self.passes_str = proc.stdout.decode().strip()
22+
self._pass_tree = None
23+
24+
def pass_tree(self) -> list:
25+
if self._pass_tree is None:
26+
self._pass_tree = self.parse(0, len(self.passes_str))
27+
return self._pass_tree
28+
29+
def parse(self, begin: int, end: int) -> list:
30+
cur = begin
31+
results = []
32+
while cur < end:
33+
if self.passes_str[cur] == ",":
34+
results.append(self.passes_str[begin:cur])
35+
begin = cur + 1
36+
elif self.passes_str[cur] == "(":
37+
key = self.passes_str[begin:cur]
38+
begin = cur + 1
39+
cur = self._closing_cur(cur)
40+
assert self.passes_str[cur - 1] == ")"
41+
subtree = self.parse(begin, cur - 1)
42+
results.append((key, subtree))
43+
begin = cur + 1
44+
cur += 1
45+
if begin < cur:
46+
results.append(self.passes_str[begin:cur])
47+
return results
48+
49+
def _closing_cur(self, cur: int):
50+
cur += 1
51+
num_open = 1
52+
while num_open != 0:
53+
if self.passes_str[cur] == "(":
54+
num_open += 1
55+
if self.passes_str[cur] == ")":
56+
num_open -= 1
57+
cur += 1
58+
return cur
59+
60+
@staticmethod
61+
def reassemble(passes: list | tuple):
62+
if isinstance(passes, list):
63+
flat = [
64+
p if isinstance(p, str) else PassParser.reassemble(p) for p in passes
65+
]
66+
return ",".join(flat)
67+
elif isinstance(passes, tuple):
68+
fn, subtree = passes
69+
return f"{fn}({PassParser.reassemble(subtree)})"
70+
else:
71+
raise ValueError("This shouldn't happen")
72+
73+
def find(self, prefix: str):
74+
return self._find(prefix, self.pass_tree())
75+
76+
@staticmethod
77+
def _find(prefix: str, subtree: list[str | tuple]):
78+
locs = []
79+
for i, node in enumerate(subtree):
80+
if isinstance(node, tuple):
81+
rest = PassParser._find(prefix, node[1])
82+
if rest:
83+
for r in rest:
84+
locs.append([i] + r)
85+
else:
86+
if node.startswith(prefix):
87+
locs.append([i])
88+
return locs
89+
90+
def split(self, locs: list[int]):
91+
return self._split(locs, self.pass_tree())
92+
93+
@staticmethod
94+
def _split(locs: list[int], subtree: list[tuple]):
95+
head, *tail = locs
96+
l, r = [], []
97+
if any(tail):
98+
k, v = subtree[head]
99+
l, r = PassParser._split(tail, v)
100+
left = subtree[0:head] + [(k, l)]
101+
right = [(k, r)] + subtree[head + 1 :]
102+
else:
103+
left = subtree[0:head]
104+
right = subtree[head:]
105+
return left, right
106+
107+
108+
def main():
109+
pp = PassParser()
110+
full = pp.pass_tree()
111+
reassembled = pp.reassemble(full)
112+
assert pp.passes_str == reassembled
113+
114+
locs = pp.find("loop-rotate")
115+
print(reassembled)
116+
print(locs)
117+
l, r = pp.split(locs[1])
118+
print(pp.reassemble(full))
119+
print("----------")
120+
print(pp.reassemble(l))
121+
print("----------")
122+
print(pp.reassemble(r))
123+
124+
125+
if __name__ == "__main__":
126+
main()

tadashi/translators.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from cython.cimports.tadashi.codegen import codegen
1919
from cython.cimports.tadashi.scop import Scop
2020

21+
from .passesparser import PassParser
22+
2123
ABC_ERROR_MSG = "Translator is an abstract base class, use a derived class."
2224
DOUBLE_SET_SOURCE = "Translator.set_source() should only be called once."
2325

@@ -219,9 +221,16 @@ class Polly(Translator):
219221
compiler: str
220222
json_paths: list[Path]
221223
cwd: Path
224+
before_polly_passes: str
225+
after_polly_passes: str
222226

223227
def __init__(self, compiler: str = "clang"):
224228
self.compiler = str(compiler)
229+
pp = PassParser()
230+
locs = pp.find("loop-rotate")
231+
before, after = pp.split(locs[1])
232+
self.before_polly_passes = pp.reassemble(before)
233+
self.after_polly_passes = pp.reassemble(after)
225234

226235
def _run(self, cmd: list[str], description: str, cwd: str = None):
227236
"""cmd is command list, description is verb-ing, cwd defailts to self.cwd"""
@@ -259,43 +268,35 @@ def populate_ccscops(self, options: list[str]):
259268
raise MemoryError()
260269
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
261270
self.cwd = Path(tempfile.mkdtemp(prefix=f"tadashi-{timestamp}-"))
262-
self._get_preopt_bitcode(options)
271+
self._get_pre_polly_bc(options)
263272
stderr = self._export_jscops(options)
264273
self.json_paths = self._fill_json_paths(stderr)
265274
for file in self.json_paths:
266275
with open(self.cwd / file) as fp:
267276
jscop = json.load(fp)
268277
self._proc_jscop(jscop)
269278

270-
def _get_preopt_bitcode(self, options: list[str]) -> Path:
271-
output = self.cwd / self.source.with_suffix(".O1.bc").name
272-
if output.exists():
273-
return output
279+
def _get_pre_polly_bc(self, options: list[str]) -> Path:
280+
compile_O0_bc = self.cwd / self.source.with_suffix(".pre_polly.bc").name
281+
pre_polly_bc = self.cwd / self.source.with_suffix(".pre_polly.bc").name
282+
if pre_polly_bc.exists():
283+
return pre_polly_bc
274284

275285
compiler_opts = {
276-
"clang": [
277-
"-O0",
278-
"-Xclang",
279-
"-disable-O0-optnone",
280-
],
281-
"flang": ["-O1"],
286+
"clang": ["-O0", "-Xclang", "-disable-O0-optnone"],
287+
"flang": ["-O0"],
282288
}
283-
cmd = [
284-
self.compiler,
285-
*options,
286-
"-c",
287-
"-emit-llvm",
288-
str(self.source),
289-
*compiler_opts[self.compiler[:5]],
290-
"-o",
291-
str(output),
292-
]
293-
self._run(cmd, "parsing")
294-
return output
289+
compile_cmd = [self.compiler, *options, "-c", "-emit-llvm", str(self.source)]
290+
compile_cmd += [*compiler_opts[self.compiler[:5]], "-o", str(compile_O0_bc)]
291+
self._run(compile_cmd, "compiling with O0")
292+
opt_cmd = ["opt", f"-passes={self.before_polly_passes}"]
293+
opt_cmd += [str(compile_O0_bc), f"-o={str(pre_polly_bc)}"]
294+
self._run(opt_cmd, "running pre polly opt passes")
295+
return pre_polly_bc
295296

296297
def _export_jscops(self, options: list[str]) -> str:
297298
cmd = self._polly() + [
298-
str(self._get_preopt_bitcode(options)),
299+
str(self._get_pre_polly_bc(options)),
299300
"-polly-export-jscop",
300301
"-o=/dev/null",
301302
]
@@ -351,22 +352,28 @@ def _proc_jscop(self, jscop):
351352
self.ccscops.emplace_back(domain, sched, read, write)
352353

353354
def legal(self) -> bool:
354-
input_path = str(self._get_preopt_bitcode([]))
355+
input_path = str(self._get_pre_polly_bc([]))
355356
cmd = self._polly() + [input_path, "-polly-import-jscop", "-o=/dev/null"]
356357
proc = self._run(cmd, "checking legality")
357358
return proc.returncode == 0
358359

359360
def _import_jscops(self, options: list[str]) -> Path:
360-
input_path = str(self._get_preopt_bitcode(options))
361-
output = self.cwd / self.source.with_suffix(".bc").name
362-
cmd = self._polly() + [
363-
input_path,
364-
"-polly-import-jscop",
361+
input_path = str(self._get_pre_polly_bc(options))
362+
post_polly_bc = self.cwd / self.source.with_suffix(".post_opt.bc").name
363+
364+
polly_cmd = self._polly() + [input_path, "-polly-import-jscop"]
365+
polly_cmd += [
365366
"-polly-codegen",
366-
f"-o={str(output)}",
367+
f"-o={str(post_polly_bc)}",
367368
"-disable-polly-legality",
368369
]
369-
self._run(cmd, "imnporting jscops")
370+
self._run(polly_cmd, "importing jscops")
371+
372+
output = self.cwd / self.source.with_suffix(".bc").name
373+
374+
opt_cmd = ["opt", f"-passes={self.after_polly_passes}"]
375+
opt_cmd += [str(post_polly_bc), f"-o={str(output)}"]
376+
opt_proc = self._run(opt_cmd, "running opt after jscops import")
370377
return output
371378

372379
@cython.ccall
@@ -375,39 +382,33 @@ def generate_code(
375382
) -> Path:
376383
for scop_idx, jscop_path in enumerate(self.json_paths):
377384
jscop_path = self.cwd / jscop_path
378-
ccscop = self.ccscops[scop_idx]
379-
sched = isl.isl_schedule_node_get_schedule(ccscop.current_node)
380-
umap = isl.isl_schedule_get_map(sched)
381-
isl.isl_schedule_free(sched)
382-
sched_str = isl.isl_union_map_to_str(umap).decode()
383385
backup_path = jscop_path.with_suffix(jscop_path.suffix + ".bak")
384386
shutil.copy2(jscop_path, backup_path)
385-
with jscop_path.open("r", encoding="utf-8") as f:
386-
jscop = json.load(f)
387-
for stmt in jscop["statements"]:
388-
name = stmt["name"]
389-
uset = isl.isl_union_set_read_from_str(
390-
self.ctx, stmt["domain"].encode()
391-
)
392-
stmt_map = isl.isl_union_map_intersect_domain_union_set(
393-
isl.isl_union_map_copy(umap), uset
394-
)
395-
stmt["schedule"] = isl.isl_union_map_to_str(stmt_map).decode()
396-
isl.isl_union_map_free(stmt_map)
397-
isl.isl_union_map_free(umap)
398-
with jscop_path.open("w", encoding="utf-8") as f:
399-
json.dump(jscop, f, indent=2)
400-
output_path = Path(output_path).with_suffix(".s")
401-
self._generate_assembly(output_path, options)
402-
return output_path
387+
self._update_jscop(jscop_path, scop_idx)
388+
return self._compile_to_obj(output_path, options)
403389

404-
def _generate_assembly(self, output: Path, options: list[str]):
390+
@cython.cfunc
391+
def _update_jscop(self, jscop_path: Path, scop_idx: int):
392+
ccscop = self.ccscops[scop_idx]
393+
sched = isl.isl_schedule_node_get_schedule(ccscop.current_node)
394+
isl.isl_schedule_free(sched)
395+
umap = isl.isl_schedule_get_map(sched)
396+
with jscop_path.open("r", encoding="utf-8") as f:
397+
jscop = json.load(f)
398+
for stmt in jscop["statements"]:
399+
uset = isl.isl_union_set_read_from_str(self.ctx, stmt["domain"].encode())
400+
tmp = isl.isl_union_map_copy(umap)
401+
stmt_map = isl.isl_union_map_intersect_domain_union_set(tmp, uset)
402+
stmt["schedule"] = isl.isl_union_map_to_str(stmt_map).decode()
403+
isl.isl_union_map_free(stmt_map)
404+
with jscop_path.open("w", encoding="utf-8") as f:
405+
json.dump(jscop, f, indent=2)
406+
isl.isl_union_map_free(umap)
407+
408+
def _compile_to_obj(self, output: Path, options: list[str]):
409+
output_path = Path(output).with_suffix(".o")
405410
input_path = str(self._import_jscops(options))
406-
cmd = [
407-
"llc",
408-
"-relocation-model=pic",
409-
# *options,
410-
input_path,
411-
f"-o={str(output)}",
412-
]
411+
cmd = ["llc", "-relocation-model=pic", "--filetype=obj"]
412+
cmd += [input_path, f"-o={str(output_path)}"]
413413
self._run(cmd, "generating assembly")
414+
return output_path

0 commit comments

Comments
 (0)