Commit be16f21
[Graph Partition] add symints to get_graph_inputs (pytorch#154679)
During `codegen_inputs`, we check whether there are undefined symbols:
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L1668-L1674
Previously, for graph partition inputs, we do not explicitly add symints.
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L3265-L3272
We relied on sizes/strides of TensorBox for codegen symint inputs. For example, a tensor with shape `[s0, 2]` will implicitly codegen `s0` as an input here. This works fine in most cases since backed symint has to come from some tensor shapes.
https://github.com/pytorch/pytorch/blob/65b1aedd09e98fcafcdd893ca4924f4fa598fd18/torch/_inductor/codegen/wrapper.py#L1624-L1632
In rare cases, this does not work. One example is saved tensors for backward where a tensor may have shape `[2*s0, 2]`. Since `2*s0` is an expression but not a symbol, `codegen_input_symbol_assignment` would not handle `s0` and later there would be an error when `_verify_input_symbol_assignment`.
The fix is add symints to `get_graph_inputs`. An alternative way is to update `codegen_input_symbol_assignment` but I want to minimize the change to graph partition only.
Pull Request resolved: pytorch#154679
Approved by: https://github.com/eellison1 parent d3c8f36 commit be16f21
File tree
4 files changed
+51
-8
lines changed- test
- distributed/tensor
- inductor
- torch/_inductor
- codegen
4 files changed
+51
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
801 | 801 | | |
802 | 802 | | |
803 | 803 | | |
804 | | - | |
805 | | - | |
806 | | - | |
807 | | - | |
808 | | - | |
809 | | - | |
| 804 | + | |
810 | 805 | | |
811 | 806 | | |
812 | 807 | | |
| |||
876 | 871 | | |
877 | 872 | | |
878 | 873 | | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
879 | 891 | | |
880 | 892 | | |
881 | 893 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14968 | 14968 | | |
14969 | 14969 | | |
14970 | 14970 | | |
| 14971 | + | |
| 14972 | + | |
| 14973 | + | |
| 14974 | + | |
| 14975 | + | |
| 14976 | + | |
| 14977 | + | |
| 14978 | + | |
| 14979 | + | |
| 14980 | + | |
| 14981 | + | |
| 14982 | + | |
| 14983 | + | |
| 14984 | + | |
| 14985 | + | |
| 14986 | + | |
| 14987 | + | |
| 14988 | + | |
| 14989 | + | |
| 14990 | + | |
| 14991 | + | |
14971 | 14992 | | |
14972 | 14993 | | |
14973 | 14994 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3282 | 3282 | | |
3283 | 3283 | | |
3284 | 3284 | | |
3285 | | - | |
| 3285 | + | |
| 3286 | + | |
| 3287 | + | |
3286 | 3288 | | |
3287 | 3289 | | |
3288 | 3290 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4251 | 4251 | | |
4252 | 4252 | | |
4253 | 4253 | | |
4254 | | - | |
| 4254 | + | |
| 4255 | + | |
| 4256 | + | |
| 4257 | + | |
| 4258 | + | |
| 4259 | + | |
| 4260 | + | |
| 4261 | + | |
| 4262 | + | |
4255 | 4263 | | |
4256 | 4264 | | |
4257 | 4265 | | |
| |||
0 commit comments