Skip to content

Commit 0decda6

Browse files
authored
Use Torch ExportedModule to import initial MLIR module (#416)
Use a `torch.export.ExportedProgram` to generate the initial MLIR module. This requires us to create an `ExportedProgram` from the initial `GraphModule`. Benefits: - We can use the torch-mlir's official entrypoint - This handles in-place ops for us - We can run decompositions and **keep** location data - This location data will stick around throughout the compile process Issues: - `aten.clamp` is decomposed by torch-mlir to `maximum(minimum(input, max), min)`. `ttnn.maximum` requires that the operand which needs to be broadcasted is on the RHS. Currently, in tt-mlir the `PartiallyBroadcastable` op trait only enforces that the broadcasted operand is on the LHS - tt-torch issue: #431 - tt-mlir issue: tenstorrent/tt-mlir#2458 - Graph parameters are inlined as constants in the graph. To have the `FxImporter` treat them as graph inputs we need to edit the `ExportedModule`s `ExportedGraphSignature` and force all "parameter" types to "user inputs" - This is a hack as the `ExportedGraphSignature` is meant to be a private member of `ExportedProgram` - Ideally we can configure the `FxImporter` to _not_ inline the parameters if we pass a flag of some sort. Perhaps a future contribution to torch-mlir. Other Info: - We need to upgrade to PyTorch 2.6.0 as it contains crucial changes which allow us to use custom decompositions (necessary to support interpolation) - AdaptiveAvgPool2d is lowered AvgPool2d and eventually to `stablehlo.reduce_window **even in the case where the op is equivalent to a global average**. Since we do not have support for lowering a sum_pool in `StablehloToTTIRPatterns.cpp` (sum because the division is afterward), I've temporarily added a custom decomposition of `aten.avg_pool2d` which will convert to a mean over the spatial dimensions when the `avg_pool2d` is equivalent to it. - `aten.split` is no longer lowered to a series of `narrow` ops. Instead it is now lowered to a series of `as_strided` ops. - `narrow` is lowered to `slice`, which can be lowered to `stablehlo.slice`. `as_strided` cannot be lowered from Torch Backend IR to Stablehlo. I've temporarily added back the old decomposition from PyTorch 2.5.0 which uses narrow as a custom decomposition. - I've made a PR which adds a lowering of `AtenAsStridedOp` to `stablehlo::SliceOp` in our fork of torch-mlir: tenstorrent/llvm-torch-mlir#4 - The tracer which generates the `GraphModule` which is passed to `backend` does not account for control flow, I believe in PyTorch 2.5.0 a graph break would be triggered during `.generate` methods in `transformers` LLMs. It does not anymore and so `.generate` will run until the max length is reached. - **this means that the entire generation becomes one program** - Once the first EOS token is generated, the rest of the length is filled with padding. We cannot compare the golden output to the result from the `GraphModule` as the output shapes are different. - Since the output of `.generate` graphs are integers PCC/atol verification is not quite useful but does return `True` when the outputs are _identical_ - The tokenizer can decode the outputs and strip padding. - I've added a flag to `ModelTester` that informs the `ModelTester` it is testing a `.generate` call. It will decode the output tokens and we compare the resulting strings. - PyTorch has an experimental `torch.cond` which they seem to intend to use to trace data-dependent control-flow. There's a note in the `transformers` source that says they intend to use it when it is no longer experimental - When the graph is compiled, the user inputs are placed **at the end** of the arguments passed to the program rather than the front. That is graph constants first, then inputs. - I needed to implement an `FxImporter` hook for importing literals to the graph. By default it will make all non-scalars `DenseElementsResourceAttr`s, however, this causes the process to hang upon cleanup whether the test fails or not. So the hook just uses `DenseElementsAttr` for all literals. - Someone has mentioned this problem in an IREE issue as well: iree-org/iree#20102 - They've traced it down to this PR in llvm that adds a GIL acquire when destroying the `DenseElementsResourceAttr`: llvm/llvm-project#124832
1 parent c79248a commit 0decda6

30 files changed

+540
-316
lines changed

docs/src/controlling.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ You can use the following environment variables to override default behaviour:
99
| TT_TORCH_VERIFY_INTERMEDIATES | Sets whether to verify runtime intermediates during execution. | False |
1010
| TT_TORCH_CONSTEVAL | Enables evaluation of constant expressions (consteval) in the Torch FX graph prior to compilation. | False |
1111
| TT_TORCH_CONSTEVAL_PARAMETERS | Extends consteval to include parameters (e.g., model weights) as well as embedded constants. | False |
12-
| TT_TORCH_EMBEDDEDD_CONSTANTS | Remove embedded constants from the Torch FX graph and convert them to constant inputs | False |
12+
| TT_TORCH_INLINE_PARAMETERS | Inlines parameters in the MLIR module (and thus flatbuffer executable) rather than requiring them as inputs. NOTE: The maximum size of a flatbuffer is 2GB so this will cause compilation to fail for sufficiently large models | False |
1313
| TT_TORCH_IR_LOG_LEVEL | Enables printing MLIR from Torch to TTNN. It supports two modes; `INFO` and `DEBUG`. `INFO` prints MLIR for all conversions steps (Torch, StableHLO, TTIR and TTNN MLIR graphs). `DEBUG` prints intermediate MLIR for all passes (IR dump before and after each pass) additionally. Be warned, `DEBUG` IR printing forces single core compile, so it is much slower. | Disable |
1414

1515
### Controlling Compiler Behaviour Programatically

env/activate

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ else
3434
cd $TT_TORCH_HOME/third_party
3535
git clone https://github.com/pytorch/vision.git
3636
cd vision
37-
git checkout v0.20.0
37+
git checkout v0.21.0
3838
pip uninstall -y torchvision
3939
TORCHVISION_USE_VIDEO_CODEC=0 TORCHVISION_USE_FFMPEG=0 CC=clang CXX=clang++ _GLIBCXX_USE_CXX11_ABI=1 USE_CUDA=OFF python setup.py bdist_wheel
4040

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
1+
torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.6.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
22
black
33
mdutils
44
ninja

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def run(self):
6565
},
6666
zip_safe=False,
6767
install_requires=[
68-
"torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl",
68+
"torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.6.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl",
6969
"numpy",
7070
],
7171
)

tests/models/Qwen/test_qwen2_casual_lm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ def test_qwen2_casual_lm(record_property, model_name, mode, op_by_op):
5757
cc.op_by_op_backend = OpByOpBackend.STABLEHLO
5858

5959
tester = ThisTester(
60-
model_name, mode, compiler_config=cc, record_property_handle=record_property
60+
model_name,
61+
mode,
62+
compiler_config=cc,
63+
record_property_handle=record_property,
64+
is_token_output=True,
6165
)
6266
results = tester.test_model()
6367

tests/models/RMBG/test_RMBG.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _load_inputs(self):
3939
"mode",
4040
["train", "eval"],
4141
)
42-
@pytest.mark.xfail(reason="Fails due pt2 compile issue, graph is traced")
42+
@pytest.mark.skip(reason="Python bus error at the end of torch op-by-op flow")
4343
@pytest.mark.parametrize(
4444
"op_by_op",
4545
[OpByOpBackend.STABLEHLO, OpByOpBackend.TORCH, None],

tests/models/beit/test_beit_image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_beit_image_classification(record_property, model_name, mode, op_by_op):
6060
if op_by_op == OpByOpBackend.STABLEHLO:
6161
cc.op_by_op_backend = OpByOpBackend.STABLEHLO
6262

63-
required_atol = 0.032 if model_name == "microsoft/beit-base-patch16-224" else 0.05
63+
required_atol = 0.032 if model_name == "microsoft/beit-base-patch16-224" else 0.065
6464
tester = ThisTester(
6565
model_name,
6666
mode,

tests/models/codegen/test_codegen.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ def test_codegen(record_property, mode, op_by_op):
4646
cc.op_by_op_backend = OpByOpBackend.STABLEHLO
4747

4848
tester = ThisTester(
49-
model_name, mode, compiler_config=cc, record_property_handle=record_property
49+
model_name,
50+
mode,
51+
compiler_config=cc,
52+
record_property_handle=record_property,
53+
is_transformers_generation=True,
5054
)
5155
results = tester.test_model()
5256

tests/models/deit/test_deit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_deit(record_property, model_name, mode, op_by_op):
6767
tester = ThisTester(
6868
model_name,
6969
mode,
70-
relative_atol=0.01,
70+
relative_atol=0.015,
7171
compiler_config=cc,
7272
record_property_handle=record_property,
7373
)

tests/models/falcon/test_falcon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_falcon(record_property, mode, op_by_op):
5050
tester = ThisTester(
5151
model_name,
5252
mode,
53-
relative_atol=0.013,
53+
relative_atol=0.015,
5454
compiler_config=cc,
5555
record_property_handle=record_property,
5656
)

0 commit comments

Comments
 (0)