Skip to content

Commit 36e4840

Browse files
committed
hide warnings
1 parent d580f33 commit 36e4840

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,7 @@ def test_tile(self):
13771377
torch.tensor([2, 2], dtype=torch.int64),
13781378
)
13791379

1380+
@ignore_warnings(UserWarning)
13801381
def test_custom_kernels(self):
13811382
class LayerNormalizationOrt(OpRunKernel):
13821383
"LayerNormalization"

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
import numpy as np
66
from scipy.spatial.distance import cdist
77
import torch
8-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_torch, requires_torch
8+
from onnx_diagnostic.ext_test_case import (
9+
ExtTestCase,
10+
hide_stdout,
11+
has_torch,
12+
requires_torch,
13+
ignore_warnings,
14+
)
915
from onnx_diagnostic.torch_export_patches import torch_export_patches, torch_export_rewrite
1016
from onnx_diagnostic.torch_export_patches.patch_module import (
1117
transform_method,
@@ -370,6 +376,7 @@ def forward(self, x, y):
370376
self.assertEqualAny(expected_0, ep.module()(x, -y))
371377
self.assertEqualAny(expected_1, ep.module()(-x, -y))
372378

379+
@ignore_warnings(UserWarning)
373380
def test_rewrite_test_in_forward_none(self):
374381

375382
class Model(torch.nn.Module):

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,8 @@ def patched_modeling_marian_eager_attention_forward(
10321032

10331033

10341034
class common_RotaryEmbedding(torch.nn.Module):
1035-
@torch.no_grad()
1035+
# This may cause some issues.
1036+
# @torch.no_grad()
10361037
@patched_dynamic_rope_update
10371038
def forward(self, x, position_ids):
10381039
inv_freq_expanded = (

0 commit comments

Comments
 (0)