
Commit f649ee7

bobrenjc93 authored and pytorchmergebot committed
Use source hashing to generate consistent symbolic ids (pytorch#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We compile statically and save example values to the PGO / automatic-dynamic profile.

Run 1, Invocation 2: We detect varying inputs, compile dynamically, get a dynamic graph, and save to PGO. Crucially, what we save to PGO is a superset of what ends up actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized. When a value gets specialized, we remove its symbol from the graph. This produces an interesting conundrum: even though we produce an isomorphic graph, PGO makes the second run cache miss. Here is how:

Run 2, Invocation 1: We fetch the PGO profile, over-mark things as dynamic, get an FX graph, look it up in the cache and... whoops, cache miss! This happens because the PGO profile causes us to over-allocate symbols. In practice, we save a graph in the cache with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, where symbols s3, s4, s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this by hashing the source names, which gives a reasonably stable symbol assignment. To prevent catastrophic symbol collisions, we resolve collisions with linear probing.

Pull Request resolved: pytorch#149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
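A minimal sketch of the idea, using a hypothetical assign_symbol_id helper rather than the actual ShapeEnv code: hash the source name deterministically, then linearly probe for a free id, so the id a source receives does not depend on how many other symbols were allocated before it.

import hashlib

def assign_symbol_id(source_name: str, used_ids: set[int], table_size: int = 2**16) -> int:
    # Deterministic hash of the source name (Python's built-in hash() is salted
    # per process, so it would not be stable across runs).
    digest = hashlib.sha256(source_name.encode("utf-8")).digest()
    idx = int.from_bytes(digest[:8], "big") % table_size
    # Linear probing: walk forward until we find an id no other source occupies.
    while idx in used_ids:
        idx = (idx + 1) % table_size
    used_ids.add(idx)
    return idx

used: set[int] = set()
print(assign_symbol_id("L['x'].size()[0]", used))  # stable across runs for this source
print(assign_symbol_id("L['y'].size()[0]", used))  # independent of allocation order

Because each id depends only on the source name, a profile that over-marks inputs as dynamic and later specializes some of them no longer shifts the ids of the surviving symbols, so the cached graph keys line up across runs.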
1 parent c49315e · commit f649ee7

23 files changed: +521 −443 lines

test/dynamo/test_aot_autograd_cache.py

Lines changed: 40 additions & 0 deletions
@@ -196,6 +196,46 @@ def fn(x, y):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch("fx_graph_cache", True)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_symbol_specialization(self):
+        """
+        Verify the symbol specializations don't cause cache miss.
+        """
+
+        def fn(x, y, z):
+            return (torch.randn(5) + x + y, z * torch.randn(1))
+
+        a = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(a, 0)
+        b = torch.rand(5)
+        c = torch.randn(6)
+        torch._dynamo.maybe_mark_dynamic(c, 0)
+
+        compiled_fn = torch.compile(fn, backend="inductor")
+
+        # A first call should miss in the cache.
+        compiled_fn(a, b, c)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
+        # A second call should hit even if a new dimension is marked as dynamic
+        # that is later specialized as part of tracing.
+        a = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(a, 0)
+        b = torch.rand(5)
+        torch._dynamo.maybe_mark_dynamic(b, 0)
+        c = torch.randn(6)
+        torch._dynamo.maybe_mark_dynamic(c, 0)
+        self._clear_dynamo_and_codecache()
+
+        compiled_fn(a, b, c)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+
     @functorch_config.patch({"enable_autograd_cache": True})
     def test_aot_runtime_trace_joint(self):
         @torch.compile(backend="inductor")

test/dynamo/test_backward_higher_order_ops.py

Lines changed: 2 additions & 2 deletions
@@ -245,7 +245,7 @@ def fn(x, y):
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
+    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s45)"):
         l_inputs_ = L_inputs_
         l_sizes_0_ = L_sizes_0_
         l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter
@@ -264,7 +264,7 @@ def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_

         copy_: "f32[2]" = new_grad_strided.copy_(aot0_tangents_1); copy_ = None

-        add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
+        add: "Sym(s45 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None

         result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None


test/dynamo/test_comptime.py

Lines changed: 5 additions & 5 deletions
@@ -57,18 +57,18 @@ def f(x):
         self.assertExpectedInline(
             FILE.getvalue().strip(),
             """\
-FakeTensor(..., size=(s0,))
+FakeTensor(..., size=(s77,))
 2
-[FakeTensor(..., size=(s0,)), 2]
-(FakeTensor(..., size=(s0,)), 2)
-{'foo': FakeTensor(..., size=(s0,))}
+[FakeTensor(..., size=(s77,)), 2]
+(FakeTensor(..., size=(s77,)), 2)
+{'foo': FakeTensor(..., size=(s77,))}
 range(1, 3, 1)
 Employee(name='foo', id=2)
 UserDefinedListVariable(mylist)
 defaultdict(NestedUserFunctionVariable(), {})
 set()
 {'a','b'}
-s0""",
+s77""",
         )

     def test_print_graph(self):

test/dynamo/test_exc.py

Lines changed: 32 additions & 32 deletions
@@ -256,34 +256,34 @@ def fn(x, shape):
 ==> L['x'].size()[0]: 3
 ==> L['x'].storage_offset(): 0
 ==> L['x'].stride()[0]: 1
-==> s0: 3
-==> s1: 0
-==> s2: 1
 ==> s3: 1
+==> s52: 1
+==> s77: 3
+==> s86: 0

 Assertions:
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s1)
-==> (== L['shape'][1] s2)
+==> (== L['shape'][0] s86)
+==> (== L['shape'][1] s52)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s0)
-==> (> s0 1)
+==> (== L['x'].size()[0] s77)
+==> (> s77 1)

 Target Expressions:
-==> (!= (+ s1 s2 s3) s0)
-==> (<= 0 s1)
-==> (<= 0 s2)
+==> (!= (+ s3 s52 s86) s77)
 ==> (<= 0 s3)
-==> (<= 2 s0)
+==> (<= 0 s52)
+==> (<= 0 s86)
+==> (<= 2 s77)
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s1)
-==> (== L['shape'][1] s2)
+==> (== L['shape'][0] s86)
+==> (== L['shape'][1] s52)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s0)
-==> (> s0 0)
-==> (>= 0 s1)
+==> (== L['x'].size()[0] s77)
+==> (> s77 0)
+==> (>= 0 s86)

 Failed Source Expressions:
 ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@@ -309,7 +309,7 @@ def fn(x, shape):
             BisectValidationException,
             lambda: fn(torch.randn(20), (5, 10, 5)),
             """\
-translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
+translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)

 Failure occurred while running node:
 %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
@@ -321,33 +321,33 @@ def fn(x, shape):
 ==> L['x'].size()[0]: 3
 ==> L['x'].storage_offset(): 0
 ==> L['x'].stride()[0]: 1
-==> s0: 3
-==> s1: 1
-==> s2: 1
 ==> s3: 0
+==> s52: 1
+==> s77: 3
+==> s86: 1

 Assertions:
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s1)
-==> (== L['shape'][1] s2)
+==> (== L['shape'][0] s86)
+==> (== L['shape'][1] s52)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s0)
-==> (> s0 1)
+==> (== L['x'].size()[0] s77)
+==> (> s77 1)

 Target Expressions:
-==> (!= (+ s1 s2 s3) s0)
-==> (<= 0 s1)
-==> (<= 0 s2)
+==> (!= (+ s3 s52 s86) s77)
 ==> (<= 0 s3)
-==> (<= 2 s0)
+==> (<= 0 s52)
+==> (<= 0 s86)
+==> (<= 2 s77)
 ==> (== 0 L['x'].storage_offset())
 ==> (== 1 L['x'].stride()[0])
-==> (== L['shape'][0] s1)
-==> (== L['shape'][1] s2)
+==> (== L['shape'][0] s86)
+==> (== L['shape'][1] s52)
 ==> (== L['shape'][2] s3)
-==> (== L['x'].size()[0] s0)
-==> (> s0 0)
+==> (== L['x'].size()[0] s77)
+==> (> s77 0)

 Failed Source Expressions:
 ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

test/dynamo/test_export.py

Lines changed: 5 additions & 5 deletions
@@ -2703,7 +2703,7 @@ def forward(self, x, y):
                 for node in ebar.graph_module.graph.nodes
                 if node.op == "placeholder"
             ],
-            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
+            ["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
         )

     @torch._dynamo.config.patch(
@@ -3480,23 +3480,23 @@ def test_symbool_guards(
         true_graph = """\
 class GraphModule(torch.nn.Module):
     def forward(self, pred, x):
-        arg1: "f32[s1, s2]";
+        arg1: "f32[s77, s27]";

         arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
         l_x_ = arg1

-        sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
+        sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
         return pytree.tree_unflatten([sin], self._out_spec)
 """
         false_graph = """\
 class GraphModule(torch.nn.Module):
     def forward(self, pred, x):
-        arg1: "f32[s1, s2]";
+        arg1: "f32[s77, s27]";

         arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
         l_x_ = arg1

-        cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
+        cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
         return pytree.tree_unflatten([cos], self._out_spec)
 """
         true_guard_code = [

test/dynamo/test_functions.py

Lines changed: 18 additions & 18 deletions
@@ -2655,7 +2655,7 @@ def forward(self, L_x_: "f32[3]"):
             normalize_gm(backend.graphs[0].print_readable(print_output=False)),
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0]"):
+    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77]"):
         l_x_ = L_x_

         sum_1: "f32[]" = l_x_.sum(); l_x_ = None
@@ -2885,13 +2885,13 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             normalize_gm(backend.graphs[0].print_readable(print_output=False)),
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+    def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_

-        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
-        mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
+        mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
+        mul_1: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None

-        mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None
+        mul_2: "f32[s9, s9]" = torch.mul(mul, mul_1); mul = mul_1 = None
         return (mul_2,)
 """,
         )
@@ -2932,14 +2932,14 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             normalize_gm(backend.graphs[0].print_readable(print_output=False)),
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+    def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_

-        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
+        mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_

-        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
+        add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None

-        mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
+        mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
         return (mul_1,)
 """,
         )
@@ -2982,14 +2982,14 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             normalize_gm(backend.graphs[0].print_readable(print_output=False)),
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+    def forward(self, s9: "Sym(s9)", L_lambda0_keywords_y_: "f32[s9, s9]"):
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_

-        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
+        mul: "f32[s9, s9]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_

-        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
+        add: "f32[s9, s9]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None

-        mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
+        mul_1: "f32[s9, s9]" = torch.mul(mul, add); mul = add = None
         return (mul_1,)
 """,
         )
@@ -3029,14 +3029,14 @@ def forward(self, L_x_: "f32[2, 2]"):
             normalize_gm(backend.graphs[0].print_readable(print_output=False)),
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
+    def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]"):
         l_x_ = L_x_

-        mul: "f32[s0, s0]" = l_x_ * 4
-        mul_1: "f32[s0, s0]" = mul * l_x_; mul = None
-        mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None
+        mul: "f32[s77, s77]" = l_x_ * 4
+        mul_1: "f32[s77, s77]" = mul * l_x_; mul = None
+        mul_2: "f32[s77, s77]" = 20 * l_x_; l_x_ = None

-        mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
+        mul_3: "f32[s77, s77]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
         return (mul_3,)
 """,
         )
