8
8
9
9
def check_wrapping (
10
10
source : str ,
11
- output : Optional [str ] = None ,
12
- num_wrapped : int = 0 ,
11
+ output : str ,
13
12
namespace : Optional [Dict [str , Any ]] = None ,
14
13
ignore : Optional [List [str ]] = None ,
15
14
):
@@ -20,15 +19,20 @@ def check_wrapping(
20
19
wrapper = AutoWrapper (namespace , ignore )
21
20
wrapped = wrapper .auto_wrap (tree )
22
21
23
- if output is not None :
24
- wrapped_lines = ast .unparse (wrapped ).splitlines ()
25
- output_lines = textwrap .dedent (output ).splitlines ()[1 :]
26
- assert wrapped_lines == output_lines
22
+ wrapped_lines = ast .unparse (wrapped ).splitlines ()
23
+ output_lines = textwrap .dedent (output ).splitlines ()[1 :]
27
24
28
- assert len (wrapper ._wrapper_fn_defs ) == num_wrapped
25
+ assert len (wrapped_lines ) == len (output_lines )
26
+ for wrapped_line , output_line in zip (wrapped_lines , output_lines ):
27
+ if "# skip" in output :
28
+ continue
29
+
30
+ assert wrapped_line == output_line
29
31
30
32
31
33
def test_static_if ():
34
+ """Checks that resolvable if statements are replaced"""
35
+
32
36
source = """
33
37
def forward():
34
38
if 1 + 1 == 2:
@@ -39,10 +43,12 @@ def forward():
39
43
if True:
40
44
pass
41
45
"""
42
- check_wrapping (source , output , 0 )
46
+ check_wrapping (source , output )
43
47
44
48
45
49
def test_static_if_global_vars ():
50
+ """Checks that resolvable if statements are replaced"""
51
+
46
52
source = """
47
53
def forward():
48
54
if config.is_false:
@@ -54,20 +60,35 @@ def forward():
54
60
pass
55
61
"""
56
62
config = SimpleNamespace (is_false = False )
57
- check_wrapping (source , output , 0 , namespace = {"config" : config })
63
+ check_wrapping (source , output , namespace = {"config" : config })
58
64
59
65
60
66
def test_dynamic_if ():
67
+ """Checks that non-resolvable if statements are ignored"""
68
+
61
69
source = """
62
70
def forward():
63
71
test = ...
64
72
if test:
65
73
pass
66
74
"""
67
- check_wrapping (source , None , 1 )
75
+ output = """
76
+ @torch.fx.wrap
77
+ def wrapped_0(test):
78
+ if test:
79
+ pass
80
+ return ()
81
+
82
+ def forward():
83
+ test = ...
84
+ () = wrapped_0(test)
85
+ """
86
+ check_wrapping (source , output )
68
87
69
88
70
89
def test_ignore_functions ():
90
+ """Checks that ignored functions are wrapped"""
91
+
71
92
def func_one ():
72
93
pass
73
94
@@ -79,11 +100,23 @@ def forward():
79
100
func_one()
80
101
func_two()
81
102
"""
103
+ output = """
104
+ @torch.fx.wrap
105
+ def wrapped_0():
106
+ return func_one()
107
+ return ()
108
+
109
+ def forward():
110
+ wrapped_0()
111
+ func_two()
112
+ """
82
113
namespace = {"func_one" : func_one , "func_two" : func_two }
83
- check_wrapping (source , None , 1 , namespace = namespace , ignore = ["func_one" ])
114
+ check_wrapping (source , output , namespace = namespace , ignore = ["func_one" ])
84
115
85
116
86
117
def test_ignore_methods ():
118
+ """Checks that ignored class methods are wrapped"""
119
+
87
120
class Model :
88
121
def meth_one (self ):
89
122
pass
@@ -96,11 +129,23 @@ def forward(self):
96
129
self.meth_one()
97
130
self.meth_two()
98
131
"""
132
+ output = """
133
+ @torch.fx.wrap
134
+ def wrapped_0():
135
+ return self.meth_one()
136
+ return ()
137
+
138
+ def forward(self):
139
+ wrapped_0()
140
+ self.meth_two()
141
+ """
99
142
namespace = {"self" : Model ()}
100
- check_wrapping (source , None , 1 , namespace = namespace , ignore = ["meth_one" ])
143
+ check_wrapping (source , output , namespace = namespace , ignore = ["meth_one" ])
101
144
102
145
103
146
def test_branch_with_self_assignment ():
147
+ """Checks that names referenced in self assignment are included in fn args"""
148
+
104
149
source = """
105
150
def forward(x, y):
106
151
if y > 0:
@@ -109,18 +154,38 @@ def forward(x, y):
109
154
x = x - 1
110
155
return x
111
156
"""
157
+ output = """
158
+ @torch.fx.wrap
159
+ def wrapped_0(x, y):
160
+ if y > 0:
161
+ x = x + 1
162
+ else:
163
+ x = x - 1
164
+ return (x,)
112
165
113
- tree = ast .parse (textwrap .dedent (source ))
114
- wrapper = AutoWrapper (namespace = {}, ignore = [])
115
- wrapper .auto_wrap (tree )
166
+ def forward(x, y):
167
+ (x,) = wrapped_0(x, y) # skip: some envs use "(x,)" -> "x,"
168
+ return x
169
+ """
170
+ check_wrapping (source , output )
116
171
117
- assert len (wrapper ._wrapper_fn_defs ) == 1
118
172
119
- # Check if both x, y are included in args
120
- wrapped_fn = wrapper ._wrapper_fn_defs [0 ]
121
- arg_names = {arg .arg for arg in wrapped_fn .args .args }
173
+ def test_function_variadic ():
174
+ """Checks for handling variadic names created via function def"""
175
+
176
+ source = """
177
+ def forward(a, *b, c=5, **d):
178
+ if a == b and c == d:
179
+ pass
180
+ """
181
+ output = """
182
+ @torch.fx.wrap
183
+ def wrapped_0(a, b, c, d):
184
+ if a == b and c == d:
185
+ pass
186
+ return ()
122
187
123
- assert arg_names == {
124
- "x" ,
125
- "y" ,
126
- }, f"Expected arguments {{'x', 'y'}}, but got { arg_names } "
188
+ def forward(a, *b, c=5, **d):
189
+ () = wrapped_0(a, b, c, d)
190
+ """
191
+ check_wrapping ( source , output )
0 commit comments