Skip to content

Commit ad8cb2d

Browse files
committed
better support for inlining new generics
1 parent 985b389 commit ad8cb2d

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

pylingual/editable_bytecode/EditableBytecode.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def inline_annotate_functions(self):
152152
("SET_FUNCTION_ATTRIBUTE", 8), # closure
153153
("STORE_NAME", "__annotate_func__"),
154154
)
155+
155156
# fmt: on
156157

157158
def try_read_annotate_func(codeobj) -> EditableBytecode | None:
@@ -264,14 +265,23 @@ def is_annotate_func_and_get_inlinable_insts(codeobj) -> tuple[bool, list[Inst]]
264265
# handle function definition annotations
265266
if inst.opname == "LOAD_CONST":
266267
is_annotate_func, inlinable_insts = is_annotate_func_and_get_inlinable_insts(inst.argval)
267-
if not is_annotate_func:
268+
if is_annotate_func:
269+
# replace
270+
# LOAD_CONST __annotate__
271+
# MAKE_FUNCTION
272+
inline_dict[(idx, tuple(self.instructions[idx : idx + 2]))] = inlinable_insts
273+
jump_target_mapping[inst] = inlinable_insts[0]
268274
continue
269275

270-
# replace
271-
# LOAD_CONST __annotate__
272-
# MAKE_FUNCTION
273-
inline_dict[(idx, tuple(self.instructions[idx : idx + 2]))] = inlinable_insts
274-
jump_target_mapping[inst] = inlinable_insts[0]
276+
if iscode(inst.argval) and inst.argval.co_name.startswith("<generic parameters of "):
277+
# fully inline generics
278+
# replace
279+
# LOAD_CONST <generic parameters of <function name>>
280+
# MAKE_FUNCTION
281+
generic_bc = EditableBytecode(inst.argval, self.opcode, self.version)
282+
inline_dict[(idx, tuple(self.instructions[idx : idx + 2]))] = generic_bc.instructions[:-1]
283+
jump_target_mapping[inst] = generic_bc.instructions[0]
284+
continue
275285

276286
# handle inline variable annotations
277287
elif inst.opname in ("LOAD_NAME", "LOAD_DEREF") and inst.argval == "__conditional_annotations__":
@@ -811,7 +821,7 @@ def insert_insts(self, insert_dict: dict[int, list[Inst]]) -> int:
811821
if inst.argval not in self.co_names:
812822
self.co_names.append(inst.argval)
813823
inst.arg = self.co_names.index(inst.argval)
814-
elif inst.optype == "free":
824+
elif inst.optype == "free" or inst.optype == "local":
815825
if inst.argval not in self.co_varnames:
816826
self.co_varnames.append(inst.argval)
817827
inst.arg = self.co_varnames.index(inst.argval)

pylingual/masking/model_disasm.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,17 @@ def create_global_masker(bytecode: EditableBytecode) -> Masker:
6565

6666
# create names
6767
for name in bc_co.co_names:
68-
if name in global_tab:
69-
continue
70-
global_tab.update({bc.resolve_namespace(name): f"<mask_{global_idx}>"})
71-
global_idx += 1
68+
if isinstance(name, tuple):
69+
for n in name:
70+
if n in global_tab:
71+
continue
72+
global_tab.update({bc.resolve_namespace(n): f"<mask_{global_idx}>"})
73+
global_idx += 1
74+
else:
75+
if name in global_tab:
76+
continue
77+
global_tab.update({bc.resolve_namespace(name): f"<mask_{global_idx}>"})
78+
global_idx += 1
7279

7380
for free in bc_co.co_freevars:
7481
if free in global_tab:
@@ -84,10 +91,17 @@ def create_global_masker(bytecode: EditableBytecode) -> Masker:
8491
global_idx += 1
8592

8693
for local in bc_co.co_varnames:
87-
if local in global_tab:
88-
continue
89-
global_tab.update({bc.resolve_namespace(local): f"<mask_{global_idx}>"})
90-
global_idx += 1
94+
if isinstance(local, tuple):
95+
for local_item in local:
96+
if local_item in global_tab:
97+
continue
98+
global_tab.update({bc.resolve_namespace(local_item): f"<mask_{global_idx}>"})
99+
global_idx += 1
100+
else:
101+
if local in global_tab:
102+
continue
103+
global_tab.update({bc.resolve_namespace(local): f"<mask_{global_idx}>"})
104+
global_idx += 1
91105

92106
global_tab.update({bc_co.co_name: f"<mask_{global_idx}>"})
93107
global_idx += 1

0 commit comments

Comments
 (0)